Caffe2 - Python API
A deep learning, cross-platform ML framework
gen_python_functions.py
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn object.
#
from collections import defaultdict
import re
from .nested_dict import nested_dict
from .gen_variable_type import should_trace
from .utils import write

try:
    from src.ATen.code_template import CodeTemplate
except ImportError:
    from tools.shared.module_loader import import_module
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate

# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    'index',
    '_indexCopy_', 'max_values', 'min_values',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
    '_cholesky.*', '_btrifact.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense',
    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
    'copy_sparse_to_sparse_',
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")
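
# ---------------------------------------------------------------------------
# Editor's sketch (not part of the generator): CodeTemplate performs simple
# ${placeholder} substitution; list values are joined with ", " when the
# placeholder appears inline, or emitted one per line when it stands alone on
# an indented line (assuming the behaviour of aten/src/ATen/code_template.py).
# For example, filling PY_VARIABLE_METHOD_DEF for a hypothetical method:
#
#   PY_VARIABLE_METHOD_DEF.substitute(
#       name='abs', pycname='THPVariable_abs',
#       flags='METH_VARARGS | METH_KEYWORDS')
#   # -> '{"abs", (PyCFunction)THPVariable_abs, METH_VARARGS | METH_KEYWORDS, NULL},'
# ---------------------------------------------------------------------------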

PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")

# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void'
}

TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout}.layout)
    .requires_grad(${requires_grad});
""")


def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensorRef),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    for arg in declaration['arguments']:
        if arg['type'] == 'SparseTensorRef' and declaration['name'] != 'sparse_mask':
            return False

    return True
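
# Editor's sketch (not part of the original file): how the skip lists are
# applied by should_generate_python_binding. The declaration dicts below are
# hypothetical and trimmed to the keys the function actually reads.
#
#   should_generate_python_binding({'name': 'conv2d_backward', 'arguments': []})
#   # -> False (the name matches the '.*_backward' pattern)
#
#   decl = {'name': 'add_', 'arguments': [
#       {'simple_type': 'Tensor', 'type': 'Tensor &'},
#       {'simple_type': 'Scalar', 'type': 'Scalar'},
#       {'simple_type': 'Scalar', 'type': 'Scalar'}]}
#   should_generate_python_binding(decl)
#   # -> False ('add_(Tensor, Scalar, Scalar)' is in SKIP_PYTHON_BINDINGS_SIGNATURES)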


def get_py_variable_methods(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as methods on Tensor.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                declaration.get('python_module') != 'nn' and
                'Tensor' in declaration['method_of'])

    return group_declarations_by_name(declarations, should_bind)


def gen_py_variable_methods(out, declarations, template_path):
    PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    py_variable_methods = get_py_variable_methods(declarations)

    env = create_python_bindings(py_variable_methods, True)
    write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
    write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)


def get_py_nn_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "nn" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))

    return group_declarations_by_name(declarations, should_bind)


def gen_py_nn_functions(out, declarations, template_path):
    PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    py_nn_functions = get_py_nn_functions(declarations)

    env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
    write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
    write(out, 'python_nn_functions.h', PY_NN_FUNCTIONS_H, env)
    write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)


def get_py_torch_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "torch" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                declaration['mode'] != 'NN' and
                declaration.get('python_module') != 'nn' and
                'namespace' in declaration['method_of'])

    return group_declarations_by_name(declarations, should_bind)


def gen_py_torch_functions(out, declarations, template_path):
    PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    py_torch_functions = get_py_torch_functions(declarations)

    env = create_python_bindings(py_torch_functions, has_self=False)
    write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
    write(out, 'python_torch_functions_dispatch.h', PY_TORCH_DISPATCH_H, env)


def group_declarations_by_name(declarations, should_bind_fn):
    """Group declarations by name ignoring _out suffix"""
    groups = defaultdict(list)
    for declaration in declarations:
        name = declaration['name']
        if should_bind_fn(declaration):
            if name.endswith('_out'):
                groups[name[:-4]].append(declaration)
            else:
                groups[name].append(declaration)
    return groups
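
# Editor's sketch (not part of the original file): grouping strips the '_out'
# suffix, so a base declaration and its out variant land in the same bucket.
# The declarations here are hypothetical and trimmed to the only key used.
#
#   decls = [{'name': 'abs'}, {'name': 'abs_out'}, {'name': 'abs_'}]
#   dict(group_declarations_by_name(decls, lambda d: True))
#   # -> {'abs': [{'name': 'abs'}, {'name': 'abs_out'}], 'abs_': [{'name': 'abs_'}]}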


def get_type_default(declaration):
    if declaration['name'].startswith('randperm') or \
            declaration['name'] == 'tril_indices' or \
            declaration['name'] == 'triu_indices':
        return 'torch.int64'
    else:
        return 'None'
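
# For example (editor's note, hypothetical inputs): randperm and the *_indices
# factories default their dtype to torch.int64, everything else to None:
#
#   get_type_default({'name': 'randperm'})  # -> 'torch.int64'
#   get_type_default({'name': 'zeros'})     # -> 'None'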


def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions"""
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
        'IntArrayRef': 'intlist',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
        'std::string': 'string',
    }

    unpack_with_default_methods = {
        'IntArrayRef': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }
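
    # Editor's sketch (not part of the original file): these tables map a C++
    # argument type to the PythonArgParser accessor used to unpack it. For a
    # hypothetical 'int64_t dim' argument parsed at index 1, parse_arg below
    # would build the expression
    #
    #     r.toInt64(1)
    #
    # and, if the argument carried a python_default_init of 'self.dim()',
    #
    #     r.toInt64WithDefault(1, self.dim())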

    def emit_single_dispatch(declaration, out_idx, base_env):
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntArrayRef['):
                typename = 'IntArrayRef'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = has_tensor_options
        for arg in inputs:
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        has_device_bind = False
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled by type_dispatched_args
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))

                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        if 'Tensor' in declaration['method_of']:
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Use the simple_return_type (Tensor) rather than the fancy return type
        # (Tensor &). This is important because the dispatch functions take
        # mutable arguments *by value*, not by reference. If you then return
        # a reference to such an argument, you will now have a pointer to a
        # dangling stack entry. Not good.
        #
        # You want:
        #
        #   Tensor dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # *not*
        #
        #   Tensor& dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
        # codegen looks like dispatch_selu_(wrap(tensor)), and you can't take a
        # mutable reference to a temporary. Maybe we could assign it to a
        # variable itself.)
        env['simple_return_type'] = simple_return_type

        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body
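
    # Editor's sketch (not part of the original file): for a simple Tensor
    # method such as abs(), emit_single_dispatch fills PY_VARIABLE_DISPATCH
    # with dispatch_call='self.abs', an empty dispatch_args list and
    # simple_return_type='Tensor', so the generated dispatch header contains
    # roughly the following inline wrapper (whitespace approximate):
    #
    #   inline Tensor dispatch_abs(Tensor & self) {
    #     AutoNoGIL no_gil;
    #     return self.abs();
    #   }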

    def emit_dispatch(i, dictionary, base_env):
        if 'out' in dictionary:
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.scalar_type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments
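
    # Editor's sketch (not part of the original file): for a factory function
    # such as randn (tensor return, no tensor inputs), the synthesized trailing
    # keyword arguments come out in this order (randn_decl is a hypothetical
    # declaration dict):
    #
    #   [a['name'] for a in get_python_binding_arguments(randn_decl)]
    #   # -> ['dtype', 'layout', 'device', 'requires_grad']
    #
    # For *_like functions that take TensorOptions, the dtype/layout/device
    # defaults are instead derived from `self` via python_default_init.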

    def emit_namedtuple_return_type_def(declaration, next_index):
        returns = declaration['returns']
        if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
            declaration['namedtuple_return_type'] = ''
            return '', next_index
        declaration['namedtuple_type_index'] = next_index
        declaration['namedtuple_fields'] = ''
        for x in returns:
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which caused a build error:
                #
                #   python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                #   PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, '
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1

    def process_function(name, declarations):
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        # generate namedtuple type declarations
        next_index = 0
        for declaration in declarations:
            typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index)
            env['declare_namedtuple_return_types'] += typedef

        # emit dispatch
        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }


def group_declarations(declarations):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing
    """
    grouped = defaultdict(dict)

    # first group by signature ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, False)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with optional out=... arguments
            v['signature'] = get_python_signature(declaration, True)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for _, dictionary in sorted(grouped.items()):
        if 'base' not in dictionary:
            raise RuntimeError("'base' not in dictionary", dictionary)
        result.append(dictionary)
    return sort_declarations(result)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):

    # TODO: This is a hack!
    #
    # For some reason, when you specify a Scalar argument in a native
    # function, you get a Declarations.yaml entry that looks like this:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is in contrast to when there is a 'real' argument in TH
    # Declarations.cwrap; this gets (correctly?) translated into
    # dynamic_type: real, and type: Scalar. I would like to fix this
    # at the source but I have never understood what dynamic_type is
    # supposed to be.
    def normalized_dynamic_type(arg):
        if arg['dynamic_type'] == 'real':
            return 'Scalar'
        return arg['dynamic_type']

    def is_coord_smaller(arg1, arg2):
        return normalized_dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2))
        all_smaller_or_equal = all(normalized_dynamic_type(arg1) == normalized_dynamic_type(arg2) or
                                   is_coord_smaller(arg1, arg2)
                                   for arg1, arg2 in zip(args1, args2))
        return any_smaller and all_smaller_or_equal

    # Construct the relation graph
    larger_than = defaultdict(set)
    for i1, decl1 in enumerate(grouped_decls):
        for i2, decl2 in enumerate(grouped_decls):
            if is_smaller(decl1, decl2):
                larger_than[i1].add(i2)

    if not larger_than:
        return grouped_decls

    # Use a topological sort to sort decls according to the partial order.
    sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls)
                   if i not in larger_than]
    for i, decl in sorted_deps:
        for i2 in sorted(larger_than.keys()):
            larger = larger_than[i2]
            larger.discard(i)
            if not larger:
                del larger_than[i2]
                sorted_deps.append((i2, grouped_decls[i2]))

    return [decl for i, decl in sorted_deps]
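
# Editor's sketch (not part of the original file): with two hypothetical
# single-argument overloads, the Tensor overload sorts ahead of the Scalar
# one, so its signature is emitted (and therefore tried) first:
#
#   d_scalar = {'base': {'arguments': [{'dynamic_type': 'Scalar'}]}}
#   d_tensor = {'base': {'arguments': [{'dynamic_type': 'Tensor'}]}}
#   sort_declarations([d_scalar, d_tensor])  # -> [d_tensor, d_scalar]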


def get_python_signature(declaration, include_out):
    # Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # tools/gen_pyi.py. If you change any logic here, please
    # check that file too.
    py_formal_args = []
    output_args = []
    type_args = []
    positional = True

    def get_py_formal_arg(arg):
        typename = arg['simple_type']
        typename = typename if typename != 'Type' else 'ScalarType'

        # TODO: remove this and make optional types in simple_type to be consistent across
        # tensor and other types after make Tensor? be optional instead of undefined
        if arg.get('is_nullable') and '?' not in typename:
            typename = '{}?'.format(typename)

        if arg.get('size') is not None:
            typename = '{}[{}]'.format(typename, arg['size'])
        param = typename + ' ' + arg['name']
        default = None
        if arg.get('default') is not None:
            default = arg['default']
            if default == 'nullptr' or default == 'nullopt' or default == '{}':
                default = 'None'
        if default is not None:
            param += '=' + str(default)
        return param

    for arg in declaration['arguments']:
        if arg.get('output', False):
            output_args.append(arg)
            continue
        if arg['simple_type'] == 'Type':
            type_args.append(arg)
            continue
        # Skip `TensorOptions` in Python, as it is only used on the C++ side.
        if arg['simple_type'] == 'TensorOptions':
            continue
        if arg.get('kwarg_only', False) and positional:
            py_formal_args.append('*')
            positional = False
        param = get_py_formal_arg(arg)
        py_formal_args.append(param)

    # add output arguments
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-4]

    if len(output_args) > 0 and include_out:
        assert declaration['name'].endswith('_out')
        if positional:
            py_formal_args.append('*')
            positional = False
        typenames = [arg['simple_type'] for arg in output_args]
        if len(typenames) > 1:
            typename = 'TensorList[{}]'.format(len(typenames))
        else:
            typename = typenames[0]
        if len(output_args) == 1:
            # The nn module bindings are often not exposed to the user directly
            # but via torch.nn modules and functionals.
            py_formal_args.append(typename + ' ' + output_args[0]['name'] + '=None')
        else:
            # NB: For more than 1 output args the type name is a TensorList
            # and as such we don't (yet) need to consider the naming.
            py_formal_args.append(typename + ' out=None')

    # we could put this in the loop above but we want to ensure both type dispatched args
    # and python binding arguments are after the out argument; this matches the case
    # where there is a python binding argument dtype, which is necessary to match
    # the function signatures between the out and non-out variant.
    assert len(type_args) <= 1
    for arg in type_args:
        if positional:  # assume type_args should be kwarg_only.
            py_formal_args.append('*')
            positional = False
        py_formal_args.append(get_py_formal_arg(arg))

    if len(declaration['python_binding_arguments']) > 0:
        for arg in declaration['python_binding_arguments']:
            if arg.get('kwarg_only', False) and positional:
                py_formal_args.append('*')
                positional = False
            py_formal_args.append(get_py_formal_arg(arg))

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure used for argument parsing.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, py_formal_args=py_formal_args)
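
# Editor's sketch (not part of the original file): a trimmed, hypothetical
# out-variant declaration and the signature string this function builds for
# it, assuming CodeTemplate joins inline list values with ', ':
#
#   decl = {
#       'name': 'clamp_out',
#       'arguments': [
#           {'simple_type': 'Tensor', 'name': 'self'},
#           {'simple_type': 'Scalar', 'name': 'min', 'is_nullable': True, 'default': 'nullopt'},
#           {'simple_type': 'Tensor', 'name': 'result', 'output': True},
#       ],
#       'python_binding_arguments': [],
#   }
#   get_python_signature(decl, include_out=True)
#   # -> 'clamp(Tensor self, Scalar? min=None, *, Tensor result=None)'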