function_wrapper.py
# HEY! Trying to understand what this file does? Read
# "what has to be done to add an Operation ..." first!

import re
from code_template import CodeTemplate

try:
    import typing  # noqa: F401
except ImportError:
    raise RuntimeError(
        'Missing build dependency: Unable to import the `typing` module. '
        'Please install it via `conda install typing` or `pip install typing`')

# flake8 doesn't take into account usages in type annotations.
from typing import Union, Set  # noqa: F401
from typing import Any, Dict, List, Optional, Tuple, NamedTuple

try:
    from mypy_extensions import TypedDict
except ImportError:
    # Avoid the dependency on the mypy_extensions package.
    # It is required, however, for type checking.
    def TypedDict(name, attrs, total=True):  # type: ignore
        return Dict[Any, Any]

import sys
if sys.version_info[0] == 3:
    string_type = str
else:
    string_type = basestring

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# what has to be done to add an Operation ...
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# 1. if broadcasting or without the full list of arguments, add a non-virtual
#    declaration under Type.h  (right now, we call this template
#    BROADCAST but it also handles default arguments)
TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) const override;
""")
# 2. broadcasting functions are implemented in Type.cpp
TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    Tensor ${broadcast_returns};
    std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
    return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
}
""")
# 3. add virtual dispatch declaration to Type.h and impl to Type.cpp; method_prefix_derived
#    is present for providing a base-class definition for a derived-type method with a prefix.
#
# If the declaration is abstract, then the actual implementation will
# be in a derived type; we put in a simple default "not implemented"
# stub.  However, if the declaration is concrete, we dispatch to the
# actual implementation.  At the moment, this situation *only* occurs
# for 'native' declarations (so the native dispatch is hardcoded into
# the template here.)
PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const = 0;
""")
DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
C10_DEPRECATED virtual ${return_type} \
${method_prefix_derived}${api_name}(${type_method_formals}) const = 0;
""")
PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
virtual ${return_type} ${api_name}(${type_method_formals}) const = 0;
""")

TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")
TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate("""\
${return_type} TypeDefault::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    AT_ERROR("${method_prefix_derived}${api_name} is not implemented for type ", toString());
}
""")
TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) const override;
""")
TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${type_definition_body}
}
""")
# 4. add override to TypeDerived.h
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")
# 5. add override definition to TypeDerived.cpp
TYPE_DERIVED_DEFINITION = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${type_definition_body}
}
""")
# NB: As far as ezyang can tell, we don't *have* to codegen this,
# because we will inherit it from the TYPE_METHOD_DEFINITION_CONCRETE in
# the superclass.  But it doesn't seem to be harmful.
TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${actuals});
}
""")
TYPE_DERIVED_DEFINITION_NATIVE_MISSING = CodeTemplate("""\
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
    AT_ERROR("${api_name} not supported on ${Type}");
}
""")
TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")

# Overrideable stubs to be used in user-extendable backends
TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
}
""")

# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
""")
# add non-virtual definition to Tensor.cpp
TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
    return type().${api_name}(${method_actuals});
}
""")
# add a method declaration in Functions.h
FUNCTION_DECLARATION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals_with_defaults});
""")
# add a deprecated method declaration in Functions.h
DEPRECATED_FUNCTION_DECLARATION = CodeTemplate("""\
C10_DEPRECATED static inline ${return_type} ${api_name}(${formals_with_defaults});
""")
# add method definition in Functions.h
FUNCTION_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    return ${inferred_type}.${api_name}(${type_method_actuals});
}
""")
# add a native declaration for a native function
NATIVE_DECLARATION = CodeTemplate("""\
CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
""")

# special method definition for factory functions in Functions.h
FACTORY_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    const DeviceGuard guard(options.device());
    return at::native::${api_name}(${type_method_actuals});
}
""")

# We need to cast to the base type because C++ may hide the base class
# implementation of ${api_name} if we have overloaded a function with
# the same name (but different signature) already
ZERO_DIM_CHECK = CodeTemplate("""\
if (${check_name}.dim() == 0) {
    return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${zero_dim_actuals});
}""")

ZERO_DIM_ONLY = CodeTemplate("""\
AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor "
    "with ", ${check_name}.dim(), " dimension(s).");
""")

SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.is_sparse()) {
    return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
}""")

BUFFER_DEFINITION = CodeTemplate("""\
auto ${name}_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
    ${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), ${THTensor}_new(), false).release();
auto ${name} = Tensor(${name}_, false);""")

CONDITIONAL_INITIALIZER = CodeTemplate("""\
if (${name}.defined()) {
    ${initializer}
}""")

CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
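# A minimal usage sketch (illustrative values): CodeTemplate performs ${...}
# substitution and joins list-valued arguments with ", ", so
#
#   CALL_TEMPLATE.substitute(cname='THFloatTensor_add',
#                            actuals=['self_', 'other_'])
#
# would yield the string 'THFloatTensor_add(self_, other_)'.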


class NYIError(Exception):
    """Indicates we don't support this declaration yet"""

    def __init__(self, reason):
        self.reason = reason

TYPE_FORMAL_GENERIC = {
    'THTensor*': 'Tensor &',
    'THSTensor*': 'SparseTensorRef',
    'THBoolTensor*': 'Tensor &',
    'THIndexTensor*': 'Tensor &',
    'THIntegerTensor*': 'Tensor &',
    'THDenseTensor*': 'Tensor &',
    'THDenseIndexTensor*': 'Tensor &',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator *',
    'IntArrayRefSize': 'IntArrayRef',
    'accreal': 'Scalar',
    'real': 'Scalar',
    'long': 'int64_t',
}

DYNAMIC_TYPE = {
    'THTensor*': 'Tensor',
    'THSTensor*': 'SparseTensorRef',
    'THBoolTensor*': 'BoolTensor',
    'THIndexTensor*': 'IndexTensor',
    'THIntegerTensor*': 'IntegerTensor',
    'THDenseTensor*': 'Tensor',
    'THDenseIndexTensor*': 'IndexTensor',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator*',
    'IntArrayRefSize': 'IntArrayRef',
    'accreal': 'accreal',
    'real': 'real',
    'long': 'int64_t',
}

NATIVE_DYNAMIC_TYPE = {
    'Tensor &': 'Tensor',
    'const Tensor &': 'Tensor',
}

TYPE_RETURN = {
    'THTensor*': 'Tensor',
    'THIndexTensor*': 'Tensor',
    'THBoolTensor*': 'Tensor',
    'THIntegerTensor*': 'Tensor',
    'THSTensor*': 'Tensor',
    'THDenseTensor*': 'Tensor',
    'THDenseIndexTensor*': 'Tensor',
    'real': 'Tensor',
    'accreal': 'Tensor',
    'long': 'int64_t',
}

CHECKED_CAST = {
    'THTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::${ScalarName})'),
    'THSTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name}.tref,"${arg_name}",${arg_pos},false, '
            'Backend::${Backend}, ScalarType::${ScalarName})'),
    'THBoolTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Byte)'),
    'THIndexTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Long)'),
    'THIntegerTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Int)'),
    'THDenseTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${DenseBackend}, ScalarType::${ScalarName})'),
    'THDenseIndexTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${DenseBackend}, ScalarType::Long)'),
    'THStorage*':
        CodeTemplate(
            'checked_storage('
            '${arg_name},"${arg_name}",${arg_pos}, '
            # We're punning here (Backend and DeviceType constructors coincide)
            # but DeviceType is the correct way to classify storages
            'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
    'THGenerator*':
        CodeTemplate(
            'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(device_type()))'),
    # This is a cast done via direct-construction
    'IntArrayRefStride': CodeTemplate('at::IntArrayRef ${result_name} = get_intlist_stride_th(${arg_name});'),
    'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
    'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
    'TensorList': CodeTemplate(
        'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, '
        'Backend::${Backend}, ScalarType::${ScalarName})'),
    'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
}

CHECKED_USE = {
    'THTensor*': '{}_',
    'THSTensor*': '{}_',
    'THIndexTensor*': '{}_',
    'THBoolTensor*': '{}_',
    'THIntegerTensor*': '{}_',
    'THDenseTensor*': '{}_',
    'THDenseIndexTensor*': '{}_',
    'THStorage*': '{}_.unsafeGetStorageImpl()',
    'THGenerator*': '{}_->generator',
    'TensorList': "{0}_.data(), {0}_.size()",
}

CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL')

ALLOC_NOARGS_WRAP = {
    'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                 '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
    'THBoolTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                     '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Byte), allocator(), false).release()',
    'THIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                      '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), allocator(), false).release()',
    'THIntegerTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                        '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Int), allocator(), false).release()',
    'THDenseTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                      '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
    'THDenseIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                           '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), '
                           'allocator(), false).release()'
}

ALLOC_WRAP = {
    'THTensor*': '${arguments}',
    'THBoolTensor*': '${arguments}',
    'THIndexTensor*': '${arguments}',
    'THIntegerTensor*': '${arguments}',
    'THDenseTensor*': '${arguments}',
    'THDenseIndexTensor*': '${arguments}',
}

# Replacements for constants when calling into TH
CONSTANT_REPLACEMENTS = [
    ('AS_REAL', '${AS_REAL}'),
    ('__last_dim', 'self.ndimension()-1'),
]

# Replacements for constants in header file function definitions
HEADER_CONSTANT_REPLACEMENTS = [
    (r'AS_REAL\((.*)\)', r'\1'),
    ('__last_dim', '-1'),
]


class nested_dict(object):
    def __init__(self, base, parent):
        self.base, self.parent = base, parent

    def __getitem__(self, x):
        r = self.base.get(x)
        if r is not None:
            return r
        return self.parent[x]
370 
371 Environment = TypedDict('Environment', {
372  'ScalarName': str,
373  'THTensor': str,
374  'THType': str,
375  'THTensor': str,
376  'Backend': str,
377  'AccScalarName': str,
378 })

TopEnvironment = TypedDict('TopEnvironment', {
    'type_registrations': List[str],
    'type_headers': List[str],
    'pure_virtual_type_method_declarations': List[str],
    'pure_virtual_extended_type_method_declarations': List[str],
    'type_method_declarations': List[str],
    'type_method_definitions': List[str],
    'tensor_method_declarations': List[str],
    'tensor_method_definitions': List[str],
    'function_declarations': List[str],
    'function_definitions': List[str],
    'type_ids': List[str],
    'native_function_declarations': List[str],
})

# A Declarations.cwrap formal argument
# type can contain THTensor* types
THFormal = TypedDict('THFormal', {
    'name': str,
    'type': str,
    'dynamic_type': str,
    'kwarg_only': bool,
    'is_nullable': bool,
    'default': str,
    'default_init': str,
    'output': bool,
    'size': int,
    'declared_type': str,
    'ignore_check': bool,
    'allocate': bool,
    'mask': bool,
    'if_true': bool,
    'if_false': bool,
    'wrap_dim': str,
    # Broadcast is originally a str but gets unwrapped to a List or Dict in-place
    'broadcast': Any,
    'resize': str,
    'cpu_zero': bool,
    'zero': bool,
}, total=False)

# Generic ATen formal or native_functions.yaml formal argument.
# type can contain Tensor & reference types.
AtFormal = TypedDict('AtFormal', {
    'name': str,
    'type': str,
    'dynamic_type': str,
    'kwarg_only': bool,
    'is_nullable': bool,
    'default': str,
    'default_init': str,
    'output': bool,
    'size': int,
}, total=False)

# Note [field_name versus name]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# What is the difference between "field_name" and "name"?
#
# Return values of ATen operators always have a name: if it is not
# explicitly assigned a name inside native_functions.yaml, like
# func: myop() -> (Tensor indices, Tensor value), then the codegen will
# automatically assign it a name like result0, or the name might be
# specified inside Declarations.cwrap.  We don't want these assigned
# names to become part of the public API when we return a namedtuple for
# any such multiple-return function.
#
# Thus field_name is like name, but it is defined only when there is a
# name specified in native_functions.yaml.  If field_name is defined,
# then the codegen generates code to return a namedtuple.  Otherwise,
# it just returns a plain tuple.

ReturnType = TypedDict('ReturnType', {
    'name': str,
    # See Note [field_name versus name]
    'field_name': str,
    'type': str,
    'dynamic_type': str,
}, total=False)

ReturnDecl = TypedDict('ReturnDecl', {
    'kind': str,
    'type': str,
    'arguments': List[int],
}, total=False)

# Represents a buffer in nn.yaml
NNBuffer = TypedDict('NNBuffer', {
    'name': str,
})

FunctionOption = TypedDict('FunctionOption', {
    'actuals': List[str],
    'api_name': str,
    'arguments': List[THFormal],
    'aten_custom_call': str,
    'aten_dense_sparse': bool,
    'backend_type_pairs': List[Tuple[str, str]],
    'backends': List[str],
    'broadcast_actuals': List[str],
    'broadcast_function': str,
    'broadcast_modified_actuals': List[str],
    'broadcast_returns': List[str],
    'buffers': List[NNBuffer],
    # cimpls is really a List[FunctionOption]
    'cimpls': List[Any],
    'cname': str,
    'condition': str,
    'const_mark': str,
    'device_guard': bool,
    'device_guard_declaration': str,
    'with_gil': bool,
    'cpu_half': bool,
    'deprecated': bool,
    'cpu_bool': bool,
    # See Note [field_name versus name]
    'field_name': str,
    'formals_list': List[AtFormal],
    'formals_with_defaults': List[str],
    'formals': List[str],
    'formals_types': List[str],
    'inferred_type': str,
    'inplace': bool,
    'matches_jit_signature': bool,
    # This controls whether or not we generate the interface in Type or
    # TypeExtendedInterface
    'extended_method': bool,
    'method_actuals': List[str],
    'method_formals_with_defaults': List[str],
    'method_formals': List[str],
    'method_prefix_derived': str,
    'mode': str,
    'python_module': str,
    'name': str,
    'native_actuals': List[str],
    'native_type_method_dispatch': str,
    # options should be List[FunctionOption]
    'options': Any,
    'schema_string': str,
    'requires_tensor': bool,
    'return_call': str,
    'return_type': str,
    'return': ReturnDecl,
    'returns': List[ReturnType],
    'scalar_check': str,
    # schema used for extension backend operator registration
    'schema': str,
    'sparse': bool,
    'type_definition_body': List[str],
    'type_method_actuals': List[str],
    'type_method_definition_dispatch': str,
    'type_method_formals': List[str],
    'variants': str,
    'when_sparse_dispatch': str,
    'zero_dim_dispatch_when_scalar': str,
    'zero_dim_tensor_only': bool,
})

OutputDeclaration = NamedTuple('OutputDeclaration', [
    ('name', str),
    ('matches_jit_signature', bool),
    ('schema_string', str),
    ('method_prefix_derived', str),
    ('arguments', List[AtFormal]),
    ('method_of', List[str]),
    ('mode', str),
    ('python_module', str),
    ('buffers', Optional[List[str]]),
    ('returns', List[ReturnType]),
    ('inplace', bool),
    ('is_factory_method', bool),
    ('abstract', bool),
    ('requires_tensor', bool),
    ('device_guard', bool),
    ('with_gil', bool),
    ('deprecated', bool),
])


def device_guard(option, formals, dispatch_options, dispatch_tensor):
    # For factory methods the `DeviceGuard` is already in the template.
    if option.get('device_guard', True):
        if dispatch_options:
            return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
        if dispatch_tensor:
            return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
    return '// DeviceGuard omitted'
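
# For example (illustrative values): with a dispatch tensor named 'self' and
# no TensorOptions argument, device_guard(option, formals, None, 'self')
# returns "const OptionalDeviceGuard device_guard(device_of(self));".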


def is_real_argument_to_wrapper(argument):
    # type: (THFormal) -> bool
    return not argument.get('output', False) and \
        argument['type'] != 'CONSTANT' and \
        argument['type'] != 'argument'


def is_mutable_formal_argument(argument, option):
    # type: (THFormal, FunctionOption) -> bool
    return argument.get('output') or (option['inplace'] and argument['name'] == 'self')


def check_methods_do_not_start_with_underscore(name, is_method):
    if name in {'_values', '_indices', '_nnz', '_dimI', '_dimV', '_coalesced_'}:
        return
    if is_method and name.startswith('_') and not name.startswith('__') and not name.startswith('_th_'):
        message = "Function '{}' starts with a single underscore and is ".format(name)
        message += "configured to have a method on Tensor. Functions that start with "
        message += "a single underscore should only be functions in the at:: "
        message += "namespace and not methods on Tensor!"
        raise RuntimeError(message)


def to_return_type(arg, option):
    # type: (THFormal, FunctionOption) -> ReturnType
    t = arg['type']
    rt = TYPE_RETURN.get(t, t)
    if rt == 'Tensor' and not arg.get('allocate'):
        rt = rt + ' &'
        if not is_mutable_formal_argument(arg, option):
            rt = 'const ' + rt
    return {
        'name': arg['name'],
        'type': rt,
        'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']),
    }
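
# e.g. (illustrative) a non-allocated THTensor* argument named 'self' in an
# in-place op comes back as {'name': 'self', 'type': 'Tensor &',
# 'dynamic_type': 'Tensor'}; non-mutable tensor arguments get 'const Tensor &'.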


def create_generic(top_env, declarations):
    # type: (TopEnvironment, List[FunctionOption]) -> List[OutputDeclaration]
    # translates defaults from cwrap types to C++ values
    def translate_default(argument, type_str, default):
        # type: (THFormal, str, Any) -> Any
        if default is None:
            # cause the default constructor for the object to run
            return '{}'
        if 'if_true' in argument:
            return argument['default'] == argument['if_true']
        for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
            default = re.sub(pattern, replacement, str(default))
        if type_str in {'Scalar', 'int64_t', 'double'}:
            try:
                return int(default)
            except Exception:
                try:
                    return float(default)
                except Exception:
                    return default
        elif type_str == 'bool':
            assert default.lower() in ['true', 'false']
            return default.lower() == 'true'
        else:
            return default
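
    # For example (illustrative): a cwrap default of 'true' for a bool formal
    # becomes the Python value True, a default of '1' for a Scalar becomes
    # the int 1, and a default of None becomes '{}' (C++ default construction).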

    # change from THTensor* to Tensor & so we get how it will appear
    # in the aten argument list...
    def translate_formal(argument, option):
        # type: (THFormal, FunctionOption) -> AtFormal
        type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
        if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
            type_str = 'const ' + type_str
        translated = {
            'name': argument['name'],
            'type': type_str,
            'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
        }  # type: AtFormal
        if 'kwarg_only' in argument:
            translated['kwarg_only'] = argument['kwarg_only']
        if 'default' in argument:
            default = translate_default(argument, type_str, argument['default'])
            translated['default'] = default
            translated['default_init'] = argument.get('default_init', default)
        if argument.get('output'):
            translated['output'] = True
        if argument.get('size'):
            translated['size'] = argument['size']
        if argument.get('is_nullable') is not None:
            translated['is_nullable'] = argument['is_nullable']
        return translated

    def get_formals(option, include_constants=False):
        # type: (FunctionOption, bool) -> List[AtFormal]
        seen = set()  # type: Set[str]
        pos_args = []  # type: List[THFormal]
        kwd_args = []  # type: List[THFormal]

        def insert(argument):
            # type: (THFormal) -> None
            if argument['name'] not in seen:
                seen.add(argument['name'])
                if argument.get('kwarg_only', False):
                    kwd_args.append(argument)
                else:
                    pos_args.append(argument)

        def has_output_mask(argument):
            # type: (THFormal) -> bool
            return argument.get('allocate', False) and argument.get('mask', False)

        for argument in option['arguments']:
            if argument.get('output') and not argument.get('allocate', False):
                insert(argument)
        for argument in option['arguments']:
            if argument['type'] == 'THSTensor*':
                # only enable for a subset of Dense/Sparse ops
                if not (option.get('aten_dense_sparse', False)):
                    raise NYIError("Sparse Tensor")

            if include_constants and argument['type'] == 'CONSTANT':
                insert(argument)
            elif is_real_argument_to_wrapper(argument):
                insert(argument)
        if any(has_output_mask(arg) for arg in option['arguments']):
            mask_size = sum(has_output_mask(arg) for arg in option['arguments'])
            insert({
                'name': 'output_mask',
                # NB: Lack of space after the comma works around a parsing
                # problem in gen_variable_type.py
                'type': 'std::array<bool,{}>'.format(mask_size),
                'default': '{{' + ', '.join(['true'] * mask_size) + '}}',
            })

        result = pos_args + kwd_args
        return [translate_formal(argument, option) for argument in result]

    def get_return_types(option):
        # type: (FunctionOption) -> List[ReturnType]
        ret = option['return']
        if ret['kind'] == 'arguments':
            argument_indices = ret['arguments']
            if len(argument_indices) == 1:
                the_arg = option['arguments'][argument_indices[0]]
                return [to_return_type(the_arg, option)]
            else:
                return [to_return_type(option['arguments'][idx], option)
                        for idx in argument_indices]
        elif ret['kind'] == 'type':
            return [{
                'type': TYPE_RETURN.get(ret['type'], ret['type']),
                'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']),
            }]
        else:
            raise Exception("format_return_type")

    def format_return_type(return_types):
        # type: (List[ReturnType]) -> str
        if len(return_types) == 1:
            return return_types[0]['type']
        return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
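
    # e.g. (illustrative) a single return [{'type': 'Tensor'}] stays 'Tensor',
    # while two returns of type 'Tensor' become 'std::tuple<Tensor,Tensor>'.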

    def find_dispatch_tensor(formals):
        # type: (List[AtFormal]) -> Optional[str]
        # dispatch to self if it's a parameter
        def is_any_tensor_type(formal):
            return (formal['dynamic_type'] == 'Tensor' or formal['dynamic_type'] == 'BoolTensor'
                    or formal['dynamic_type'] == 'IndexTensor')

        for formal in formals:
            if formal['name'] == 'self' and is_any_tensor_type(formal) and not formal.get('is_nullable', False):
                return formal['name']
        # otherwise dispatch to the first Tensor or TensorList
        for formal in formals:
            if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal) and \
                    not formal.get('is_nullable', False):
                return formal['name']

        return None

    def format_formal(f):
        # type: (AtFormal) -> str
        return '{} {}'.format(f['type'], f['name'])

    def formal_with_default(f):
        # type: (AtFormal) -> str
        s = format_formal(f)
        v = f.get('default')
        if v is None:
            return s
        if isinstance(v, bool):
            v = str(v).lower()
        return '{}={}'.format(s, v)
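
    # e.g. (illustrative) {'type': 'Scalar', 'name': 'alpha', 'default': 1}
    # formats as 'Scalar alpha=1'; a bool default of True renders as 'true'.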

    def get_broadcast_argument(option):
        # type: (FunctionOption) -> Optional[THFormal]
        for argument in option['arguments']:
            if argument.get('broadcast'):
                return argument
        return None

    def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
        # type: (THFormal, bool, bool) -> List[str]
        # Note: broadcast_dims can change type...
        # return the actuals that will be passed to the broadcast function.
        # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
        #    that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
        # 2) in the broadcast_dims case, this is the broadcasted argument (e.g. "self") followed by the sizes
        #    it is broadcasted to (as an initializer list), so e.g. the specification
        #    "mat1.dim0,mat2.dim1" gets transformed to "self, {mat1.size(0),mat2.size(1)}"
        if not broadcast_dims:
            broadcast_actuals = [broadcast_arg['name']] + broadcast_arg['broadcast'].split()[0].split(",")
        else:
            broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
            # generate size call for each dimension
            broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'  # type: ignore
                               for x in broadcast_dims_spec])
            broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'  # type: ignore
            broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]

        return broadcast_actuals

    def emit_nn_body(option):
        # type: (FunctionOption) -> Union[str, List[str]]
        # Concrete definition on Type.cpp for NN functions. Delegates to the
        # xxx_forward variant after creating any necessary buffers.
        actuals = option['actuals']
        base_name = option['name'][:-1] if option['inplace'] else option['name']
        fwd_name = option['api_name'].replace(base_name, base_name + '_forward')

        if len(option['buffers']) == 0:
            return 'return {}({});'.format(fwd_name, ', '.join(actuals))

        body = []  # type: List[str]
        if option['api_name'].endswith('_out'):
            # _out variants must create buffers and insert them in the
            # arguments list between output and input arguments
            for buffer in option['buffers']:
                body.append('Tensor {} = at::empty({{0}}, this->options());'.format(buffer['name']))
            actuals = [arg['name'] for arg in option['arguments'] if arg.get('output')]
            actuals += [buffer['name'] for buffer in option['buffers']]
            actuals += [arg['name'] for arg in option['arguments'] if not arg.get('output')]

        body.append('return std::get<0>({}({}));'.format(fwd_name, ', '.join(actuals)))
        return body

    def process_option(option, output_options):
        # type: (FunctionOption, List[OutputDeclaration]) -> None
        option['inplace'] = re.search(
            '(^__i|[^_]_$)', option['api_name']) is not None
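        # e.g. (illustrative) 'add_' and '__iadd__' are detected as in-place;
        # 'add' and 'add_out' are not.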

        # print(yaml.dump(option))
        formals = get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f) for f in formals]
        option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
        option['returns'] = get_return_types(option)
        option['return_type'] = format_return_type(option['returns'])
        option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
        option['actuals'] = [f['name'] for f in formals]

        option['method_formals'] = [format_formal(f) for f in formals
                                    if f['name'] != 'self']
        option['method_formals_with_defaults'] = (
            [formal_with_default(f) for f in formals if f['name'] != 'self'])
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]

        # There are no cases where these differ, but they do in native_functions
        option['type_method_formals'] = option['formals']
        option['type_method_actuals'] = option['actuals']

        option['const_mark'] = '' if option['inplace'] else ' const'

        assert 'method' not in option['variants'], 'TH functions cannot be methods'
        is_function = 'function' in option['variants']
        dispatch_tensor = find_dispatch_tensor(formals)
        is_namespace_function = is_function and dispatch_tensor is not None

        broadcast_arg = get_broadcast_argument(option)
        # "s_" for "same size".
        option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
        if option['mode'] == 'TH':
            option['device_guard'] = False
        option['device_guard_declaration'] = device_guard(option, formals, False, dispatch_tensor)

        env = nested_dict(option, top_env)

        mode = option['mode']
        abstract = True
        assert option['extended_method'], 'Expected legacy operator to be an extended method'

        if mode == 'NN' and option.get('cimpls') is None:
            # NN functions without a _forward/_backward suffix don't have cimpls.
            # They call the _forward function and discard any buffer returns
            abstract = False
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
            body = emit_nn_body(option)
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_CONCRETE.substitute(
                    env, type_definition_body=body))
        elif broadcast_arg is None:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
        else:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))

            broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
            broadcast_dims = 'dims:' in broadcast_arg['broadcast']
            option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
            if not broadcast_dims:
                option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
                                                if x != broadcast_arg['name'] or not broadcast_inplace])
            else:
                option['broadcast_returns'] = ["b_" + broadcast_arg['name']]

            option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
                                                        else 'size' if broadcast_dims else 'outplace')
            option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
                                                    for y in option['actuals']]
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_BROADCAST.substitute(env))

        method_of = ['Type']
        if is_namespace_function:
            option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
            top_env['function_declarations'].append(
                FUNCTION_DECLARATION.substitute(env))
            top_env['function_definitions'].append(
                FUNCTION_DEFINITION.substitute(env))
            method_of.append('namespace')

        buffer_names = [buffer['name'] for buffer in option.get('buffers', [])]

        output_options.append(OutputDeclaration(
            name=option['api_name'],
            matches_jit_signature=option['matches_jit_signature'],
            schema_string=option['schema_string'],
            method_prefix_derived=option['method_prefix_derived'],
            arguments=formals,
            method_of=method_of,
            mode=mode,
            python_module=option.get('python_module', ''),
            buffers=buffer_names,
            returns=option['returns'],
            inplace=option['inplace'],
            is_factory_method=False,
            # See Note [Abstract ATen methods]
            abstract=abstract,
            requires_tensor=option.get('requires_tensor', False),
            device_guard=option.get('device_guard', True),
            with_gil=option.get('with_gil', False),
            deprecated=option.get('deprecated', False)
        ))

    def native_get_formals(option, include_constants=False):
        # type: (FunctionOption, bool) -> List[AtFormal]
        seen = set()  # type: Set[str]
        pos_args = []
        kwd_args = []

        def insert(argument):
            # type: (AtFormal) -> None
            if argument['name'] not in seen:
                seen.add(argument['name'])
                if argument.get('kwarg_only', False):
                    kwd_args.append(argument)
                else:
                    pos_args.append(argument)

        for argument in option['arguments']:
            insert(argument)

        # not clear we need dynamic_type translation as we can specify the correct type
        # directly in native functions
        def add_dynamic_type(argument, option):
            # type: (AtFormal, FunctionOption) -> AtFormal
            argument['dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument['type'], argument['type'])
            return argument

        result = pos_args + kwd_args
        result = [add_dynamic_type(argument, option) for argument in result]

        # ensure we get reference-type formals when appropriate
        def native_translate_formals(argument, option):
            # type: (AtFormal, FunctionOption) -> AtFormal
            def translate_map(const):
                # type: (bool) -> Dict[str, str]
                return {
                    'Tensor': 'const Tensor &' if const else 'Tensor &',
                    'BoolTensor': 'const Tensor &' if const else 'Tensor &',
                    'IndexTensor': 'const Tensor &' if const else 'Tensor &',
                    'Type': 'const Type &' if const else 'Type &',
                    'TensorOptions': 'const TensorOptions &' if const else 'TensorOptions &',
                    'TensorList': 'TensorList',
                }

            if argument.get('is_nullable') and argument['type'] not in translate_map(False).keys():
                argument['type'] = "c10::optional<{}>".format(argument['type'])

            if (option['inplace'] and argument['name'] == 'self') or argument.get('output', False):
                argument['type'] = translate_map(False).get(argument['type'], argument['type'])
            else:
                argument['type'] = translate_map(True).get(argument['type'], argument['type'])

            return argument

        result = [native_translate_formals(argument, option) for argument in result]
        return result

    # this can return multiple return types in a list, e.g. ['Tensor', 'Tensor']
    def native_get_return_types(option):
        # type: (FunctionOption) -> List[ReturnType]
        ret = option['return']

        return_types = []  # type: List[ReturnType]
        for t_raw in ret:
            # See Note [field_name versus name]
            field_name = None
            if isinstance(t_raw, string_type):
                t = t_raw
                name = None
            elif t_raw is None:
                t = 'void'
                name = None
            else:
                t = t_raw['type']
                name = t_raw['name']
                if 'field_name' in t_raw:
                    field_name = t_raw['field_name']

            # can't actually return a TensorList (since it's a reference object)
            actual_return_type = {'TensorList': 'std::vector<Tensor>'}.get(t, t)

            if actual_return_type == 'Tensor' and (option['inplace'] or option['api_name'].endswith('_out')):
                # follow normal ATen convention of returning Tensor & for inplace functions.
                actual_return_type = 'Tensor &'

            rtype = {
                'type': actual_return_type,
                'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t),
            }  # type: ReturnType
            if name is not None:
                rtype['name'] = name
            if field_name is not None:
                rtype['field_name'] = field_name
            return_types.append(rtype)

        return return_types

    def process_native(option, output_options):
        # type: (FunctionOption, List[OutputDeclaration]) -> None
        assert option['python_module'] == '' or option['python_module'] == 'nn', \
            "Found python_module of {} for decl {}, but only '' string or 'nn' are supported".format(
                option['python_module'], option['name'])

        formals = native_get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f) for f in formals]
        option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
        option['returns'] = native_get_return_types(option)
        option['return_type'] = format_return_type(option['returns'])
        option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
        option['actuals'] = [f['name'] for f in formals]

        option['method_formals'] = [format_formal(f) for f in formals
                                    if f['name'] != 'self']
        option['method_formals_with_defaults'] = (
            [formal_with_default(f) for f in formals if f['name'] != 'self'])
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]

        def find_formal(formal_name, formals):
            for formal in formals:
                if formal_name == formal['dynamic_type']:
                    return formal
            return None

        assert find_formal('Type', formals) is None, \
            "Found Type argument in {}({}). Use TensorOptions instead.".format(
                option['name'], ", ".join(option['method_formals_with_defaults']))

        type_method_dispatch = option['type_method_definition_dispatch']

        dispatch_options = find_formal('TensorOptions', formals)
        # Only dispatch via tensor if there is no Options argument
        dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals)

        option['type_method_formals'] = [format_formal(f) for f in formals]
        option['type_method_actuals'] = [f['name'] for f in formals]
        option['native_actuals'] = [f['name'] for f in formals]

        option['const_mark'] = '' if option['inplace'] else ' const'

        is_method = 'method' in option['variants']
        is_namespace_function = 'function' in option['variants']
        is_factory_method = find_formal('TensorOptions', formals) and \
            not dispatch_options and 'method' not in option['variants']

        check_methods_do_not_start_with_underscore(option['name'], is_method)

        option['method_prefix_derived'] = ''
        option['device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)

        env = nested_dict(option, top_env)

        broadcast_arg = get_broadcast_argument(option)
        if broadcast_arg is not None:
            raise Exception("broadcasting is not yet supported for native functions, "
                            "but specified for function {}".format(option['name']))

        if option['extended_method']:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
        else:
            top_env['pure_virtual_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
        top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
        option['native_type_method_dispatch'] = type_method_dispatch

        # Note [Abstract ATen methods]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # An abstract ATen method is one whose dispatch differs between
        # types.  These are implemented in derived types (with a
        # standard (throwing) definition in Type).  A concrete ATen
        # method is one which has the same dispatch for all types;
        # we just implement it in the base Type.  This is exposed
        # in Declarations.yaml via a field named 'abstract'.
        abstract = False
        if isinstance(type_method_dispatch, dict):
            abstract = True
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
        else:
            body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_CONCRETE.substitute(
                    env, type_definition_body=body))

        # generate the at::native function declarations (i.e. what the user will implement)
        if isinstance(type_method_dispatch, dict):
            generated_native_functions = []  # type: List[str]
            for key in sorted(type_method_dispatch.keys()):
                value = type_method_dispatch[key]
                if value not in generated_native_functions:
                    option['native_type_method_dispatch'] = value
                    top_env['native_function_declarations'].append(
                        NATIVE_DECLARATION.substitute(env))
                    generated_native_functions.append(value)
        else:
            top_env['native_function_declarations'].append(
                NATIVE_DECLARATION.substitute(env))

        method_of = ['Type']
        if is_method:
            top_env['tensor_method_declarations'].append(
                TENSOR_METHOD_DECLARATION.substitute(env))
            top_env['tensor_method_definitions'].append(
                TENSOR_METHOD_DEFINITION.substitute(env))
            method_of.append('Tensor')

        if is_namespace_function:
            if dispatch_tensor:
                option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
            elif dispatch_options:
                option['inferred_type'] = 'at::getType({})'.format(dispatch_options['name'])
            else:
                # doesn't depend on a specific type, use undefined float
                option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
            declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
            top_env['function_declarations'].append(declaration.substitute(env))
            top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
            method_of.append('namespace')

        output_options.append(OutputDeclaration(
            name=option['api_name'],
            matches_jit_signature=option["matches_jit_signature"],
            schema_string=option["schema_string"],
            method_prefix_derived=option['method_prefix_derived'],
            arguments=formals,
            method_of=method_of,
            mode=option['mode'],
            python_module=option['python_module'],
            buffers=None,
            returns=option['returns'],
            inplace=option['inplace'],
            is_factory_method=is_factory_method,
            # See Note [Abstract ATen methods]
            abstract=abstract,
            requires_tensor=option.get('requires_tensor', False),
            device_guard=option.get('device_guard', True),
            with_gil=option.get('with_gil', False),
            deprecated=option['deprecated'],
        ))

    output_declarations = []  # type: List[OutputDeclaration]
    for declaration in declarations:
        output_options = []  # type: List[OutputDeclaration]
        for option in declaration['options']:
            option["matches_jit_signature"] = declaration["matches_jit_signature"]
            option["schema_string"] = declaration["schema_string"]
            try:
                if option['mode'] != 'native':
                    process_option(option, output_options)
                else:
                    process_native(option, output_options)
            except NYIError:
                option['skip'] = True
        output_declarations.extend(output_options)

    return output_declarations


def create_derived(backend_type_env, declarations):
    # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
    type_object_declarations = []
    type_object_definitions = []

    is_cuda = 'CUDA' in backend_type_env['Backend']

    def replace_with_null(argument):
        # type: (THFormal) -> bool
        return (argument['type'] == 'THGenerator*' and
                backend_type_env['Backend'] == 'CUDA')

    def requires_checked_cast(argument):
        # type: (THFormal) -> bool
        if argument['type'] == 'IntArrayRef':
            return 'size' in argument
        return argument['type'] in CHECKED_CAST

    def nullable_argument(argument):
        # type: (THFormal) -> bool
        return argument.get('is_nullable', False)

    def bool_option_is_string(argument):
        # type: (THFormal) -> bool
        return 'if_true' in argument and isinstance(argument['if_true'], string_type)

    def get_argument(argument, option):
        # type: (THFormal, FunctionOption) -> str
        if replace_with_null(argument):
            return 'NULL'
        elif requires_checked_cast(argument):
            checked_use = CHECKED_USE.get(
                argument['type'], '{}_').format(argument['name'])
            if nullable_argument(argument):
                checked_use = CHECKED_USE_NULLABLE.substitute(
                    env={}, arg_name=argument['name'], usage=checked_use)
            return checked_use
        elif argument['type'] == 'bool' and 'if_true' in argument:
            if bool_option_is_string(argument):
                tpl = '({}) ? "{}" : "{}"'
            else:
                tpl = '({}) ? {} : {}'
            return tpl.format(argument['name'],
                              argument['if_true'], argument['if_false'])
        elif argument['type'] == 'CONSTANT':
            # this is a bool that is actually a string...
            if bool_option_is_string(argument):
                return '"{}"'.format(argument['name'])
            v = str(argument.get('default', argument['name']))
            for pattern, replacement in CONSTANT_REPLACEMENTS:
                v = re.sub(pattern, replacement, v)
            return CodeTemplate(v).substitute(backend_type_env)
        # e.g. argument 0, i.e. repeat the 0th argument in this position...
        elif argument['type'] == 'argument':
            index = int(argument['name'])
            return get_argument(option['arguments'][index], option)
        else:
            return argument['name']

    def drop_argument(argument, option):
        # type: (THFormal, FunctionOption) -> bool
        # Devices are handled in the body of the function.
        if argument['name'] == 'device':
            return True
        return 'CUDA' in backend_type_env['Backend'] and (
            option['mode'] == 'TH' and argument['type'] == 'THGenerator*')

    def get_arguments(arguments, option):
        # type: (List[THFormal], FunctionOption) -> List[str]
        return [get_argument(argument, option)
                for argument in arguments if not drop_argument(argument, option)]

    def is_actual_return_long(ret):
        # type: (ReturnDecl) -> bool
        if ret['type'] == 'long':
            return True
        if ret['type'] == 'real':
            return backend_type_env['ScalarName'] == 'Long'
        if ret['type'] == 'accreal':
            return backend_type_env['AccScalarName'] == 'Long'
        return False

    def handle_zero_dim(env, option):
        # type: (Environment, FunctionOption) -> List[str]
        zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
        if not zero_dim_dispatch:
            return []
        broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
        zero_dim_only = option.get('zero_dim_tensor_only', False)
        # this combination doesn't seem to make sense
        assert not (broadcasts_arg and zero_dim_only)
        # if the argument broadcasts, then this would only affect cases where all broadcasted
        # tensors were zero-dim, which is inconsistent with the scalar handling.
        if broadcasts_arg:
            return []
        zero_dim_actuals = [arg['name']
                            if arg['name'] != zero_dim_dispatch else "{}.item()".format(arg['name'])
                            for arg in option['formals_list']]
        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]

    def handle_only_zero_dim(env, option):
        # type: (Environment, FunctionOption) -> Optional[List[str]]
        if option.get('zero_dim_tensor_only', False):
            check_name = option['zero_dim_dispatch_when_scalar']
            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
        else:
            return None

    def handle_sparse(env, option):
        # type: (Environment, FunctionOption) -> List[str]
        if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
            return []
        check_name = option['when_sparse_dispatch']
        sparse_actuals = [arg['name']
                          if arg['name'] != check_name else "SparseTensorRef({})".format(arg['name'])
                          for arg in option['formals_list']]
        return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]

    def allocate_arg(env, arg, output_count):
        # type: (Environment, THFormal, int) -> List[str]
        name = arg['name']
        state = ''
        if is_cuda:
            state = 'globalContext().getTHCState()'
        allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env)
        tensor_arg = '{}_'.format(name)
        if arg.get('mask', False):
            allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
            tensor_arg = ('{}_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*){}_'
                          .format(name, name))
        intrusive_ptr_type = 'c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>'
        return [
            'auto {}_ = {};'.format(name, allocation),
            'auto {} = Tensor({}::reclaim({}));'.format(name, intrusive_ptr_type, tensor_arg),
        ]

    def resize_arg(arg):
        # type: (THFormal) -> str
        resize = arg['resize']
        if isinstance(resize, str):
            return "{}.resize_({}.sizes());".format(arg['name'], resize)
        else:
            resize_scalar = arg.get('resize_scalar', False)
            if resize_scalar:
                dims = ['{}.dim() == 0 ? 1 : {}.size({})'.format(name, name, dim) for name, dim in resize]
            else:
                dims = ['{}.size({})'.format(name, dim) for name, dim in resize]
            return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))
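
    # For example (illustrative): for an arg named 'result', resize='self'
    # yields "result.resize_(self.sizes());", while
    # resize=[('mat1', 0), ('mat2', 1)] yields
    # "result.resize_({ mat1.size(0),mat2.size(1) });".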

    def handle_call(env, option, cimpl):
        # type: (Environment, FunctionOption, FunctionOption) -> str
        is_nn = option['mode'] == 'NN'
        actuals = get_arguments(cimpl['arguments'], option)
        if is_cuda or is_nn:
            actuals = ['globalContext().getTHCState()'] + actuals

        cname = cimpl['cname']
        if option.get('sparse', False):
            if is_cuda:
                cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname
            else:
                cname = env['THTensor'].replace('TH', 'THS') + '_' + cname
        elif is_nn:
            cname = 'THNN_{}'.format(env['THType']) + cname
        else:
            cname = env['THTensor'] + '_' + cname

        call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname)
        if cimpl.get('condition') is not None:
            call = 'if ({}) {}'.format(cimpl['condition'], call)
        return call
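
    # For example (illustrative): a plain TH cimpl named 'add' on a CPU Float
    # backend (THTensor == 'THFloatTensor') produces the call
    # "THFloatTensor_add(self_, other_)"; on CUDA and for NN ops the THCState
    # actual is prepended to the argument list.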

    def emit_body(env, option):
        # type: (Environment, FunctionOption) -> List[str]
        body = []  # type: List[str]
        body += handle_sparse(env, option)
        body += handle_zero_dim(env, option)
        only_zero_dim_check = handle_only_zero_dim(env, option)
        if only_zero_dim_check is not None:
            # code below only_zero_dim_check is unreachable so we do not need to generate the rest.
            body += only_zero_dim_check
            return body

        # arguments are potentially duplicated because of one argument
        # referencing another
        seen_names = set()  # type: Set[str]
        seen_tensorlists = set()  # type: Set[str]
        count = 0
        output_count = 0

        # scalar_check is the heuristic condition for when a result may be a scalar.
        # if there is an IntArrayRefSize argument, then its dimensions are used to determine scalar.
        # otherwise, it is true if all the input tensors are scalars,
        scalar_check_is_from_size = False
        scalar_check_is_from_option = False
        scalar_check = None
        scalar_check_opt = option.get('scalar_check')
        if scalar_check_opt is not None:
            if isinstance(scalar_check_opt, bool):
                scalar_check = str(scalar_check_opt).lower()
            else:
                scalar_check = scalar_check_opt
            scalar_check_is_from_option = True

        for arg in option['arguments']:
            if is_real_argument_to_wrapper(arg):
                count += 1
            if arg['type'] == 'IntArrayRefSize' and not scalar_check_is_from_option:
                scalar_check_is_from_size = True
                scalar_check = '{}.size() == 0'.format(arg['name'])
            if arg['type'] == 'TensorList':
                seen_tensorlists.add(arg['name'])

            wrap_dim_target = arg.get('wrap_dim', None)
            if wrap_dim_target is not None:
                # for Tensors, "name_" is the TensorImpl, but for TensorLists it is a
                # std::vector of TH*s. Since TH*s have different dimension rules, we use
                # "name" instead, but keep "name_" for Tensors to avoid an extra function call.
                if wrap_dim_target not in seen_tensorlists:
                    wrap_dim_target = wrap_dim_target + "_"
                body.append("{} = maybe_wrap_dim({}, {});"
                            .format(arg['name'], arg['name'], wrap_dim_target))

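            # Illustrative sketch (hypothetical argument, not original source): a
            # 'dim' formal with wrap_dim == 'self' emits
            #     dim = maybe_wrap_dim(dim, self_);
            # while a TensorList target keeps the unsuffixed name, e.g.
            #     dim = maybe_wrap_dim(dim, tensors);
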
            # only generate a checked cast the first time we see an argument
            if arg['name'] not in seen_names and requires_checked_cast(arg):
                seen_names.add(arg['name'])

                # make a new allocation of TensorImpl, then wrap a Tensor around it.
                if arg.get('allocate', False):
                    body += allocate_arg(env, arg, output_count)
                    output_count += 1
                # extract the TensorImpl from an existing tensor (or Storage, etc.)
                else:
                    # special case where we allow undefined Tensors, and thus
                    # the checked cast succeeds even if the Tensor is not
                    # defined
                    null_okay = 'true' if nullable_argument(arg) else 'false'
                    default_init = []
                    if 'default_init' in arg:
                        default_init.append(arg['default_init'])

                    check_cast = CHECKED_CAST[arg['type']].substitute(
                        env, arg_name=arg['name'], arg_pos=count,
                        null_okay=null_okay, default_init=default_init,
                        size=arg.get('size'))
                    body.append("auto {}_ = {};".format(
                        arg['name'], check_cast))
                    if drop_argument(arg, option) or replace_with_null(arg):
                        body.append(
                            "(void) {}_; //silence unused warning".format(arg['name']))

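                    # Illustrative sketch (the exact expansion depends on the
                    # CHECKED_CAST templates defined earlier in this file and not
                    # shown in this excerpt): the emitted line has the shape
                    #     auto self_ = <checked cast of self at position N>;
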
                initializers = []

                # resize tensors for special ops that require it
                if 'resize' in arg:
                    initializers.append(resize_arg(arg))

                # also special handling where we zero some outputs.
                if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
                    initializers.append("{}.zero_();".format(arg['name']))

                # only initialize non-null arguments
                if nullable_argument(arg) and len(initializers) > 0:
                    body.append(CONDITIONAL_INITIALIZER.substitute({
                        'name': arg['name'],
                        'initializer': initializers
                    }))
                else:
                    body += initializers

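                # Illustrative sketch (hypothetical output 'result' that is both
                # resized and zeroed): the non-nullable path appends
                #     result.resize_(self.sizes());
                #     result.zero_();
                # while nullable arguments get the same statements guarded by
                # CONDITIONAL_INITIALIZER (a template defined earlier, not shown here).
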
                # for out-of-place: dim() == 0 for all input tensors is AND'd together
                # to form the test for whether the output is also a scalar.
                # for in-place: dim() == 0 shouldn't change as a result of the operation.
                if (not arg.get('output') and 'Tensor' in arg['type'] and
                        'TensorList' not in arg['type'] and
                        'THS' not in arg['type'] and
                        not scalar_check_is_from_size and
                        not scalar_check_is_from_option and
                        not option['inplace']):
                    check = '{}->dim() == 0'.format(arg['name'] + '_')
                    if nullable_argument(arg):
                        check = '(!{} || {})'.format(arg['name'] + '_', check)
                    scalar_check = (check if scalar_check is None
                                    else scalar_check + ' && ' + check)

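                # Illustrative sketch (hypothetical inputs 'self' and 'other', not
                # original source): the accumulated expression looks like
                #     self_->dim() == 0 && other_->dim() == 0
                # with a nullable input contributing (!other_ || other_->dim() == 0).
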
        # cimpls, if it exists, contains the underlying C function names and
        # arguments. Otherwise, use the option itself.
        cimpls = option.get('cimpls', [option])
        calls = [handle_call(env, option, cimpl) for cimpl in cimpls]

        ret = option['return']

        if ret['kind'] == 'arguments':
            if 'aten_custom_call' in option:
                # all aten_custom_call bodies handle the scalar_check settings on their own.
                scalar_check = None
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))
            else:
                body.extend([call + ';' for call in calls])
            arguments_indices = ret['arguments']
            arguments = [option['arguments'][argi]
                         for argi in arguments_indices]
            if scalar_check is not None:
                if not isinstance(scalar_check, dict):
                    if len(arguments) > 1:
                        body.append("bool maybe_scalar = {};".format(scalar_check))
                        scalar_check = 'maybe_scalar'
                for arg in arguments:
                    scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
                                        else scalar_check.get(arg['name']))  # type: ignore
                    if scalar_check_arg is not None:
                        stmt = "{}_->maybe_zero_dim({});".format(arg['name'], scalar_check_arg)
                        if nullable_argument(arg):
                            stmt = "if ({}_) {}".format(arg['name'], stmt)
                        body.append(stmt)
            if len(arguments_indices) == 1:
                arg = arguments[0]
                body.append("return {};".format(arg['name']))
            else:
                types = [to_return_type(arg, option)['type']
                         for arg in arguments]
                # TODO: check for move semantics...
                names = [arg['name'] for arg in arguments]
                body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                    types=types, names=names))
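            # Illustrative sketch (hypothetical two-output op, not original source):
            # the generated epilogue might read
            #     bool maybe_scalar = self_->dim() == 0;
            #     output_->maybe_zero_dim(maybe_scalar);
            #     indices_->maybe_zero_dim(maybe_scalar);
            #     return std::tuple<Tensor &, Tensor &>(output, indices);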
        elif ret['kind'] == 'type':
            assert len(calls) == 1
            call = calls[0]
            if 'aten_custom_call' in option:
                # all aten_custom_call bodies handle the scalar_check settings on their own.
                scalar_check = None
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))

            if ret['type'] in ALLOC_WRAP.keys():
                maybe_scalar = "->maybe_zero_dim({})".format(scalar_check) \
                               if scalar_check is not None \
                               else ""
                wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute(
                    env, arguments=[call])
                return_tensor = (
                    "return Tensor(" +
                    "c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(" +
                    "(${wrapped_tensor})${maybe_scalar}));")
                body.append(CodeTemplate(return_tensor).substitute(
                    env, wrapped_tensor=wrapped_tensor, maybe_scalar=maybe_scalar))
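            # Illustrative sketch (hypothetical ALLOC_WRAP expansion, not original
            # source): the generated statement has the shape
            #     return Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(
            #         (<allocated-and-wrapped C call>)->maybe_zero_dim(<scalar_check>)));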
            # return the same underlying Tensor type for both real and accreal; this ensures
            # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
            # ScalarType before constructing the scalar_tensor to avoid overflow checking.
            elif ret['type'] == 'accreal' or ret['type'] == 'real':
                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
                body.append(CodeTemplate(return_scalar).substitute(env, call=call))
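            # Illustrative sketch (hypothetical substitution with ScalarType == float):
            #     return at::scalar_tensor(convert<float>(THFloatTensor_sumall(self_)), options());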
            else:
                # we use int64_t for long in the API, so correct it here...
                if is_actual_return_long(ret):
                    call = "static_cast<int64_t>({})".format(call)
                body.append("return {};".format(call))
        else:
            raise Exception("NYI - return handling")
        return body

    def process_option(option):
        # type: (FunctionOption) -> None
        pair = (backend_type_env['Backend'],
                backend_type_env['ScalarName'])
        if pair in option['backend_type_pairs']:
            env = nested_dict(option, backend_type_env)
            body = emit_body(env, option)  # type: ignore
            option['type_definition_body'] = body
            type_object_declarations.append(
                TYPE_DERIVED_DECLARATION.substitute(env))
            type_object_definitions.append(
                TYPE_DERIVED_DEFINITION.substitute(env))

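    # A note for readers (an assumption, since the TYPE_DERIVED_* templates are
    # defined earlier in this file and not shown in this excerpt): for a
    # backend_type_env like {'Backend': 'CPU', 'ScalarName': 'Float', ...}, an
    # option listing the pair ('CPU', 'Float') gets one derived declaration and
    # one definition emitted.
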
    def process_native(option):
        # type: (FunctionOption) -> None
        dispatch = option['type_method_definition_dispatch']
        env = nested_dict(option, backend_type_env)

        if isinstance(dispatch, dict):
            pair = (backend_type_env['Backend'],
                    backend_type_env['ScalarName'])
            if pair in option['backend_type_pairs']:
                native_dispatch = dispatch.get(pair[0])
                type_object_declarations.append(
                    TYPE_DERIVED_DECLARATION.substitute(env))
                if native_dispatch is None:
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE_MISSING.substitute(env))
                else:
                    option['native_type_method_dispatch'] = native_dispatch
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE.substitute(env))

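    # Illustrative sketch (hypothetical dispatch table, not original source): a
    # native declaration with
    #     dispatch = {'CPU': 'addmm_cpu', 'CUDA': 'addmm_cuda'}
    # emits the NATIVE definition for a matching backend and the NATIVE_MISSING
    # stub for a backend absent from the table.
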
    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    if option['mode'] == 'NN' and option.get('cimpls') is None:
                        continue
                    if option['mode'] != 'native':
                        process_option(option)
                    else:
                        process_native(option)
                except NYIError:
                    pass
    return type_object_declarations, type_object_definitions


def create_extension_backend(backend_type_env, declarations):
    # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
    type_object_declarations = []
    type_object_definitions = []

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    option['formals_types'] = [f['type'] for f in option['formals_list']]
                    option['native_actuals'] = [f['name'] for f in option['formals_list']]
                    schema_args = ", ".join(
                        ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']])
                    return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
                    option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
                    env = nested_dict(option, backend_type_env)
                    type_object_declarations.append(
                        TYPE_DERIVED_DECLARATION.substitute(env))
                    type_object_definitions.append(
                        TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
                except NYIError:
                    pass
    return type_object_declarations, type_object_definitions
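
# Illustrative sketch (hypothetical option, not original source): for api_name
# 'add' with formals (Tensor self, Tensor other), the schema string is built as
#     "add(Tensor self, Tensor other) -> Tensor"
# and is presumably consumed by TYPE_DEFINITION_EXTENSION_BACKEND (defined
# earlier in this file, not shown in this excerpt).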