from __future__ import print_function

import argparse
import collections
import os
import re

import yaml

from ..autograd.utils import YamlLoader, CodeTemplate, write
from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
from ..autograd.gen_autograd import load_aten_declarations
"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor.  It should eventually be expanded to cover all functions
that are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file.  This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in Declarations.yaml and
  generate type hints for them (generate_type_hints).

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""

needed_modules = set()
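# A rough sketch of the flow, for orientation (illustrative only, not
# executed): each function recorded in Declarations.yaml is rendered into one
# or more stub lines such as
#
#   def add(input: Tensor, other: Tensor) -> Tensor: ...
#
# and the collected lines are then substituted into torch/__init__.pyi.in
# via CodeTemplate.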
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"

INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"

# Functions for which we do not emit auto-generated type hints.
blacklist = [
    'cosine_embedding_loss',
    'hinge_embedding_loss',
    'margin_ranking_loss',
    'triplet_margin_loss',
]
def type_to_python(typename, size=None):
    """type_to_python(typename: str, size: str) -> str

    Transforms a Declarations.yaml type name into a Python type specification
    as used for type hints.
    """
    typename = typename.replace(' ', '')

    # Disambiguate explicitly sized array arguments from unsized ones.
    if typename in {'IntArrayRef', 'TensorList'} and size is not None:
        typename += '[]'

    typename = {
        'Device': 'Union[_device, str, None]',
        'Generator*': 'Generator',
        'IntegerTensor': 'Tensor',
        'Scalar': 'Number',
        'ScalarType': '_dtype',
        'Storage': 'Storage',
        'BoolTensor': 'Tensor',
        'IndexTensor': 'Tensor',
        'SparseTensorRef': 'Tensor',
        'Tensor': 'Tensor',
        'IntArrayRef': '_size',
        'IntArrayRef[]': 'Union[_int, _size]',
        'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
        'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
        'std::string': 'str',
    }[typename]

    return typename
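# A minimal usage sketch of type_to_python (shown as comments so nothing runs
# at import time):
#
#   type_to_python('IntArrayRef')            # -> '_size'
#   type_to_python('IntArrayRef', size='2')  # -> 'Union[_int, _size]' (sized)
#   type_to_python('std::string')            # -> 'str'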
def arg_to_type_hint(arg):
    """arg_to_type_hint(arg) -> str

    This takes one argument in a Declarations and returns a string
    representing this argument in a type hint signature.
    """
    name = arg['name']
    typename = type_to_python(arg['dynamic_type'], arg.get('size'))
    if arg.get('is_nullable'):
        typename = 'Optional[' + typename + ']'
    if 'default' in arg:
        default = arg['default']
        if default == 'nullptr':
            default = None
        elif default == 'c10::nullopt':
            default = None
        elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
            if arg['dynamic_type'] == 'Tensor' and default == '{}':
                default = None
            elif arg['dynamic_type'] == 'IntArrayRef':
                # translate a C++ brace initializer into a Python tuple literal
                default = '(' + default[1:-1] + ')'
            else:
                raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
        default = '={}'.format(default)
    else:
        default = ''
    return name + ': ' + typename + default
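# A minimal usage sketch of arg_to_type_hint with hand-made argument records
# (the field names match those consumed above; the values are made up):
#
#   arg_to_type_hint({'name': 'size', 'dynamic_type': 'IntArrayRef'})
#   # -> 'size: _size'
#   arg_to_type_hint({'name': 'stride', 'dynamic_type': 'IntArrayRef', 'default': '{1,1}'})
#   # -> 'stride: _size=(1,1)'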
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
              'iadd', 'iand', 'idiv', 'ilshift', 'imul',
              'ior', 'irshift', 'isub', 'itruediv', 'ixor')
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname):
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)"""
    assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
    name = opname[2:-2]
    if name in binary_ops:
        return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
    elif name in comparison_ops:
        return ['def {}(self, other: Any) -> Tensor: ... # type: ignore'.format(opname)]
    elif name in unary_ops:
        return ['def {}(self) -> Tensor: ...'.format(opname)]
    elif name in to_py_type_ops:
        if name in {'bool', 'float'}:
            tname = name
        elif name == 'nonzero':
            tname = 'Tensor'
        else:
            tname = 'int'
        if tname in {'float', 'int'}:
            tname = 'builtins.' + tname
        return ['def {}(self) -> {}: ...'.format(opname, tname)]
    else:
        raise Exception("unknown op", opname)
def generate_type_hints(fname, decls, is_tensor=False):
    """generate_type_hints(fname, decls, is_tensor=False)

    Generates type hints for the declarations pertaining to the function
    :attr:`fname`. :attr:`decls` are the declarations from the parsed
    Declarations.yaml.
    The :attr:`is_tensor` flag indicates whether we are parsing
    members of the Tensor class (true) or functions in the
    `torch` namespace (default, false).

    This function currently encodes quite a bit about the semantics of
    the translation C++ -> Python.
    """
    if fname in blacklist:
        return []

    type_hints = []
    dnames = ([d['name'] for d in decls])
    has_out = fname + '_out' in dnames

    if has_out:
        decls = [d for d in decls if d['name'] != fname + '_out']

    for decl in decls:
        python_args = []

        render_kw_only_separator = True  # whether we still need to emit '*' before kwarg-only args

        has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']]

        for a in decl['arguments']:
            if a['dynamic_type'] != 'TensorOptions':
                if a.get('kwarg_only', False) and render_kw_only_separator:
                    python_args.append('*')
                    render_kw_only_separator = False
                python_args.append(arg_to_type_hint(a))

        if is_tensor:
            if 'self: Tensor' in python_args:
                python_args.remove('self: Tensor')
                python_args = ['self'] + python_args
            else:
                raise Exception("method without self is unexpected")

        if has_out:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args.append('out: Optional[Tensor]=None')

        if has_tensor_options:
            if render_kw_only_separator:
                python_args.append('*')
                render_kw_only_separator = False
            python_args += ["dtype: _dtype=None",
                            "layout: layout=strided",
                            "device: Union[_device, str, None]=None",
                            "requires_grad: bool=False"]

        python_args_s = ', '.join(python_args)
        python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]

        if len(python_returns) > 1:
            python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
        else:
            python_returns_s = python_returns[0]

        type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)

        # Check whether we also need a vararg overload: a single positional
        # IntArrayRef/TensorList argument can alternatively be passed as
        # *varargs of the element type.
        numargs = len(decl['arguments'])
        vararg_pos = int(is_tensor)
        have_vararg_version = (numargs > vararg_pos and
                               decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
                               (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
                               (not is_tensor or decl['arguments'][0]['name'] == 'self'))

        type_hints.append(type_hint)

        if have_vararg_version:
            typelist = decl['arguments'][vararg_pos]['dynamic_type']
            if typelist == 'IntArrayRef':
                vararg_type = '_int'
            else:
                vararg_type = 'Tensor'
            # replace the list argument with a vararg of the element type
            python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
                           ': ' + vararg_type] + python_args[vararg_pos + 2:])
            python_args_s = ', '.join(python_args)
            type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
            type_hints.append(type_hint)

    return type_hints
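# A minimal usage sketch of generate_type_hints on a hand-made declaration
# (the function name 'frob' and the record below are hypothetical; real
# entries in Declarations.yaml carry many more fields):
#
#   decls = [{'name': 'frob',
#             'arguments': [{'name': 'input', 'dynamic_type': 'Tensor'},
#                           {'name': 'size', 'dynamic_type': 'IntArrayRef'}],
#             'returns': [{'dynamic_type': 'Tensor'}]}]
#   generate_type_hints('frob', decls)
#   # -> ['def frob(input: Tensor, size: _size) -> Tensor: ...']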
def gen_pyi(declarations_path, out):
    """gen_pyi(declarations_path, out)

    This function generates a pyi file for torch.
    """
    declarations = load_aten_declarations(declarations_path)

    # Generate type signatures for top-level torch functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor, List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad: bool=False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
    })
    for binop in ['add', 'sub', 'mul', 'div']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' value: Number,'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    function_declarations = get_py_torch_functions(declarations)
    for name in sorted(function_declarations.keys()):
        unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])
    # Generate type signatures for deprecated functions

    with open('tools/autograd/deprecated.yaml', 'r') as f:
        deprecated = yaml.load(f, Loader=YamlLoader)
    for d in deprecated:
        name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
        sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
        sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
        unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints
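    # Illustrative sketch: a name registered with several hints (e.g. the
    # three 'arange' signatures above) is emitted with '@overload' markers:
    #
    #   @overload
    #   def arange(end: Number, *, out: Optional[Tensor]=None, ...) -> Tensor: ...
    #   @overload
    #   def arange(start: Number, end: Number, *, ...) -> Tensor: ...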
    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
        'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
                     format(type_to_python('IntArrayRef'), type_to_python('Scalar'), FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
        'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
        'storage': ['def storage(self) -> Storage: ...'],
        'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
                 ' -> Union[str, Tensor]: ...'],
        'get_device': ['def get_device(self) -> _int: ...'],
        'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
        'is_cuda': ['def is_cuda(self) -> bool: ...'],
        'is_leaf': ['def is_leaf(self) -> bool: ...'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
    })
    for binop in ['add', 'sub', 'mul', 'div']:
        for inplace in [True, False]:
            out_suffix = ', *, out: Optional[Tensor]=None'
            name = binop
            if inplace:
                name += '_'
                out_suffix = ''  # in-place variants take no out= argument
            unsorted_tensor_method_hints[name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))
            unsorted_tensor_method_hints[name].append(
                'def {}(self, value: Number,'
                ' other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(name, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
    tensor_method_declarations = get_py_variable_methods(declarations)
    for name in sorted(tensor_method_declarations.keys()):
        unsorted_tensor_method_hints[name] += \
            generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)
    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints
    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_class_hints = []
    for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
              'ShortStorage', 'CharStorage', 'ByteStorage'):
        legacy_class_hints.append('class {}(Storage): ...'.format(c))

    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'CharTensor', 'ByteTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))
    # Generate type signatures for dtype instances
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'complex128']]
    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'dtype_class_hints': dtype_class_hints,
    }

    TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))

    write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
def main():
    parser = argparse.ArgumentParser(
        description='Generate type stubs for PyTorch')
    parser.add_argument('--declarations-path', metavar='DECL',
                        default='torch/share/ATen/Declarations.yaml',
                        help='path to Declarations.yaml')
    parser.add_argument('--out', metavar='OUT',
                        help='path to output directory')
    args = parser.parse_args()
    gen_pyi(args.declarations_path, args.out)


if __name__ == '__main__':
    main()
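# Example invocation (an assumption: the script is run as a module from the
# repository root so that the relative imports above resolve):
#
#   python -m tools.pyi.gen_pyi --declarations-path torch/share/ATen/Declarations.yaml --out .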