import re
import sys

from code_template import CodeTemplate

try:
    import typing  # noqa: F401
except ImportError:
    raise RuntimeError(
        'Missing build dependency: Unable to import the `typing` module. '
        'Please install it via `conda install typing` or `pip install typing`')

from typing import Union, Set
from typing import Any, Dict, List, Optional, Tuple, NamedTuple

try:
    from mypy_extensions import TypedDict
except ImportError:
    # Fallback when mypy_extensions is unavailable: a stub that simply returns
    # a plain dict type, so the annotations below stay importable at runtime.
    def TypedDict(name, attrs, total=True):
        return Dict[Any, Any]

if sys.version_info[0] == 3:
    string_type = str
else:
    string_type = basestring

TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) const override;
""")

TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    Tensor ${broadcast_returns};
    std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
    return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
}
""")

PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const = 0;
""")

DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
C10_DEPRECATED virtual ${return_type} \
${method_prefix_derived}${api_name}(${type_method_formals}) const = 0;
""")

PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
virtual ${return_type} ${api_name}(${type_method_formals}) const = 0;
""")

TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")

TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate("""\
${return_type} TypeDefault::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    AT_ERROR("${method_prefix_derived}${api_name} is not implemented for type ", toString());
}
""")

TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) const override;
""")

TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${type_definition_body}
}
""")

TYPE_DERIVED_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")

TYPE_DERIVED_DEFINITION = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${type_definition_body}
}
""")

TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
    ${device_guard_declaration}
    ${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${actuals});
}
""")

TYPE_DERIVED_DEFINITION_NATIVE_MISSING = CodeTemplate("""\
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
    AT_ERROR("${api_name} not supported on ${Type}");
}
""")

TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")

TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
}
""")

TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
""")

TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
    return type().${api_name}(${method_actuals});
}
""")

FUNCTION_DECLARATION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals_with_defaults});
""")

DEPRECATED_FUNCTION_DECLARATION = CodeTemplate("""\
C10_DEPRECATED static inline ${return_type} ${api_name}(${formals_with_defaults});
""")

FUNCTION_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    return ${inferred_type}.${api_name}(${type_method_actuals});
}
""")

NATIVE_DECLARATION = CodeTemplate("""\
CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
""")

FACTORY_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    const DeviceGuard guard(options.device());
    return at::native::${api_name}(${type_method_actuals});
}
""")

ZERO_DIM_CHECK = CodeTemplate("""\
if (${check_name}.dim() == 0) {
    return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${zero_dim_actuals});
}""")

ZERO_DIM_ONLY = CodeTemplate("""\
AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor "
    "with ", ${check_name}.dim(), " dimension(s).");
""")

SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.is_sparse()) {
    return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
}""")

BUFFER_DEFINITION = CodeTemplate("""\
auto ${name}_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
    ${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), ${THTensor}_new(), false).release();
auto ${name} = Tensor(${name}_, false);""")

CONDITIONAL_INITIALIZER = CodeTemplate("""\
if (${name}.defined()) {
    ${initializer}
}""")

CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")


class NYIError(Exception):
    """Indicates we don't support this declaration yet"""

    def __init__(self, reason):
        self.reason = reason
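
# Illustrative note (not from the original file): each CodeTemplate constant
# above is rendered by filling its ${placeholders} from an environment dict
# via .substitute(), e.g. schematically:
#
#   FUNCTION_DECLARATION.substitute({
#       'return_type': 'Tensor',
#       'api_name': 'add',
#       'formals_with_defaults': ['const Tensor & self', 'const Tensor & other'],
#   })
#   # roughly -> 'static inline Tensor add(const Tensor & self, const Tensor & other);'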

TYPE_FORMAL_GENERIC = {
    'THTensor*': 'Tensor &',
    'THSTensor*': 'SparseTensorRef',
    'THBoolTensor*': 'Tensor &',
    'THIndexTensor*': 'Tensor &',
    'THIntegerTensor*': 'Tensor &',
    'THDenseTensor*': 'Tensor &',
    'THDenseIndexTensor*': 'Tensor &',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator *',
    'IntArrayRefSize': 'IntArrayRef',
}

DYNAMIC_TYPE = {
    'THTensor*': 'Tensor',
    'THSTensor*': 'SparseTensorRef',
    'THBoolTensor*': 'BoolTensor',
    'THIndexTensor*': 'IndexTensor',
    'THIntegerTensor*': 'IntegerTensor',
    'THDenseTensor*': 'Tensor',
    'THDenseIndexTensor*': 'IndexTensor',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator*',
    'IntArrayRefSize': 'IntArrayRef',
    'accreal': 'accreal',
}

NATIVE_DYNAMIC_TYPE = {
    'Tensor &': 'Tensor',
    'const Tensor &': 'Tensor',
}

TYPE_RETURN = {
    'THTensor*': 'Tensor',
    'THIndexTensor*': 'Tensor',
    'THBoolTensor*': 'Tensor',
    'THIntegerTensor*': 'Tensor',
    'THSTensor*': 'Tensor',
    'THDenseTensor*': 'Tensor',
    'THDenseIndexTensor*': 'Tensor',
}

CHECKED_CAST = {
    'THTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::${ScalarName})'),
    'THSTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name}.tref,"${arg_name}",${arg_pos},false, '
            'Backend::${Backend}, ScalarType::${ScalarName})'),
    'THBoolTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Byte)'),
    'THIndexTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Long)'),
    'THIntegerTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${Backend}, ScalarType::Int)'),
    'THDenseTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${DenseBackend}, ScalarType::${ScalarName})'),
    'THDenseIndexTensor*':
        CodeTemplate(
            'checked_tensor_unwrap('
            '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
            'Backend::${DenseBackend}, ScalarType::Long)'),
    'THStorage*':
        CodeTemplate(
            'checked_storage('
            '${arg_name},"${arg_name}",${arg_pos}, '
            'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
    'THGenerator*':
        CodeTemplate(
            'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(device_type()))'),
    'IntArrayRefStride': CodeTemplate('at::IntArrayRef ${result_name} = get_intlist_stride_th(${arg_name});'),
    'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
    'TensorList':
        CodeTemplate(
            'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, '
            'Backend::${Backend}, ScalarType::${ScalarName})'),
    'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
}

CHECKED_USE = {
    'THIndexTensor*': '{}_',
    'THBoolTensor*': '{}_',
    'THIntegerTensor*': '{}_',
    'THDenseTensor*': '{}_',
    'THDenseIndexTensor*': '{}_',
    'THStorage*': '{}_.unsafeGetStorageImpl()',
    'THGenerator*': '{}_->generator',
    'TensorList': "{0}_.data(), {0}_.size()",
}

CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL')

ALLOC_NOARGS_WRAP = {
    'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                 '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
    'THBoolTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                     '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Byte), allocator(), false).release()',
    'THIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                      '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), allocator(), false).release()',
    'THIntegerTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                        '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Int), allocator(), false).release()',
    'THDenseTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                      '(${Backend}TensorId(), caffe2::TypeMeta::Make<${ScalarType}>(), allocator(), false).release()',
    'THDenseIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                           '(${Backend}TensorId(), scalarTypeToTypeMeta(ScalarType::Long), '
                           'allocator(), false).release()',
}

ALLOC_WRAP = {
    'THTensor*': '${arguments}',
    'THBoolTensor*': '${arguments}',
    'THIndexTensor*': '${arguments}',
    'THIntegerTensor*': '${arguments}',
    'THDenseTensor*': '${arguments}',
    'THDenseIndexTensor*': '${arguments}',
}

CONSTANT_REPLACEMENTS = [
    ('AS_REAL', '${AS_REAL}'),
    ('__last_dim', 'self.ndimension()-1'),
]

HEADER_CONSTANT_REPLACEMENTS = [
    (r'AS_REAL\((.*)\)', r'\1'),
    ('__last_dim', '-1'),
]


class nested_dict(object):
    def __init__(self, base, parent):
        self.base, self.parent = base, parent

    def __getitem__(self, x):
        r = self.base.get(x)
        if r is not None:
            return r
        return self.parent[x]
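
# Illustrative usage (not part of the original source): nested_dict layers a
# per-option environment over a shared parent environment, so lookups consult
# the local dict first and fall through to the parent otherwise, e.g.
#
#   env = nested_dict({'api_name': 'add'}, {'return_type': 'Tensor'})
#   env['api_name']     # 'add'    -- found in the local (base) dict
#   env['return_type']  # 'Tensor' -- falls through to the parent dict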

Environment = TypedDict('Environment', {
    'ScalarName': str,
    'THTensor': str,
    'THType': str,
    'Backend': str,
    'AccScalarName': str,
})

TopEnvironment = TypedDict('TopEnvironment', {
    'type_registrations': List[str],
    'type_headers': List[str],
    'pure_virtual_type_method_declarations': List[str],
    'pure_virtual_extended_type_method_declarations': List[str],
    'type_method_declarations': List[str],
    'type_method_definitions': List[str],
    'tensor_method_declarations': List[str],
    'tensor_method_definitions': List[str],
    'function_declarations': List[str],
    'function_definitions': List[str],
    'type_ids': List[str],
    'native_function_declarations': List[str],
})

THFormal = TypedDict('THFormal', {
    'declared_type': str,
    'ignore_check': bool,
}, total=False)

AtFormal = TypedDict('AtFormal', {
    'name': str,
    'type': str,
    'dynamic_type': str,
}, total=False)

ReturnType = TypedDict('ReturnType', {
    'name': str,
    'type': str,
    'dynamic_type': str,
    'field_name': str,
}, total=False)

ReturnDecl = TypedDict('ReturnDecl', {
    'kind': str,
    'type': str,
    'arguments': List[int],
}, total=False)

NNBuffer = TypedDict('NNBuffer', {
    'name': str,
})

FunctionOption = TypedDict('FunctionOption', {
    'actuals': List[str],
    'arguments': List[THFormal],
    'aten_custom_call': str,
    'aten_dense_sparse': bool,
    'backend_type_pairs': List[Tuple[str, str]],
    'backends': List[str],
    'broadcast_actuals': List[str],
    'broadcast_function': str,
    'broadcast_modified_actuals': List[str],
    'broadcast_returns': List[str],
    'buffers': List[NNBuffer],
    'device_guard': bool,
    'device_guard_declaration': str,
    'formals_list': List[AtFormal],
    'formals_with_defaults': List[str],
    'formals': List[str],
    'formals_types': List[str],
    'inferred_type': str,
    'matches_jit_signature': bool,
    'extended_method': bool,
    'method_actuals': List[str],
    'method_formals_with_defaults': List[str],
    'method_formals': List[str],
    'method_prefix_derived': str,
    'python_module': str,
    'native_actuals': List[str],
    'native_type_method_dispatch': str,
    'schema_string': str,
    'requires_tensor': bool,
    'return': ReturnDecl,
    'returns': List[ReturnType],
    'type_definition_body': List[str],
    'type_method_actuals': List[str],
    'type_method_definition_dispatch': str,
    'type_method_formals': List[str],
    'when_sparse_dispatch': str,
    'zero_dim_dispatch_when_scalar': str,
    'zero_dim_tensor_only': bool,
}, total=False)

OutputDeclaration = NamedTuple('OutputDeclaration', [
    ('name', str),
    ('matches_jit_signature', bool),
    ('schema_string', str),
    ('method_prefix_derived', str),
    ('arguments', List[AtFormal]),
    ('method_of', List[str]),
    ('mode', str),
    ('python_module', str),
    ('buffers', Optional[List[str]]),
    ('returns', List[ReturnType]),
    ('inplace', bool),
    ('is_factory_method', bool),
    ('requires_tensor', bool),
    ('device_guard', bool),
    ('with_gil', bool),
    ('deprecated', bool),
])


def device_guard(option, formals, dispatch_options, dispatch_tensor):
    # Prefer the TensorOptions argument for the guard; otherwise fall back to
    # the dispatch tensor, if any.
    if option.get('device_guard', True):
        if dispatch_options:
            return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
        if dispatch_tensor:
            return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
    return '// DeviceGuard omitted'


def is_real_argument_to_wrapper(argument):
    return not argument.get('output', False) and\
        argument['type'] != 'CONSTANT' and\
        argument['type'] != 'argument'


def is_mutable_formal_argument(argument, option):
    return argument.get('output') or option['inplace'] and argument['name'] == 'self'


def check_methods_do_not_start_with_underscore(name, is_method):
    if name in {'_values', '_indices', '_nnz', '_dimI', '_dimV', '_coalesced_'}:
        return

    if is_method and name.startswith('_') and not name.startswith('__') and not name.startswith('_th_'):
        message = "Function '{}' starts with a single underscore and is ".format(name)
        message += "configured to have a method on Tensor. Functions that start with "
        message += " a single underscore should only be functions in the at:: "
        message += "namespace and not methods on Tensor!"
        raise RuntimeError(message)


def to_return_type(arg, option):
    t = arg['type']
    rt = TYPE_RETURN.get(t, t)
    if rt == 'Tensor' and not arg.get('allocate'):
        rt = rt + ' &'
        if not is_mutable_formal_argument(arg, option):
            rt = 'const ' + rt
    return {
        'name': arg['name'],
        'type': rt,
        'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']),
    }


def create_generic(top_env, declarations):

    # translates default values from the cwrap declarations into C++/Python values
    def translate_default(argument, type_str, default):
        if 'if_true' in argument:
            return argument['default'] == argument['if_true']
        for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
            default = re.sub(pattern, replacement, str(default))
        if type_str in {'Scalar', 'int64_t', 'double'}:
            try:
                return int(default)
            except ValueError:
                return float(default)
        elif type_str == 'bool':
            assert default.lower() in ['true', 'false']
            return default.lower() == 'true'
        return default

    # change from THTensor* to Tensor & (etc.) so we get the signature as it
    # will appear in the ATen argument list
    def translate_formal(argument, option):
        type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
        if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
            type_str = 'const ' + type_str
        translated = {
            'name': argument['name'],
            'type': type_str,
            'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
        }
        if 'kwarg_only' in argument:
            translated['kwarg_only'] = argument['kwarg_only']
        if 'default' in argument:
            default = translate_default(argument, type_str, argument['default'])
            translated['default'] = default
            translated['default_init'] = argument.get('default_init', default)
        if argument.get('output'):
            translated['output'] = True
        if argument.get('size'):
            translated['size'] = argument['size']
        if argument.get('is_nullable') is not None:
            translated['is_nullable'] = argument['is_nullable']
        return translated

    def get_formals(option, include_constants=False):
        seen = set()
        pos_args = []
        kwd_args = []

        def insert(argument):
            if argument['name'] not in seen:
                seen.add(argument['name'])
                if argument.get('kwarg_only', False):
                    kwd_args.append(argument)
                else:
                    pos_args.append(argument)

        def has_output_mask(argument):
            return argument.get('allocate', False) and argument.get('mask', False)

        for argument in option['arguments']:
            if argument.get('output') and not argument.get('allocate', False):
                insert(argument)
        for argument in option['arguments']:
            if argument['type'] == 'THSTensor*':
                # only accept sparse (THSTensor) arguments when the declaration
                # is explicitly flagged to mix dense and sparse tensors
                if not (option.get('aten_dense_sparse', False)):
                    raise NYIError("Sparse Tensor")

            if include_constants and argument['type'] == 'CONSTANT':
                insert(argument)
            elif is_real_argument_to_wrapper(argument):
                insert(argument)
        if any(has_output_mask(arg) for arg in option['arguments']):
            mask_size = sum(has_output_mask(arg) for arg in option['arguments'])
            insert({
                'name': 'output_mask',
                'type': 'std::array<bool,{}>'.format(mask_size),
                'default': '{{' + ', '.join(['true'] * mask_size) + '}}',
            })

        result = pos_args + kwd_args
        return [translate_formal(argument, option) for argument in result]

    def get_return_types(option):
        ret = option['return']
        if ret['kind'] == 'arguments':
            argument_indices = ret['arguments']
            if len(argument_indices) == 1:
                the_arg = option['arguments'][argument_indices[0]]
                return [to_return_type(the_arg, option)]
            else:
                return [to_return_type(option['arguments'][idx], option)
                        for idx in argument_indices]
        elif ret['kind'] == 'type':
            return [{
                'type': TYPE_RETURN.get(ret['type'], ret['type']),
                'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']),
            }]
        else:
            raise Exception("format_return_type")

    def format_return_type(return_types):
        if len(return_types) == 1:
            return return_types[0]['type']
        return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
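
    # Illustrative example (not from the original file): a single return type
    # is emitted as-is, multiple returns are packed into a std::tuple, e.g.
    #   format_return_type([{'type': 'Tensor'}])                      -> 'Tensor'
    #   format_return_type([{'type': 'Tensor'}, {'type': 'int64_t'}]) -> 'std::tuple<Tensor,int64_t>'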

    def find_dispatch_tensor(formals):
        # dispatch to self if it's a (non-nullable) tensor parameter
        def is_any_tensor_type(formal):
            return (formal['dynamic_type'] == 'Tensor' or formal['dynamic_type'] == 'BoolTensor'
                    or formal['dynamic_type'] == 'IndexTensor')

        for formal in formals:
            if formal['name'] == 'self' and is_any_tensor_type(formal) and not formal.get('is_nullable', False):
                return formal['name']
        # otherwise dispatch to the first Tensor or TensorList
        for formal in formals:
            if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal) and \
               not formal.get('is_nullable', False):
                return formal['name']
        return None

    def format_formal(f):
        return '{} {}'.format(f['type'], f['name'])

    def formal_with_default(f):
        s = format_formal(f)
        v = f.get('default')
        if v is None:
            return s
        if isinstance(v, bool):
            v = str(v).lower()
        return '{}={}'.format(s, v)

    def get_broadcast_argument(option):
        for argument in option['arguments']:
            if argument.get('broadcast'):
                return argument
        return None

    def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
        # The actuals for the broadcast call: either the comma-separated names
        # from the broadcast spec, or the argument plus an initializer list of
        # the sizes spelled out in its dims: specification.
        if not broadcast_dims:
            broadcast_actuals = [broadcast_arg['name']] + broadcast_arg['broadcast'].split()[0].split(",")
        else:
            broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
            # generate a C++ size() call for each dimension specification
            broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'
                              for x in broadcast_dims_spec])
            broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'
            broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]

        return broadcast_actuals

    def emit_nn_body(option):
        # Concrete definition on TypeDefault for NN functions: delegate to the
        # corresponding _forward variant.
        actuals = option['actuals']
        base_name = option['name'][:-1] if option['inplace'] else option['name']
        fwd_name = option['api_name'].replace(base_name, base_name + '_forward')

        if len(option['buffers']) == 0:
            return 'return {}({});'.format(fwd_name, ', '.join(actuals))

        body = []
        if option['api_name'].endswith('_out'):
            # _out variants must create buffers and insert them in the
            # arguments list between output and input arguments
            for buffer in option['buffers']:
                body.append('Tensor {} = at::empty({{0}}, this->options());'.format(buffer['name']))
            actuals = [arg['name'] for arg in option['arguments'] if arg.get('output')]
            actuals += [buffer['name'] for buffer in option['buffers']]
            actuals += [arg['name'] for arg in option['arguments'] if not arg.get('output')]

        body.append('return std::get<0>({}({}));'.format(fwd_name, ', '.join(actuals)))
        return body

    def process_option(option, output_options):
        option['inplace'] = re.search(
            '(^__i|[^_]_$)', option['api_name']) is not None

        formals = get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f) for f in formals]
        option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
        option['returns'] = get_return_types(option)
        option['return_type'] = format_return_type(option['returns'])
        option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
        option['actuals'] = [f['name'] for f in formals]

        option['method_formals'] = [format_formal(f) for f in formals
                                    if f['name'] != 'self']
        option['method_formals_with_defaults'] = (
            [formal_with_default(f) for f in formals if f['name'] != 'self'])
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]

        option['type_method_formals'] = option['formals']
        option['type_method_actuals'] = option['actuals']

        option['const_mark'] = '' if option['inplace'] else ' const'

        assert 'method' not in option['variants'], 'TH functions cannot be methods'
        is_function = 'function' in option['variants']
        dispatch_tensor = find_dispatch_tensor(formals)
        is_namespace_function = is_function and dispatch_tensor is not None

        broadcast_arg = get_broadcast_argument(option)
        # "s_" for "same size"
        option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
        if option['mode'] == 'TH':
            option['device_guard'] = False
        option['device_guard_declaration'] = device_guard(option, formals, False, dispatch_tensor)

        env = nested_dict(option, top_env)

        mode = option['mode']
        assert option['extended_method'], 'Expected legacy operator to be an extended method'

        if mode == 'NN' and option.get('cimpls') is None:
            # NN functions without cimpls call the _forward function and
            # discard any buffer returns
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
            body = emit_nn_body(option)
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_CONCRETE.substitute(
                    env, type_definition_body=body))
        elif broadcast_arg is None:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
        else:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))

            broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
            broadcast_dims = 'dims:' in broadcast_arg['broadcast']
            option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
            if not broadcast_dims:
                option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
                                               if x != broadcast_arg['name'] or not broadcast_inplace])
            else:
                option['broadcast_returns'] = ["b_" + broadcast_arg['name']]

            option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
                                                        else 'size' if broadcast_dims else 'outplace')
            option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
                                                    for y in option['actuals']]
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_BROADCAST.substitute(env))

        method_of = ['Type']
        if is_namespace_function:
            option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
            top_env['function_declarations'].append(
                FUNCTION_DECLARATION.substitute(env))
            top_env['function_definitions'].append(
                FUNCTION_DEFINITION.substitute(env))
            method_of.append('namespace')

        buffer_names = [buffer['name'] for buffer in option.get('buffers', [])]

        output_options.append(OutputDeclaration(
            name=option['api_name'],
            matches_jit_signature=option['matches_jit_signature'],
            schema_string=option['schema_string'],
            method_prefix_derived=option['method_prefix_derived'],
            arguments=formals,
            method_of=method_of,
            mode=mode,
            python_module=option.get('python_module', ''),
            buffers=buffer_names,
            returns=option['returns'],
            inplace=option['inplace'],
            is_factory_method=False,
            requires_tensor=option.get('requires_tensor', False),
            device_guard=option.get('device_guard', True),
            with_gil=option.get('with_gil', False),
            deprecated=option.get('deprecated', False)
        ))

    def native_get_formals(option, include_constants=False):
        seen = set()
        pos_args = []
        kwd_args = []

        def insert(argument):
            if argument['name'] not in seen:
                seen.add(argument['name'])
                if argument.get('kwarg_only', False):
                    kwd_args.append(argument)
                else:
                    pos_args.append(argument)

        for argument in option['arguments']:
            insert(argument)

        # not clear we need dynamic_type translation as we can specify the
        # correct type directly in native functions
        def add_dynamic_type(argument, option):
            argument['dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument['type'], argument['type'])
            return argument

        result = pos_args + kwd_args
        result = [add_dynamic_type(argument, option) for argument in result]
   971             def translate_map(const):
   974                     'Tensor': 
'const Tensor &' if const 
else 'Tensor &',
   975                     'BoolTensor': 
'const Tensor &' if const 
else 'Tensor &',
   976                     'IndexTensor': 
'const Tensor &' if const 
else 'Tensor &',
   977                     'Type': 
'const Type &' if const 
else 'Type &',
   978                     'TensorOptions': 
'const TensorOptions &' if const 
else 'TensorOptions &',
   979                     'TensorList': 
'TensorList',
   982             if argument.get(
'is_nullable') 
and argument[
'type'] 
not in translate_map(
False).keys():
   983                 argument[
'type'] = 
"c10::optional<{}>".format(argument[
'type'])
   985             if (option[
'inplace'] 
and argument[
'name'] == 
'self') 
or argument.get(
'output', 
False):
   986                 argument[
'type'] = translate_map(
False).get(argument[
'type'], argument[
'type'])
   988                 argument[
'type'] = translate_map(
True).get(argument[
'type'], argument[
'type'])
   992         result = [native_translate_formals(argument, option) 
for argument 
in result]

    def native_get_return_types(option):
        ret = option['return']

        return_types = []
        for t_raw in ret:
            if isinstance(t_raw, string_type):
                t = t_raw
                name = None
                field_name = None
            else:
                t = t_raw['type']
                name = t_raw['name']
                if 'field_name' in t_raw:
                    field_name = t_raw['field_name']
                else:
                    field_name = None

            # can't actually return a TensorList (since it's a reference object)
            actual_return_type = {'TensorList': 'std::vector<Tensor>'}.get(t, t)

            if actual_return_type == 'Tensor' and (option['inplace'] or option['api_name'].endswith('_out')):
                # follow the normal ATen convention of returning Tensor & for
                # inplace and _out functions
                actual_return_type = 'Tensor &'

            rtype = {
                'type': actual_return_type,
                'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t),
            }
            if name is not None:
                rtype['name'] = name
            if field_name is not None:
                rtype['field_name'] = field_name
            return_types.append(rtype)

        return return_types

    def process_native(option, output_options):
        assert option['python_module'] == '' or option['python_module'] == 'nn', \
            "Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
                option['python_module'], option['name'])

        formals = native_get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f) for f in formals]
        option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
        option['returns'] = native_get_return_types(option)
        option['return_type'] = format_return_type(option['returns'])
        option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
        option['actuals'] = [f['name'] for f in formals]

        option['method_formals'] = [format_formal(f) for f in formals
                                    if f['name'] != 'self']
        option['method_formals_with_defaults'] = (
            [formal_with_default(f) for f in formals if f['name'] != 'self'])
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]

        def find_formal(formal_name, formals):
            for formal in formals:
                if formal_name == formal['dynamic_type']:
                    return formal
            return None

        assert find_formal('Type', formals) is None, \
            "Found Type argument in {}({}). Use TensorOptions instead.".format(
                option['name'], ", ".join(option['method_formals_with_defaults']))

        type_method_dispatch = option['type_method_definition_dispatch']

        dispatch_options = find_formal('TensorOptions', formals)
        # only dispatch via tensor if there is no TensorOptions argument
        dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals)

        option['type_method_formals'] = [format_formal(f) for f in formals]
        option['type_method_actuals'] = [f['name'] for f in formals]
        option['native_actuals'] = [f['name'] for f in formals]

        option['const_mark'] = '' if option['inplace'] else ' const'

        is_method = 'method' in option['variants']
        is_namespace_function = 'function' in option['variants']
        is_factory_method = find_formal('TensorOptions', formals) and \
            not dispatch_options and 'method' not in option['variants']

        check_methods_do_not_start_with_underscore(option['name'], is_method)

        option['method_prefix_derived'] = ''
        option['device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)

        env = nested_dict(option, top_env)

        broadcast_arg = get_broadcast_argument(option)
        if broadcast_arg is not None:
            raise Exception("broadcasting is not yet supported for native functions, "
                            "but specified for function {}", option['name'])

        if option['extended_method']:
            top_env['pure_virtual_extended_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
        else:
            top_env['pure_virtual_type_method_declarations'].append(
                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
        top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
        option['native_type_method_dispatch'] = type_method_dispatch

        # generate the Type method definition: abstract if the dispatch is
        # per-backend, concrete (calling at::native directly) otherwise
        if isinstance(type_method_dispatch, dict):
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
        else:
            body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_CONCRETE.substitute(
                    env, type_definition_body=body))

        # generate the at::native function declarations (i.e. what the user will implement)
        if isinstance(type_method_dispatch, dict):
            generated_native_functions = []
            for key in sorted(type_method_dispatch.keys()):
                value = type_method_dispatch[key]
                if value not in generated_native_functions:
                    option['native_type_method_dispatch'] = value
                    top_env['native_function_declarations'].append(
                        NATIVE_DECLARATION.substitute(env))
                    generated_native_functions.append(value)
        else:
            top_env['native_function_declarations'].append(
                NATIVE_DECLARATION.substitute(env))

        method_of = ['Type']
        if is_method:
            top_env['tensor_method_declarations'].append(
                TENSOR_METHOD_DECLARATION.substitute(env))
            top_env['tensor_method_definitions'].append(
                TENSOR_METHOD_DEFINITION.substitute(env))
            method_of.append('Tensor')

        if is_namespace_function:
            if dispatch_tensor:
                option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
            elif dispatch_options:
                option['inferred_type'] = 'at::getType({})'.format(dispatch_options['name'])
            else:
                # doesn't depend on a specific type; use an undefined-backend float type
                option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
            declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
            top_env['function_declarations'].append(declaration.substitute(env))
            top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
            method_of.append('namespace')

        output_options.append(OutputDeclaration(
            name=option['api_name'],
            matches_jit_signature=option["matches_jit_signature"],
            schema_string=option["schema_string"],
            method_prefix_derived=option['method_prefix_derived'],
            arguments=formals,
            method_of=method_of,
            mode=option['mode'],
            python_module=option['python_module'],
            buffers=None,
            returns=option['returns'],
            inplace=option['inplace'],
            is_factory_method=is_factory_method,
            requires_tensor=option.get('requires_tensor', False),
            device_guard=option.get('device_guard', True),
            with_gil=option.get('with_gil', False),
            deprecated=option['deprecated'],
        ))

    output_declarations = []
    for declaration in declarations:
        output_options = []
        for option in declaration['options']:
            option["matches_jit_signature"] = declaration["matches_jit_signature"]
            option["schema_string"] = declaration["schema_string"]
            try:
                if option['mode'] != 'native':
                    process_option(option, output_options)
                else:
                    process_native(option, output_options)
            except NYIError:
                option['skip'] = True
        output_declarations.extend(output_options)

    return output_declarations
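

# Illustrative sketch (assumed shapes, not from the original file): create_generic
# consumes parsed declarations, each carrying per-overload options; non-native
# options flow through process_option and native ones through process_native:
#
#   declarations = [{
#       'matches_jit_signature': True,
#       'schema_string': 'aten::add(Tensor self, Tensor other) -> Tensor',
#       'options': [{'mode': 'native', ...}],
#   }]
#   output_declarations = create_generic(top_env, declarations)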


def create_derived(backend_type_env, declarations):
    type_object_declarations = []
    type_object_definitions = []

    is_cuda = 'CUDA' in backend_type_env['Backend']

    def replace_with_null(argument):
        return (argument['type'] == 'THGenerator*' and
                backend_type_env['Backend'] == 'CUDA')

    def requires_checked_cast(argument):
        if argument['type'] == 'IntArrayRef':
            return 'size' in argument
        return argument['type'] in CHECKED_CAST

    def nullable_argument(argument):
        return argument.get('is_nullable', False)

    def bool_option_is_string(argument):
        return 'if_true' in argument and isinstance(argument['if_true'], string_type)

    def get_argument(argument, option):
        if replace_with_null(argument):
            return 'NULL'
        elif requires_checked_cast(argument):
            checked_use = CHECKED_USE.get(
                argument['type'], '{}_').format(argument['name'])
            if nullable_argument(argument):
                checked_use = CHECKED_USE_NULLABLE.substitute(
                    env={}, arg_name=argument['name'], usage=checked_use)
            return checked_use
        elif argument['type'] == 'bool' and 'if_true' in argument:
            if bool_option_is_string(argument):
                tpl = '({}) ? "{}" : "{}"'
            else:
                tpl = '({}) ? {} : {}'
            return tpl.format(argument['name'],
                              argument['if_true'], argument['if_false'])
        elif argument['type'] == 'CONSTANT':
            if bool_option_is_string(argument):
                return '"{}"'.format(argument['name'])
            v = str(argument.get('default', argument['name']))
            for pattern, replacement in CONSTANT_REPLACEMENTS:
                v = re.sub(pattern, replacement, v)
            return CodeTemplate(v).substitute(backend_type_env)
        elif argument['type'] == 'argument':
            index = int(argument['name'])
            return get_argument(option['arguments'][index], option)
        else:
            return argument['name']

    def drop_argument(argument, option):
        # devices are handled in the body of the function
        if argument['name'] == 'device':
            return True
        return 'CUDA' in backend_type_env['Backend'] and (
            option['mode'] == 'TH' and argument['type'] == 'THGenerator*')

    def get_arguments(arguments, option):
        return [get_argument(argument, option)
                for argument in arguments if not drop_argument(argument, option)]

    def is_actual_return_long(ret):
        if ret['type'] == 'long':
            return True
        if ret['type'] == 'real':
            return backend_type_env['ScalarName'] == 'Long'
        if ret['type'] == 'accreal':
            return backend_type_env['AccScalarName'] == 'Long'
        return False

    def handle_zero_dim(env, option):
        zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
        if not zero_dim_dispatch:
            return []
        broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
        zero_dim_only = option.get('zero_dim_tensor_only', False)
        # this combination doesn't seem to make sense
        assert not (broadcasts_arg and zero_dim_only)
        zero_dim_actuals = [arg['name']
                            if arg['name'] != zero_dim_dispatch else "{}.item()".format(arg['name'])
                            for arg in option['formals_list']]
        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
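
    # Illustrative example (assumed names, not from the original file): for an
    # option with zero_dim_dispatch_when_scalar == 'other', the emitted guard
    # looks roughly like:
    #
    #   if (other.dim() == 0) {
    #       return static_cast<const TypeExtendedInterface*>(this)->add_(self, other.item());
    #   }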

    def handle_only_zero_dim(env, option):
        if option.get('zero_dim_tensor_only', False):
            check_name = option['zero_dim_dispatch_when_scalar']
            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
        else:
            return None

    def handle_sparse(env, option):
        if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
            return []
        check_name = option['when_sparse_dispatch']
        sparse_actuals = [arg['name']
                          if arg['name'] != check_name else "SparseTensorRef({})".format(arg['name'])
                          for arg in option['formals_list']]
        return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]

    def allocate_arg(env, arg, output_count):
        name = arg['name']
        state = ''
        if is_cuda:
            state = 'globalContext().getTHCState()'
        allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env)
        tensor_arg = '{}_'.format(name)
        if arg.get('mask', False):
            allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
            tensor_arg = ('{}_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*){}_'
                          .format(name, name))
        intrusive_ptr_type = 'c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>'
        return [
            'auto {}_ = {};'.format(name, allocation),
            'auto {} = Tensor({}::reclaim({}));'.format(name, intrusive_ptr_type, tensor_arg),
        ]

    def resize_arg(arg):
        resize = arg['resize']
        if isinstance(resize, str):
            return "{}.resize_({}.sizes());".format(arg['name'], resize)
        else:
            resize_scalar = arg.get('resize_scalar', False)
            if resize_scalar:
                dims = ['{}.dim() == 0 ? 1 : {}.size({})'.format(name, name, dim) for name, dim in resize]
            else:
                dims = ['{}.size({})'.format(name, dim) for name, dim in resize]
            return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))

    def handle_call(env, option, cimpl):
        is_nn = option['mode'] == 'NN'
        actuals = get_arguments(cimpl['arguments'], option)
        if is_cuda or is_nn:
            actuals = ['globalContext().getTHCState()'] + actuals

        cname = cimpl['cname']
        if option.get('sparse', False):
            if is_cuda:
                cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname
            else:
                cname = env['THTensor'].replace('TH', 'THS') + '_' + cname
        elif is_nn:
            cname = 'THNN_{}'.format(env['THType']) + cname
        else:
            cname = env['THTensor'] + '_' + cname

        call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname)
        if cimpl.get('condition') is not None:
            call = 'if ({}) {}'.format(cimpl['condition'], call)
        return call

    def emit_body(env, option):
        body = []
        body += handle_sparse(env, option)
        body += handle_zero_dim(env, option)
        only_zero_dim_check = handle_only_zero_dim(env, option)
        if only_zero_dim_check is not None:
            # code below the only-zero-dim guard is unreachable, so we do not
            # need to generate the rest
            body += only_zero_dim_check
            return body

        seen_names = set()
        seen_tensorlists = set()
        count = 0
        output_count = 0

        # scalar_check is the heuristic condition for when a result may be a
        # scalar: from an IntArrayRefSize argument's size, from an explicit
        # scalar_check option, or (below) from all input tensors being 0-dim
        scalar_check = None
        scalar_check_is_from_size = False
        scalar_check_is_from_option = False
        scalar_check_opt = option.get('scalar_check')
        if scalar_check_opt is not None:
            if isinstance(scalar_check_opt, bool):
                scalar_check = str(scalar_check_opt).lower()
            else:
                scalar_check = scalar_check_opt
            scalar_check_is_from_option = True

        for arg in option['arguments']:
            if is_real_argument_to_wrapper(arg):
                count += 1
            if arg['type'] == 'IntArrayRefSize' and not scalar_check_is_from_option:
                scalar_check_is_from_size = True
                scalar_check = '{}.size() == 0'.format(arg['name'])
            if arg['type'] == 'TensorList':
                seen_tensorlists.add(arg['name'])

            wrap_dim_target = arg.get('wrap_dim', None)
            if wrap_dim_target is not None:
                # for Tensors, "name_" is the TensorImpl, but for TensorLists it is an
                # std::vector of TH*s; since TH*s have different dimension rules, use
                # "name" for lists but keep "name_" for tensors to avoid an extra call
                if wrap_dim_target not in seen_tensorlists:
                    wrap_dim_target = wrap_dim_target + "_"
                body.append("{} = maybe_wrap_dim({}, {});"
                            .format(arg['name'], arg['name'], wrap_dim_target))

            # only generate checked casts the first time we see an argument
            if arg['name'] not in seen_names and requires_checked_cast(arg):
                seen_names.add(arg['name'])

                # make a new allocation of TensorImpl, then wrap a Tensor around it
                if arg.get('allocate', False):
                    body += allocate_arg(env, arg, output_count)
                    output_count += 1
                else:
                    # special case where we allow undefined Tensors, and thus
                    # the checked cast succeeds even if the Tensor is not defined
                    null_okay = 'true' if nullable_argument(arg) else 'false'
                    default_init = []
                    if 'default_init' in arg:
                        default_init.append(arg['default_init'])

                    check_cast = CHECKED_CAST[arg['type']].substitute(
                        env, arg_name=arg['name'], arg_pos=count,
                        null_okay=null_okay, default_init=default_init,
                        size=arg.get('size'))
                    body.append("auto {}_ = {};".format(
                        arg['name'], check_cast))
                if drop_argument(arg, option) or replace_with_null(arg):
                    body.append(
                        "(void) {}_; //silence unused warning".format(arg['name']))

                initializers = []

                # resize tensors for special ops that require it
                if 'resize' in arg:
                    initializers.append(resize_arg(arg))

                # also special handling where we zero some outputs.
                if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
                    initializers.append("{}.zero_();".format(arg['name']))

                # only initialize non-null arguments
                if nullable_argument(arg) and len(initializers) > 0:
                    body.append(CONDITIONAL_INITIALIZER.substitute({
                        'name': arg['name'],
                        'initializer': initializers
                    }))
                else:
                    body += initializers

                # for out-of-place functions, the dim() == 0 checks of all input
                # tensors are and-ed together to decide whether the output is a scalar
                if (not arg.get('output') and 'Tensor' in arg['type'] and
                        'TensorList' not in arg['type'] and
                        'THS' not in arg['type'] and
                        not scalar_check_is_from_size and
                        not scalar_check_is_from_option and
                        not option['inplace']):
                    check = '{}->dim() == 0'.format(arg['name'] + '_')
                    if nullable_argument(arg):
                        check = '(!{} || {})'.format(arg['name'] + '_', check)
                    scalar_check = (check if scalar_check is None
                                    else scalar_check + ' && ' + check)

        # cimpls, if it exists, contains the underlying C function names and
        # arguments; otherwise use the option itself
        cimpls = option.get('cimpls', [option])
        calls = [handle_call(env, option, cimpl) for cimpl in cimpls]

        ret = option['return']

        if ret['kind'] == 'arguments':
            if 'aten_custom_call' in option:
                # aten_custom_call bodies handle the call themselves
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))
            else:
                body.extend([call + ';' for call in calls])
            arguments_indices = ret['arguments']
            arguments = [option['arguments'][argi]
                         for argi in arguments_indices]
            if scalar_check is not None:
                if not isinstance(scalar_check, dict):
                    if len(arguments) > 1:
                        body.append("bool maybe_scalar = {};".format(scalar_check))
                        scalar_check = 'maybe_scalar'
                for arg in arguments:
                    scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict)
                                        else scalar_check.get(arg['name']))
                    if scalar_check_arg is not None:
                        stmt = "{}_->maybe_zero_dim({});".format(arg['name'], scalar_check_arg)
                        if nullable_argument(arg):
                            stmt = "if ({}_) {}".format(arg['name'], stmt)
                        body.append(stmt)
            if len(arguments_indices) == 1:
                arg = arguments[0]
                body.append("return {};".format(arg['name']))
            else:
                types = [to_return_type(arg, option)['type']
                         for arg in arguments]
                names = [arg['name'] for arg in arguments]
                body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                    types=types, names=names))
        elif ret['kind'] == 'type':
            assert len(calls) == 1
            call = calls[0]
            if 'aten_custom_call' in option:
                # aten_custom_call bodies handle the call themselves
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))

            if ret['type'] in ALLOC_WRAP.keys():
                maybe_scalar = "->maybe_zero_dim({})".format(scalar_check) \
                               if scalar_check is not None \
                               else ""
                wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute(
                    env, arguments=[call])
                return_tensor = (
                    "return Tensor(" +
                    "c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(" +
                    "(${wrapped_tensor})${maybe_scalar}));")
                body.append(CodeTemplate(return_tensor).substitute(
                    env, wrapped_tensor=wrapped_tensor, maybe_scalar=maybe_scalar))
            # return the same underlying Tensor type for both real and accreal;
            # this ensures e.g. x.sum(0) and x.sum() return the same type
            elif ret['type'] == 'accreal' or ret['type'] == 'real':
                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
                body.append(CodeTemplate(return_scalar).substitute(env, call=call))
            else:
                # we use int64_t for long in the API, so correct it here
                if is_actual_return_long(ret):
                    call = "static_cast<int64_t>({})".format(call)
                body.append("return {};".format(call))
        else:
            raise Exception("NYI - return handling")

        return body

    def process_option(option):
        pair = (backend_type_env['Backend'],
                backend_type_env['ScalarName'])
        if pair in option['backend_type_pairs']:
            env = nested_dict(option, backend_type_env)
            body = emit_body(env, option)
            option['type_definition_body'] = body
            type_object_declarations.append(
                TYPE_DERIVED_DECLARATION.substitute(env))
            type_object_definitions.append(
                TYPE_DERIVED_DEFINITION.substitute(env))

    def process_native(option):
        dispatch = option['type_method_definition_dispatch']
        env = nested_dict(option, backend_type_env)

        if isinstance(dispatch, dict):
            pair = (backend_type_env['Backend'],
                    backend_type_env['ScalarName'])
            if pair in option['backend_type_pairs']:
                native_dispatch = dispatch.get(pair[0])
                type_object_declarations.append(
                    TYPE_DERIVED_DECLARATION.substitute(env))
                if native_dispatch is None:
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE_MISSING.substitute(env))
                else:
                    option['native_type_method_dispatch'] = native_dispatch
                    type_object_definitions.append(
                        TYPE_DERIVED_DEFINITION_NATIVE.substitute(env))

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    if option['mode'] == 'NN' and option.get('cimpls') is None:
                        continue
                    if option['mode'] != 'native':
                        process_option(option)
                    else:
                        process_native(option)
                except NYIError:
                    pass

    return type_object_declarations, type_object_definitions


def create_extension_backend(backend_type_env, declarations):
    type_object_declarations = []
    type_object_definitions = []

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    option['formals_types'] = [f['type'] for f in option['formals_list']]
                    option['native_actuals'] = [f['name'] for f in option['formals_list']]
                    schema_args = ", ".join(
                        ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']])
                    return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
                    option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
                    env = nested_dict(option, backend_type_env)
                    type_object_declarations.append(
                        TYPE_DERIVED_DECLARATION.substitute(env))
                    type_object_definitions.append(
                        TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
                except NYIError:
                    pass

    return type_object_declarations, type_object_definitions