Caffe2 - Python API
A deep learning, cross-platform ML framework
annotations.py
1 import re
2 import sys
3 import ast
4 import inspect
5 import torch
6 from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
7  BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
8 from torch._C import TensorType, TupleType, FloatType, IntType, \
9  ListType, StringType, DictType
10 from textwrap import dedent
11 
12 
# True on Python 3.5+; get_signature uses this flag to decide whether to try
# reading real (inspect.signature) annotations before falling back to
# ``# type:`` comments.
PY35 = sys.version_info[:2] >= (3, 5)
14 
15 
class Module(object):
    """Minimal stand-in for a Python module inside the annotation eval env.

    Only the explicitly listed ``members`` are reachable as attributes, so an
    expression such as ``torch.Tensor`` resolves while any other attribute
    fails with a descriptive error.
    """

    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. Read our own state
        # straight from __dict__ so that an instance created without
        # __init__ (e.g. by copy/pickle via __new__) cannot recurse forever
        # looking up 'members'.
        members = self.__dict__.get('members', {})
        try:
            return members[name]
        except KeyError:
            pass
        if name.startswith('__') and name.endswith('__'):
            # Protocol probes (hasattr, copy, pickle) expect AttributeError
            # for missing dunder attributes.
            raise AttributeError(name)
        raise RuntimeError("Module {} has no member called {}".format(self.__dict__.get('name'), name))
26 
27 
# Globals used when eval()-ing the text of a ``# type:`` comment. Only the
# names listed here (plus builtins, which eval() inserts as ``__builtins__``
# when the key is absent) resolve inside an annotation.
# NOTE(review): eval() runs text taken from the function's own source
# comment, so this assumes the annotated source is trusted.
_eval_env = {
    'torch': Module('torch', {'Tensor': torch.Tensor}),
    'Tensor': torch.Tensor,
    # Expose the same generics through ``typing.`` as are available bare, so
    # ``typing.List[...]`` / ``typing.Dict[...]`` resolve like ``List[...]``.
    'typing': Module('typing', {'Tuple': Tuple, 'List': List, 'Dict': Dict}),
    'Tuple': Tuple,
    'List': List,
    'Dict': Dict,
}
36 
37 
def get_signature(fn):
    """Return the signature of ``fn`` as ``(arg_types, return_type)``, or None.

    Real Python 3 annotations are tried first (when PY35 is set). Otherwise
    the function source is searched for a ``# type: (...) -> ...`` comment,
    which is parsed by parse_type_line. Returns None when neither source of
    type information is available.
    """
    # Python 3.5 adds support for the nice annotation syntax, so try that first.
    if PY35:
        sig = try_real_annotations(fn)
        if sig is not None:
            return sig

    type_line = None
    try:
        source = dedent(inspect.getsource(fn))
        type_line = get_type_line(source)
    except (TypeError, IOError):
        # TypeError: fn has no Python source (builtin / C object).
        # IOError (an alias of OSError on Python 3): the source file cannot
        # be located, e.g. for interactively defined functions. This matches
        # the exceptions get_num_params treats as "source unavailable".
        pass
    # This might happen both because we failed to get the source of fn, or
    # because it didn't have any annotations.
    if type_line is None:
        return None

    return parse_type_line(type_line)
57 
58 
def get_num_params(fn):
    """Return how many positional parameters ``fn`` takes, or None if unknown.

    This is essentially a weaker form of get_signature(): the types are not
    needed, only the number of parameters. Returns None when the source of
    ``fn`` is unavailable, or when it takes ``*args`` or keyword-only
    parameters. For bound methods the implicit ``self`` is not counted.

    Raises:
        RuntimeError: if the source is a single class definition, or is not
            exactly one top-level function.
    """
    try:
        source = dedent(inspect.getsource(fn))
    except (TypeError, IOError):
        return None
    py_ast = ast.parse(source)
    if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
        raise RuntimeError("cannot instantiate class object ({}) inside jit.script".format(py_ast.body[0].name))
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError("expected a single top-level function")
    py_def = py_ast.body[0]
    if py_def.args.vararg is not None:
        return None
    if getattr(py_def.args, 'kwonlyargs', None):
        return None
    # Python 3.8+ stores parameters declared before ``/`` in ``posonlyargs``
    # rather than ``args``; they are still positional and must be counted.
    num_params = len(getattr(py_def.args, 'posonlyargs', [])) + len(py_def.args.args)
    if inspect.ismethod(fn):
        num_params -= 1
    return num_params
84 
85 
def parse_type_line(type_line):
    """Turn a ``# type: (args) -> ret`` comment into ``(arg_types, ret_type)``.

    Both halves are evaluated against ``_eval_env`` and then converted with
    ann_to_type. A single argument such as ``(Tensor)`` evaluates to the bare
    object rather than a tuple, so it is wrapped into a 1-tuple first.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Raises:
        RuntimeError: if either half fails to evaluate with a NameError or
            SyntaxError.
    """
    args_src, ret_src = split_type_line(type_line)

    try:
        parsed_args = eval(args_src, _eval_env)
    except (NameError, SyntaxError) as err:
        raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(err)))

    parsed_args = parsed_args if isinstance(parsed_args, tuple) else (parsed_args,)

    try:
        parsed_ret = eval(ret_src, _eval_env)
    except (NameError, SyntaxError) as err:
        raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(err)))

    # Argument conversion runs before return-type conversion, as before.
    return list(map(ann_to_type, parsed_args)), ann_to_type(parsed_ret)
110 
111 
def get_type_line(source):
    """Find the first ``# type:`` comment in ``source``.

    Returns the comment text starting at the ``# type:`` marker, stripped of
    surrounding whitespace, or None if no line contains the marker. Any code
    that precedes the comment on the same line (as in
    ``def f(x):  # type: (int) -> int``) is dropped, so the result always
    begins with the marker.
    """
    marker = '# type:'
    for line in source.split('\n'):
        pos = line.find(marker)
        if pos != -1:
            return line[pos:].strip()
    return None
123 
124 
def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    Text before the ``# type:`` marker (e.g. a ``def`` on the same line) is
    ignored, and only an arrow after the marker is used.

    Raises:
        RuntimeError: if no ``->`` follows the marker.
    """
    marker = '# type:'
    marker_pos = type_line.find(marker)
    if marker_pos == -1:
        # Callers normally pass a line containing the marker; when it is
        # absent, keep the historical behaviour of a fixed offset.
        marker_pos = 0
    start_offset = marker_pos + len(marker)
    try:
        arrow_pos = type_line.index('->', marker_pos)
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (couldn't find `->`)") from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
141 
142 
def try_real_annotations(fn):
    """Tries to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)`` converted with ann_to_type. A missing
    annotation is passed to ann_to_type as None, which maps it to Tensor.
    Returns None if inspect cannot produce a signature for ``fn``, or if
    neither its parameters nor its return value carry any annotation.
    """
    try:
        sig = inspect.signature(fn)
    except (TypeError, ValueError):
        # ValueError: no signature can be provided for fn.
        # TypeError: fn is not an object inspect.signature supports.
        # In both cases there are no real annotations to read.
        return None

    all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()]
    if all(ann is sig.empty for ann in all_annots):
        return None

    def as_ann(ann):
        # sig.empty is really annoying so convert it to None
        return ann if ann is not sig.empty else None

    arg_types = [ann_to_type(as_ann(p.annotation))
                 for p in sig.parameters.values()]
    return_type = ann_to_type(as_ann(sig.return_annotation))
    return arg_types, return_type
162 
163 
def ann_to_type(ann):
    """Convert a Python type annotation into the corresponding TorchScript type.

    ``None`` (a missing annotation) and ``torch.Tensor`` map to TensorType.
    Tuple/List/Dict generics are converted recursively from ``__args__``.
    ``float``, ``int`` and ``str`` map to FloatType, IntType and StringType.

    Raises:
        ValueError: if ``ann`` is not one of the supported annotations.
    """
    if ann is None or ann is torch.Tensor:
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    if is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    if is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    # Not every annotation object has __name__ (string annotations, for
    # example); fall back to repr() so the intended ValueError is raised
    # instead of an AttributeError.
    raise ValueError("Unknown type annotation: '{}'".format(getattr(ann, '__name__', repr(ann))))