1 from __future__
import print_function
3 from common_utils
import TestCase, run_tests, download_file
def _is_complete_statement(source):
    """Return True if ``source`` compiles as a module-level Python snippet."""
    try:
        compile(source, "", "exec")
    except SyntaxError:
        return False
    return True


def get_examples_from_docstring(docstr):
    """
    Extracts all runnable python code from the examples
    in docstrings; returns a list of lines.

    An example line is a whitespace-indented doctest prompt followed by one
    space: ``>>> code`` starts a statement and ``... code`` explicitly
    continues it.  A statement ends at the first line that is not a ``...``
    continuation; if it compiles at that point it is kept.  While the
    statement does not compile yet, an unprompted line is taken as a further
    continuation (some docstrings omit the ``...`` prompt), whereas a new
    prompt discards the unfinished statement.  Statements that never compile
    are dropped.

    Args:
        docstr: docstring text to scan, e.g. the result of ``inspect.getdoc``.

    Returns:
        list of str: the extracted source lines, each prefixed with four
        spaces so they can be used directly as the body of a generated
        function.  Empty when the docstring holds no complete example.
    """
    # group(1) is the prompt, group(2) the code after "prompt + one space".
    exampleline_re = re.compile(r"^\s+(>>>|\.\.\.) (.*)$")
    example_file_lines = []
    pending = ""  # text of the statement currently being accumulated

    for line in docstr.split('\n'):
        m = exampleline_re.match(line)
        if pending and m and m.group(1) == '...':
            # Explicit continuation: never end the statement here, even if it
            # already compiles (e.g. the `if` half of an if/else, or the first
            # method of a multi-method class).
            pending += "\n" + m.group(2)
            continue
        if pending and _is_complete_statement(pending):
            example_file_lines += pending.split('\n')
            pending = ""
        if m:
            # A new prompt starts a fresh statement.  Anything still left in
            # `pending` could not be completed, so it is discarded here rather
            # than allowed to swallow every example that follows.
            pending = m.group(2)
        elif pending:
            # Unprompted continuation line of a still-incomplete statement.
            pending += "\n" + line

    # The docstring may end right after a complete statement.
    if pending and _is_complete_statement(pending):
        example_file_lines += pending.split('\n')

    return ['    ' + l for l in example_file_lines]
# get_all_examples: build one synthetic Python module, returned as a single
# string, from the docstring examples of every attribute of `torch` and of
# `torch.Tensor`.
#  - The module starts with a fixed prelude of imports and a `preprocess`
#    stub annotated with a "# type:" comment.
#  - Attributes without a docstring, or whose name is in `blacklist`, are
#    skipped.
#  - For every other attribute, the lines returned by
#    get_examples_from_docstring() are appended under a header
#    `def example_torch_<name>():` (or `def example_torch_tensor_<name>():`
#    for Tensor attributes).
#  - All lines are joined with "\n".
#
# NOTE(review): this copy is an incomplete extraction. The numeric prefixes
# are upstream line numbers, and several lines are missing:
#  - the end of the docstring and the definition of `blacklist`;
#  - the first prelude entry (68) and entries 74-77;
#  - the rest of `preprocess` and the list's closing bracket (80-82);
#  - lines 88/97 between the extraction and the header append. These are
#    presumably an `if e:` guard, so that docstrings without examples do not
#    emit an empty (syntactically invalid) function. Confirm.
# TODO restore from the upstream file.
58 def get_all_examples():
59 """get_all_examples() -> str 61 This function grabs (hopefully all) examples from the torch documentation 62 strings and puts them in one nonsensical module returned as a string. 67 example_file_lines = [
69 "import torch.nn.functional as F",
70 "import math # type: ignore",
71 "import numpy # type: ignore",
72 "import io # type: ignore",
73 "import itertools # type: ignore",
78 "def preprocess(inp):",
79 " # type: (torch.Tensor) -> torch.Tensor",
83 for fname
in dir(torch):
84 fn = getattr(torch, fname)
85 docstr = inspect.getdoc(fn)
86 if docstr
and fname
not in blacklist:
87 e = get_examples_from_docstring(docstr)
89 example_file_lines.append(
"\n\ndef example_torch_{}():".format(fname))
90 example_file_lines += e
92 for fname
in dir(torch.Tensor):
93 fn = getattr(torch.Tensor, fname)
94 docstr = inspect.getdoc(fn)
95 if docstr
and fname
not in blacklist:
96 e = get_examples_from_docstring(docstr)
98 example_file_lines.append(
"\n\ndef example_torch_tensor_{}():".format(fname))
99 example_file_lines += e
101 return "\n".join(example_file_lines)
# test_doc_examples: run the generated documentation-example module through
# mypy.
#  - Skipped on Python 2 and when mypy is not available (HAVE_MYPY).
#  - Writes get_all_examples() to generated_type_hints_smoketest.py, next to
#    this file.
#  - Inside a TemporaryDirectory, it appears to symlink the installed torch
#    package directory in as `torch`. This is presumably so mypy analyses
#    torch as local source rather than as an installed package; confirm.
#  - Then invokes mypy with --follow-imports silent and --check-untyped-defs
#    on the generated file.
#  - A subprocess.CalledProcessError is re-raised as an AssertionError.
#    Presumably a failure to create the symlink raises SkipTest('cannot
#    symlink').
#
# NOTE(review): this copy is an incomplete extraction. The numeric prefixes
# are upstream line numbers, and the following are missing:
#  - the enclosing class header;
#  - the docstring quotes;
#  - the explanatory comment block (114-143);
#  - the try/except around the symlink;
#  - the subprocess invocation (153-156, 160-161);
#  - the body of the __main__ guard.
# The method's `def` header is displaced to the last line of this span.
# TODO restore from the upstream file.
105 @unittest.skipIf(sys.version_info[0] == 2,
"no type hints for Python 2")
106 @unittest.skipIf(
not HAVE_MYPY,
"need mypy")
109 Run documentation examples through mypy. 111 fn = os.path.join(os.path.dirname(__file__),
'generated_type_hints_smoketest.py')
112 with open(fn,
"w")
as f:
113 print(get_all_examples(), file=f)
144 with tempfile.TemporaryDirectory()
as tmp_dir:
147 os.path.dirname(torch.__file__),
148 os.path.join(tmp_dir,
'torch'),
149 target_is_directory=
True 152 raise unittest.SkipTest(
'cannot symlink')
157 '--follow-imports',
'silent',
158 '--check-untyped-defs',
159 os.path.abspath(fn)],
162 except subprocess.CalledProcessError
as e:
163 raise AssertionError(
"mypy failed. Look above this error for mypy's output.")
# Script entry point. Its body is missing from this copy; presumably it
# calls run_tests() (imported at the top of the file). Confirm.
166 if __name__ ==
'__main__':
# NOTE(review): displaced header of the decorated test method above; it
# belongs between the decorators and the docstring.
def test_doc_examples(self)