import ast
import inspect
import sys
from textwrap import dedent

import torch
from torch._C import TensorType, TupleType, FloatType, IntType, \
    ListType, StringType, DictType

from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
    BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
# True on Python 3.5+, where try_real_annotations() can read real
# (typing-style) annotations instead of ``# type:`` comments.
PY35 = sys.version_info >= (3, 5)
class Module(object):
    """Minimal stand-in for a Python module when evaluating type comments.

    Attribute access is resolved against the ``members`` mapping. Unknown
    names raise RuntimeError rather than AttributeError.
    """

    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        # Only reached for names not found through normal lookup, i.e. anything
        # other than ``name`` and ``members``.
        members = self.members
        if name in members:
            return members[name]
        # Raised outside of an ``except`` block so Python 3 does not attach an
        # irrelevant KeyError as context to the traceback.
        raise RuntimeError("Module {} has no member called {}".format(self.name, name))
# Globals used when eval()-ing the two halves of a ``# type:`` comment in
# parse_type_line(). Only the listed members of the ``torch`` and ``typing``
# stand-ins are reachable through attribute access. The bare ``Tuple``/``List``/
# ``Dict`` names let comments such as ``-> Tuple[Tensor]`` resolve.
# NOTE: eval() inserts ``__builtins__`` into this dict on first use. That is
# how ``int``/``float``/``str`` in type comments resolve.
_eval_env = {
    'torch': Module('torch', {'Tensor': torch.Tensor}),
    'Tensor': torch.Tensor,
    'typing': Module('typing', {'Tuple': Tuple}),
    'Tuple': Tuple,
    'List': List,
    'Dict': Dict,
}
def get_signature(fn):
    """Returns ``(arg_types, return_type)`` for ``fn``, or ``None`` if unannotated.

    Real annotations are tried first on Python 3.5+. Otherwise a ``# type:``
    comment is looked up in the function's source.
    """
    if PY35:
        sig = try_real_annotations(fn)
        if sig is not None:
            return sig

    type_line = None
    try:
        source = dedent(inspect.getsource(fn))
        type_line = get_type_line(source)
    except (TypeError, IOError):
        # getsource raises TypeError for builtins and IOError/OSError when the
        # source cannot be retrieved. This matches get_num_params' handling.
        pass
    # type_line is None both when the source was unavailable and when it has
    # no type comment.
    if type_line is None:
        return None

    return parse_type_line(type_line)
def get_num_params(fn):
    """Returns how many parameters ``fn`` takes, judging from its source.

    This is a weaker form of get_signature(): no types are needed, only the
    parameter count.

    Returns:
        The number of positional parameters (excluding ``self`` for bound
        methods), or ``None`` if the source is unavailable or ``fn`` takes
        ``*args`` or keyword-only arguments.

    Raises:
        RuntimeError: if the source is a class definition, or is not a single
            top-level function.
    """
    try:
        source = dedent(inspect.getsource(fn))
    except (TypeError, IOError):
        # TypeError for builtins, IOError/OSError when no source is available.
        return None
    py_ast = ast.parse(source)
    if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
        raise RuntimeError("cannot instantiate class object ({}) inside jit.script".format(py_ast.body[0].name))
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError("expected a single top-level function")
    py_def = py_ast.body[0]
    if py_def.args.vararg is not None:
        # *args means the parameter count is not fixed.
        return None
    elif hasattr(py_def.args, 'kwonlyargs') and len(py_def.args.kwonlyargs) > 0:
        # ``kwonlyargs`` only exists on Python 3's ast, hence the hasattr.
        return None
    num_params = len(py_def.args.args)
    if inspect.ismethod(fn):
        # The source lists ``self``, but the bound method does not take it.
        num_params = num_params - 1
    return num_params
def parse_type_line(type_line):
    """Parses a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns:
        ``(arg_types, return_type)`` converted with ann_to_type().

    Raises:
        RuntimeError: if either half of the annotation fails to evaluate.
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    # NOTE(review): eval() executes the comment text. get_signature passes a line
    # taken from the compiled function's own source, so it is as trusted as that
    # code. Do not feed untrusted strings here.
    try:
        arg_ann = eval(arg_ann_str, _eval_env)
    except (NameError, SyntaxError) as e:
        raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e)))

    # "(Tensor)" evaluates to Tensor itself rather than a 1-tuple.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    try:
        ret_ann = eval(ret_ann_str, _eval_env)
    except (NameError, SyntaxError) as e:
        raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e)))

    arg_types = [ann_to_type(ann) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann)
def get_type_line(source):
    """Tries to find the line containing a comment with the type annotation.

    Returns the first line of ``source`` that contains ``# type:``, stripped of
    surrounding whitespace (any code before the comment is kept), or ``None``
    if there is no such line.
    """
    for line in source.split('\n'):
        if '# type:' in line:
            return line.strip()
    return None
def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` marker need not start the line. For an inline comment such
    as ``def f(x):  # type: (Tensor) -> Tensor`` the text before the marker is
    ignored.

    Raises:
        RuntimeError: if no ``->`` follows the ``# type:`` marker.
    """
    marker = '# type:'
    # get_type_line() returns the whole stripped source line, so code may
    # precede the marker. If the marker is absent, fall back to offset 0.
    start_offset = max(type_line.find(marker), 0) + len(marker)
    try:
        # Only an arrow after the marker belongs to the type comment.
        arrow_pos = type_line.index('->', start_offset)
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (couldn't find `->`)")
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
def try_real_annotations(fn):
    """Tries to use the Py3.5+ annotation syntax to get the type.

    Returns:
        ``(arg_types, return_type)`` converted with ann_to_type(), or ``None``
        if no signature is available or nothing in it is annotated.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # inspect.signature raises ValueError when no signature can be provided.
        return None

    params = list(sig.parameters.values())
    all_annots = [sig.return_annotation] + [p.annotation for p in params]
    if all(ann is sig.empty for ann in all_annots):
        return None

    def as_ann(ann):
        # sig.empty marks a missing annotation; ann_to_type expects None instead.
        return ann if ann is not sig.empty else None

    arg_types = [ann_to_type(as_ann(p.annotation)) for p in params]
    return_type = ann_to_type(as_ann(sig.return_annotation))
    return arg_types, return_type
def ann_to_type(ann):
    """Converts a Python type annotation into the corresponding JIT type.

    ``None`` (a missing annotation) and ``torch.Tensor`` map to the Tensor type.
    Tuple/List/Dict annotations are converted recursively from ``__args__``.
    ``float``, ``int`` and ``str`` map to the JIT scalar types.

    Raises:
        ValueError: if the annotation is not one of the supported forms.
    """
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    # Not every annotation has __name__ (e.g. string forward references or some
    # typing constructs). Fall back to the object so callers get the intended
    # ValueError instead of an AttributeError.
    raise ValueError("Unknown type annotation: '{}'".format(getattr(ann, '__name__', ann)))