Caffe2 - Python API
A deep learning, cross platform ML framework
gen_jit_dispatch.py
1 """
2 To run this file by hand from the root of the PyTorch
3 repository, run:
4 
5 python -m tools.jit.gen_jit_dispatch \
6  build/aten/src/ATen/Declarations.yaml \
7  $OUTPUT_DIR \
8  tools/jit/templates
9 
10 Where $OUTPUT_DIR is where you would like the files to be
11 generated. In the full build system, OUTPUT_DIR is
12 torch/csrc/jit/generated/
13 """
14 
15 import os
16 import argparse
17 import re
18 import copy
19 from itertools import count, combinations, groupby
20 from ..autograd.utils import CodeTemplate, write, uninplace_api_name
21 from ..autograd.gen_autograd import load_aten_declarations
22 from collections import OrderedDict
23 from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
24 
25 # JIT has a type system of
26 # Scalar = int | float | bool # int is the largest int (int64_t),
27 # float is the largest float (double) we don't have the others because they are never held in tensors
28 # Type = Scalar # primitive numbers
29 # | Tensor # any tensor, as defined by at::Tensor
30 # | Type[] # a dynamically sized list of a type
31 # | Scalar[N] # a homogeneous fixed size scalar list, single scalars can expand to this list
32 # | (Type1, Type2, ...) # a heterogeneous tuple
33 # | Layout | ScalarType | Device | Generator # special singleton types for built-in concepts in tensor lib
34 
35 # clean up the variety of C++ types in the ATen declarations
36 # to be in the restricted set of types that the IR represents
37 # note: no default values for this map, to make it clear what types
38 # can be passed through
39 
# Maps an ATen 'simple_type' string to the JIT schema type string used when
# rendering operator signatures (see jit_type_of / signature below).
TYPE_MAP = {
    'std::array<bool,2>': 'bool[2]',
    'std::array<bool,3>': 'bool[3]',
    'std::array<bool,4>': 'bool[4]',
    'std::string': 'str',
    'Scalar': 'Scalar',
    'Scalar?': 'Scalar?',
    'Tensor': 'Tensor',
    'Tensor?': 'Tensor?',
    'TensorList': 'Tensor[]',
    # this appears in return values instead of TensorList
    # since TensorList is a ArrayRef in arguments but a vector
    # in returns
    'std::vector<Tensor>': 'Tensor[]',
    'IntArrayRef': 'int[]',
    'Layout': 'Layout',
    'Layout?': 'Layout?',
    'Device': 'Device',
    'Device?': 'Device?',
    'ScalarType': 'ScalarType',
    'ScalarType?': 'ScalarType?',
    'int64_t': 'int',
    'int64_t?': 'int?',
    'double': 'float',
    'bool': 'bool',
    # Generator args are only accepted when they carry a default (see
    # is_jit_arg / default_only_types), so they surface as optional
    'Generator': 'Generator?',
}
67 
68 
def optional_type_of(arg, typ):
    """Return *typ* with an optional ('?') annotation when *arg* is nullable.

    Handles the special spelling 'Tensor?[]' for nullable tensor lists;
    types that already carry a '?' are returned unchanged.
    """
    # Nothing to do unless the declaration is nullable and not yet annotated.
    if not arg.get('is_nullable') or '?' in typ:
        return typ
    if typ in ('TensorList', 'Tensor[]'):
        return 'Tensor?[]'
    return '{}?'.format(typ)
78 
79 
def jit_type_of(arg):
    """Map an ATen argument/return dict to its JIT schema type string."""
    # ops that were annotated earlier (e.g. by annotate_op) carry an
    # explicit jit type which overrides the computed one
    if 'jit_type' in arg:
        return arg['jit_type']
    if is_sized_intlist_arg(arg):
        base = 'int[{}]'.format(arg['size'])
    else:
        base = TYPE_MAP[arg['simple_type']]
    return optional_type_of(arg, base)
91 
92 
# map from aten 'simple_type' to the function that will turn a tensor into
# that type
# Each value is a C++ expression template; '{}' is replaced with the IValue
# expression being unpacked (see from_ivalue below).
FROM_IVALUE = {
    'Device': '{}.toDevice()',
    'Device?': '{}.toOptional<c10::Device>()',
    'IntArrayRef': '{}.toIntList()->elements()',
    'Layout': '{}.toLayout()',
    'Layout?': '{}.toOptional<c10::Layout>()',
    'Scalar': '{}.toScalar()',
    'Scalar?': '{}.toOptional<Scalar>()',
    'ScalarType': '{}.toScalarType()',
    'ScalarType?': '{}.toOptional<ScalarType>()',
    'Tensor': '{}.toTensor()',
    'Tensor?': 'toOptionalTensor({})',
    'Tensor?[]': 'toListOfOptionalTensor({})',
    'TensorList': '{}.toTensorList()->elements()',
    'bool': '{}.toBool()',
    'double': '{}.toDouble()',
    'int64_t': '{}.toInt()',
    'int64_t?': '{}.toOptional<int64_t>()',
    'std::string': '{}.toString()->string()',
    # Generator inputs are never read from the stack; the default is passed
    'Generator': 'nullptr',
    'std::array<bool,2>': 'as_bool_array<2>({}.toBoolListRef())',
    'std::array<bool,3>': 'as_bool_array<3>({}.toBoolListRef())',
    'std::array<bool,4>': 'as_bool_array<4>({}.toBoolListRef())',
}
119 
120 
def from_ivalue(arg, value):
    """Render the C++ expression that unpacks IValue expression *value*
    into the C++ type expected for *arg*."""
    lookup_key = optional_type_of(arg, arg['simple_type'])
    template = FROM_IVALUE[lookup_key]
    return template.format(value)
124 
125 
# C++ code templates for the generated operator implementations.
# NOTE(review): template-string interiors reconstructed from a
# whitespace-mangled source; indentation inside the emitted C++ is cosmetic.

# Free-function call: at::name(args...)
CALL_NAMESPACE = CodeTemplate("""\
auto result_ = at::${name}(
    ${args}
);
""")
# Method call on the first argument: (first).name(args...)
CALL_METHOD = CodeTemplate("""\
auto result_ = (${first}).${name}(
    ${args}
);
""")
# Variants that repack dtype/layout/device back into a TensorOptions
# (see expand_options in gen_jit_dispatch).
CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
        .dtype(${dtype})
        .layout(${layout})
        .device(${device});
auto result_ = torch::${name}(${args_with_tensor_options});
""")
CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
        .dtype(${dtype})
        .layout(${layout})
        .device(${device});
auto result_ = (${first}).${name}(${args_with_tensor_options});
""")

# Lambda body for a single operator: unpack inputs, call, repack result.
CONSTRUCTOR = CodeTemplate("""\
[](Stack & stack) {
    autograd::profiler::RecordFunction record("${name}");
    ${lvalues}
    ${call}
    drop(stack, ${num_inputs});
    pack(stack, std::move(result_));
    return 0;
}
""")

# Registration entry pairing a schema string with its implementation.
OPERATOR = CodeTemplate("""\
Operator(
    "${signature}",
    ${op}
),
""")
168 
169 
# ATen simple_types that have no JIT representation; ops using them are skipped.
blacklisted_types = {'SparseTensorRef', 'Storage', 'void*'}
# Types only accepted when the argument carries a default value (see is_jit_arg).
default_only_types = {'Generator'}
172 
173 
def is_jit_arg(i, arg):
    """Return True when argument *arg* (at position *i*) can be expressed
    in the JIT schema; False disqualifies the whole declaration."""
    simple_type = arg['simple_type']
    if simple_type in blacklisted_types:
        return False
    # some types (e.g. Generator) are only tolerated with a default value
    if simple_type in default_only_types and 'default' not in arg:
        return False
    return simple_type != 'Type'
183 
184 
def is_jit_op(decl):
    """Decide whether an ATen declaration should get a generated JIT operator."""
    # We currently don't support functions that return nothing
    if all(r['type'] == 'void' for r in decl['returns']):
        return False

    arguments = decl['arguments']

    # there must be a single out argument for an out variant
    if is_out_variant(decl):
        output_count = sum(1 for arg in arguments if arg.get('output'))
        if output_count > 1:
            return False

    dispatchable = 'namespace' in decl['method_of'] or 'Tensor' in decl['method_of']
    args_ok = all(is_jit_arg(i, arg) for i, arg in enumerate(decl['arguments']))
    rets_ok = all(is_jit_arg(i, arg) for i, arg in enumerate(decl['returns']))
    return dispatchable and args_ok and rets_ok
199 
200 
def is_tensor_arg(arg):
    """Return True when the argument holds a Tensor or a TensorList."""
    return arg['simple_type'] in ('Tensor', 'TensorList')
203 
204 
def is_sized_intlist_arg(arg):
    """Returns True for arguments declared as IntArrayRef[k], but False for IntArrayRef."""
    # the 'size' key is only present when the declaration fixed the list length
    return (arg['simple_type'] == 'IntArrayRef') and ('size' in arg)
208 
209 
def base_name(decl):
    """Return the op name with any inplace '_' or '_out' suffix stripped."""
    name = decl['name']
    if decl.get('inplace', False):
        return name[:-1]
    if name.endswith('_out'):
        return name[:-4]
    return name
213 
214 
def is_view(decl):
    """True when the op returns a view of its input (per RETURNS_VIEWS_OF_INPUT)."""
    stripped = base_name(decl)
    return stripped in RETURNS_VIEWS_OF_INPUT
217 
218 
def is_out_variant(decl):
    """True when the declaration is an out= variant (name ends in '_out')."""
    name = decl['name']
    return name.endswith('_out')
221 
222 
def argument_order(decl):
    """For each argument in decl, give the position it should occupy in the
    JIT schema declaration, e.g.

        arguments = [x, y, z]            # the order in aten
        jit_argument_order = [2, 0, 1]
        aten::my_arg(Tensor y, Tensor z, Tensor x)   # the order in schema

    Used to move 'out' arguments to the end of the list. Falls back to the
    identity ordering when no explicit order was recorded.
    """
    explicit = decl.get('jit_argument_order')
    if explicit:
        return explicit
    return list(range(len(decl['arguments'])))
231 
232 
def gen_jit_dispatch(declarations, out, template_path):
    """Generate the sharded register_aten_ops_*.cpp files.

    Loads the ATen declarations from *declarations* (Declarations.yaml),
    filters them to JIT-supported ops, annotates/expands them, and writes
    the generated operator registrations to the *out* directory using the
    templates found under *template_path*.
    """
    REGISTER_ATEN_OPS_CPP = CodeTemplate.from_file(template_path + '/register_aten_ops.cpp')

    ops = []

    def get_invocation(decl, args, num_inputs):
        # Build the C++ call expression for one op, choosing between the
        # free-function / method templates and their TensorOptions variants.

        # because the arg list can get lengthy we put them on a separate line
        def pack_arguments(args):
            return ',\n'.join(args)
        is_namespace_function = 'namespace' in decl['method_of']
        tensor_options_arg_index = decl.get('tensor_options_arg_index', None)
        if tensor_options_arg_index is not None:
            # dtype/layout/device were expanded contiguously by expand_options;
            # repack them into a single 'options' argument for the C++ call
            dtype = args[tensor_options_arg_index]
            layout = args[tensor_options_arg_index + 1]
            device = args[tensor_options_arg_index + 2]
            args_with_tensor_options = args[:tensor_options_arg_index] + \
                ['options'] + args[(tensor_options_arg_index + 3):]
            if is_namespace_function:
                return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute(
                    name=decl['name'], dtype=dtype, layout=layout, device=device,
                    args_with_tensor_options=pack_arguments(args_with_tensor_options))
            else:
                return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute(
                    name=decl['name'], dtype=dtype, layout=layout, device=device,
                    args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]),
                    first=args_with_tensor_options[0], num_inputs=num_inputs)
        else:
            if is_namespace_function:
                return CALL_NAMESPACE.substitute(name=decl['name'],
                                                 args=pack_arguments(args),
                                                 num_inputs=num_inputs)
            else:
                # method call: first argument becomes the receiver
                return CALL_METHOD.substitute(
                    name=decl['name'], first=args[0],
                    args=pack_arguments(args[1:]), num_inputs=num_inputs)

    def requires_lvalue(arg):
        # arguments annotated as mutated ('Tensor(a!)') are passed by
        # non-const reference in C++ and therefore need a named variable
        return 'jit_type' in arg and arg['jit_type'] in {"Tensor!", "Tensor(a!)"}

    def emit_decl_variant(decl):
        # Emit the CONSTRUCTOR lambda body for one declaration.
        kw_assignments = []

        # mutable arguments in aten are passed as non const references
        # these must be lvalues, so we have to put them in variables
        # before calling the function
        lvalues = []

        arguments = []
        num_inputs = len(decl['arguments'])
        op_capture = ''
        order = argument_order(decl)
        for i, arg in enumerate(decl['arguments']):
            # peek at the stack slot corresponding to this arg's schema position
            value = from_ivalue(arg, '(std::move(peek(stack, {}, {})))'.format(order[i], num_inputs))
            if requires_lvalue(arg):
                lvalues.append('auto {} = {};\n'.format(arg['name'], value))
                value = arg['name']
            arguments.append(value)

        call = get_invocation(decl, arguments, num_inputs)

        returns = decl['returns']

        constructor = CONSTRUCTOR.substitute(name=decl['name'],
                                             call=call,
                                             kw_assignments=kw_assignments,
                                             num_inputs=num_inputs,
                                             op_capture=op_capture,
                                             lvalues=lvalues)
        return constructor

    # This function declares an order on declarations. This is necessary because
    # there is some ambiguity in the choice of overload: if an argument is overloaded
    # to accept both Scalar and Tensor, the schema with the Tensor should come first
    # TODO: this can (probably) be removed when we remove the implicit conversion
    # from Tensor -> Number.
    def sort_decls(jit_decls):
        def declkey(decl):
            # key = sum_{i < len(args)} {1 if arg is tensor else 2} * (3 ** i)
            # This is a ternary encoding where
            # 0: No argument at this position
            # 1: Tensor argument at this position
            # 2: Some other argument at this position.
            args = decl['arguments']
            result = 0
            for i in range(len(args)):
                result += (3 ** i) * (1 if args[i]['simple_type'] == 'Tensor' else 2)
            return result

        # NB: itertools.groupby requires the list be sorted.
        sorted_decls = sorted(jit_decls, key=lambda decl: decl['name'])
        grouped_decls = [list(g) for _, g in
                         groupby(sorted_decls, key=lambda decl: decl['name'])]
        return [sorted(g, key=declkey) for g in grouped_decls]

    # We need to add methods implemented manually in TensorImpl
    tensor_impl_methods = [{
        'name': name,
        'api_name': name,
        'method_of': ['Tensor'],
        'arguments': [{'name': 'self', 'simple_type': 'Tensor'}],
        'returns': [{'name': 'result', 'type': 'int64_t', 'dynamic_type': 'int64_t', 'simple_type': 'int64_t'}],
    } for name in ['sizes', 'strides', 'dim']]
    aten_decls = load_aten_declarations(declarations) + tensor_impl_methods
    jit_decls = [d for d in aten_decls if is_jit_op(d)]

    # add arguments dtype and device for functions like zeros
    def expand_options(decl, i, arg):
        # Replace a single TensorOptions argument at position i with the
        # three scalar arguments dtype/layout/device; records the position
        # so get_invocation can repack them.
        if arg['simple_type'] != 'TensorOptions':
            return [arg]
        assert decl.get('tensor_options_arg_index') != i
        decl['tensor_options_arg_index'] = i
        tensor_options_expansion = [
            # XXX - until we actually have first-class interpreter types for these
            # concepts, the default values to be encoded in Tensors
            # If you change this, you also need to update [TensorOptions in script]
            # in the tracer code.
            # dtype is specified as an int64_t of at::ScalarType
            {'name': 'dtype', 'simple_type': 'ScalarType'},
            # layout is specified as an int64_t of at::Layout
            {'name': 'layout', 'simple_type': 'Layout'},
            # device is specified as an IntArrayRef of { at::Device::Type, device_id }
            {'name': 'device', 'simple_type': 'Device'},
        ]
        # TODO: Don't repack this into TensorOptions. Needs various changes in downstream code.
        if 'default' in arg:
            tensor_options_expansion[0]['simple_type'] += '?'
            tensor_options_expansion[1]['simple_type'] += '?'
            tensor_options_expansion[2]['simple_type'] += '?'
            tensor_options_expansion[0]['default'] = 'None'
            tensor_options_expansion[1]['default'] = 'None'
            tensor_options_expansion[2]['default'] = 'None'
        if 'default' in arg and arg['default'] == 'at::kLong':
            tensor_options_expansion[0]['default'] = 'long'
        if 'kwarg_only' in arg and arg['kwarg_only']:
            tensor_options_expansion[0]['kwarg_only'] = True
            tensor_options_expansion[1]['kwarg_only'] = True
            tensor_options_expansion[2]['kwarg_only'] = True
        return tensor_options_expansion

    additional_jit_decls = []

    for decl in jit_decls:
        decl['arguments'] = [a for i, arg in enumerate(decl['arguments']) for a in expand_options(decl, i, arg)]
        # add annotations about alias an mutability of arguments
        annotate_op(decl)

        decl['should_match_schema'] = True

        # ops with a nullable TensorList also get a non-nullable variant
        # whose constructed signature is not checked against schema_string
        decl_copy = copy.deepcopy(decl)
        for arg in decl_copy['arguments']:
            if arg['simple_type'] == 'TensorList' and arg.get('is_nullable'):
                arg['is_nullable'] = False
                decl_copy['should_match_schema'] = False
                additional_jit_decls.append(decl_copy)

    jit_decls.extend(additional_jit_decls)

    # Group and sort the generated snippets to ensure that the
    # generation is deterministic
    jit_decl_groups = sort_decls(jit_decls)

    # NOTE: see Note [Sharded File] at the top of the register_aten_ops.cpp
    # template regarding sharding of the generated files.
    #
    # If you edit the number of shards here, you will also have to
    # modify generate_code.py, torch/CMakeLists.txt, and the TARGETS
    # files.
    num_shards = 3
    shards = [[] for _ in range(num_shards)]

    # ops are assigned arbitrarily but stably to a file based on hash
    for group in jit_decl_groups:
        x = sum(ord(c) for c in group[0]['name']) % num_shards
        for decl in group:
            shards[x].append(OPERATOR.substitute(signature=signature(decl, decl['should_match_schema']),
                                                 op=emit_decl_variant(decl)))

    for i, shard in enumerate(shards):
        env = {
            'constructors': shard,
        }
        write(out, 'register_aten_ops_%d.cpp' % i, REGISTER_ATEN_OPS_CPP, env)
416 
417 
# C++ default-value spellings that render as 'None' in the JIT schema.
default_map = {'{}': 'None', 'nullptr': 'None', 'c10::nullopt': 'None'}
419 
420 
def annotate_op(decl):
    """Mutate *decl* in place, adding alias/mutability ('jit_type')
    annotations for inplace, out-variant and viewing operators, and
    recording the jit_argument_order that moves out arguments last."""
    # insert alias annotations into viewing operators
    if decl.get('inplace') or is_out_variant(decl):
        # inplace/out ops mutate their first argument and return it: (a!)
        first_arg = decl['arguments'][0]
        assert(jit_type_of(first_arg) == 'Tensor')
        first_arg['jit_type'] = 'Tensor(a!)'
        first_ret = decl['returns'][0]
        assert(jit_type_of(first_ret) == 'Tensor')
        first_ret['jit_type'] = 'Tensor(a!)'
        if is_out_variant(decl):
            assert(first_arg['output'])
            # the output variant must go at the end
            # note: this is an annoying side effect of using a single '*'
            # to denote kwarg_only
            nargs = len(decl['arguments'])
            decl['jit_argument_order'] = [nargs - 1] + list(range(nargs - 1))
    elif is_view(decl):
        # viewing ops alias (but do not mutate) their first argument: (a)
        first_arg = decl['arguments'][0]
        assert jit_type_of(first_arg) == 'Tensor'
        first_arg['jit_type'] = 'Tensor(a)'
        first_ret = decl['returns'][0]
        ret_type = jit_type_of(first_ret)
        if ret_type == 'Tensor[]':
            first_ret['jit_type'] = 'Tensor(a)[]'
        elif ret_type == 'Tensor':
            first_ret['jit_type'] = 'Tensor(a)'
447 
448 
def is_kwarg_only(a):
    """An argument is keyword-only in the schema when flagged kwarg_only
    or when it is an output argument."""
    if a.get('kwarg_only'):
        return a.get('kwarg_only')
    return a.get('output')
451 
452 
def match_signature(decl, constructed_string, should_match_schema):
    """Return the schema string to register for *decl*.

    If matches_jit_signature has been specified, the signature constructed
    from the declared attributes must match the raw schema_string passed
    through (this tracks the alignment of native_functions.yaml's function
    schema with the one built by this parser). Otherwise the constructed
    string is used as-is.
    """
    if not (decl.get('matches_jit_signature') and should_match_schema):
        return constructed_string
    assert constructed_string == decl['schema_string'], \
        decl['schema_string'] + ' is flagged as JIT signature compliant' + \
        ', but does not match the signature ' + constructed_string
    return decl['schema_string']
466 
467 
def signature(decl, should_match_schema=True):
    """Build the 'aten::name(args) -> returns' schema string for *decl*,
    ordering arguments per argument_order and inserting '*' before the
    first keyword-only argument."""
    def format_arg(arg):
        name = arg['name']
        typ = jit_type_of(arg)
        # NOTE: this local 'decl' shadows the outer parameter; it is just
        # the rendered "type name" fragment for one argument
        decl = '{} {}'.format(typ, name)
        if 'default' in arg:
            # clean up initializer lists {{true, true}} -> [true, true]
            default = arg['default']
            # NOTE: str(float) in python2 truncates, which makes JIT signatures not match native_functions
            # signatures. repr(float) doesn't seem to truncate in these cases.
            default = str(default) if not isinstance(default, float) else repr(default)
            default = default \
                .replace('{{', '[') \
                .replace('}}', ']') \
                .replace('true', 'True') \
                .replace('false', 'False') \
                .replace('Reduction::Mean', 'Mean') \
                .replace('{}', 'None' if is_tensor_arg(arg) else '[]') \
                .replace('{', '[') \
                .replace('}', ']')

            default = default_map.get(default, default)
            decl = '{}={}'.format(decl, default)
        return decl

    args = []
    kwarg_only = False

    ordered_arguments = sorted(zip(argument_order(decl), decl['arguments']))
    for _, a in ordered_arguments:
        # emit a single '*' marker before the first kwarg-only argument
        if not kwarg_only and is_kwarg_only(a):
            args.append('*')
            kwarg_only = True
        args.append(format_arg(a))

    arg_list = ', '.join(args)
    if len(decl['returns']) == 1:
        ret_list = jit_type_of(decl['returns'][0])
        # Adding output name if it exists
        if decl['returns'][0].get('field_name'):
            ret_list += ' ' + decl['returns'][0]['field_name']
    else:
        # multiple returns render as a parenthesized tuple
        def type_maybe_field(r):
            return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r)
        ret_list = '({})'.format(', '.join(type_maybe_field(r) for r in decl['returns']))
    # out variants register under the base name (without the '_out' suffix)
    name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
    constructed_string = 'aten::{}({}) -> {}'.format(name, arg_list, ret_list)
    return match_signature(decl, constructed_string, should_match_schema)
516 
517 
def main():
    """Command-line entry point: parse the three path arguments and run
    the generator (see the module docstring for usage)."""
    parser = argparse.ArgumentParser(description='Generate JIT op dispatch')
    positionals = [
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
        ('template_path', 'TEMPLATE_PATH', 'path to templates directory'),
    ]
    for dest, meta, desc in positionals:
        parser.add_argument(dest, metavar=meta, help=desc)
    args = parser.parse_args()
    gen_jit_dispatch(args.declarations, args.out, args.template_path)
529 
530 
# script entry point (run via `python -m tools.jit.gen_jit_dispatch ...`)
if __name__ == '__main__':
    main()