gen_pyi.py
from __future__ import print_function
import multiprocessing
import sys
import os
import inspect
import collections
import yaml
import types
import re
import argparse

from ..autograd.utils import YamlLoader, CodeTemplate, write
from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
from ..autograd.gen_autograd import load_aten_declarations

"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file. This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in Declarations.yaml and
  generate type hints for them (generate_type_hints).

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
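
# For example, the generated torch/__init__.pyi ends up containing entries
# along these lines (an illustrative sketch, not verbatim output):
#
#   def cos(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...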

# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier on the human eye.

needed_modules = set()

FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"

# This could be more precise w.r.t. list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"

blacklist = [
    '__init_subclass__',
    '__new__',
    '__subclasshook__',
    'clamp',
    'clamp_',
    'device',
    'grad',
    'requires_grad',
    'range',
    # defined in functional
    'einsum',
    # reduction argument; these bindings don't make sense
    'ctc_loss',
    'cosine_embedding_loss',
    'hinge_embedding_loss',
    'kl_div',
    'margin_ranking_loss',
    'triplet_margin_loss',
    # Somehow, these are defined in both _C and in functional. Ick!
    'broadcast_tensors',
    'meshgrid',
    'cartesian_prod',
    'norm',
    'chain_matmul',
    'stft',
    'tensordot',
    'norm',
    'split',
    # These are handled specially by python_arg_parser.cpp
    'add',
    'add_',
    'add_out',
    'sub',
    'sub_',
    'sub_out',
    'mul',
    'mul_',
    'mul_out',
    'div',
    'div_',
    'div_out',
]


def type_to_python(typename, size=None):
    """type_to_python(typename: str, size: str) -> str

    Transforms a Declarations.yaml type name into a Python type specification
    as used for type hints.
    """
    typename = typename.replace(' ', '')  # normalize spaces, e.g., 'Generator *'

    # Disambiguate explicitly sized int/tensor lists from implicitly
    # sized ones. These permit non-list inputs too. (IntArrayRef[] and
    # TensorList[] are not real types; this is just for convenience.)
    if typename in {'IntArrayRef', 'TensorList'} and size is not None:
        typename += '[]'

    typename = {
        'Device': 'Union[_device, str, None]',
        'Generator*': 'Generator',
        'IntegerTensor': 'Tensor',
        'Scalar': 'Number',
        'ScalarType': '_dtype',
        'Storage': 'Storage',
        'BoolTensor': 'Tensor',
        'IndexTensor': 'Tensor',
        'SparseTensorRef': 'Tensor',
        'Tensor': 'Tensor',
        'IntArrayRef': '_size',
        'IntArrayRef[]': 'Union[_int, _size]',
        'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
        'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
        'bool': 'bool',
        'double': '_float',
        'int64_t': '_int',
        'accreal': 'Number',
        'real': 'Number',
        'void*': '_int',  # data_ptr
        'void': 'None',
        'std::string': 'str',
    }[typename]

    return typename

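# For example (derived directly from the mapping above):
#   type_to_python('Scalar')            -> 'Number'
#   type_to_python('IntArrayRef')       -> '_size'
#   type_to_python('IntArrayRef', '2')  -> 'Union[_int, _size]'

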
def arg_to_type_hint(arg):
    """arg_to_type_hint(arg) -> str

    This takes one argument from a Declarations.yaml entry and returns a
    string representing this argument in a type hint signature.
    """
    name = arg['name']
    if name == 'from':  # from is a Python keyword...
        name += '_'
    typename = type_to_python(arg['dynamic_type'], arg.get('size'))
    if arg.get('is_nullable'):
        typename = 'Optional[' + typename + ']'
    if 'default' in arg:
        default = arg['default']
        if default == 'nullptr':
            default = None
        elif default == 'c10::nullopt':
            default = None
        elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
            if arg['dynamic_type'] == 'Tensor' and default == '{}':
                default = None
            elif arg['dynamic_type'] == 'IntArrayRef':
                default = '(' + default[1:-1] + ')'
            else:
                raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
        default = '={}'.format(default)
    else:
        default = ''
    return name + ': ' + typename + default

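# For example (a hypothetical argument record):
#   arg_to_type_hint({'name': 'dim', 'dynamic_type': 'int64_t', 'default': 0})
#   -> 'dim: _int=0'

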
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
              'matmul',
              'radd', 'rmul',      # reverse arithmetic
              'and', 'or', 'xor',  # logic
              'iadd', 'iand', 'idiv', 'ilshift', 'imul',
              'ior', 'irshift', 'isub', 'itruediv', 'ixor',  # inplace ops
              )
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname):
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)"""

    # we have to do this by hand, because they are hand-bound in Python

    assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)

    name = opname[2:-2]
    if name in binary_ops:
        return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
    elif name in comparison_ops:
        # unsafe override https://github.com/python/mypy/issues/5704
        return ['def {}(self, other: Any) -> Tensor: ...  # type: ignore'.format(opname)]
    elif name in unary_ops:
        return ['def {}(self) -> Tensor: ...'.format(opname)]
    elif name in to_py_type_ops:
        if name in {'bool', 'float'}:
            tname = name
        elif name == 'nonzero':
            tname = 'bool'
        else:
            tname = 'int'
        if tname in {'float', 'int'}:
            tname = 'builtins.' + tname
        return ['def {}(self) -> {}: ...'.format(opname, tname)]
    else:
        raise Exception("unknown op", opname)

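# For example:
#   sig_for_ops('__add__')   -> ['def __add__(self, other: Any) -> Tensor: ...']
#   sig_for_ops('__float__') -> ['def __float__(self) -> builtins.float: ...']

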
def generate_type_hints(fname, decls, is_tensor=False):
    """generate_type_hints(fname, decls, is_tensor=False)

    Generates type hints for the declarations pertaining to the function
    :attr:`fname`. :attr:`decls` are the declarations from the parsed
    Declarations.yaml.
    The :attr:`is_tensor` flag indicates whether we are parsing
    members of the Tensor class (true) or functions in the
    `torch` namespace (default, false).

    This function currently encodes quite a bit about the semantics of
    the C++ -> Python translation.
    """
    if fname in blacklist:
        return []

    type_hints = []
    dnames = ([d['name'] for d in decls])
    has_out = fname + '_out' in dnames

    if has_out:
        decls = [d for d in decls if d['name'] != fname + '_out']

    for decl in decls:
        render_kw_only_separator = True  # whether we add a '*' if we see a keyword only argument
        python_args = []

        has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']]

        for a in decl['arguments']:
            if a['dynamic_type'] != 'TensorOptions':
                if a.get('kwarg_only', False) and render_kw_only_separator:
                    python_args.append('*')
                    render_kw_only_separator = False
                python_args.append(arg_to_type_hint(a))

        if is_tensor:
            if 'self: Tensor' in python_args:
                python_args.remove('self: Tensor')
                python_args = ['self'] + python_args
            else:
                raise Exception("method without self is unexpected")

        if has_out:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args.append('out: Optional[Tensor]=None')

        if has_tensor_options:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args += ["dtype: _dtype=None",
                            "layout: layout=strided",
                            "device: Union[_device, str, None]=None",
                            "requires_grad: bool=False"]

        python_args_s = ', '.join(python_args)
        python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]

        if len(python_returns) > 1:
            python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
        else:
            python_returns_s = python_returns[0]

        type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
        numargs = len(decl['arguments'])
        vararg_pos = int(is_tensor)
        have_vararg_version = (numargs > vararg_pos and
                               decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
                               (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
                               (not is_tensor or decl['arguments'][0]['name'] == 'self'))

        type_hints.append(type_hint)

        if have_vararg_version:
            # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
            # is an IntArrayRef or TensorList, it will be used as a vararg variant.
            # The following outputs the vararg variant; the "pass a list" variant is output above.
            # The other thing is that in Python, the varargs are annotated with the element type, not the list type.
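            # E.g. (an illustrative sketch): torch.zeros accepts both
            # zeros((2, 3)) and zeros(2, 3), so in addition to the
            # 'size: _size' form we also emit a '*size: _int' variant.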
            typelist = decl['arguments'][vararg_pos]['dynamic_type']
            if typelist == 'IntArrayRef':
                vararg_type = '_int'
            else:
                vararg_type = 'Tensor'
            # replace first argument and eliminate '*' if present
            python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
                           ': ' + vararg_type] + python_args[vararg_pos + 2:])
            python_args_s = ', '.join(python_args)
            type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
            type_hints.append(type_hint)

    return type_hints

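
# A hypothetical illustration: a declaration such as
#   {'name': 'sigmoid', 'arguments': [{'name': 'self', 'dynamic_type': 'Tensor'}],
#    'returns': [{'dynamic_type': 'Tensor'}]}
# yields, with is_tensor=True:
#   ['def sigmoid(self) -> Tensor: ...']

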
def gen_pyi(declarations_path, out):
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking. If you update this, consider whether your change
    # also needs to update the other file.

    # Load information from YAML
    declarations = load_aten_declarations(declarations_path)

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor, List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad: bool=False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
    })
    for binop in ['add', 'sub', 'mul', 'div']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' value: Number,'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    function_declarations = get_py_torch_functions(declarations)
    for name in sorted(function_declarations.keys()):
        unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])

    # Generate type signatures for deprecated functions

    # TODO: Maybe we shouldn't generate type hints for deprecated
    # functions :) However, examples like those for addcdiv rely on these.
    with open('tools/autograd/deprecated.yaml', 'r') as f:
        deprecated = yaml.load(f, Loader=YamlLoader)
    for d in deprecated:
        name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
        sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
        sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
        unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))

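    # For instance (a hypothetical deprecated.yaml entry), a name like
    #   'addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2)'
    # turns into:
    #   'def addcdiv(self: Tensor, value: Number, tensor1: Tensor, tensor2: Tensor) -> Tensor: ...'
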
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

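    # When a name collects several signatures, each is prefixed with
    # '@overload' so that mypy treats them as alternatives; e.g. each of the
    # three 'arange' hints above comes out as (abridged):
    #   @overload
    #   def arange(end: Number, *, out: Optional[Tensor]=None, ...) -> Tensor: ...
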
    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntArrayRef'), type_to_python('Scalar'), FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # clamp has no default values in the Declarations
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
        'storage': ['def storage(self) -> Storage: ...'],
        'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
                 ' -> Union[str, Tensor]: ...'],
        'get_device': ['def get_device(self) -> _int: ...'],
        'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
        'is_cuda': ['def is_cuda(self) -> bool: ...'],
        'is_leaf': ['def is_leaf(self) -> bool: ...'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
    })
    for binop in ['add', 'sub', 'mul', 'div']:
        for inplace in [True, False]:
            name = binop
            out_suffix = ', *, out: Optional[Tensor]=None'
            if inplace:
                name += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))
            unsorted_tensor_method_hints[name].append(
                'def {}(self, value: Number,'
                ' other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    tensor_method_declarations = get_py_variable_methods(declarations)
    for name in sorted(tensor_method_declarations.keys()):
        unsorted_tensor_method_hints[name] += \
            generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated; maybe we shouldn't type hint them
    legacy_class_hints = []
    for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
              'ShortStorage', 'CharStorage', 'ByteStorage'):
        legacy_class_hints.append('class {}(Storage): ...'.format(c))

    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'CharTensor', 'ByteTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'complex128']]

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'dtype_class_hints': dtype_class_hints,
    }
    TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))

    write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)


def main():
    parser = argparse.ArgumentParser(
        description='Generate type stubs for PyTorch')
    parser.add_argument('--declarations-path', metavar='DECL',
                        default='torch/share/ATen/Declarations.yaml',
                        help='path to Declarations.yaml')
    parser.add_argument('--out', metavar='OUT',
                        default='.',
                        help='path to output directory')
    args = parser.parse_args()
    gen_pyi(args.declarations_path, args.out)


if __name__ == '__main__':
    main()
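
# Typical invocation (assuming the module sits in a package where the relative
# imports above resolve, e.g. as tools/pyi/gen_pyi.py, and that a build has
# produced Declarations.yaml at the default path):
#   python -m tools.pyi.gen_pyi --declarations-path torch/share/ATen/Declarations.yaml --out .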