import ast
import inspect
import sys
from textwrap import dedent

import torch
from torch._C import TensorType, TupleType, FloatType, IntType, \
    ListType, StringType, DictType

from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
    BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
    13 PY35 = sys.version_info >= (3, 5)
    17     def __init__(self, name, members):
    21     def __getattr__(self, name):
    25             raise RuntimeError(
"Module {} has no member called {}".format(self.
name, name))
    29     'torch': 
Module(
'torch', {
'Tensor': torch.Tensor}),
    30     'Tensor': torch.Tensor,
    31     'typing': 
Module(
'typing', {
'Tuple': Tuple}),
    38 def get_signature(fn):
    41         sig = try_real_annotations(fn)
    45     type_line, source = 
None, 
None    47         source = dedent(inspect.getsource(fn))
    48         type_line = get_type_line(source)
    56     return parse_type_line(type_line)
    62 def get_num_params(fn):
    64         source = dedent(inspect.getsource(fn))
    65     except (TypeError, IOError):
    69     py_ast = ast.parse(source)
    70     if len(py_ast.body) == 1 
and isinstance(py_ast.body[0], ast.ClassDef):
    71         raise RuntimeError(
"cannot instantiate class object ({}) inside jit.script".format(py_ast.body[0].name))
    72     if len(py_ast.body) != 1 
or not isinstance(py_ast.body[0], ast.FunctionDef):
    73         raise RuntimeError(
"expected a single top-level function")
    74     py_def = py_ast.body[0]
    75     if py_def.args.vararg 
is not None:
    77     elif hasattr(py_def.args, 
'kwonlyargs') 
and len(py_def.args.kwonlyargs) > 0:
    80         num_params = len(py_def.args.args)
    81         if inspect.ismethod(fn):
    82             num_params = num_params - 1
    86 def parse_type_line(type_line):
    87     """Parses a type annotation specified as a comment.    90         # type: (Tensor, torch.Tensor) -> Tuple[Tensor]    91         # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor    93     arg_ann_str, ret_ann_str = split_type_line(type_line)
    96         arg_ann = eval(arg_ann_str, _eval_env)
    97     except (NameError, SyntaxError) 
as e:
    98         raise RuntimeError(
"Failed to parse the argument list of a type annotation: {}".format(str(e)))
   100     if not isinstance(arg_ann, tuple):
   104         ret_ann = eval(ret_ann_str, _eval_env)
   105     except (NameError, SyntaxError) 
as e:
   106         raise RuntimeError(
"Failed to parse the return type of a type annotation: {}".format(str(e)))
   108     arg_types = [ann_to_type(ann) 
for ann 
in arg_ann]
   109     return arg_types, ann_to_type(ret_ann)
   112 def get_type_line(source):
   113     """Tries to find the line containing a comment with the type annotation."""   114     lines = source.split(
'\n')
   118         if '# type:' in line:
   119             type_line = line.strip()
   125 def split_type_line(type_line):
   126     """Splits the comment with the type annotation into parts for argument and return types.   128     For example, for an input of:   129         # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]   131     This function will return:   132         ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")   135     start_offset = len(
'# type:')
   137         arrow_pos = type_line.index(
'->')
   139         raise RuntimeError(
"Syntax error in type annotation (cound't find `->`)")
   140     return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
   143 def try_real_annotations(fn):
   144     """Tries to use the Py3.5+ annotation syntax to get the type."""   146         sig = inspect.signature(fn)
   150     all_annots = [sig.return_annotation] + [p.annotation 
for p 
in sig.parameters.values()]
   151     if all(ann 
is sig.empty 
for ann 
in all_annots):
   156         return ann 
if ann 
is not sig.empty 
else None   158     arg_types = [ann_to_type(as_ann(p.annotation))
   159                  for p 
in sig.parameters.values()]
   160     return_type = ann_to_type(as_ann(sig.return_annotation))
   161     return arg_types, return_type
   164 def ann_to_type(ann):
   166         return TensorType.get()
   167     elif ann 
is torch.Tensor:
   168         return TensorType.get()
   170         return TupleType([ann_to_type(a) 
for a 
in ann.__args__])
   172         return ListType(ann_to_type(ann.__args__[0]))
   174         key = ann_to_type(ann.__args__[0])
   175         value = ann_to_type(ann.__args__[1])
   176         return DictType(key, value)
   178         return FloatType.get()
   182         return StringType.get()
   183     raise ValueError(
"Unknown type annotation: '{}'".format(ann.__name__))