import re
from collections import defaultdict

from .nested_dict import nested_dict
from .gen_variable_type import should_trace
from .utils import write

try:
    from src.ATen.code_template import CodeTemplate
except ImportError:
    from tools.shared.module_loader import import_module
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate

SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    '_indexCopy_', 'max_values', 'min_values',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
    '_cholesky.*', '_btrifact.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense',
    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
    'copy_sparse_to_sparse_',
]

SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")

SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void',
}

TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout}.layout)
    .requires_grad(${requires_grad});
""")


def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    # skip SparseTensorRef overloads; sparse_mask is the only one we bind
    for arg in declaration['arguments']:
        if arg['type'] == 'SparseTensorRef' and declaration['name'] != 'sparse_mask':
            return False

    return True
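
# For example, a name like 'conv2d_backward' matches the '.*_backward' pattern and the
# signature 'add(Tensor, Scalar, Scalar)' is in the skip list, so neither gets a
# generated Python binding.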


def get_py_variable_methods(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as methods on Tensor.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                declaration.get('python_module') != 'nn' and
                'Tensor' in declaration['method_of'])

    return group_declarations_by_name(declarations, should_bind)


def gen_py_variable_methods(out, declarations, template_path):
    PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    py_variable_methods = get_py_variable_methods(declarations)

    env = create_python_bindings(py_variable_methods, True)
    write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
    write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)


def get_py_nn_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "nn" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))

    return group_declarations_by_name(declarations, should_bind)


def gen_py_nn_functions(out, declarations, template_path):
    PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    py_nn_functions = get_py_nn_functions(declarations)

    env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
    write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
    write(out, 'python_nn_functions.h', PY_NN_FUNCTIONS_H, env)
    write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)


def get_py_torch_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "torch" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                declaration.get('python_module') != 'nn' and
                'namespace' in declaration['method_of'])

    return group_declarations_by_name(declarations, should_bind)


def gen_py_torch_functions(out, declarations, template_path):
    PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    py_torch_functions = get_py_torch_functions(declarations)

    env = create_python_bindings(py_torch_functions, has_self=False)
    write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
    write(out, 'python_torch_functions_dispatch.h', PY_TORCH_DISPATCH_H, env)


def group_declarations_by_name(declarations, should_bind_fn):
    """Group declarations by name ignoring _out suffix"""
    groups = defaultdict(list)
    for declaration in declarations:
        name = declaration['name']
        if should_bind_fn(declaration):
            if name.endswith('_out'):
                groups[name[:-4]].append(declaration)
            else:
                groups[name].append(declaration)
    return groups
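
# For example, 'add' and 'add_out' both land in groups['add'], so the plain and
# out= variants of a function are bound under a single Python name.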


def get_type_default(declaration):
    if declaration['name'].startswith('randperm') or \
            declaration['name'] == 'tril_indices' or \
            declaration['name'] == 'triu_indices':
        # index-producing factories default to int64
        return 'torch.int64'
    else:
        return 'None'


def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions"""
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
        'IntArrayRef': 'intlist',
        'int64_t': 'toInt64',
        'double': 'toDouble',
        'std::string': 'string',
    }

    unpack_with_default_methods = {
        'IntArrayRef': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }
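
    # For example, an argument of type 'const Tensor &' at parser index 0 is unpacked
    # as r.tensor(0), and an 'int64_t' argument with a python_default_init of '5' at
    # index 2 becomes r.toInt64WithDefault(2, 5).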

    def emit_single_dispatch(declaration, out_idx, base_env):
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determine the dtype if present, so only use the binding arg
            # when there are no outputs
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntArrayRef['):
                typename = 'IntArrayRef'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # we always want to unpack when we have TensorOptions
        unpack = has_tensor_options
        for arg in inputs:
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                formal_args.append('Tensor & self')
                actuals.append('self')
            else:
                append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
                arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        # type-dispatched args are parsed after the outputs (if any)
        parsed_type_args = None
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        has_device_bind = False
        requires_grad = None
        device = None
        layout = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled via type_args above
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    device, device_type = parse_arg(arg, device_idx, True)
                    if not has_tensor_options:
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append(
                            "torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))
                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad,
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        if 'Tensor' in declaration['method_of']:
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        env['simple_return_type'] = simple_return_type

        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body
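
    # Illustrative sketch, assuming an add(Tensor self, Tensor other, Scalar alpha)
    # declaration: the emitted dispatch wrapper for a Tensor method looks roughly like
    #   inline Tensor dispatch_add(Tensor & self, const Tensor & other, Scalar alpha) {
    #     AutoNoGIL no_gil;
    #     return self.add(other, alpha);
    #   }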

    def emit_dispatch(i, dictionary, base_env):
        if 'out' in dictionary:
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                has_tensor_return = True

        name = declaration['name']
        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.scalar_type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device,
            }
            python_binding_arguments.append(device_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments
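
    # For a factory function such as empty(IntArrayRef size, TensorOptions options),
    # this appends keyword-only dtype, layout, device and requires_grad arguments,
    # which is how torch.empty(..., dtype=..., device=..., requires_grad=...) gets
    # its keyword arguments.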

    def emit_namedtuple_return_type_def(declaration, next_index):
        returns = declaration['returns']
        if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
            declaration['namedtuple_return_type'] = ''
            return '', next_index
        declaration['namedtuple_type_index'] = next_index
        declaration['namedtuple_fields'] = ''
        for x in returns:
            if 'field_name' not in x:
                # unnamed fields in a namedtuple return are not supported;
                # either name all fields or none of them
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, '
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1
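
    # For example, a declaration whose returns carry the field_names 'values' and
    # 'indices' gets namedtuple_fields '{"values", ""}, {"indices", ""}, ' and a
    # namedtuple_return_type of '&type0, ' (for the first such type emitted).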

    def process_function(name, declarations):
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        # declare the namedtuple return types used by this function
        next_index = 0
        for declaration in declarations:
            typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index)
            env['declare_namedtuple_return_types'] += typedef

        # emit one dispatch case per grouped signature
        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            else:
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }


def group_declarations(declarations):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing
    """
    grouped = defaultdict(dict)

    # first group by signature ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, False)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with the optional out=... argument
            v['signature'] = get_python_signature(declaration, True)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for _, dictionary in sorted(grouped.items()):
        if 'base' not in dictionary:
            raise RuntimeError("'base' not in dictionary", dictionary)
        result.append(dictionary)
    return sort_declarations(result)
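
# For example, conv2d and conv2d_out end up in a single dictionary
# {'base': <conv2d>, 'out': <conv2d_out>, 'signature': 'conv2d(..., *, Tensor out=None)'},
# keyed by their shared out-less signature.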


def sort_declarations(grouped_decls):
    # Sort overloads into a partial order: 'real' is normalized to 'Scalar', and a
    # signature is "smaller" than another when it has Scalar arguments where the
    # other has Tensor arguments in the same positions.

    def normalized_dynamic_type(arg):
        if arg['dynamic_type'] == 'real':
            return 'Scalar'
        return arg['dynamic_type']

    def is_coord_smaller(arg1, arg2):
        return normalized_dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2))
        all_smaller_or_equal = all(normalized_dynamic_type(arg1) == normalized_dynamic_type(arg2) or
                                   is_coord_smaller(arg1, arg2)
                                   for arg1, arg2 in zip(args1, args2))
        return any_smaller and all_smaller_or_equal

    # construct the relation graph
    larger_than = defaultdict(set)
    for i1, decl1 in enumerate(grouped_decls):
        for i2, decl2 in enumerate(grouped_decls):
            if is_smaller(decl1, decl2):
                larger_than[i1].add(i2)

    if not larger_than:
        return grouped_decls

    # topologically sort, emitting declarations that nothing is larger than first
    sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls)
                   if i not in larger_than]
    for i, decl in sorted_deps:
        for i2 in sorted(larger_than.keys()):
            larger = larger_than[i2]
            larger.discard(i)
            if not larger:
                del larger_than[i2]
                sorted_deps.append((i2, grouped_decls[i2]))

    return [decl for i, decl in sorted_deps]
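
# For example, given add(Tensor, Tensor, Scalar) and add(Tensor, Scalar, Scalar)
# overloads, the Tensor/Tensor signature is ordered before the Tensor/Scalar one,
# so the argument parser tries it first.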


def get_python_signature(declaration, include_out):
    # Produces the signature string handed to PythonArgParser, e.g.
    # "name(Type1 arg1, Type2 arg2=default, *, Type3 kwarg1=default)".
    py_formal_args = []
    output_args = []
    type_args = []
    positional = True

    def get_py_formal_arg(arg):
        typename = arg['simple_type']
        typename = typename if typename != 'Type' else 'ScalarType'

        if arg.get('is_nullable') and '?' not in typename:
            typename = '{}?'.format(typename)

        if arg.get('size') is not None:
            typename = '{}[{}]'.format(typename, arg['size'])
        param = typename + ' ' + arg['name']
        default = None
        if arg.get('default') is not None:
            default = arg['default']
            if default == 'nullptr' or default == 'nullopt' or default == '{}':
                default = 'None'
        if default is not None:
            param += '=' + str(default)
        return param

    for arg in declaration['arguments']:
        if arg.get('output', False):
            output_args.append(arg)
            continue
        if arg['simple_type'] == 'Type':
            type_args.append(arg)
            continue
        # TensorOptions is only used on the C++ side; it never shows up in the
        # Python signature
        if arg['simple_type'] == 'TensorOptions':
            continue
        if arg.get('kwarg_only', False) and positional:
            py_formal_args.append('*')
            positional = False
        param = get_py_formal_arg(arg)
        py_formal_args.append(param)

    # ensure we get a sane name for the out= variant
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-4]

    if len(output_args) > 0 and include_out:
        assert declaration['name'].endswith('_out')
        if positional:
            py_formal_args.append('*')
            positional = False
        typenames = [arg['simple_type'] for arg in output_args]
        if len(typenames) > 1:
            typename = 'TensorList[{}]'.format(len(typenames))
        else:
            typename = typenames[0]
        if len(output_args) == 1:
            py_formal_args.append(typename + ' ' + output_args[0]['name'] + '=None')
        else:
            # more than one output arg is surfaced as a single out=... TensorList
            py_formal_args.append(typename + ' out=None')

    # type dispatched args and python binding arguments go after the out argument
    assert len(type_args) <= 1
    for arg in type_args:
        if positional:
            py_formal_args.append('*')
            positional = False
        py_formal_args.append(get_py_formal_arg(arg))

    if len(declaration['python_binding_arguments']) > 0:
        for arg in declaration['python_binding_arguments']:
            if arg.get('kwarg_only', False) and positional:
                py_formal_args.append('*')
                positional = False
            py_formal_args.append(get_py_formal_arg(arg))

    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, py_formal_args=py_formal_args)
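
# For example, for an "_out" declaration with a single Tensor output and include_out
# set to True, the trailing "_out" is stripped from the name and a keyword-only
# "Tensor out=None" formal is appended, yielding roughly
# "add(Tensor self, Tensor other, *, Tensor out=None)".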