5 from code_template
import CodeTemplate
11 'Missing build dependency: Unable to import the `typing` module. ' 12 'Please install it via `conda install typing` or `pip install typing`')
15 from typing
import Union, Set
16 from typing
import Any, Dict, List, Optional, Tuple, NamedTuple
19 from mypy_extensions
import TypedDict
23 def TypedDict(name, attrs, total=True):
27 if sys.version_info[0] == 3:
30 string_type = basestring
41 TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate(
"""\ 42 ${return_type} ${api_name}(${type_method_formals}) const override; 45 TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate(
"""\ 46 ${return_type} TypeDefault::${api_name}(${type_method_formals}) const { 47 ${device_guard_declaration} 48 Tensor ${broadcast_returns}; 49 std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}"); 50 return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals}); 62 PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate(
"""\ 63 virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const = 0; 65 DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate(
"""\ 66 C10_DEPRECATED virtual ${return_type} \ 67 ${method_prefix_derived}${api_name}(${type_method_formals}) const = 0; 69 PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate(
"""\ 70 virtual ${return_type} ${api_name}(${type_method_formals}) const = 0; 73 TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate(
"""\ 74 ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override; 76 TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate(
"""\ 77 ${return_type} TypeDefault::${method_prefix_derived}${api_name}(${type_method_formals}) const { 78 AT_ERROR("${method_prefix_derived}${api_name} is not implemented for type ", toString()); 81 TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate(
"""\ 82 ${return_type} ${api_name}(${type_method_formals}) const override; 84 TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate(
"""\ 85 ${return_type} TypeDefault::${api_name}(${type_method_formals}) const { 86 ${device_guard_declaration} 87 ${type_definition_body} 91 TYPE_DERIVED_DECLARATION = CodeTemplate(
"""\ 92 ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override; 95 TYPE_DERIVED_DEFINITION = CodeTemplate(
"""\ 96 ${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const { 97 ${device_guard_declaration} 98 ${type_definition_body} 104 TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate(
"""\ 105 ${return_type} ${Type}::${api_name}(${type_method_formals}) const { 106 ${device_guard_declaration} 107 ${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${actuals}); 110 TYPE_DERIVED_DEFINITION_NATIVE_MISSING = CodeTemplate(
"""\ 111 ${return_type} ${Type}::${api_name}(${type_method_formals}) const { 112 AT_ERROR("${api_name} not supported on ${Type}"); 115 TYPE_DEFINITION_BODY_NATIVE = CodeTemplate(
"""\ 116 ${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals}); 120 TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate(
"""\ 121 ${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const { 122 return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals}); 127 TENSOR_METHOD_DECLARATION = CodeTemplate(
"""\ 128 ${return_type} ${api_name}(${method_formals_with_defaults})${const_mark}; 131 TENSOR_METHOD_DEFINITION = CodeTemplate(
"""\ 132 inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} { 133 return type().${api_name}(${method_actuals}); 137 FUNCTION_DECLARATION = CodeTemplate(
"""\ 138 static inline ${return_type} ${api_name}(${formals_with_defaults}); 141 DEPRECATED_FUNCTION_DECLARATION = CodeTemplate(
"""\ 142 C10_DEPRECATED static inline ${return_type} ${api_name}(${formals_with_defaults}); 145 FUNCTION_DEFINITION = CodeTemplate(
"""\ 146 static inline ${return_type} ${api_name}(${formals}) { 147 return ${inferred_type}.${api_name}(${type_method_actuals}); 151 NATIVE_DECLARATION = CodeTemplate(
"""\ 152 CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults}); 156 FACTORY_DEFINITION = CodeTemplate(
"""\ 157 static inline ${return_type} ${api_name}(${formals}) { 158 const DeviceGuard guard(options.device()); 159 return at::native::${api_name}(${type_method_actuals}); 166 ZERO_DIM_CHECK = CodeTemplate(
"""\ 167 if (${check_name}.dim() == 0) { 168 return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${zero_dim_actuals}); 171 ZERO_DIM_ONLY = CodeTemplate(
"""\ 172 AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor " 173 "with ", ${check_name}.dim(), " dimension(s)."); 176 SPARSE_CHECK = CodeTemplate(
"""\ 177 if(${check_name}.is_sparse()) { 178 return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals}); 181 BUFFER_DEFINITION = CodeTemplate(
"""\ 182 auto ${name}_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>( 183 ${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), ${THTensor}_new(), false).release(); 184 auto ${name} = Tensor(${name}_, false);""")
186 CONDITIONAL_INITIALIZER = CodeTemplate(
"""\ 187 if (${name}.defined()) { 191 CALL_TEMPLATE = CodeTemplate(
"${cname}(${actuals})")
195 """Indicates we don't support this declaration yet""" 197 def __init__(self, reason):
201 TYPE_FORMAL_GENERIC = {
202 'THTensor*':
'Tensor &',
203 'THSTensor*':
'SparseTensorRef',
204 'THBoolTensor*':
'Tensor &',
205 'THIndexTensor*':
'Tensor &',
206 'THIntegerTensor*':
'Tensor &',
207 'THDenseTensor*':
'Tensor &',
208 'THDenseIndexTensor*':
'Tensor &',
209 'THStorage*':
'Storage',
210 'THGenerator*':
'Generator *',
211 'IntArrayRefSize':
'IntArrayRef',
218 'THTensor*':
'Tensor',
219 'THSTensor*':
'SparseTensorRef',
220 'THBoolTensor*':
'BoolTensor',
221 'THIndexTensor*':
'IndexTensor',
222 'THIntegerTensor*':
'IntegerTensor',
223 'THDenseTensor*':
'Tensor',
224 'THDenseIndexTensor*':
'IndexTensor',
225 'THStorage*':
'Storage',
226 'THGenerator*':
'Generator*',
227 'IntArrayRefSize':
'IntArrayRef',
228 'accreal':
'accreal',
233 NATIVE_DYNAMIC_TYPE = {
234 'Tensor &':
'Tensor',
235 'const Tensor &':
'Tensor',
239 'THTensor*':
'Tensor',
240 'THIndexTensor*':
'Tensor',
241 'THBoolTensor*':
'Tensor',
242 'THIntegerTensor*':
'Tensor',
243 'THSTensor*':
'Tensor',
244 'THDenseTensor*':
'Tensor',
245 'THDenseIndexTensor*':
'Tensor',
254 'checked_tensor_unwrap(' 255 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 256 'Backend::${Backend}, ScalarType::${ScalarName})'),
259 'checked_tensor_unwrap(' 260 '${arg_name}.tref,"${arg_name}",${arg_pos},false, ' 261 'Backend::${Backend}, ScalarType::${ScalarName})'),
264 'checked_tensor_unwrap(' 265 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 266 'Backend::${Backend}, ScalarType::Byte)'),
269 'checked_tensor_unwrap(' 270 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 271 'Backend::${Backend}, ScalarType::Long)'),
274 'checked_tensor_unwrap(' 275 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 276 'Backend::${Backend}, ScalarType::Int)'),
279 'checked_tensor_unwrap(' 280 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 281 'Backend::${DenseBackend}, ScalarType::${ScalarName})'),
282 'THDenseIndexTensor*':
284 'checked_tensor_unwrap(' 285 '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, ' 286 'Backend::${DenseBackend}, ScalarType::Long)'),
290 '${arg_name},"${arg_name}",${arg_pos}, ' 293 'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
296 'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(device_type()))'),
298 'IntArrayRefStride':
CodeTemplate(
'at::IntArrayRef ${result_name} = get_intlist_stride_th(${arg_name});'),
300 'accreal':
CodeTemplate(
'${arg_name}.to${AccScalarName}()'),
302 'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, ' 303 'Backend::${Backend}, ScalarType::${ScalarName})'),
304 'IntArrayRef':
CodeTemplate(
'check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
310 'THIndexTensor*':
'{}_',
311 'THBoolTensor*':
'{}_',
312 'THIntegerTensor*':
'{}_',
313 'THDenseTensor*':
'{}_',
314 'THDenseIndexTensor*':
'{}_',
315 'THStorage*':
'{}_.unsafeGetStorageImpl()',
316 'THGenerator*':
'{}_->generator',
317 'TensorList':
"{0}_.data(), {0}_.size()",
320 CHECKED_USE_NULLABLE =
CodeTemplate(
'${arg_name}_ ? ${usage} : NULL')
322 ALLOC_NOARGS_WRAP = {
323 'THTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 324 '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
325 'THBoolTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 326 '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Byte), allocator(), false).release()',
327 'THIndexTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 328 '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), allocator(), false).release()',
329 'THIntegerTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 330 '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Int), allocator(), false).release()',
331 'THDenseTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 332 '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
333 'THDenseIndexTensor*':
'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>' 334 '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), ' 335 'allocator(), false).release()' 339 'THTensor*':
'${arguments}',
340 'THBoolTensor*':
'${arguments}',
341 'THIndexTensor*':
'${arguments}',
342 'THIntegerTensor*':
'${arguments}',
343 'THDenseTensor*':
'${arguments}',
344 'THDenseIndexTensor*':
'${arguments}',
348 CONSTANT_REPLACEMENTS = [
349 (
'AS_REAL',
'${AS_REAL}'),
350 (
'__last_dim',
'self.ndimension()-1'),
354 HEADER_CONSTANT_REPLACEMENTS = [
355 (
r'AS_REAL\((.*)\)',
r'\1'),
356 (
'__last_dim',
'-1'),
360 class nested_dict(object):
361 def __init__(self, base, parent):
362 self.base, self.parent = base, parent
364 def __getitem__(self, x):
368 return self.parent[x]
371 Environment = TypedDict(
'Environment', {
377 'AccScalarName': str,
380 TopEnvironment = TypedDict(
'TopEnvironment', {
381 'type_registrations': List[str],
382 'type_headers': List[str],
383 'pure_virtual_type_method_declarations': List[str],
384 'pure_virtual_extended_type_method_declarations': List[str],
385 'type_method_declarations': List[str],
386 'type_method_definitions': List[str],
387 'tensor_method_declarations': List[str],
388 'tensor_method_definitions': List[str],
389 'function_declarations': List[str],
390 'function_definitions': List[str],
391 'type_ids': List[str],
392 'native_function_declarations': List[str],
397 THFormal = TypedDict(
'THFormal', {
407 'declared_type': str,
408 'ignore_check': bool,
423 AtFormal = TypedDict(
'AtFormal', {
452 ReturnType = TypedDict(
'ReturnType', {
460 ReturnDecl = TypedDict(
'ReturnDecl', {
463 'arguments': List[int],
467 NNBuffer = TypedDict(
'NNBuffer', {
471 FunctionOption = TypedDict(
'FunctionOption', {
472 'actuals': List[str],
474 'arguments': List[THFormal],
475 'aten_custom_call': str,
476 'aten_dense_sparse': bool,
477 'backend_type_pairs': List[Tuple[str, str]],
478 'backends': List[str],
479 'broadcast_actuals': List[str],
480 'broadcast_function': str,
481 'broadcast_modified_actuals': List[str],
482 'broadcast_returns': List[str],
483 'buffers': List[NNBuffer],
489 'device_guard': bool,
490 'device_guard_declaration': str,
497 'formals_list': List[AtFormal],
498 'formals_with_defaults': List[str],
499 'formals': List[str],
500 'formals_types': List[str],
501 'inferred_type': str,
503 'matches_jit_signature': bool,
506 'extended_method': bool,
507 'method_actuals': List[str],
508 'method_formals_with_defaults': List[str],
509 'method_formals': List[str],
510 'method_prefix_derived': str,
512 'python_module': str,
514 'native_actuals': List[str],
515 'native_type_method_dispatch': str,
518 'schema_string': str,
519 'requires_tensor': bool,
522 'return': ReturnDecl,
523 'returns': List[ReturnType],
528 'type_definition_body': List[str],
529 'type_method_actuals': List[str],
530 'type_method_definition_dispatch': str,
531 'type_method_formals': List[str],
533 'when_spares_dispatch': str,
534 'when_sparse_dispatch': str,
536 'zero_dim_dispatch_when_scalar': str,
537 'zero_dim_tensor_only': bool,
540 OutputDeclaration = NamedTuple(
'OutputDeclaration', [
542 (
'matches_jit_signature', bool),
543 (
'schema_string', str),
544 (
'method_prefix_derived', str),
545 (
'arguments', List[AtFormal]),
546 (
'method_of', List[str]),
548 (
'python_module', str),
549 (
'buffers', Optional[List[str]]),
550 (
'returns', List[ReturnType]),
552 (
'is_factory_method', bool),
554 (
'requires_tensor', bool),
555 (
'device_guard', bool),
557 (
'deprecated', bool),
561 def device_guard(option, formals, dispatch_options, dispatch_tensor):
563 if option.get(
'device_guard',
True):
565 return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options[
'name'])
567 return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
568 return '// DeviceGuard omitted' 571 def is_real_argument_to_wrapper(argument):
573 return not argument.get(
'output',
False)
and\
574 argument[
'type'] !=
'CONSTANT' and\
575 argument[
'type'] !=
'argument' 578 def is_mutable_formal_argument(argument, option):
580 return argument.get(
'output')
or option[
'inplace']
and argument[
'name'] ==
'self' 583 def check_methods_do_not_start_with_underscore(name, is_method):
584 if name
in {
'_values',
'_indices',
'_nnz',
'_dimI',
'_dimV',
'_coalesced_'}:
586 if is_method
and name.startswith(
'_')
and not name.startswith(
'__')
and not name.startswith(
'_th_'):
587 message =
"Function '{}' starts with a single underscore and is ".format(name)
588 message +=
"configured to have a method on Tensor. Functions that start with " 589 message +=
" a single underscore should only be functions in the at:: " 590 message +=
"namespace and not methods on Tensor!" 591 raise RuntimeError(message)
594 def to_return_type(arg, option):
597 rt = TYPE_RETURN.get(t, t)
598 if rt ==
'Tensor' and not arg.get(
'allocate'):
600 if not is_mutable_formal_argument(arg, option):
605 'dynamic_type': DYNAMIC_TYPE.get(arg[
'type'], arg[
'type']),
609 def create_generic(top_env, declarations):
612 def translate_default(argument, type_str, default):
617 if 'if_true' in argument:
618 return argument[
'default'] == argument[
'if_true']
619 for pattern, replacement
in HEADER_CONSTANT_REPLACEMENTS:
620 default = re.sub(pattern, replacement, str(default))
621 if type_str
in {
'Scalar',
'int64_t',
'double'}:
626 return float(default)
629 elif type_str ==
'bool':
630 assert default.lower()
in [
'true',
'false']
631 return default.lower() ==
'true' 637 def translate_formal(argument, option):
639 type_str = TYPE_FORMAL_GENERIC.get(argument[
'type'], argument[
'type'])
640 if type_str ==
'Tensor &' and not is_mutable_formal_argument(argument, option):
641 type_str =
'const ' + type_str
643 'name': argument[
'name'],
645 'dynamic_type': DYNAMIC_TYPE.get(argument[
'type'], argument[
'type']),
647 if 'kwarg_only' in argument:
648 translated[
'kwarg_only'] = argument[
'kwarg_only']
649 if 'default' in argument:
650 default = translate_default(argument, type_str, argument[
'default'])
651 translated[
'default'] = default
652 translated[
'default_init'] = argument.get(
'default_init', default)
653 if argument.get(
'output'):
654 translated[
'output'] =
True 655 if argument.get(
'size'):
656 translated[
'size'] = argument[
'size']
657 if argument.get(
'is_nullable')
is not None:
658 translated[
'is_nullable'] = argument[
'is_nullable']
661 def get_formals(option, include_constants=False):
667 def insert(argument):
669 if argument[
'name']
not in seen:
670 seen.add(argument[
'name'])
671 if argument.get(
'kwarg_only',
False):
672 kwd_args.append(argument)
674 pos_args.append(argument)
676 def has_output_mask(argument):
678 return argument.get(
'allocate',
False)
and argument.get(
'mask',
False)
680 for argument
in option[
'arguments']:
681 if argument.get(
'output')
and not argument.get(
'allocate',
False):
683 for argument
in option[
'arguments']:
684 if argument[
'type'] ==
'THSTensor*':
686 if not (option.get(
'aten_dense_sparse',
False)):
689 if include_constants
and argument[
'type'] ==
'CONSTANT':
691 elif is_real_argument_to_wrapper(argument):
693 if any(has_output_mask(arg)
for arg
in option[
'arguments']):
694 mask_size = sum(has_output_mask(arg)
for arg
in option[
'arguments'])
696 'name':
'output_mask',
699 'type':
'std::array<bool,{}>'.format(mask_size),
700 'default':
'{{' +
', '.join([
'true'] * mask_size) +
'}}',
703 result = pos_args + kwd_args
704 return [translate_formal(argument, option)
for argument
in result]
706 def get_return_types(option):
708 ret = option[
'return']
709 if ret[
'kind'] ==
'arguments':
710 argument_indices = ret[
'arguments']
711 if len(argument_indices) == 1:
712 the_arg = option[
'arguments'][argument_indices[0]]
713 return [to_return_type(the_arg, option)]
715 return [to_return_type(option[
'arguments'][idx], option)
716 for idx
in argument_indices]
717 elif ret[
'kind'] ==
'type':
719 'type': TYPE_RETURN.get(ret[
'type'], ret[
'type']),
720 'dynamic_type': DYNAMIC_TYPE.get(ret[
'type'], ret[
'type']),
723 raise Exception(
"format_return_type")
725 def format_return_type(return_types):
727 if len(return_types) == 1:
728 return return_types[0][
'type']
729 return "std::tuple<{}>".format(
','.join(r[
'type']
for r
in return_types))
731 def find_dispatch_tensor(formals):
734 def is_any_tensor_type(formal):
735 return (formal[
'dynamic_type'] ==
'Tensor' or formal[
'dynamic_type'] ==
'BoolTensor' 736 or formal[
'dynamic_type'] ==
'IndexTensor')
738 for formal
in formals:
739 if formal[
'name'] ==
'self' and is_any_tensor_type(formal)
and not formal.get(
'is_nullable',
False):
740 return formal[
'name']
742 for formal
in formals:
743 if 'TensorList' == formal[
'dynamic_type']
or is_any_tensor_type(formal)
and \
744 not formal.get(
'is_nullable',
False):
745 return formal[
'name']
749 def format_formal(f):
751 return '{} {}'.format(f[
'type'], f[
'name'])
753 def formal_with_default(f):
759 if isinstance(v, bool):
761 return '{}={}'.format(s, v)
763 def get_broadcast_argument(option):
765 for argument
in option[
'arguments']:
766 if argument.get(
'broadcast'):
770 def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
779 if not broadcast_dims:
780 broadcast_actuals = [broadcast_arg[
'name']] + broadcast_arg[
'broadcast'].
split()[0].
split(
",")
782 broadcast_dims_spec = broadcast_arg[
'broadcast'].
split()[1].
split(
':')[1].
split(
',')
784 broadcast_dims = ([x.split(
'.')[0] +
'.size(' + x.split(
'.')[1].replace(
'dim',
'') +
')' 785 for x
in broadcast_dims_spec])
786 broadcast_dims_init_list =
'{' +
','.join(broadcast_dims) +
'}' 787 broadcast_actuals = [broadcast_arg[
'name'], broadcast_dims_init_list]
789 return broadcast_actuals
791 def emit_nn_body(option):
795 actuals = option[
'actuals']
796 base_name = option[
'name'][:-1]
if option[
'inplace']
else option[
'name']
797 fwd_name = option[
'api_name'].replace(base_name, base_name +
'_forward')
799 if len(option[
'buffers']) == 0:
800 return 'return {}({});'.format(fwd_name,
', '.join(actuals))
803 if option[
'api_name'].endswith(
'_out'):
806 for buffer
in option[
'buffers']:
807 body.append(
'Tensor {} = at::empty({{0}}, this->options());'.format(buffer[
'name']))
808 actuals = [arg[
'name']
for arg
in option[
'arguments']
if arg.get(
'output')]
809 actuals += [buffer[
'name']
for buffer
in option[
'buffers']]
810 actuals += [arg[
'name']
for arg
in option[
'arguments']
if not arg.get(
'output')]
812 body.append(
'return std::get<0>({}({}));'.format(fwd_name,
', '.join(actuals)))
815 def process_option(option, output_options):
817 option[
'inplace'] = re.search(
818 '(^__i|[^_]_$)', option[
'api_name'])
is not None 821 formals = get_formals(option)
822 option[
'formals_list'] = formals
823 option[
'formals'] = [format_formal(f)
for f
in formals]
824 option[
'formals_with_defaults'] = [formal_with_default(f)
for f
in formals]
825 option[
'returns'] = get_return_types(option)
826 option[
'return_type'] = format_return_type(option[
'returns'])
827 option[
'return_call'] =
'return ' if option[
'return_type'] !=
'void' else '' 828 option[
'actuals'] = [f[
'name']
for f
in formals]
830 option[
'method_formals'] = [format_formal(f)
for f
in formals
831 if f[
'name'] !=
'self']
832 option[
'method_formals_with_defaults'] = (
833 [formal_with_default(f)
for f
in formals
if f[
'name'] !=
'self'])
834 option[
'method_actuals'] = [
835 f[
'name']
if f[
'name'] !=
'self' else '*this' for f
in formals]
838 option[
'type_method_formals'] = option[
'formals']
839 option[
'type_method_actuals'] = option[
'actuals']
841 option[
'const_mark'] =
'' if option[
'inplace']
else ' const' 843 assert 'method' not in option[
'variants'],
'TH functions cannot be methods' 844 is_function =
'function' in option[
'variants']
845 dispatch_tensor = find_dispatch_tensor(formals)
846 is_namespace_function = is_function
and dispatch_tensor
is not None 848 broadcast_arg = get_broadcast_argument(option)
850 option[
'method_prefix_derived'] =
'' if broadcast_arg
is None else 's_' 851 if option[
'mode'] ==
'TH':
852 option[
'device_guard'] =
False 853 option[
'device_guard_declaration'] = device_guard(option, formals,
False, dispatch_tensor)
855 env = nested_dict(option, top_env)
857 mode = option[
'mode']
859 assert option[
'extended_method'],
'Expected legacy operator to be an extended method' 861 if mode ==
'NN' and option.get(
'cimpls')
is None:
865 top_env[
'pure_virtual_extended_type_method_declarations'].append(
866 PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
867 top_env[
'type_method_declarations'].append(
868 TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
869 body = emit_nn_body(option)
870 top_env[
'type_method_definitions'].append(
871 TYPE_METHOD_DEFINITION_CONCRETE.substitute(
872 env, type_definition_body=body))
873 elif broadcast_arg
is None:
874 top_env[
'pure_virtual_extended_type_method_declarations'].append(
875 PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
876 top_env[
'type_method_declarations'].append(
877 TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
878 top_env[
'type_method_definitions'].append(
879 TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
881 top_env[
'pure_virtual_extended_type_method_declarations'].append(
882 PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
883 top_env[
'pure_virtual_extended_type_method_declarations'].append(
884 PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
885 top_env[
'type_method_declarations'].append(
886 TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
887 top_env[
'type_method_declarations'].append(
888 TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
889 top_env[
'type_method_definitions'].append(
890 TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
892 broadcast_inplace =
'inplace' in broadcast_arg[
'broadcast']
893 broadcast_dims =
'dims:' in broadcast_arg[
'broadcast']
894 option[
'broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
895 if not broadcast_dims:
896 option[
'broadcast_returns'] = ([
"b_" + x
for x
in option[
'broadcast_actuals']
897 if x != broadcast_arg[
'name']
or not broadcast_inplace])
899 option[
'broadcast_returns'] = [
"b_" + broadcast_arg[
'name']]
901 option[
'broadcast_function'] =
'expand_' + (
'inplace' if broadcast_inplace
902 else 'size' if broadcast_dims
else 'outplace')
903 option[
'broadcast_modified_actuals'] = [
'b_' + y
if 'b_' + y
in option[
'broadcast_returns']
else y
904 for y
in option[
'actuals']]
905 top_env[
'type_method_definitions'].append(
906 TYPE_METHOD_DEFINITION_BROADCAST.substitute(env))
909 if is_namespace_function:
910 option[
'inferred_type'] =
'detail::infer_type({})'.format(dispatch_tensor)
911 top_env[
'function_declarations'].append(
912 FUNCTION_DECLARATION.substitute(env))
913 top_env[
'function_definitions'].append(
914 FUNCTION_DEFINITION.substitute(env))
915 method_of.append(
'namespace')
917 buffer_names = [buffer[
'name']
for buffer
in option.get(
'buffers', [])]
919 output_options.append(OutputDeclaration(
920 name=option[
'api_name'],
921 matches_jit_signature=option[
'matches_jit_signature'],
922 schema_string=option[
'schema_string'],
923 method_prefix_derived=option[
'method_prefix_derived'],
927 python_module=option.get(
'python_module',
''),
928 buffers=buffer_names,
929 returns=option[
'returns'],
930 inplace=option[
'inplace'],
931 is_factory_method=
False,
934 requires_tensor=option.get(
'requires_tensor',
False),
935 device_guard=option.get(
'device_guard',
True),
936 with_gil=option.get(
'with_gil',
False),
937 deprecated=option.get(
'deprecated',
False)
940 def native_get_formals(option, include_constants=False):
946 def insert(argument):
948 if argument[
'name']
not in seen:
949 seen.add(argument[
'name'])
950 if argument.get(
'kwarg_only',
False):
951 kwd_args.append(argument)
953 pos_args.append(argument)
955 for argument
in option[
'arguments']:
960 def add_dynamic_type(argument, option):
962 argument[
'dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument[
'type'], argument[
'type'])
965 result = pos_args + kwd_args
966 result = [add_dynamic_type(argument, option)
for argument
in result]
969 def native_translate_formals(argument, option):
971 def translate_map(const):
974 'Tensor':
'const Tensor &' if const
else 'Tensor &',
975 'BoolTensor':
'const Tensor &' if const
else 'Tensor &',
976 'IndexTensor':
'const Tensor &' if const
else 'Tensor &',
977 'Type':
'const Type &' if const
else 'Type &',
978 'TensorOptions':
'const TensorOptions &' if const
else 'TensorOptions &',
979 'TensorList':
'TensorList',
982 if argument.get(
'is_nullable')
and argument[
'type']
not in translate_map(
False).keys():
983 argument[
'type'] =
"c10::optional<{}>".format(argument[
'type'])
985 if (option[
'inplace']
and argument[
'name'] ==
'self')
or argument.get(
'output',
False):
986 argument[
'type'] = translate_map(
False).get(argument[
'type'], argument[
'type'])
988 argument[
'type'] = translate_map(
True).get(argument[
'type'], argument[
'type'])
992 result = [native_translate_formals(argument, option)
for argument
in result]
996 def native_get_return_types(option):
998 ret = option[
'return']
1004 if isinstance(t_raw, string_type):
1012 name = t_raw[
'name']
1013 if 'field_name' in t_raw:
1014 field_name = t_raw[
'field_name']
1017 actual_return_type = {
'TensorList':
'std::vector<Tensor>'}.get(t, t)
1019 if actual_return_type ==
'Tensor' and (option[
'inplace']
or option[
'api_name'].endswith(
'_out')):
1021 actual_return_type =
'Tensor &' 1024 'type': actual_return_type,
1025 'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t),
1027 if name
is not None:
1028 rtype[
'name'] = name
1029 if field_name
is not None:
1030 rtype[
'field_name'] = field_name
1031 return_types.append(rtype)
1035 def process_native(option, output_options):
1037 assert option[
'python_module'] ==
'' or option[
'python_module'] ==
'nn', \
1038 "Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
1039 option[
'python_module'], option[
'name'])
1041 formals = native_get_formals(option)
1042 option[
'formals_list'] = formals
1043 option[
'formals'] = [format_formal(f)
for f
in formals]
1044 option[
'formals_with_defaults'] = [formal_with_default(f)
for f
in formals]
1045 option[
'returns'] = native_get_return_types(option)
1046 option[
'return_type'] = format_return_type(option[
'returns'])
1047 option[
'return_call'] =
'return ' if option[
'return_type'] !=
'void' else '' 1048 option[
'actuals'] = [f[
'name']
for f
in formals]
1050 option[
'method_formals'] = [format_formal(f)
for f
in formals
1051 if f[
'name'] !=
'self']
1052 option[
'method_formals_with_defaults'] = (
1053 [formal_with_default(f)
for f
in formals
if f[
'name'] !=
'self'])
1054 option[
'method_actuals'] = [
1055 f[
'name']
if f[
'name'] !=
'self' else '*this' for f
in formals]
1057 def find_formal(formal_name, formals):
1058 for formal
in formals:
1059 if formal_name == formal[
'dynamic_type']:
1063 assert find_formal(
'Type', formals)
is None, \
1064 "Found Type argument in {}({}). Use TensorOptions instead.".format(
1065 option[
'name'],
", ".join(option[
'method_formals_with_defaults']))
1067 type_method_dispatch = option[
'type_method_definition_dispatch']
1069 dispatch_options = find_formal(
'TensorOptions', formals)
1071 dispatch_tensor =
None if dispatch_options
else find_dispatch_tensor(formals)
1073 option[
'type_method_formals'] = [format_formal(f)
for f
in formals]
1074 option[
'type_method_actuals'] = [f[
'name']
for f
in formals]
1075 option[
'native_actuals'] = [f[
'name']
for f
in formals]
1077 option[
'const_mark'] =
'' if option[
'inplace']
else ' const' 1079 is_method =
'method' in option[
'variants']
1080 is_namespace_function =
'function' in option[
'variants']
1081 is_factory_method = find_formal(
'TensorOptions', formals)
and \
1082 not dispatch_options
and 'method' not in option[
'variants']
1084 check_methods_do_not_start_with_underscore(option[
'name'], is_method)
1086 option[
'method_prefix_derived'] =
'' 1087 option[
'device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)
1089 env = nested_dict(option, top_env)
1091 broadcast_arg = get_broadcast_argument(option)
1092 if broadcast_arg
is not None:
1093 raise Exception(
"broadcasting is not yet supported for native functions, " 1094 "but specified for function {}", option[
'name'])
1096 if option[
'extended_method']:
1097 top_env[
'pure_virtual_extended_type_method_declarations'].append(
1098 PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
1100 top_env[
'pure_virtual_type_method_declarations'].append(
1101 PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
1102 top_env[
'type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
1103 option[
'native_type_method_dispatch'] = type_method_dispatch
1114 if isinstance(type_method_dispatch, dict):
1116 top_env[
'type_method_definitions'].append(
1117 TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
1119 body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
1120 top_env[
'type_method_definitions'].append(
1121 TYPE_METHOD_DEFINITION_CONCRETE.substitute(
1122 env, type_definition_body=body))
1125 if isinstance(type_method_dispatch, dict):
1126 generated_native_functions = []
1127 for key
in sorted(type_method_dispatch.keys()):
1128 value = type_method_dispatch[key]
1129 if value
not in generated_native_functions:
1130 option[
'native_type_method_dispatch'] = value
1131 top_env[
'native_function_declarations'].append(
1132 NATIVE_DECLARATION.substitute(env))
1133 generated_native_functions.append(value)
1135 top_env[
'native_function_declarations'].append(
1136 NATIVE_DECLARATION.substitute(env))
1138 method_of = [
'Type']
1140 top_env[
'tensor_method_declarations'].append(
1141 TENSOR_METHOD_DECLARATION.substitute(env))
1142 top_env[
'tensor_method_definitions'].append(
1143 TENSOR_METHOD_DEFINITION.substitute(env))
1144 method_of.append(
'Tensor')
1146 if is_namespace_function:
1148 option[
'inferred_type'] =
'detail::infer_type({})'.format(dispatch_tensor)
1149 elif dispatch_options:
1150 option[
'inferred_type'] =
'at::getType({})'.format(dispatch_options[
'name'])
1153 option[
'inferred_type'] =
'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)' 1154 declaration = DEPRECATED_FUNCTION_DECLARATION
if option[
'deprecated']
else FUNCTION_DECLARATION
1155 top_env[
'function_declarations'].append(declaration.substitute(env))
1156 top_env[
'function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
1157 method_of.append(
'namespace')
1159 output_options.append(OutputDeclaration(
1160 name=option[
'api_name'],
1161 matches_jit_signature=option[
"matches_jit_signature"],
1162 schema_string=option[
"schema_string"],
1163 method_prefix_derived=option[
'method_prefix_derived'],
1165 method_of=method_of,
1166 mode=option[
'mode'],
1167 python_module=option[
'python_module'],
1169 returns=option[
'returns'],
1170 inplace=option[
'inplace'],
1171 is_factory_method=is_factory_method,
1174 requires_tensor=option.get(
'requires_tensor',
False),
1175 device_guard=option.get(
'device_guard',
True),
1176 with_gil=option.get(
'with_gil',
False),
1177 deprecated=option[
'deprecated'],
1180 output_declarations = []
1181 for declaration
in declarations:
1183 for option
in declaration[
'options']:
1184 option[
"matches_jit_signature"] = declaration[
"matches_jit_signature"]
1185 option[
"schema_string"] = declaration[
"schema_string"]
1187 if option[
'mode'] !=
'native':
1188 process_option(option, output_options)
1190 process_native(option, output_options)
1192 option[
'skip'] =
True 1193 output_declarations.extend(output_options)
1195 return output_declarations
def create_derived(backend_type_env, declarations):
    """Generate the per-backend Type method declarations and definitions.

    For every option of every declaration that is not marked ``skip``, the
    TYPE_DERIVED_* templates are substituted when the option applies to the
    ``(Backend, ScalarName)`` pair described by ``backend_type_env``.
    Options with mode 'NN' and no 'cimpls' are ignored. Options whose
    processing raises NYIError are left out.

    Args:
        backend_type_env: template environment for a single backend/scalar
            combination. Keys read here include 'Backend', 'ScalarName' and
            'AccScalarName'.
        declarations: declaration dicts, each carrying an 'options' list.

    Returns:
        Tuple ``(type_object_declarations, type_object_definitions)``. Both
        are lists of generated C++ snippets.
    """
    type_object_declarations = []
    type_object_definitions = []

    is_cuda = 'CUDA' in backend_type_env['Backend']

    def replace_with_null(argument):
        """Whether a THGenerator* argument is passed as NULL (plain CUDA backend only)."""
        return (argument['type'] == 'THGenerator*' and
                backend_type_env['Backend'] == 'CUDA')

    def requires_checked_cast(argument):
        """Whether the argument is unwrapped through a CHECKED_CAST template."""
        if argument['type'] == 'IntArrayRef':
            # Only IntArrayRef arguments that carry a 'size' get a checked cast.
            return 'size' in argument
        return argument['type'] in CHECKED_CAST

    def nullable_argument(argument):
        """Whether the argument is declared as possibly null/undefined."""
        return argument.get('is_nullable', False)

    def bool_option_is_string(argument):
        """Whether the argument's 'if_true' alternative is a string value."""
        return 'if_true' in argument and isinstance(argument['if_true'], string_type)

    def get_argument(argument, option):
        """Return the C++ expression passed for ``argument`` in the TH call."""
        if replace_with_null(argument):
            return 'NULL'
        elif requires_checked_cast(argument):
            checked_use = CHECKED_USE.get(
                argument['type'], '{}_').format(argument['name'])
            if nullable_argument(argument):
                checked_use = CHECKED_USE_NULLABLE.substitute(
                    env={}, arg_name=argument['name'], usage=checked_use)
            return checked_use
        elif argument['type'] == 'bool' and 'if_true' in argument:
            if bool_option_is_string(argument):
                tpl = '({}) ? "{}" : "{}"'
            else:
                tpl = '({}) ? {} : {}'
            return tpl.format(argument['name'],
                              argument['if_true'], argument['if_false'])
        elif argument['type'] == 'CONSTANT':
            # A CONSTANT whose 'if_true' is a string is emitted as a quoted literal.
            if bool_option_is_string(argument):
                return '"{}"'.format(argument['name'])
            v = str(argument.get('default', argument['name']))
            for pattern, replacement in CONSTANT_REPLACEMENTS:
                v = re.sub(pattern, replacement, v)
            # Fill any template placeholders in the constant from the backend environment.
            return CodeTemplate(v).substitute(backend_type_env)
        elif argument['type'] == 'argument':
            # Repeat another argument of the same option, referenced by its index.
            index = int(argument['name'])
            return get_argument(option['arguments'][index], option)
        else:
            return argument['name']

    def drop_argument(argument, option):
        """Whether ``argument`` is left out of the TH call's actual arguments."""
        # 'device' arguments are never forwarded to the TH call.
        if argument['name'] == 'device':
            return True
        return 'CUDA' in backend_type_env['Backend'] and (
            option['mode'] == 'TH' and argument['type'] == 'THGenerator*')

    def get_arguments(arguments, option):
        """Return the C++ actuals for ``arguments``, skipping dropped ones."""
        return [get_argument(argument, option)
                for argument in arguments if not drop_argument(argument, option)]

    def is_actual_return_long(ret):
        """Whether the TH call's result must be cast to int64_t for this backend."""
        if ret['type'] == 'long':
            return True
        if ret['type'] == 'real':
            return backend_type_env['ScalarName'] == 'Long'
        if ret['type'] == 'accreal':
            return backend_type_env['AccScalarName'] == 'Long'
        return False

    def handle_zero_dim(env, option):
        """Emit the ZERO_DIM_CHECK dispatch requested by 'zero_dim_dispatch_when_scalar'."""
        zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
        if not zero_dim_dispatch:
            return []
        broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
        zero_dim_only = option.get('zero_dim_tensor_only', False)
        # These two settings are not expected together.
        assert not (broadcasts_arg and zero_dim_only)
        # No zero-dim dispatch is generated when the checked argument is broadcast.
        if broadcasts_arg:
            return []
        zero_dim_actuals = [arg['name']
                            if arg['name'] != zero_dim_dispatch else "{}.item()".format(arg['name'])
                            for arg in option['formals_list']]
        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]

    def handle_only_zero_dim(env, option):
        """Emit the ZERO_DIM_ONLY check, or return None when the option does not request it."""
        if option.get('zero_dim_tensor_only', False):
            check_name = option['zero_dim_dispatch_when_scalar']
            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
        return None

    def handle_sparse(env, option):
        """Emit the SPARSE_CHECK dispatch on non-sparse backends when 'when_sparse_dispatch' is set."""
        if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
            return []
        check_name = option['when_sparse_dispatch']
        sparse_actuals = [arg['name']
                          if arg['name'] != check_name else "SparseTensorRef({})".format(arg['name'])
                          for arg in option['formals_list']]
        return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]

    def allocate_arg(env, arg, output_count):
        """Emit C++ lines that allocate a new TensorImpl for ``arg`` and wrap it in a Tensor."""
        name = arg['name']
        allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env)
        tensor_arg = '{}_'.format(name)
        if arg.get('mask', False):
            # Masked outputs are only allocated when output_mask selects them.
            allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
            tensor_arg = ('{}_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*){}_'
                          .format(name, name))
        intrusive_ptr_type = 'c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>'
        return [
            'auto {}_ = {};'.format(name, allocation),
            'auto {} = Tensor({}::reclaim({}));'.format(name, intrusive_ptr_type, tensor_arg),
        ]

    def resize_arg(arg):
        """Emit the C++ resize_ call described by the argument's 'resize' entry."""
        resize = arg['resize']
        # A string names another tensor whose sizes are copied.
        # string_type keeps this consistent with bool_option_is_string.
        if isinstance(resize, string_type):
            return "{}.resize_({}.sizes());".format(arg['name'], resize)
        # Otherwise 'resize' is an iterable of (tensor name, dim) pairs.
        if arg.get('resize_scalar', False):
            dims = ['{}.dim() == 0 ? 1 : {}.size({})'.format(name, name, dim) for name, dim in resize]
        else:
            dims = ['{}.size({})'.format(name, dim) for name, dim in resize]
        return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))

    def handle_call(env, option, cimpl):
        """Build the C++ call expression for one TH/THS/THNN implementation."""
        is_nn = option['mode'] == 'NN'
        actuals = get_arguments(cimpl['arguments'], option)
        if is_cuda or is_nn:
            actuals = ['globalContext().getTHCState()'] + actuals

        cname = cimpl['cname']
        if option.get('sparse', False):
            if is_cuda:
                cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname
            else:
                cname = env['THTensor'].replace('TH', 'THS') + '_' + cname
        elif is_nn:
            cname = 'THNN_{}'.format(env['THType']) + cname
        else:
            cname = env['THTensor'] + '_' + cname

        call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname)
        if cimpl.get('condition') is not None:
            call = 'if ({}) {}'.format(cimpl['condition'], call)
        return call

    def emit_body(env, option):
        """Return the list of C++ lines forming the body of one Type method."""
        body = []
        body += handle_sparse(env, option)
        body += handle_zero_dim(env, option)
        only_zero_dim_check = handle_only_zero_dim(env, option)
        if only_zero_dim_check is not None:
            # Nothing else is emitted for zero-dim-only options.
            body += only_zero_dim_check
            return body

        # An argument may be referenced more than once (type 'argument'), so
        # checked casts are generated only the first time a name is seen.
        seen_names = set()
        seen_tensorlists = set()
        count = 0
        output_count = 0

        # scalar_check decides whether the result is marked zero-dim. It comes
        # from the option if given, else from an IntArrayRefSize argument, else
        # from and-ing "input->dim() == 0" over the input tensors.
        scalar_check_is_from_size = False
        scalar_check_is_from_option = False
        scalar_check = None
        scalar_check_opt = option.get('scalar_check')
        if scalar_check_opt is not None:
            if isinstance(scalar_check_opt, bool):
                scalar_check = str(scalar_check_opt).lower()
            else:
                scalar_check = scalar_check_opt
            scalar_check_is_from_option = True

        for arg in option['arguments']:
            if is_real_argument_to_wrapper(arg):
                count += 1
            if arg['type'] == 'IntArrayRefSize' and not scalar_check_is_from_option:
                scalar_check_is_from_size = True
                scalar_check = '{}.size() == 0'.format(arg['name'])
            if arg['type'] == 'TensorList':
                seen_tensorlists.add(arg['name'])

            wrap_dim_target = arg.get('wrap_dim', None)
            if wrap_dim_target is not None:
                # Tensor targets are referenced through their "name_" local;
                # TensorList targets keep their plain name.
                if wrap_dim_target not in seen_tensorlists:
                    wrap_dim_target = wrap_dim_target + "_"
                body.append("{} = maybe_wrap_dim({}, {});"
                            .format(arg['name'], arg['name'], wrap_dim_target))

            if arg['name'] not in seen_names and requires_checked_cast(arg):
                seen_names.add(arg['name'])

                if arg.get('allocate', False):
                    # Outputs the wrapper allocates itself.
                    body += allocate_arg(env, arg, output_count)
                    output_count += 1
                else:
                    # Unwrap an existing argument through its checked cast.
                    null_okay = 'true' if nullable_argument(arg) else 'false'
                    default_init = []
                    if 'default_init' in arg:
                        default_init.append(arg['default_init'])

                    check_cast = CHECKED_CAST[arg['type']].substitute(
                        env, arg_name=arg['name'], arg_pos=count,
                        null_okay=null_okay, default_init=default_init,
                        size=arg.get('size'))
                    body.append("auto {}_ = {};".format(
                        arg['name'], check_cast))
                if drop_argument(arg, option) or replace_with_null(arg):
                    body.append(
                        "(void) {}_; //silence unused warning".format(arg['name']))

                initializers = []

                if 'resize' in arg:
                    initializers.append(resize_arg(arg))

                if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
                    initializers.append("{}.zero_();".format(arg['name']))

                # Nullable arguments are only initialized when present.
                if nullable_argument(arg) and len(initializers) > 0:
                    body.append(CONDITIONAL_INITIALIZER.substitute({
                        'name': arg['name'],
                        'initializer': initializers
                    }))
                else:
                    body += initializers

                # Out-of-place ops: the result is zero-dim only if every input
                # tensor is zero-dim (checks are and-ed together).
                if (not arg.get('output') and 'Tensor' in arg['type'] and
                        'TensorList' not in arg['type'] and
                        'THS' not in arg['type'] and
                        not scalar_check_is_from_size and
                        not scalar_check_is_from_option and
                        not option['inplace']):
                    check = '{}->dim() == 0'.format(arg['name'] + '_')
                    if nullable_argument(arg):
                        check = '(!{} || {})'.format(arg['name'] + '_', check)
                    scalar_check = (check if scalar_check is None
                                    else scalar_check + ' && ' + check)

        # 'cimpls' lists the underlying C implementations; default to the option itself.
        cimpls = option.get('cimpls', [option])
        calls = [handle_call(env, option, cimpl) for cimpl in cimpls]

        ret = option['return']

        if ret['kind'] == 'arguments':
            if 'aten_custom_call' in option:
                # Custom calls manage zero-dim handling themselves.
                scalar_check = None
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))
            else:
                body.extend([call + ';' for call in calls])
            arguments_indices = ret['arguments']
            arguments = [option['arguments'][argi]
                         for argi in arguments_indices]
            if scalar_check is not None:
                if not isinstance(scalar_check, dict):
                    if len(arguments) > 1:
                        # Evaluate the check once and share it across outputs.
                        body.append("bool maybe_scalar = {};".format(scalar_check))
                        scalar_check = 'maybe_scalar'
                for arg in arguments:
                    scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
                                        else scalar_check.get(arg['name']))
                    if scalar_check_arg is not None:
                        stmt = "{}_->maybe_zero_dim({});".format(arg['name'], scalar_check_arg)
                        if nullable_argument(arg):
                            stmt = "if ({}_) {}".format(arg['name'], stmt)
                        body.append(stmt)
            if len(arguments_indices) == 1:
                arg = arguments[0]
                body.append("return {};".format(arg['name']))
            else:
                types = [to_return_type(arg, option)['type']
                         for arg in arguments]
                names = [arg['name'] for arg in arguments]
                body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                    types=types, names=names))
        elif ret['kind'] == 'type':
            assert len(calls) == 1
            call = calls[0]
            if 'aten_custom_call' in option:
                # Custom calls manage zero-dim handling themselves.
                scalar_check = None
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))

            if ret['type'] in ALLOC_WRAP:
                maybe_scalar = ("->maybe_zero_dim({})".format(scalar_check)
                                if scalar_check is not None
                                else "")
                wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute(
                    env, arguments=[call])
                return_tensor = (
                    "return Tensor(" +
                    "c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(" +
                    "(${wrapped_tensor})${maybe_scalar}));")
                body.append(CodeTemplate(return_tensor).substitute(
                    env, wrapped_tensor=wrapped_tensor, maybe_scalar=maybe_scalar))
            elif ret['type'] == 'accreal' or ret['type'] == 'real':
                # Scalar results are returned as a 0-dim tensor of ${ScalarType}.
                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
                body.append(CodeTemplate(return_scalar).substitute(env, call=call))
            else:
                if is_actual_return_long(ret):
                    call = "static_cast<int64_t>({})".format(call)
                body.append("return {};".format(call))
        else:
            raise Exception("NYI - return handling")

        return body

    def process_option(option):
        """Emit declaration/definition for a TH/NN option matching this backend."""
        pair = (backend_type_env['Backend'],
                backend_type_env['ScalarName'])
        if pair in option['backend_type_pairs']:
            env = nested_dict(option, backend_type_env)
            body = emit_body(env, option)
            option['type_definition_body'] = body
            type_object_declarations.append(
                TYPE_DERIVED_DECLARATION.substitute(env))
            type_object_definitions.append(
                TYPE_DERIVED_DEFINITION.substitute(env))

    def process_native(option):
        """Emit declaration/definition for a native option dispatched per backend."""
        dispatch = option['type_method_definition_dispatch']
        env = nested_dict(option, backend_type_env)

        if isinstance(dispatch, dict):
            pair = (backend_type_env['Backend'],
                    backend_type_env['ScalarName'])
            if pair in option['backend_type_pairs']:
                native_dispatch = dispatch.get(pair[0])
                type_object_declarations.append(
                    TYPE_DERIVED_DECLARATION.substitute(env))
                if native_dispatch is None:
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE_MISSING.substitute(env))
                else:
                    option['native_type_method_dispatch'] = native_dispatch
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE.substitute(env))

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    # NN options without explicit C implementations are not generated here.
                    if option['mode'] == 'NN' and option.get('cimpls') is None:
                        continue
                    if option['mode'] != 'native':
                        process_option(option)
                    else:
                        process_native(option)
                except NYIError:
                    # Unsupported declarations are deliberately left out.
                    pass
    return type_object_declarations, type_object_definitions
def create_extension_backend(backend_type_env, declarations):
    """Generate Type method declarations/definitions for an extension backend.

    Each non-skipped option is annotated in place with 'formals_types',
    'native_actuals' and a 'schema' string of the form
    ``"name(type arg, ...) -> return_type"``. It is then rendered through
    TYPE_DERIVED_DECLARATION and TYPE_DEFINITION_EXTENSION_BACKEND. Options
    whose processing raises NYIError are left out.

    Args:
        backend_type_env: template environment for the extension backend.
        declarations: declaration dicts, each carrying an 'options' list.

    Returns:
        Tuple ``(type_object_declarations, type_object_definitions)``. Both
        are lists of generated C++ snippets.
    """
    type_object_declarations = []
    type_object_definitions = []

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    formals = option['formals_list']
                    option['formals_types'] = [f['type'] for f in formals]
                    option['native_actuals'] = [f['name'] for f in formals]
                    schema_args = ", ".join(
                        ["{} {}".format(f['dynamic_type'], f['name']) for f in formals])
                    # Use the dynamic-type spelling of the return type when one is registered.
                    return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
                    option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
                    env = nested_dict(option, backend_type_env)
                    type_object_declarations.append(
                        TYPE_DERIVED_DECLARATION.substitute(env))
                    type_object_definitions.append(
                        TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
                except NYIError:
                    # Unsupported declarations are deliberately left out.
                    pass
    return type_object_declarations, type_object_definitions
Module caffe2.python.layers.split.