from __future__ import print_function

import sys

from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
from .gen_autograd import VIEW_FUNCTIONS
from .gen_autograd_functions import uses_single_grad

MANUAL_IMPLEMENTATIONS = {
    'resize_', 'resize_as_', 'detach', 'detach_', 's_copy_', '_s_copy_from'
}

DONT_RECORD_TRACE = {
    'convolution', 'conv1d', 'conv2d', 'conv3d',
    'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
    'lstm_cell', 'gru_cell', 'rnn_tanh_cell', 'rnn_relu_cell', 'linear',
}
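
# Note: functions in MANUAL_IMPLEMENTATIONS are written by hand in
# templates/VariableType.cpp, so gen_variable_type_shard below emits only their
# declarations, not their definitions. Entries in DONT_RECORD_TRACE are skipped
# by the tracer (should_trace below returns False for them).
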
# (declaration name, argument name) -> attribute name
RENAME_ATTRIBUTES = {
    ('fill_', 'value'): 'fill_value'
}

DONT_PROFILE = {
    'data_ptr', 'get_device', 'is_contiguous', 'is_cuda', 'is_distributed',
    'is_same_size', 'is_set_to', 'is_signed', 'is_sparse', 'numel',
    'size', 'storage_offset', 'stride',
}

DONT_REQUIRE_DERIVATIVE = {
    'ones_like', 'zeros_like', 'rand_like', 'randn_like',
    '__and__', '__iand__', '__ilshift__', '__ior__', '__irshift__', '__ixor__',
    '__lshift__', '__or__', '__rshift__', '__xor__',
}

SAVE_TENSOR_STORAGE = CodeTemplate("""\
c10::optional<Storage> ${tensor_name}_storage_saved =
  ${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
""")

ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
if (${tensor_name}_storage_saved.has_value())
  AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${tensor_name}.storage()));
""")

SAVE_TENSORLIST_STORAGE = CodeTemplate("""\
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
for (Tensor tensor : ${tensorlist_name})
  ${tensorlist_name}_storage_saved.push_back(
    tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
""")

ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
  if (${tensorlist_name}_storage_saved[i].has_value())
    AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
}
""")

SAVE_TENSOR_IMPL = CodeTemplate("""\
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
""")

ENFORCE_SAME_TENSOR_IMPL = CodeTemplate("""\
if (${tensor_name}_impl_saved) AT_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
""")

SAVE_TENSORLIST_IMPL = CodeTemplate("""\
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++)
  if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
""")

ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
  if (${tensorlist_name}_impl_saved[i])
    AT_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
}
""")

DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
    '_th_set_', '_cudnn_rnn_flatten_weight',
}
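
# The SAVE_*/ENFORCE_* snippets above are emitted in pairs around the base type
# call (see enforce_same_tensorimpl_and_storage below): the TensorImpl and
# Storage pointers of every Tensor/TensorList argument are recorded before the
# call and asserted unchanged afterwards, since autograd relies on ATen kernels
# not swapping them out. The generated checks are wrapped in
# RUN_ONLY_IN_DEBUG_MODE, and ops listed in
# DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE are exempt.
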
METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")

METHOD_DEFINITION = CodeTemplate("""\
${return_type} VariableType::${method_prefix_derived}${api_name}(${type_method_formals}) const {
  ${type_definition_body}
}
""")

UNPACK_TENSOR = CodeTemplate("""\
auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")

UNPACK_OPTIONS = CodeTemplate("""\
auto ${arg_name}_ = TensorOptions(${arg_name}).is_variable(false);""")

DECLARE_GRAD_FN = CodeTemplate("""\
std::shared_ptr<${op}> grad_fn;
""")

SETUP_DERIVATIVE = CodeTemplate("""\
if (compute_requires_grad( ${args_with_derivatives} )) {
  ${setup}
}
""")

ASSIGN_GRAD_FN = CodeTemplate("""\
grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteFunction);
grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
""")

CALL_VIA_TYPE = CodeTemplate("""\
TypeDefault::${method_prefix_derived}${api_name}(${type_method_args})""")

CALL_VIA_DERIVED = CodeTemplate("""\
baseType->${method_prefix_derived}${base_name}(${unpacked_args})""")

DISPATCH_TO_NON_VAR_TYPE_WITH_RETURN_VALUES = CodeTemplate("""\
auto tmp = ([&]() {
  at::AutoNonVariableTypeMode non_var_type_mode(true);
  return ${base_type_call};
})();
${return_values} = ${rhs_value};
""")

DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate("""\
{
  at::AutoNonVariableTypeMode non_var_type_mode(true);
  ${base_type_call};
}
""")

SET_HISTORY = CodeTemplate("""\
${fn}_history(${differentiable_outputs}, grad_fn);
""")

CONDITIONAL = CodeTemplate("""\
if (${cond}) {
  ${statements}
}
""")

RECORD_FUNCTION = CodeTemplate("""\
profiler::RecordFunction profiler("${name}", Function::peek_at_next_sequence_nr());""")

SELECT = CodeTemplate("""\
if (${cond}) {
  ${true}
} else {
  ${false}
}
""")

OP_NAME = CodeTemplate("""\
op_name = jit::Symbol::fromQualString("aten::${trace_name}");
""")

PRE_RECORD_TRACE = CodeTemplate("""\
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
  tracer_state = jit::tracer::getTracingState();
  at::Symbol op_name;
  ${set_op_name}
  node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
  jit::tracer::recordSourceLocation(node);
  ${add_trace_inputs}
  tracer_state->graph->insertNode(node);
  ${inplace_guard}
  jit::tracer::setTracingState(nullptr);
}
""")

INPLACE_GUARD = CodeTemplate("""\
jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
""")

ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")

POST_RECORD_TRACE = CodeTemplate("""\
if (tracer_state) {
  jit::tracer::setTracingState(std::move(tracer_state));
  ${add_trace_outputs}
}
""")

RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
#ifndef NDEBUG
${statements}
#endif
""")

FACTORY_FUNCTION_NAMES = None


def find_factory_functions(declarations):
    global FACTORY_FUNCTION_NAMES
    FACTORY_FUNCTION_NAMES = set()

    for declaration in declarations:
        if declaration['is_factory_method']:
            FACTORY_FUNCTION_NAMES.add(declaration['api_name'])

def should_trace(declaration):
    # Operations involving Storage or Type are not traceable at the moment
    if any(arg['simple_type'] in {'Storage', 'Type'} for arg in declaration['arguments']):
        return False
    # We can't trace functions which don't have any Tensor or TensorList returns
    if 'Tensor' not in declaration['return_type']:
        return False
    name = declaration['name']
    base_name = name[:-1] if declaration['inplace'] else name[:-4] if name.endswith('_out') else name
    if base_name in DONT_RECORD_TRACE or name in DONT_RECORD_TRACE:
        return False
    return True
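
# For example: for an in-place declaration named 'add_' the base_name above
# strips the trailing underscore ('add'), and for an out variant named
# 'add_out' it strips the '_out' suffix, so both forms are matched against the
# same DONT_RECORD_TRACE entry as the plain 'add'.
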
def is_out_overload(declaration):
    return declaration['api_name'].endswith('_out')

def format_postrecord_trace(declaration):
    if is_out_overload(declaration):
        output_names_outplace = [arg['name'] for arg in declaration['arguments'] if arg.get('output', False)]
        output_names_inplace = [r['name'] for r in declaration['returns']]

        # the common case is that the two variants return the same names
        if output_names_outplace == output_names_inplace:
            outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace]
            return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)

        local = {}
        local['cond'] = 'force_outplace'
        local['true'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace]
        local['false'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_inplace]
        selection = SELECT.substitute(local)
        return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)

    output_names = [r['name'] for r in declaration['returns']]
    outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names]
    return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
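
# In the *_out case the outputs recorded for the trace differ between the
# out-of-place form (the declared output arguments) and the in-place form (the
# declared returns); when the two name lists differ, a runtime SELECT on
# force_outplace chooses which set of addOutput calls runs.
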
def format_trace_op_name(declaration):
    is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])

    if not is_inplace or is_out_overload(declaration):
        # out-of-place ops and *_out overloads record a single op name
        trace_name = uninplace_api_name(declaration['api_name'])
        trace_name = RENAME_TRACE.get(trace_name, trace_name)
        return OP_NAME.substitute(trace_name=trace_name)

    # in-place ops record either the in-place or the out-of-place name,
    # depending on force_outplace
    outplace_trace_name = uninplace_api_name(declaration['api_name'])
    inplace_trace_name = declaration['api_name']
    outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name)
    inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name)

    select_params = {}
    select_params['cond'] = 'tracer_state->force_outplace'
    select_params['true'] = OP_NAME.substitute(trace_name=outplace_trace_name)
    select_params['false'] = OP_NAME.substitute(trace_name=inplace_trace_name)

    return SELECT.substitute(select_params)

def format_trace_inputs(declaration):
    def dispatch_trace_input(arg_spec):
        name, value, simple_type, nullable = arg_spec
        # nullable TensorLists get an extra boolean flag when passed to the tracer
        if simple_type == 'TensorList' and nullable:
            return '''jit::tracer::addInputs(node, "{}", {}, {});'''.format(name, value, "true")
        else:
            return ADD_TRACE_INPUT.substitute(name=name, input=value)

    trace_inputs = declaration['arguments']

    if is_out_overload(declaration):
        # *_out functions take the result as their first argument; it is not a
        # regular trace input
        out_input = trace_inputs[0]
        trace_inputs = trace_inputs[1:]

    trace_input_spec = [(i['name'], i['name'], i['simple_type'], i.get('is_nullable')) for i in trace_inputs]

    trace_inputs = \
        '\n'.join(dispatch_trace_input(arg_spec) for arg_spec in trace_input_spec)

    if is_out_overload(declaration):
        inplace = ADD_TRACE_INPUT.substitute(name=out_input['name'], input=out_input['name'])

        trace_name = uninplace_api_name(declaration['api_name'])
        has_factory_name = trace_name in FACTORY_FUNCTION_NAMES
        if has_factory_name:
            outplace = ADD_TRACE_INPUT.substitute(name='out', input='out.options()')
        else:
            outplace = ''

        trace_inputs += '\n'
        trace_inputs += SELECT.substitute(
            cond='tracer_state->force_outplace', true=outplace, false=inplace)

    return trace_inputs
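
# For *_out overloads the result argument is therefore pulled out of the traced
# input list: the in-place branch of the SELECT adds it back as a regular
# input, while the out-of-place branch adds nothing, unless the op is a factory
# function, in which case 'out.options()' stands in for the TensorOptions
# argument of the out-of-place overload.
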
def format_prerecord_trace(declaration):
    local = {}
    is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])

    local['set_op_name'] = format_trace_op_name(declaration)
    local['add_trace_inputs'] = format_trace_inputs(declaration)

    local['inplace_guard'] = ''
    if is_inplace:
        local['inplace_guard'] = INPLACE_GUARD.substitute(
            name=declaration['api_name'],
            mutable_input=declaration['arguments'][0]['name'])

    return PRE_RECORD_TRACE.substitute(local)


def format_trace(declaration):
    return (format_prerecord_trace(declaration), format_postrecord_trace(declaration))

def gen_variable_type(out, aten_declarations, template_path):
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.
    """
    # WARNING: mutates the global FACTORY_FUNCTION_NAMES
    find_factory_functions(aten_declarations)

    aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name']))

    gen_variable_type_shard(out, aten_declarations, template_path, None, True)

    # the generated definitions are split across several .cpp shards to reduce
    # the compilation time of any single file
    num_shards = 5
    shards = [[] for _ in range(num_shards)]

    # functions are assigned arbitrarily but stably to a file based on hash
    for decl in aten_declarations:
        x = sum(ord(c) for c in decl['name']) % num_shards
        shards[x].append(decl)

    for i, shard in enumerate(shards):
        gen_variable_type_shard(out, shard, template_path, '_%d' % i, False)
    gen_variable_type_shard(out, aten_declarations, template_path, 'Everything', False)
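
# Illustrative example of the shard assignment above: with num_shards = 5, a
# declaration named 'abs' hashes to (ord('a') + ord('b') + ord('s')) % 5
# = (97 + 98 + 115) % 5 = 310 % 5 = 0, so its definition lands in
# VariableType_0.cpp; the 'Everything' shard is generated from the full list.
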
def gen_variable_type_shard(out, aten_declarations, template_path, suffix, header):
    VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
    VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')

    type_declarations = []
    type_definitions = []

    for declaration in aten_declarations:
        # factory methods are not emitted as VariableType methods
        if declaration['is_factory_method']:
            continue
        type_declarations.append(METHOD_DECLARATION.substitute(declaration))
        if declaration['name'] not in MANUAL_IMPLEMENTATIONS:
            type_definitions.append(emit_method_definition(declaration))

    env = {
        'type_derived_method_declarations': type_declarations,
        'type_derived_method_definitions': type_definitions,
    }
    if header:
        write(out, 'VariableType.h', VARIABLE_TYPE_H, env)
    else:
        write(out, 'VariableType%s.cpp' % suffix, VARIABLE_TYPE_CPP, env)

def emit_method_definition(declaration):
    body = emit_body(declaration)
    return METHOD_DEFINITION.substitute(declaration, type_definition_body=body)

def emit_body(declaration):
    strategy = dispatch_strategy(declaration)

    arguments = declaration['arguments']
    returns = declaration['returns']
    func = declaration['derivative']
    name = declaration['name']
    inplace = declaration['inplace']
    is_out_fn = name.endswith('_out')
    modifies_arguments = inplace or is_out_fn
    returns_void = len(returns) == 1 and returns[0]['type'] == 'void'

    base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
    view_info = VIEW_FUNCTIONS.get(base_name, None)

    def is_differentiable(arg):
        if 'TensorOptions' in arg['type']:
            return False
        if 'Tensor' not in arg['type']:
            return False
        if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
            return False
        return True

    def find_args_with_derivatives(differentiable_inputs):
        """Find arguments that have derivative definitions"""
        if func is None:
            return differentiable_inputs
        names = set(name for d in func['derivatives'] for name in d['var_names'])
        differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
        if len(differentiable) != len(names):
            missing = names - set(arg['name'] for arg in differentiable)
            raise RuntimeError('Missing arguments for derivatives: {} in {}'.format(missing, func['name']))
        return differentiable

    inputs = [arg for arg in arguments if not arg.get('output', False)]
    differentiable_inputs = list(filter(is_differentiable, inputs))
    args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
    not_differentiable_args_names = func['not_differentiable_args_names'] if func else []
    candidate_differentiable_outputs = list(filter(is_differentiable, returns))

    if func is not None and func.get('output_differentiability') is not None:
        differentiable_outputs = []
        output_differentiability = func.get('output_differentiability')
        for differentiable, output in zip(output_differentiability, returns):
            if differentiable:
                differentiable_outputs.append(output)
    elif uses_single_grad(func):
        differentiable_outputs = candidate_differentiable_outputs[:1]
    else:
        differentiable_outputs = candidate_differentiable_outputs

    requires_derivative = (
        base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and
        len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0 and
        strategy == 'use_derived')

    if func is not None and not requires_derivative:
        print('WARNING: derivative ignored for {}'.format(name), file=sys.stderr)

    def emit_save_inputs():
        setup = []
        if func is None:
            return setup

        has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_derivatives'])

        # Returns a C++ guard expression under which it is worth saving `arg`,
        # or None to save it unconditionally.
        def guard_for(arg):
            # It's hard to determine the edge offset if we have TensorLists
            if has_tensorlist_arg:
                return None

            # guards in backward functions have empirically not been worth emitting
            if 'backward' in func['name']:
                return None

            # If there's a single derivative we could compute, the existing
            # requires_grad check is already sufficient
            if len(func['args_with_derivatives']) <= 1:
                return None

            # We only care about trimming down the amount of tensors we save
            if arg['type'] != 'Tensor':
                return None

            # We only emit simple guards, so checking one input must be enough
            # to determine whether we need that value
            used_in = [d for d in func['derivatives'] if arg in d['saved_inputs']]
            assert len(used_in) > 0
            if len(used_in) != 1:
                return None
            derivative = used_in[0]
            if len(derivative['var_names']) != 1:
                return None
            derivative_var_name = derivative['var_names'][0]

            # Figure out the offset of the edge that uses this variable
            for edge_off, arg in enumerate(func['args_with_derivatives']):
                if arg['name'] == derivative_var_name:
                    break
            else:
                assert False

            return 'grad_fn->should_compute_output({})'.format(edge_off)

        setup.extend(save_variables(func['saved_inputs'], False, guard_for))
        for arg in func['args_with_derivatives']:
            if arg['type'] == 'TensorList':
                setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))

        return setup
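
    # guard_for above conservatively decides whether a saved input can be
    # wrapped in a 'grad_fn->should_compute_output(i)' check, so that a tensor
    # needed only for one particular input gradient is not saved when that
    # gradient is never required; whenever the analysis is ambiguous it returns
    # None and the value is saved unconditionally.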

    def setup_derivative(differentiable_inputs):
        env = {}
        env['args_with_derivatives'] = reference_args(args_with_derivatives)
        env['op'] = func['op'] if func is not None else 'NotImplemented'
        env['op_ctor'] = '' if func is not None else '"{}"'.format(declaration['api_name'])

        if is_out_fn:
            # *_out functions do not support differentiation; error out at
            # runtime if any relevant input requires grad
            setup = ['throw_error_out_requires_grad("{}");'.format(base_name)]
            body = []
            body.append(DECLARE_GRAD_FN.substitute(op='Function'))
            body.append(SETUP_DERIVATIVE.substitute(
                setup=setup,
                args_with_derivatives=reference_args(differentiable_inputs)))
            body.append(SETUP_DERIVATIVE.substitute(
                setup=setup,
                args_with_derivatives=reference_args(differentiable_outputs)))
            return body

        setup = []
        setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
        setup.extend(emit_save_inputs())

        body = []
        body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
        body.append(DECLARE_GRAD_FN.substitute(env))
        body.append(SETUP_DERIVATIVE.substitute(env, setup=setup))
        return body

    def emit_check_no_requires_grad(tensor_args, args_with_derivatives):
        """Checks that arguments without derivatives don't require grad"""
        body = []
        for arg in tensor_args:
            if arg in args_with_derivatives:
                continue
            name = arg['name']
            if name in not_differentiable_args_names:
                continue
            if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
                continue
            body.append('check_no_requires_grad({}, "{}");'.format(name, name))
        return body

    def save_variables(saved_variables, is_output, guard_for=lambda name: None):
        # assign the saved variables to the generated grad_fn
        stmts = []
        for arg in saved_variables:
            name = arg['name']
            expr = arg.get('expr', arg['name'])
            if arg['type'] == 'Tensor' or (is_output and arg['type'] == 'Scalar'):
                name += '_'
                var = arg['name']
                if var == 'self' and inplace:
                    # an in-place op overwrites self, so save a copy of the input
                    var = 'self.clone()'
                if inplace and is_output:
                    var = 'self'
                expr = 'SavedVariable({}, {})'.format(var, str(is_output).lower())
            elif arg['type'] == 'TensorList':
                name += '_'
                expr = 'make_saved_variable_list({})'.format(arg['name'])
            elif arg['type'] == 'IntArrayRef':
                expr = expr + ".vec()"
            guard = guard_for(arg)
            if guard is None:
                stmts.append('grad_fn->{} = {};'.format(name, expr))
            else:
                stmts.append('if ({}) {{'.format(guard))
                stmts.append(' grad_fn->{} = {};'.format(name, expr))
                stmts.append('}')
        return stmts

    def reference_args(args):
        res = []
        for arg in args:
            if arg['type'] == 'SparseTensorRef':
                res.append('{}.tref'.format(arg['name']))
            else:
                res.append(arg['name'])
        return res

    def emit_record_trace(env):
        if not should_trace(declaration):
            return ('', '')
        return format_trace(declaration)

    def declare_returned_variables():
        if modifies_arguments:
            return ''
        if len(declaration['returns']) == 1:
            return ''
        names = [ret['type'] + ' ' + ret['name'] + ';' for ret in declaration['returns']]
        return '\n'.join(names)

    def wrap_output(call):
        # Returns a pair (rhs_value, extra_wrapping_stmts): an expression that
        # replaces `call`, plus any extra per-output wrapping statements.
        if 'Tensor' not in declaration['return_type']:
            return call, []
        elif view_info is not None:
            differentiable_output_vars = {r['name'] for r in differentiable_outputs}
            tensor_output_vars = {r['name'] for r in returns if 'Tensor' in r['type']}
            if not isinstance(view_info, dict):
                if len(differentiable_output_vars) == len(tensor_output_vars):
                    # all outputs are differentiable views of the base
                    return 'as_view({}, {}, true)'.format(view_info, call), []
                elif len(differentiable_output_vars) == 0:
                    # no output is differentiable
                    return 'as_view({}, {}, false)'.format(view_info, call), []
                else:
                    # only some outputs are differentiable: expand view_info
                    # into a per-output dict
                    base_name = view_info
                    view_info_dict = {}
                    for i, return_info in enumerate(returns):
                        if 'Tensor' in return_info['type']:
                            view_info_dict[i] = base_name
            else:
                view_info_dict = view_info

            def wrap_view_single(output_var, base_var):
                fmt = '{output_var} = as_view({base_var}, {output_var}, {is_differentiable});'
                if output_var in differentiable_output_vars:
                    is_differentiable = 'true'
                else:
                    is_differentiable = 'false'
                return fmt.format(output_var=output_var, base_var=base_var,
                                  is_differentiable=is_differentiable)

            extra_wrapping_stmts = []
            for output_idx, return_info in enumerate(returns):
                if 'Tensor' not in return_info['type']:
                    assert output_idx not in view_info_dict, 'Can not wrap non-Tensor output as a view'
                    continue
                output_var = return_info['name']
                if output_idx in view_info_dict:
                    stmt = wrap_view_single(output_var, view_info_dict[output_idx])
                elif 'Tensor' in return_info['type']:
                    stmt = '{output_var} = as_variable({output_var});'.format(output_var=output_var)
                extra_wrapping_stmts.append(stmt)
            return call, extra_wrapping_stmts
        else:
            return 'as_variable({})'.format(call), []
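
    # Note that wrap_output only rewrites the call expression itself in the
    # simple cases (plain as_variable / as_view); in the mixed case, where just
    # some outputs of a view function are differentiable, the call is returned
    # unchanged together with per-output statements that re-wrap each output
    # after the fact.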

    def enforce_same_tensorimpl_and_storage(env, call):
        save_ptrs_stmts = []
        enforce_same_ptrs_stmts = []
        if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            for arg in env.get('unpacked_args', []):
                simple_type = env['unpacked_args_simple_type'][arg]
                if simple_type == 'TensorList':
                    save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                        SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                    enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                                ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                elif simple_type == 'Tensor':
                    save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                        SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
                    enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
                                                ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
        assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or \
            (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
        if save_ptrs_stmts and enforce_same_ptrs_stmts:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
        return call
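
    # A rough sketch of the generated C++ for a Tensor argument unpacked as 'x'
    # (hypothetical name, illustrative only, assuming RUN_ONLY_IN_DEBUG_MODE
    # expands to an NDEBUG guard as reconstructed above):
    #
    #   #ifndef NDEBUG
    #   c10::intrusive_ptr<TensorImpl> x_impl_saved;
    #   if (x.defined()) x_impl_saved = x.getIntrusivePtr();
    #   #endif
    #   ... dispatch to the base type ...
    #   #ifndef NDEBUG
    #   if (x_impl_saved) AT_ASSERT(x_impl_saved == x.getIntrusivePtr());
    #   #endif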

    def emit_call(env):
        combined = nested_dict(env, declaration)
        extra_wrapping_stmts = []
        if strategy == 'use_derived':
            # the AutoNonVariableTypeMode guard makes sure the base type call
            # dispatches to non-Variable kernels even though the arguments are
            # Variables
            base_type_call = CALL_VIA_DERIVED.substitute(combined)
            if not modifies_arguments and not returns_void:
                rhs_value, extra_wrapping_stmts = wrap_output('tmp')
                call = DISPATCH_TO_NON_VAR_TYPE_WITH_RETURN_VALUES.substitute(
                    base_type_call=base_type_call,
                    return_values=tie_return_values(),
                    rhs_value=rhs_value)
            else:
                call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
                    base_type_call=base_type_call)
        else:
            call = CALL_VIA_TYPE.substitute(declaration)
            if not modifies_arguments and not returns_void:
                call = '{} = {}'.format(tie_return_values(), call)
            call = call + ';'
        for stmt in extra_wrapping_stmts:
            call += '\n' + stmt
        call = enforce_same_tensorimpl_and_storage(env, call)
        return call

    def tie_return_values():
        if len(declaration['returns']) == 1:
            return 'auto {}'.format(declaration['returns'][0]['name'])
        names = [ret['name'] for ret in declaration['returns']]
        return 'std::tie({})'.format(', '.join(names))

    def get_return_value():
        if inplace:
            return 'self'
        if is_out_fn:
            return_names = [arg['name'] for arg in arguments
                            if arg.get('output', False)]
            if len(return_names) == 1:
                return return_names[0]
            return 'std::forward_as_tuple({})'.format(', '.join(return_names))

        returns = declaration['returns']
        if len(returns) == 1:
            return returns[0]['name']
        moved = ['std::move({})'.format(r['name']) for r in returns]
        return 'std::make_tuple({})'.format(', '.join(moved))

    def emit_history():
        fn = 'rebase' if modifies_arguments and view_info is None else 'set'
        output_names = [r['name'] for r in differentiable_outputs]
        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
        return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)

    def emit_save_outputs():
        if is_out_fn:
            # out= functions do not support automatic differentiation
            return ''
        func = declaration['derivative']
        if func is None:
            return ''
        stmts = save_variables(func['saved_outputs'], True)
        if len(stmts) == 0:
            return ''
        return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)

    def emit_check_inplace():
        if not inplace:
            return []
        return ['check_inplace({});'.format(arg['name']) for arg in differentiable_outputs]

    def emit_increment_version():
        if not modifies_arguments:
            return []
        return ['increment_version({});'.format(arg['name']) for arg in differentiable_outputs]

    env = {}
    combined = nested_dict(env, declaration)

    body = []
    if base_name not in DONT_PROFILE:
        body.append(RECORD_FUNCTION.substitute(combined))
    if strategy != 'use_type':
        body.extend(unpack_args(env, declaration))
    if requires_derivative:
        body.extend(emit_check_inplace())
        body.extend(setup_derivative(differentiable_inputs))
    body.append(declare_returned_variables())

    pre_record_trace, post_record_trace = emit_record_trace(env)

    body.append(pre_record_trace)
    body.append(emit_call(env))
    if requires_derivative:
        # the version bump has to happen before the history is rebased
        body.extend(emit_increment_version())
        body.append(emit_history())
    body.append(post_record_trace)
    if requires_derivative:
        body.append(emit_save_outputs())
    if not returns_void:
        body.append('return {};'.format(get_return_value()))
    return body

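# Putting the pieces together, the emitted body of a generated VariableType
# method roughly follows this shape (illustrative summary, not verbatim output):
#   profiler record (unless the op is in DONT_PROFILE)
#   unpack arguments                          (use_derived strategy only)
#   check_inplace / set up grad_fn and save inputs  (if a derivative is required)
#   pre-record trace
#   dispatch to the base type / TypeDefault
#   increment versions and rebase/set history       (if a derivative is required)
#   post-record trace, save outputs
#   return
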
def unpack_args(env, declaration):
    def requires_unpack(arg):
        return 'Tensor' in arg['dynamic_type']

    body = []
    unpacked_args = []
    unpacked_args_simple_type = {}
    for i, arg in enumerate(declaration['arguments']):
        if not requires_unpack(arg):
            unpacked_args.append(arg['name'])
            unpacked_args_simple_type[arg['name']] = arg['simple_type']
            continue

        dynamic_type = arg['dynamic_type']
        if 'TensorOptions' not in dynamic_type:
            is_nullable = arg.get('is_nullable', False)
            ref = (not is_nullable) and dynamic_type not in ['TensorList', 'SparseTensorRef']
            suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''

            body.append(UNPACK_TENSOR.substitute(
                arg_name=arg['name'],
                arg_pos=i,
                suffix=suffix,
                ref='&' if ref else '',
            ))
        else:
            # TensorOptions are re-created as non-Variable options
            body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name']))

        unpacked_args.append(arg['name'] + '_')
        unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']

    env['unpacked_args'] = unpacked_args
    env['unpacked_args_simple_type'] = unpacked_args_simple_type
    return body

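# unpack_args produces the local '<arg>_' variables consumed by the base type
# call (CALL_VIA_DERIVED's ${unpacked_args}). For example, a non-nullable Tensor
# argument 'self' at position 0 expands via UNPACK_TENSOR to
#   auto& self_ = unpack(self, "self", 0);
# while a TensorOptions argument is rebuilt with .is_variable(false).
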
def dispatch_strategy(declaration):
    """How are we going to call the underlying implementation of a declaration?

    There are two strategies:

        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance). Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors. If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance).

        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable).
    """
    if (declaration['abstract'] or declaration['requires_tensor'] or
            declaration['derivative'] is not None):
        return 'use_derived'
    else:
        return 'use_type'
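
# For instance, a declaration that is abstract (implemented per backend) or that
# has an entry in derivatives.yaml is dispatched with 'use_derived': VariableType
# unpacks the Variables and calls the derived type's kernel directly. A concrete
# convenience function with no derivative of its own goes through 'use_type', so
# the ATen functions it calls internally dispatch back into VariableType and have
# their history recorded there.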