"""
To run this file by hand from the root of the PyTorch
repository, run:

python -m tools.jit.gen_jit_dispatch \
       build/aten/src/ATen/Declarations.yaml \
       $OUTPUT_DIR \
       tools/jit/templates

Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/jit/generated/
"""

import argparse
import copy
from itertools import count, combinations, groupby
from ..autograd.utils import CodeTemplate, write, uninplace_api_name
from ..autograd.gen_autograd import load_aten_declarations
from collections import OrderedDict
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT

TYPE_MAP = {
    'std::array<bool,2>': 'bool[2]',
    'std::array<bool,3>': 'bool[3]',
    'std::array<bool,4>': 'bool[4]',
    'TensorList': 'Tensor[]',
    'std::vector<Tensor>': 'Tensor[]',
    'IntArrayRef': 'int[]',
    'ScalarType': 'ScalarType',
    'ScalarType?': 'ScalarType?',
    'Generator': 'Generator?',
}

def optional_type_of(arg, typ):
    # special handling for nullable arguments: a nullable TensorList becomes
    # a list of optional Tensors, everything else just gets a trailing '?'
    if arg.get('is_nullable') and '?' not in typ:
        if typ == 'TensorList' or typ == 'Tensor[]':
            typ = 'Tensor?[]'
        else:
            typ = '{}?'.format(typ)
    return typ


def jit_type_of(arg):
    # return the JIT schema type of an argument, preferring an explicitly
    # annotated 'jit_type' if one has been set (e.g. by annotate_op below)
    if 'jit_type' in arg:
        return arg['jit_type']
    typ = TYPE_MAP[arg['simple_type']]
    if is_sized_intlist_arg(arg):
        typ = 'int[{}]'.format(arg['size'])

    typ = optional_type_of(arg, typ)
    return typ
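
# For example (hypothetical argument dicts):
#   jit_type_of({'simple_type': 'IntArrayRef', 'size': 2})          -> 'int[2]'
#   jit_type_of({'simple_type': 'TensorList', 'is_nullable': True}) -> 'Tensor?[]'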

FROM_IVALUE = {
    'Device': '{}.toDevice()',
    'Device?': '{}.toOptional<c10::Device>()',
    'IntArrayRef': '{}.toIntList()->elements()',
    'Layout': '{}.toLayout()',
    'Layout?': '{}.toOptional<c10::Layout>()',
    'Scalar': '{}.toScalar()',
    'Scalar?': '{}.toOptional<Scalar>()',
    'ScalarType': '{}.toScalarType()',
    'ScalarType?': '{}.toOptional<ScalarType>()',
    'Tensor': '{}.toTensor()',
    'Tensor?': 'toOptionalTensor({})',
    'Tensor?[]': 'toListOfOptionalTensor({})',
    'TensorList': '{}.toTensorList()->elements()',
    'bool': '{}.toBool()',
    'double': '{}.toDouble()',
    'int64_t': '{}.toInt()',
    'int64_t?': '{}.toOptional<int64_t>()',
    'std::string': '{}.toString()->string()',
    'Generator': 'nullptr',
    'std::array<bool,2>': 'as_bool_array<2>({}.toBoolListRef())',
    'std::array<bool,3>': 'as_bool_array<3>({}.toBoolListRef())',
    'std::array<bool,4>': 'as_bool_array<4>({}.toBoolListRef())',
}

def from_ivalue(arg, value):
    typ = optional_type_of(arg, arg['simple_type'])
    return FROM_IVALUE[typ].format(value)
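
# For example, from_ivalue({'simple_type': 'Scalar'}, 'peek(stack, 1, 3)')
# yields the C++ expression 'peek(stack, 1, 3).toScalar()' (hypothetical arg dict).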

CALL_NAMESPACE = CodeTemplate("""\
auto result_ = at::${name}(
    ${args}
);
""")
CALL_METHOD = CodeTemplate("""\
auto result_ = (${first}).${name}(
    ${args}
);
""")
CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
        .dtype(${dtype})
        .layout(${layout})
        .device(${device});
auto result_ = torch::${name}(${args_with_tensor_options});
""")
CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
        .dtype(${dtype})
        .layout(${layout})
        .device(${device});
auto result_ = (${first}).${name}(${args_with_tensor_options});
""")

CONSTRUCTOR = CodeTemplate("""\
[](Stack & stack) {
    autograd::profiler::RecordFunction record("${name}");
    ${lvalues}
    ${call}
    drop(stack, ${num_inputs});
    pack(stack, std::move(result_));
    return 0;
}
""")

OPERATOR = CodeTemplate("""\
Operator(
    "${signature}",
    ${op}
),
""")

blacklisted_types = {'SparseTensorRef', 'Storage', 'void*'}
default_only_types = {'Generator'}

def is_jit_arg(i, arg):
    simple_type = arg['simple_type']
    if simple_type in blacklisted_types:
        return False
    if simple_type in default_only_types and 'default' not in arg:
        return False
    if simple_type == 'Type':
        return False
    return True


def is_jit_op(decl):
    # ops that return nothing are not exposed through the JIT dispatch
    if all(r['type'] == 'void' for r in decl['returns']):
        return False

    arguments = decl['arguments']

    # skip _out variants that write into more than one output tensor
    if is_out_variant(decl) and sum([not not arg.get('output') for arg in arguments]) > 1:
        return False

    return (('namespace' in decl['method_of'] or 'Tensor' in decl['method_of']) and
            all(is_jit_arg(i, arg) for i, arg in enumerate(decl['arguments'])) and
            all(is_jit_arg(i, arg) for i, arg in enumerate(decl['returns'])))

def is_tensor_arg(arg):
    return arg['simple_type'] in {'Tensor', 'TensorList'}
206 """Returns True for arguments declared as IntArrayRef[k], but False for IntArrayRef.""" 207 return (arg[
'simple_type'] ==
'IntArrayRef')
and (
'size' in arg)

def base_name(decl):
    name = decl['name']
    return name[:-1] if decl.get('inplace', False) else name[:-4] if name.endswith('_out') else name


def is_view(decl):
    return base_name(decl) in RETURNS_VIEWS_OF_INPUT
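
# For example (hypothetical decls): base_name({'name': 'add_', 'inplace': True})
# and base_name({'name': 'add_out'}) both return 'add'.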

def is_out_variant(decl):
    return decl['name'].endswith('_out')

def argument_order(decl):
    return decl.get('jit_argument_order') or list(range(len(decl['arguments'])))
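
# A decl's 'jit_argument_order' (set by annotate_op below for out variants)
# maps each declared argument index to its position in the JIT schema and on
# the interpreter stack. E.g. for a hypothetical out variant declared as
# (result, self, other), the order is [2, 0, 1]: the out tensor is declared
# first in ATen but passed last in the schema.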

def gen_jit_dispatch(declarations, out, template_path):
    REGISTER_ATEN_OPS_CPP = CodeTemplate.from_file(template_path + '/register_aten_ops.cpp')

    def get_invocation(decl, args, num_inputs):

        # print out the arguments, one per line
        def pack_arguments(args):
            return ',\n'.join(args)
        is_namespace_function = 'namespace' in decl['method_of']
        tensor_options_arg_index = decl.get('tensor_options_arg_index', None)
        if tensor_options_arg_index is not None:
            dtype = args[tensor_options_arg_index]
            layout = args[tensor_options_arg_index + 1]
            device = args[tensor_options_arg_index + 2]
            args_with_tensor_options = args[:tensor_options_arg_index] + \
                ['options'] + args[(tensor_options_arg_index + 3):]
            if is_namespace_function:
                return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute(
                    name=decl['name'], dtype=dtype, layout=layout, device=device,
                    args_with_tensor_options=pack_arguments(args_with_tensor_options))
            else:
                return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute(
                    name=decl['name'], dtype=dtype, layout=layout, device=device,
                    args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]),
                    first=args_with_tensor_options[0], num_inputs=num_inputs)
        else:
            if is_namespace_function:
                return CALL_NAMESPACE.substitute(name=decl['name'],
                                                 args=pack_arguments(args),
                                                 num_inputs=num_inputs)
            else:
                return CALL_METHOD.substitute(
                    name=decl['name'], first=args[0],
                    args=pack_arguments(args[1:]), num_inputs=num_inputs)

    def requires_lvalue(arg):
        return 'jit_type' in arg and arg['jit_type'] in {"Tensor!", "Tensor(a!)"}

    def emit_decl_variant(decl):
        kw_assignments = []
        lvalues = []
        arguments = []
        op_capture = ''

        num_inputs = len(decl['arguments'])

        order = argument_order(decl)
        for i, arg in enumerate(decl['arguments']):
            value = from_ivalue(arg, '(std::move(peek(stack, {}, {})))'.format(order[i], num_inputs))
            if requires_lvalue(arg):
                lvalues.append('auto {} = {};\n'.format(arg['name'], value))
                value = arg['name']
            arguments.append(value)

        call = get_invocation(decl, arguments, num_inputs)

        returns = decl['returns']

        constructor = CONSTRUCTOR.substitute(name=decl['name'],
                                             call=call,
                                             lvalues=lvalues,
                                             kw_assignments=kw_assignments,
                                             num_inputs=num_inputs,
                                             op_capture=op_capture)
        return constructor
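
    # The generated operator body follows the interpreter calling convention:
    # each argument is peeked off the stack at the position given by
    # argument_order(), the ATen call produces result_, then the inputs are
    # dropped and result_ is packed back onto the stack (see CONSTRUCTOR above).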

    def sort_decls(jit_decls):
        # Order overloads within a name group so that, other things being
        # equal, an overload taking a Tensor at a given position sorts before
        # one taking a non-Tensor there.  declkey reads the argument list as a
        # base-3 number (1 for Tensor, 2 otherwise), least significant digit first.
        def declkey(decl):
            result = 0
            args = decl['arguments']
            for i in range(len(args)):
                result += (3 ** i) * (1 if args[i]['simple_type'] == 'Tensor' else 2)
            return result

        # NB: itertools.groupby requires its input to be sorted by the grouping key
        sorted_decls = sorted(jit_decls, key=lambda decl: decl['name'])
        grouped_decls = [list(g) for _, g in
                         groupby(sorted_decls, key=lambda decl: decl['name'])]
        return [sorted(g, key=declkey) for g in grouped_decls]
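
    # For example (hypothetical overloads of one op): arguments (Tensor, Tensor)
    # give declkey 1*1 + 3*1 = 4, while (Tensor, Scalar) give 1*1 + 3*2 = 7, so
    # the Tensor-Tensor overload is emitted first within its group.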

    # Hand-written entries for a few TensorImpl methods that do not appear in
    # Declarations.yaml
    tensor_impl_methods = [{
        'name': name,
        'method_of': ['Tensor'],
        'arguments': [{'name': 'self', 'simple_type': 'Tensor'}],
        'returns': [{'name': 'result', 'type': 'int64_t', 'dynamic_type': 'int64_t',
                     'simple_type': 'int64_t'}],
    } for name in ['sizes', 'strides', 'dim']]
    aten_decls = load_aten_declarations(declarations) + tensor_impl_methods
    jit_decls = [d for d in aten_decls if is_jit_op(d)]

    # TensorOptions is not a first-class JIT type, so an options argument is
    # expanded into three optional arguments: dtype, layout and device.
    def expand_options(decl, i, arg):
        if arg['simple_type'] != 'TensorOptions':
            return [arg]
        assert decl.get('tensor_options_arg_index') != i
        decl['tensor_options_arg_index'] = i
        tensor_options_expansion = [
            {'name': 'dtype', 'simple_type': 'ScalarType'},
            {'name': 'layout', 'simple_type': 'Layout'},
            {'name': 'device', 'simple_type': 'Device'},
        ]
        tensor_options_expansion[0]['simple_type'] += '?'
        tensor_options_expansion[1]['simple_type'] += '?'
        tensor_options_expansion[2]['simple_type'] += '?'
        tensor_options_expansion[0]['default'] = 'None'
        tensor_options_expansion[1]['default'] = 'None'
        tensor_options_expansion[2]['default'] = 'None'
        if 'default' in arg and arg['default'] == 'at::kLong':
            tensor_options_expansion[0]['default'] = 'long'
        if 'kwarg_only' in arg and arg['kwarg_only']:
            tensor_options_expansion[0]['kwarg_only'] = True
            tensor_options_expansion[1]['kwarg_only'] = True
            tensor_options_expansion[2]['kwarg_only'] = True
        return tensor_options_expansion
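
    # For instance, a hypothetical declaration "zeros(IntArrayRef size,
    # TensorOptions options)" would surface in the schema with the arguments
    # "int[] size, ScalarType? dtype=None, Layout? layout=None, Device? device=None".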

    additional_jit_decls = []

    for decl in jit_decls:
        decl['arguments'] = [a for i, arg in enumerate(decl['arguments'])
                             for a in expand_options(decl, i, arg)]

        decl['should_match_schema'] = True

        # For ops taking a nullable TensorList, also register a copy whose
        # TensorList is non-nullable; the copy is not expected to match the
        # native schema string, so it is flagged should_match_schema=False.
        decl_copy = copy.deepcopy(decl)
        for arg in decl_copy['arguments']:
            if arg['simple_type'] == 'TensorList' and arg.get('is_nullable'):
                arg['is_nullable'] = False
                decl_copy['should_match_schema'] = False
                additional_jit_decls.append(decl_copy)

    jit_decls.extend(additional_jit_decls)
    jit_decl_groups = sort_decls(jit_decls)

    # The registrations are split across several generated .cpp files; each
    # name group is assigned to a shard by a simple hash of the operator name
    # so that all overloads of an op land in the same file.
    num_shards = 3
    shards = [[] for _ in range(num_shards)]

    for group in jit_decl_groups:
        x = sum(ord(c) for c in group[0]['name']) % num_shards
        for decl in group:
            shards[x].append(OPERATOR.substitute(signature=signature(decl, decl['should_match_schema']),
                                                 op=emit_decl_variant(decl)))

    for i, shard in enumerate(shards):
        env = {
            'constructors': shard,
        }
        write(out, 'register_aten_ops_%d.cpp' % i, REGISTER_ATEN_OPS_CPP, env)
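
    # This emits one register_aten_ops_%d.cpp file per shard into the output
    # directory (torch/csrc/jit/generated/ in the full build).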

default_map = {'{}': 'None', 'nullptr': 'None', 'c10::nullopt': 'None'}

def annotate_op(decl):
    # Add alias annotations to ops that write to or view their first argument.
    if decl.get('inplace') or is_out_variant(decl):
        first_arg = decl['arguments'][0]
        assert(jit_type_of(first_arg) == 'Tensor')
        first_arg['jit_type'] = 'Tensor(a!)'
        first_ret = decl['returns'][0]
        assert(jit_type_of(first_ret) == 'Tensor')
        first_ret['jit_type'] = 'Tensor(a!)'
        if is_out_variant(decl):
            assert(first_arg['output'])
            # the output tensor is declared first in ATen but goes last in the
            # JIT schema, so record the argument permutation
            nargs = len(decl['arguments'])
            decl['jit_argument_order'] = [nargs - 1] + list(range(nargs - 1))
    elif is_view(decl):
        first_arg = decl['arguments'][0]
        assert jit_type_of(first_arg) == 'Tensor'
        first_arg['jit_type'] = 'Tensor(a)'
        first_ret = decl['returns'][0]
        ret_type = jit_type_of(first_ret)
        if ret_type == 'Tensor[]':
            first_ret['jit_type'] = 'Tensor(a)[]'
        elif ret_type == 'Tensor':
            first_ret['jit_type'] = 'Tensor(a)'


def is_kwarg_only(a):
    return a.get('kwarg_only') or a.get('output')
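
# For example, for an in-place op the schema produced below reads roughly
# "aten::add_(Tensor(a!) self, Tensor other, ...) -> Tensor(a!)" (illustrative):
# the (a!) annotation records that the first argument is mutated and aliased
# by the return value.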

def match_signature(decl, constructed_string, should_match_schema):
    # If the declaration claims to already match the JIT signature, check that
    # the schema string we constructed agrees with it and use the native one.
    if decl.get('matches_jit_signature') and should_match_schema:
        assert(constructed_string == decl['schema_string']), \
            decl['schema_string'] + ' is flagged as JIT signature compliant' + \
            ', but does not match the signature ' + constructed_string
        return decl['schema_string']

    return constructed_string

def signature(decl, should_match_schema=True):
    def format_arg(arg):
        name = arg['name']
        typ = jit_type_of(arg)
        decl = '{} {}'.format(typ, name)
        if 'default' in arg:
            default = arg['default']
            # rewrite C++ default values into legal schema syntax,
            # e.g. {{true, true}} -> [True, True]
            default = str(default) if not isinstance(default, float) else repr(default)
            default = default \
                .replace('{{', '[') \
                .replace('}}', ']') \
                .replace('true', 'True') \
                .replace('false', 'False') \
                .replace('Reduction::Mean', 'Mean') \
                .replace('{}', 'None' if is_tensor_arg(arg) else '[]')
            default = default_map.get(default, default)
            decl = '{}={}'.format(decl, default)
        return decl

    args = []
    kwarg_only = False

    ordered_arguments = sorted(zip(argument_order(decl), decl['arguments']))
    for _, a in ordered_arguments:
        if not kwarg_only and is_kwarg_only(a):
            args.append('*')
            kwarg_only = True
        args.append(format_arg(a))

    arg_list = ', '.join(args)
    if len(decl['returns']) == 1:
        ret_list = jit_type_of(decl['returns'][0])
        # named returns carry a field name in the schema
        if decl['returns'][0].get('field_name'):
            ret_list += ' ' + decl['returns'][0]['field_name']
    else:
        def type_maybe_field(r):
            return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r)
        ret_list = '({})'.format(', '.join(type_maybe_field(r) for r in decl['returns']))
    name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
    constructed_string = 'aten::{}({}) -> {}'.format(name, arg_list, ret_list)
    return match_signature(decl, constructed_string, should_match_schema)
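
# The constructed strings look like, e.g. (illustrative),
# "aten::add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", with a
# '*' separating positional from keyword-only arguments.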

def main():
    parser = argparse.ArgumentParser(
        description='Generate JIT op dispatch')
    parser.add_argument('declarations', metavar='DECL',
                        help='path to Declarations.yaml')
    parser.add_argument('out', metavar='OUT',
                        help='path to output directory')
    parser.add_argument('template_path', metavar='TEMPLATE_PATH',
                        help='path to templates directory')
    args = parser.parse_args()
    gen_jit_dispatch(args.declarations, args.out, args.template_path)


if __name__ == '__main__':
    main()