gen_variable_type.py
1 # Generates VariableType.h/cpp
2 #
3 # VariableType is a subclass of at::Type that provides the binding code
4 # necessary to provide a differentiable version of ATen operators. There are a
5 # number of different cases to handle:
6 #
7 # - Given a non-differentiable forward implementation, we might
8 # directly associate it with a backward implementation to make
9 # it differentiable. This is the common case.
10 #
11 # - Some functions don't need a backwards implementation, because
12 # backpropagation will never propagate beyond them. There are a
13 # number of different reasons why this may be the case:
14 #
15 # - The function has no differentiable inputs
16 # - The function's output is not differentiable
17 # - The function has no data dependency on its input
18 #
19 # - Some functions don't need a backwards implementation because they
20 # are implemented as a composition of other (differentiable) ATen
21 # functions. These are dispatched directly to the Type superclass,
22 # which will in turn dispatch back to VariableType for its
23 # differentiable subcomponents.
24 #
25 from __future__ import print_function
26 import os
27 import sys
28 from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
29 from .gen_autograd import VIEW_FUNCTIONS
30 from .gen_autograd_functions import uses_single_grad
31 
32 # These functions are written manually in templates/VariableType.cpp
33 MANUAL_IMPLEMENTATIONS = {
34  'resize_', 'resize_as_', 'detach', 'detach_', 's_copy_', '_s_copy_from'
35 }
36 
37 # These functions we don't want to record for tracing, because we always want
38 # to trace their constituent parts. This is a temporary hack in lieu
39 # of proper scopes, where subsequent compilation passes can ask for the unfolding
40 # on demand. Only concrete ATen methods can be disabled this way; it will have
41 # NO EFFECT otherwise.
42 DONT_RECORD_TRACE = {
43  'convolution', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
44  'conv_transpose2d', 'conv_transpose3d', 'lstm_cell', 'gru_cell',
45  'rnn_tanh_cell', 'rnn_relu_cell', 'linear',
46  # FIXME: figure out a better way when we support sparse tensors in jit
47  '_coalesced_',
48 }
49 
50 # These functions are recorded in the trace under a renamed op.
51 RENAME_TRACE = {
52  'zero': 'zeros_like',
53  'fill': 'full_like',
54 }
55 
56 # (declaration name, argument name) -> attribute name
57 RENAME_ATTRIBUTES = {
58  ('fill_', 'value'): 'fill_value'
59 }
60 
61 # These functions are not worth profiling because they are very cheap and may
62 # be called very often.
63 DONT_PROFILE = {
64  'data_ptr', 'get_device', 'is_contiguous', 'is_cuda', 'is_distributed',
65  'is_same_size', 'is_set_to', 'is_signed', 'is_sparse', 'numel',
66  'size', 'storage_offset', 'stride',
67 }
68 
69 # We don't set or modify grad_fn on these methods. Generally, they return
70 # tensors that have requires_grad=False. In-place functions listed here will
71 # not examine or modify requires_grad or grad_fn.
72 DONT_REQUIRE_DERIVATIVE = {
73  # These only depend on the input Tensor's shape and device, not the data
74  'ones_like', 'zeros_like', 'rand_like', 'randn_like',
75  # These are only implemented on integral types
76  '__and__', '__iand__', '__ilshift__', '__ior__', '__irshift__', '__ixor__',
77  '__lshift__', '__or__', '__rshift__', '__xor__',
78  # This is an unsafe method that is meant to be out of reach of autograd.
79  '_coalesced_',
80 }
81 
82 # NOTE [ Invariant: TensorImpl and Storage Pointer Equality ]
83 #
84 # When a function modifies its input tensors (via inplace or out-variants),
85 # it should never change the input tensors' underlying c10::TensorImpl pointers
86 # or c10::Storage pointers.
87 #
88 # The following code templates implement the checks for this invariant (an illustrative expansion is sketched right after the templates):
89 SAVE_TENSOR_STORAGE = CodeTemplate("""\
90 c10::optional<Storage> ${tensor_name}_storage_saved =
91  ${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
92 """)
93 
94 ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
95 if (${tensor_name}_storage_saved.has_value())
96  AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${tensor_name}.storage()));
97 """)
98 
99 SAVE_TENSORLIST_STORAGE = CodeTemplate("""\
100 std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
101 for (size_t i=0; i<${tensorlist_name}.size(); i++)
102  ${tensorlist_name}_storage_saved[i] = ${tensorlist_name}[i].has_storage() ?
103  c10::optional<Storage>(${tensorlist_name}[i].storage()) : c10::nullopt;
104 """)
105 
106 ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate("""\
107 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
108  if (${tensorlist_name}_storage_saved[i].has_value())
109  AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
110 }
111 """)
112 
113 SAVE_TENSOR_IMPL = CodeTemplate("""\
114 c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
115 if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
116 """)
117 
118 ENFORCE_SAME_TENSOR_IMPL = CodeTemplate("""\
119 if (${tensor_name}_impl_saved) AT_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
120 """)
121 
122 SAVE_TENSORLIST_IMPL = CodeTemplate("""\
123 std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
124 for (size_t i=0; i<${tensorlist_name}.size(); i++)
125  if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
126 """)
127 
128 ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate("""\
129 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
130  if (${tensorlist_name}_impl_saved[i])
131  AT_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
132 }
133 """)
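
# Illustrative only; added for exposition and not called by the generator.
# It shows what the storage half of the invariant checks above expands to for
# a hypothetical tensor argument named "self". The argument name is made up;
# real names come from the parsed ATen declarations.
def _example_storage_check_expansion():
    save = SAVE_TENSOR_STORAGE.substitute(tensor_name='self')
    enforce = ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name='self')
    # `save` captures self's Storage (if any) before the call; `enforce`
    # asserts that the storage observed after the call aliases the saved one.
    return save + enforce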
134 
135 # The following list contains functions for which we don't enforce the invariant.
136 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
137  # These functions are expected to change impl or storage of input tensors
138  '_th_set_', '_cudnn_rnn_flatten_weight',
139 }
140 # END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]
141 
142 METHOD_DECLARATION = CodeTemplate("""\
143 ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
144 """)
145 
146 METHOD_DEFINITION = CodeTemplate("""\
147 ${return_type} VariableType::${method_prefix_derived}${api_name}(${type_method_formals}) const {
148  ${type_definition_body}
149 }
150 """)
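
# Illustrative only; not called by the generator. One possible substitution of
# METHOD_DECLARATION above, expanding to roughly
#   Tensor abs(const Tensor & self) const override;
# Every field value here is a made-up example; real values are taken from each
# parsed ATen declaration.
def _example_method_declaration():
    return METHOD_DECLARATION.substitute(
        return_type='Tensor',
        method_prefix_derived='',
        api_name='abs',
        type_method_formals='const Tensor & self')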
151 
152 UNPACK_TENSOR = CodeTemplate("""\
153 auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")
154 
155 UNPACK_OPTIONS = CodeTemplate("""\
156 auto ${arg_name}_ = TensorOptions(${arg_name}).is_variable(false);""")
157 
158 DECLARE_GRAD_FN = CodeTemplate("""\
159 std::shared_ptr<${op}> grad_fn;
160 """)
161 
162 SETUP_DERIVATIVE = CodeTemplate("""\
163 if (compute_requires_grad( ${args_with_derivatives} )) {
164  ${setup}
165 }
166 """)
167 
168 ASSIGN_GRAD_FN = CodeTemplate("""\
169 grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteFunction);
170 grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
171 """)
172 
173 CALL_VIA_TYPE = CodeTemplate("""\
174 TypeDefault::${method_prefix_derived}${api_name}(${type_method_args})""")
175 
176 CALL_VIA_DERIVED = CodeTemplate("""\
177 baseType->${method_prefix_derived}${base_name}(${unpacked_args})""")
178 
179 # If the `baseType` operation has return values, we use the `tmp` variable to hold the
180 # values temporarily and pass the values to the return variables outside of the
181 # `at::AutoNonVariableTypeMode` guard block.
182 DISPATCH_TO_NON_VAR_TYPE_WITH_RETURN_VALUES = CodeTemplate("""\
183 auto tmp = ([&]() {
184  at::AutoNonVariableTypeMode non_var_type_mode(true);
185  return ${base_type_call};
186 })();
187 ${return_values} = ${rhs_value};
188 """)
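
# Illustrative only; not called by the generator. A sketch of how the guarded
# dispatch template above might be filled in. The call expression and variable
# names here are assumptions for the sketch; in the generated code they come
# from CALL_VIA_DERIVED, tie_return_values() and wrap_output() below.
def _example_non_var_type_dispatch():
    return DISPATCH_TO_NON_VAR_TYPE_WITH_RETURN_VALUES.substitute(
        base_type_call='baseType->abs(self_)',
        return_values='auto result',
        rhs_value='as_variable(tmp)')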
189 
190 DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate("""\
191 {
192  at::AutoNonVariableTypeMode non_var_type_mode(true);
193  ${base_type_call};
194 }
195 """)
196 
197 SET_HISTORY = CodeTemplate("""\
198 ${fn}_history(${differentiable_outputs}, grad_fn);
199 """)
200 
201 CONDITIONAL = CodeTemplate("""\
202 if (${cond}) {
203  ${statements}
204 }
205 """)
206 
207 RECORD_FUNCTION = CodeTemplate("""\
208 profiler::RecordFunction profiler("${name}", Function::peek_at_next_sequence_nr());""")
209 
210 SELECT = CodeTemplate("""\
211 if (${cond}) {
212  ${true}
213 } else {
214  ${false}
215 }
216 """)
217 
218 OP_NAME = CodeTemplate("""\
219 op_name = jit::Symbol::fromQualString("aten::${trace_name}");
220 """)
221 
222 PRE_RECORD_TRACE = CodeTemplate("""\
223 torch::jit::Node* node = nullptr;
224 std::shared_ptr<jit::tracer::TracingState> tracer_state;
225 if (jit::tracer::isTracing()) {
226  tracer_state = jit::tracer::getTracingState();
227  at::Symbol op_name;
228  ${set_op_name}
229  node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
230  jit::tracer::recordSourceLocation(node);
231  ${add_trace_inputs}
232  tracer_state->graph->insertNode(node);
233  ${inplace_guard}
234  jit::tracer::setTracingState(nullptr);
235 }
236 """)
237 
238 INPLACE_GUARD = CodeTemplate("""\
239 jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
240 """)
241 
242 ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")
243 
244 POST_RECORD_TRACE = CodeTemplate("""\
245 if (tracer_state) {
246  jit::tracer::setTracingState(std::move(tracer_state));
247  ${add_trace_outputs}
248 }
249 """)
250 
251 RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
252 #ifndef NDEBUG
253 ${statements}
254 #endif
255 """)
256 
257 
258 FACTORY_FUNCTION_NAMES = None
259 
260 
261 def find_factory_functions(declarations):
262  global FACTORY_FUNCTION_NAMES
263  FACTORY_FUNCTION_NAMES = set()
264 
265  for declaration in declarations:
266  if declaration['is_factory_method']:
267  FACTORY_FUNCTION_NAMES.add(declaration['api_name'])
268 
269 
270 def should_trace(declaration):
271  # Operations involving Storage or Type are not traceable at the moment
272  if any(arg['simple_type'] in {'Storage', 'Type'} for arg in declaration['arguments']):
273  return False
274  # We can't trace functions which don't have any Tensor or TensorList returns
275  if 'Tensor' not in declaration['return_type']:
276  return False
277  name = declaration['name']
278  base_name = name[:-1] if declaration['inplace'] else name[:-4] if name.endswith('_out') else name
279  if base_name in DONT_RECORD_TRACE or name in DONT_RECORD_TRACE:
280  return False
281  return True
282 
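# Illustrative only; not called by the generator. A minimal declaration dict
# carrying just the fields should_trace() reads; the op and its arguments are
# made up for this sketch.
def _example_should_trace():
    decl = {
        'name': 'add_',
        'inplace': True,
        'return_type': 'Tensor',
        'arguments': [{'simple_type': 'Tensor'}, {'simple_type': 'Tensor'}],
    }
    # base_name is 'add', which is not in DONT_RECORD_TRACE, so this is traced.
    return should_trace(decl)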
283 
284 def is_out_overload(declaration):
285  return declaration['api_name'].endswith('_out')
286 
287 
288 def format_postrecord_trace(declaration):
289  # When recorded out-of-place, *_out overloads need special handling to move
290  # the output *argument* to a return value
291  if is_out_overload(declaration):
292  output_names_outplace = [arg['name'] for arg in declaration['arguments'] if arg.get('output', False)]
293  output_names_inplace = [r['name'] for r in declaration['returns']]
294 
295  # Code size optimization: the common case is that the return value is
296  # the same for both variants
297  if output_names_outplace == output_names_inplace:
298  outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace]
299  return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
300 
301  local = {}
302  local['cond'] = 'force_outplace'
303  local['true'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace]
304  local['false'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_inplace]
305  selection = SELECT.substitute(local)
306  return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
307 
308  output_names = [r['name'] for r in declaration['returns']]
309  outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names]
310  return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
311 
312 
313 def format_trace_op_name(declaration):
314  is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])
315 
316  if not is_inplace or is_out_overload(declaration):
317  # special case for *_out functions: the in-place and out-of-place ops
318  # are overloaded with the same name in the JIT
319  trace_name = uninplace_api_name(declaration['api_name'])
320  trace_name = RENAME_TRACE.get(trace_name, trace_name)
321  return OP_NAME.substitute(trace_name=trace_name)
322 
323  # otherwise, this is an in-place op and we need to emit both in- and
324  # out-of-place versions
325  outplace_trace_name = uninplace_api_name(declaration['api_name'])
326  inplace_trace_name = declaration['api_name']
327  outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name)
328  inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name)
329 
330  select_params = {}
331  select_params['cond'] = 'tracer_state->force_outplace'
332  select_params['true'] = OP_NAME.substitute(trace_name=outplace_trace_name)
333  select_params['false'] = OP_NAME.substitute(trace_name=inplace_trace_name)
334 
335  return SELECT.substitute(select_params)
336 
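# Illustrative only; not called by the generator. For an in-place op such as
# zero_(), the emitted code selects at runtime between the out-of-place trace
# name ('zeros_like' via RENAME_TRACE, assuming uninplace_api_name strips the
# trailing underscore) and the in-place name ('zero_'). The one-field dict is
# a stand-in; real declarations carry many more fields.
def _example_trace_op_name():
    return format_trace_op_name({'api_name': 'zero_'})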
337 
338 def format_trace_inputs(declaration):
339  def dispatch_trace_input(arg_spec):
340  name, value, simple_type, nullable = arg_spec
341  # XXX: For args of type Tensor?[], the tracer passes allow_undefined to addInputs
342  if simple_type == 'TensorList' and nullable:
343  return '''jit::tracer::addInputs(node, "{}", {}, {});'''.format(name, value, "true")
344  else:
345  return ADD_TRACE_INPUT.substitute(name=name, input=value)
346 
347  trace_inputs = declaration['arguments']
348 
349  if is_out_overload(declaration):
350  # *_out functions take the result as their first argument, but it is the
351  # last argument in the JIT schema.
352  out_input = trace_inputs[0]
353  trace_inputs = trace_inputs[1:]
354 
355  trace_input_spec = [(i['name'], i['name'], i['simple_type'], i.get('is_nullable')) for i in trace_inputs]
356 
357  trace_inputs = \
358  '\n'.join(dispatch_trace_input(arg_spec) for arg_spec in trace_input_spec)
359 
360  if is_out_overload(declaration):
361  # for *_out functions, handle the result argument differently for inplace/outplace.
362  # For inplace: just add the input to the end to conform to the JIT schema
363  inplace = ADD_TRACE_INPUT.substitute(name=out_input['name'], input=out_input['name'])
364 
365  # for outplace: do nothing, except if the declaration is a factory.
366  # Factories are a bit special because their out-of-place overloads
367  # take an extra TensorOptions argument, which is missing in the _out function
368  trace_name = uninplace_api_name(declaration['api_name'])
369  has_factory_name = trace_name in FACTORY_FUNCTION_NAMES
370  if has_factory_name:
371  outplace = ADD_TRACE_INPUT.substitute(name='out', input='out.options()')
372  else:
373  outplace = ''
374 
375  trace_inputs += '\n'
376  trace_inputs += SELECT.substitute(
377  cond='tracer_state->force_outplace', true=outplace, false=inplace)
378 
379  return trace_inputs
380 
381 
382 def format_prerecord_trace(declaration):
383  local = {}
384  is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])
385 
386  local['set_op_name'] = format_trace_op_name(declaration)
387  local['add_trace_inputs'] = format_trace_inputs(declaration)
388 
389  local['inplace_guard'] = ''
390  if is_inplace:
391  local['inplace_guard'] = INPLACE_GUARD.substitute(
392  name=declaration['api_name'],
393  mutable_input=declaration['arguments'][0]['name'])
394 
395  return PRE_RECORD_TRACE.substitute(local)
396 
397 
398 def format_trace(declaration):
399  return (format_prerecord_trace(declaration), format_postrecord_trace(declaration))
400 
401 
402 def gen_variable_type(out, aten_declarations, template_path):
403  """VariableType.h and VariableType.cpp body
404 
405  This is the at::Type subclass for differentiable tensors. The
406  implementation of each function dispatches to the base tensor type to
407  compute the output. The grad_fn is attached to differentiable functions.
408  """
409 
410  # WARNING: this function call modifies global mutable state
411  find_factory_functions(aten_declarations)
412 
413  aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name']))
414 
415  gen_variable_type_shard(out, aten_declarations, template_path, None, True)
416 
417  # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
418  # template regarding sharding of the generated files.
419  num_shards = 5
420  shards = [[] for _ in range(num_shards)]
421 
422  # functions are assigned arbitrarily but stably to a file based on hash
423  for decl in aten_declarations:
424  x = sum(ord(c) for c in decl['name']) % num_shards
425  shards[x].append(decl)
426 
427  for i, shard in enumerate(shards):
428  gen_variable_type_shard(out, shard, template_path, '_%d' % i, False)
429  gen_variable_type_shard(out, aten_declarations, template_path, 'Everything', False)
430 
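# Illustrative only; not called by the generator. The shard assignment above
# is a plain character-sum hash, so it is arbitrary but identical on every
# run, which keeps the generated file layout stable. This helper mirrors that
# expression; the example op name is made up.
def _example_shard_index(name, num_shards=5):
    # e.g. _example_shard_index('add') == (97 + 100 + 100) % 5 == 2
    return sum(ord(c) for c in name) % num_shards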
431 
432 def gen_variable_type_shard(out, aten_declarations, template_path, suffix, header):
433  VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
434  VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
435 
436  type_declarations = []
437  type_definitions = []
438 
439  for declaration in aten_declarations:
440  # Factory methods usually do not appear in `VariableType` at all, since they
441  # don't dispatch via `Type`; except in the case where the implementation is 'abstract'
442  # in which case they do!
443  if declaration['is_factory_method']:
444  continue
445  type_declarations.append(METHOD_DECLARATION.substitute(declaration))
446  if declaration['name'] not in MANUAL_IMPLEMENTATIONS:
447  type_definitions.append(emit_method_definition(declaration))
448 
449  env = {
450  'type_derived_method_declarations': type_declarations,
451  'type_derived_method_definitions': type_definitions,
452  }
453  if header:
454  write(out, 'VariableType.h', VARIABLE_TYPE_H, env)
455  else:
456  write(out, 'VariableType%s.cpp' % suffix, VARIABLE_TYPE_CPP, env)
457 
458 
459 def emit_method_definition(declaration):
460  body = emit_body(declaration)
461  return METHOD_DEFINITION.substitute(declaration, type_definition_body=body)
462 
463 
464 def emit_body(declaration):
465  strategy = dispatch_strategy(declaration)
466 
467  arguments = declaration['arguments']
468  returns = declaration['returns']
469  func = declaration['derivative']
470  name = declaration['name']
471  inplace = declaration['inplace']
472  is_out_fn = name.endswith('_out')
473  modifies_arguments = inplace or is_out_fn
474  returns_void = len(returns) == 1 and returns[0]['type'] == 'void'
475 
476  base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
477  view_info = VIEW_FUNCTIONS.get(base_name, None)
478 
479  # These exclude things like BoolTensor, int64_t, and Scalar
480  def is_differentiable(arg):
481  if 'TensorOptions' in arg['type']:
482  return False
483  if 'Tensor' not in arg['type']:
484  return False
485  if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
486  # TODO: Enable this after native_functions.yaml schema unification.
487  # These are necessary for legacy code and should be
488  # used by legacy code only!
489  # assert name.startswith('_th_'), \
490  # "IndexTensor and BoolTensor are restricted to legacy _th_ functions only."
491  return False
492  return True
493 
494  def find_args_with_derivatives(differentiable_inputs):
495  """Find arguments that have derivative definitions"""
496  if func is None:
497  return differentiable_inputs
498  names = set(name for d in func['derivatives'] for name in d['var_names'])
499  differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
500  if len(differentiable) != len(names):
501  missing = names - set(arg['name'] for arg in differentiable)
502  raise RuntimeError('Missing arguments for derivatives: {} in {}'.format(missing, func['name']))
503  return differentiable
504 
505  inputs = [arg for arg in arguments if not arg.get('output', False)]
506  differentiable_inputs = list(filter(is_differentiable, inputs))
507  args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
508  not_differentiable_args_names = func['not_differentiable_args_names'] if func else []
509  candidate_differentiable_outputs = list(filter(is_differentiable, returns))
510 
511  if func is not None and func.get('output_differentiability') is not None:
512  differentiable_outputs = []
513  output_differentiability = func.get('output_differentiability')
514  for differentiable, output in zip(output_differentiability, returns):
515  if differentiable:
516  differentiable_outputs.append(output)
517  elif uses_single_grad(func):
518  differentiable_outputs = candidate_differentiable_outputs[:1]
519  else:
520  differentiable_outputs = candidate_differentiable_outputs
521 
522  requires_derivative = (
523  base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and
524  len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0 and
525  strategy == 'use_derived')
526 
527  if func is not None and not requires_derivative:
528  print('WARNING: derivative ignored for {}'.format(name), file=sys.stderr)
529 
530  def emit_save_inputs():
531  setup = []
532  if func is None:
533  return setup
534 
535  has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_derivatives'])
536 
537  # We don't want to save tensors if we know that they will never be used
538  # when computing the derivative, so we add guards to those statements
539  def guard_for(arg):
540  # It's hard to determine the edge offset if we have TensorLists
541  if has_tensorlist_arg:
542  return None
543 
544  # Empirical evaluation of the cases where we insert those guards in
545  # backward shows that they are somewhat useless. E.g. there's no need
546  # to guard on some values captured from forward, because they must have
547  # required grad if the backward function gets executed at all. I don't
548  # have any good ideas for detecting those cases, so I simply disabled the
549  # checks.
550  if 'backward' in func['name']:
551  return None
552 
553  # If there's a single derivative we could compute, we already have
554  # a requires_grad check that is sufficient
555  if len(func['args_with_derivatives']) <= 1:
556  return None
557 
558  # We really only care about trimming down the number of tensors we save
559  if arg['type'] != 'Tensor':
560  return None
561 
562  # We want to emit simple guards, so we only allow that if checking one
563  # input is enough to determine whether we need that value
564  used_in = [d for d in func['derivatives'] if arg in d['saved_inputs']]
565  assert len(used_in) > 0
566  if len(used_in) != 1:
567  return None
568  derivative = used_in[0]
569  if len(derivative['var_names']) != 1:
570  return None
571  derivative_var_name = derivative['var_names'][0]
572 
573  # Figure out the offset of the edge that uses this variable
574  for edge_off, arg in enumerate(func['args_with_derivatives']):
575  if arg['name'] == derivative_var_name:
576  break
577  else:
578  assert False
579 
580  return 'grad_fn->should_compute_output({})'.format(edge_off)
581 
582  setup.extend(save_variables(func['saved_inputs'], False, guard_for))
583  for arg in func['args_with_derivatives']:
584  if arg['type'] == 'TensorList':
585  setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))
586 
587  return setup
588 
589  def setup_derivative(differentiable_inputs):
590 
591  env = {}
592  env['args_with_derivatives'] = reference_args(args_with_derivatives)
593  env['op'] = func['op'] if func is not None else 'NotImplemented'
594  env['op_ctor'] = '' if func is not None else '"{}"'.format(declaration['api_name'])
595 
596  if is_out_fn:
597  setup = ['throw_error_out_requires_grad("{}");'.format(base_name)]
598  body = []
599  body.append(DECLARE_GRAD_FN.substitute(op='Function'))
600  body.append(SETUP_DERIVATIVE.substitute(
601  setup=setup,
602  args_with_derivatives=reference_args(differentiable_inputs)))
603  body.append(SETUP_DERIVATIVE.substitute(
604  setup=setup,
605  args_with_derivatives=reference_args(differentiable_outputs)))
606  return body
607 
608  setup = []
609  setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
610  setup.extend(emit_save_inputs())
611 
612  body = []
613  body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
614  body.append(DECLARE_GRAD_FN.substitute(env))
615  body.append(SETUP_DERIVATIVE.substitute(env, setup=setup))
616  return body
617 
618  def emit_check_no_requires_grad(tensor_args, args_with_derivatives):
619  """Checks that arguments without derivatives don't require grad"""
620  body = []
621  for arg in tensor_args:
622  if arg in args_with_derivatives:
623  continue
624  name = arg['name']
625  if name in not_differentiable_args_names:
626  continue
627  if name == 'output':
628  # Double-backwards definitions sometimes take in 'input' and
629  # 'output', but only define the derivative for input.
630  continue
631  if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
632  continue
633  body.append('check_no_requires_grad({}, "{}");'.format(name, name))
634  return body
635 
636  def save_variables(saved_variables, is_output, guard_for=lambda name: None):
637  # assign the saved variables to the generated grad_fn
638  stmts = []
639  for arg in saved_variables:
640  name = arg['name']
641  expr = arg.get('expr', arg['name'])
642  if arg['type'] == 'Tensor' or (is_output and arg['type'] == 'Scalar'):
643  name += '_'
644  var = arg['name']
645  if var == 'self' and inplace:
646  var = 'self.clone()'
647  assert not is_output
648  if inplace and is_output:
649  var = 'self'
650  expr = 'SavedVariable({}, {})'.format(var, str(is_output).lower())
651  elif arg['type'] == 'TensorList':
652  name += '_'
653  expr = 'make_saved_variable_list({})'.format(arg['name'])
654  elif arg['type'] == 'IntArrayRef':
655  expr = expr + ".vec()"
656  guard = guard_for(arg)
657  if guard is None:
658  stmts.append('grad_fn->{} = {};'.format(name, expr))
659  else:
660  stmts.append('if ({}) {{'.format(guard))
661  stmts.append(' grad_fn->{} = {};'.format(name, expr))
662  stmts.append('}')
663  return stmts
664 
665  def reference_args(args):
666  res = []
667  for arg in args:
668  if arg['type'] == 'SparseTensorRef':
669  res.append('{}.tref'.format(arg['name']))
670  else:
671  res.append(arg['name'])
672  return res
673 
674  def emit_record_trace(env):
675  if not should_trace(declaration):
676  return ('', '')
677  return format_trace(declaration)
678 
679  def declare_returned_variables():
680  if modifies_arguments:
681  return ''
682  if len(declaration['returns']) == 1:
683  return ''
684  # TODO: this will be ugly
685  names = [ret['type'] + ' ' + ret['name'] + ';' for ret in declaration['returns']]
686  return '\n'.join(names)
687 
688  def wrap_output(call):
689  # Returns a 2-tuple `(wrapped_call, extra_wrapping_stmts)`, where
690  # `wrapped_call` is a drop-in replacement for `call`, and
691  # `extra_wrapping_stmts` is a list of extra statements to run after
692  # `call`.
693  if 'Tensor' not in declaration['return_type']:
694  return call, []
695  elif view_info is not None:
696  # See NOTE [ Autograd View Variables ] in variable.h for details.
697  differentiable_output_vars = {r['name'] for r in differentiable_outputs}
698  tensor_output_vars = {r['name'] for r in returns if 'Tensor' in r['type']}
699  if not isinstance(view_info, dict):
700  if len(differentiable_output_vars) == len(tensor_output_vars):
701  # all outputs are differentiable
702  return 'as_view({}, {}, true)'.format(view_info, call), []
703  elif len(differentiable_output_vars) == 0:
704  # no output is differentiable
705  return 'as_view({}, {}, false)'.format(view_info, call), []
706  else:
707  # some of the outputs are differentiable
708  # need to expand to dict mode, i.e., one entry per output
709  base_name = view_info
710  view_info_dict = {}
711  for i, return_info in enumerate(returns):
712  if 'Tensor' in return_info['type']:
713  view_info_dict[i] = base_name
714  else:
715  view_info_dict = view_info
716 
717  def wrap_view_single(output_var, base_var):
718  fmt = '{output_var} = as_view({base_var}, {output_var}, {is_differentiable});'
719  if output_var in differentiable_output_vars:
720  # If `GradMode::is_enabled()` is False, this is a
721  # non-differentiable view. Gradients should not flow through.
722  is_differentiable = 'true'
723  else:
724  # This output is non-differentiable, so it is a
725  # non-differentiable view. Gradients should not flow through.
726  is_differentiable = 'false'
727  return fmt.format(output_var=output_var, base_var=base_var,
728  is_differentiable=is_differentiable)
729 
730  extra_wrapping_stmts = []
731  for output_idx, return_info in enumerate(returns):
732  if 'Tensor' not in return_info['type']:
733  assert output_idx not in view_info_dict, 'Can not wrap non-Tensor output as a view'
734  continue
735  output_var = return_info['name']
736  if output_idx in view_info_dict:
737  stmt = wrap_view_single(output_var, view_info_dict[output_idx])
738  elif 'Tensor' in return_info['type']:
739  stmt = '{output_var} = as_variable({output_var});'.format(output_var=output_var)
740  extra_wrapping_stmts.append(stmt)
741  return call, extra_wrapping_stmts
742  else:
743  return 'as_variable({})'.format(call), []
744 
745  def enforce_same_tensorimpl_and_storage(env, call):
746  save_ptrs_stmts = []
747  enforce_same_ptrs_stmts = []
748  if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
749  for arg in env.get('unpacked_args', []):
750  simple_type = env['unpacked_args_simple_type'][arg]
751  if simple_type == 'TensorList':
752  save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
753  SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
754  enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
755  ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
756  elif simple_type == 'Tensor':
757  save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
758  SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
759  enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
760  ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
761  assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
762  if save_ptrs_stmts and enforce_same_ptrs_stmts:
763  call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
764  call + \
765  RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
766  return call
767 
768  def emit_call(env):
769  combined = nested_dict(env, declaration)
770  extra_wrapping_stmts = []
771  if strategy == 'use_derived':
772  # We only care about adding the `at::AutoNonVariableTypeMode` guard for `baseType` dispatch
773  # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
774  # the baseType operations still dispatch to non-Variable type, even if the arguments passed
775  # in are now Variables.
776  # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
777  base_type_call = CALL_VIA_DERIVED.substitute(combined)
778  if not modifies_arguments and not returns_void:
779  rhs_value, extra_wrapping_stmts = wrap_output('tmp')
780  call = DISPATCH_TO_NON_VAR_TYPE_WITH_RETURN_VALUES.substitute(
781  base_type_call=base_type_call,
782  return_values=tie_return_values(),
783  rhs_value=rhs_value)
784  else:
785  call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
786  base_type_call=base_type_call)
787  else:
788  call = CALL_VIA_TYPE.substitute(declaration)
789  if not modifies_arguments and not returns_void:
790  call = '{} = {}'.format(tie_return_values(), call)
791  call = call + ';'
792  for stmt in extra_wrapping_stmts:
793  call += '\n' + stmt
794  call = enforce_same_tensorimpl_and_storage(env, call)
795  return call
796 
797  def tie_return_values():
798  if len(declaration['returns']) == 1:
799  return 'auto {}'.format(declaration['returns'][0]['name'])
800  names = [ret['name'] for ret in declaration['returns']]
801  return 'std::tie({})'.format(', '.join(names))
802 
803  def get_return_value():
804  if inplace:
805  return 'self'
806  if is_out_fn:
807  return_names = [arg['name'] for arg in arguments
808  if arg.get('output', False)]
809  if len(return_names) == 1:
810  return return_names[0]
811  return 'std::forward_as_tuple({})'.format(', '.join(return_names))
812 
813  returns = declaration['returns']
814  if len(returns) == 1:
815  return returns[0]['name']
816  moved = ['std::move({})'.format(r['name']) for r in returns]
817  return 'std::make_tuple({})'.format(', '.join(moved))
818 
819  def emit_history():
820  fn = 'rebase' if modifies_arguments and view_info is None else 'set'
821  output_names = [r['name'] for r in differentiable_outputs]
822  # TODO: flatten allocates a std::vector, which could be expensive
823  outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
824  return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
825 
826  def emit_save_outputs():
827  if is_out_fn:
828  # out functions don't currently support differentiation
829  return ''
830  func = declaration['derivative']
831  if func is not None:
832  stmts = save_variables(func['saved_outputs'], True)
833  if len(stmts) == 0:
834  return ''
835  return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)
836  return ''
837 
838  def emit_check_inplace():
839  if not inplace:
840  return []
841  return ['check_inplace({});'.format(arg['name']) for arg in differentiable_outputs]
842 
843  def emit_increment_version():
844  if not modifies_arguments:
845  return []
846  return ['increment_version({});'.format(arg['name']) for arg in differentiable_outputs]
847 
848  env = {}
849  combined = nested_dict(env, declaration)
850 
851  body = []
852  if base_name not in DONT_PROFILE:
853  body.append(RECORD_FUNCTION.substitute(combined))
854  if strategy != 'use_type':
855  body.extend(unpack_args(env, declaration))
856  if requires_derivative:
857  body.extend(emit_check_inplace())
858  body.extend(setup_derivative(differentiable_inputs))
859  body.append(declare_returned_variables())
860 
861  pre_record_trace, post_record_trace = emit_record_trace(env)
862 
863  body.append(pre_record_trace)
864  body.append(emit_call(env))
865  if requires_derivative:
866  # set_flags has to appear after version_counter, because rebase_history
867  # requires that the counter is incremented before it is called
868  body.extend(emit_increment_version())
869  body.append(emit_history())
870  # post_record_trace must appear before save_outputs so that saved outputs
871  # have their tracing state saved (that is set up by recordTrace)
872  body.append(post_record_trace)
873  if requires_derivative:
874  body.append(emit_save_outputs())
875  if not returns_void:
876  body.append('return {};'.format(get_return_value()))
877  return body
878 
879 
880 def unpack_args(env, declaration):
881  def requires_unpack(arg):
882  return 'Tensor' in arg['dynamic_type']
883 
884  body = []
885  unpacked_args = []
886  unpacked_args_simple_type = {}
887  for i, arg in enumerate(declaration['arguments']):
888  if not requires_unpack(arg):
889  unpacked_args.append(arg['name'])
890  unpacked_args_simple_type[arg['name']] = arg['simple_type']
891  continue
892 
893  dynamic_type = arg['dynamic_type']
894  if 'TensorOptions' not in dynamic_type:
895  is_nullable = arg.get('is_nullable', False)
896  ref = (not is_nullable) and dynamic_type not in ['TensorList', 'SparseTensorRef']
897  suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
898 
899  body.append(UNPACK_TENSOR.substitute(
900  arg_name=arg['name'],
901  arg_pos=i,
902  suffix=suffix,
903  ref='&' if ref else '',
904  ))
905  else:
906  # Okay, we are abusing the definition of 'unpack' here a bit,
907  # although it's still getting the non-variable from the variable
908  # (in this case via TensorOptions rather than Variable/Tensor).
909  body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name']))
910 
911  unpacked_args.append(arg['name'] + '_')
912  unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']
913 
914  env['unpacked_args'] = unpacked_args
915  env['unpacked_args_simple_type'] = unpacked_args_simple_type
916  return body
917 
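# Illustrative only; not called by the generator. unpack_args() applied to a
# minimal, made-up declaration with a single Tensor argument. UNPACK_TENSOR
# expands to something like
#   auto& self_ = unpack(self, "self", 0);
# and env is populated with the unpacked names and their simple types.
def _example_unpack_args():
    env = {}
    decl = {'arguments': [{'name': 'self',
                           'dynamic_type': 'Tensor',
                           'simple_type': 'Tensor'}]}
    body = unpack_args(env, decl)
    return body, env['unpacked_args']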
918 
919 def dispatch_strategy(declaration):
920  """How are we going to call the underlying implementation of a
921  declaration? There are two strategies:
922 
923  - use_derived: we want to call the implementation on CPUDoubleType
924  (or a similar, derived Type instance). Because these derived
925  instances deal in Tensors, not Variables (it's a completely different
926  object, so it doesn't dispatch back to VariableType), code on
927  this dispatch path needs to wrap/unwrap tensors. If the
928  derived implementation takes and returns tensors, the
929  implementation is usually differentiable (although we also use
930  the derived dispatch path for non-differentiable functions
931  that we still want to dispatch on the derived Type instance;
932  e.g., size())
933 
934  - use_type: we want to call the implementation on Type, because
935  it is implemented concretely, and the functions it invokes will
936  get dispatched back to VariableType (which will ensure that they
937  are differentiable.)
938  """
939  if (declaration['abstract'] or declaration['requires_tensor'] or
940  declaration['derivative'] is not None):
941  # If the function is abstract (not implemented on at::Type), we must
942  # call the implementation on the derived type with unpacked tensors.
943 
944  # If the function has a derivative specified and is concrete, we could
945  # call either implementation. We prefer calling the derived
946  # type's implementation with unpacked tensors because it is more
947  # performant in some cases: any internal calls to other ATen functions
948  # won't have the history tracked.
949 
950  # If the function has a type dispatched argument (i.e. is a factory),
951  # we prefer calling the derived type's implementation both because it is
952  # more performant and to ensure factory functions return tensors with _version
953  # of 0 (probably not strictly necessary, but nice to have to keep versions simple
954  # to understand).
955  return 'use_derived'
956  else:
957  # If the function is concrete (we don't have to override it) and we
958  # didn't declare it in derivatives.yaml, we'll assume that it is
959  # actually implemented out of differentiable functions. (This
960  # assumption might not hold, but then you'll see gradcheck fail.)
961  return 'use_type'
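
# Illustrative only; not called by the generator. Two minimal, made-up
# declarations showing which branch of dispatch_strategy() each one takes;
# only the three fields the function reads are included.
def _example_dispatch_strategy():
    abstract_op = {'abstract': True, 'requires_tensor': False, 'derivative': None}
    composite_op = {'abstract': False, 'requires_tensor': False, 'derivative': None}
    # Returns ('use_derived', 'use_type').
    return dispatch_strategy(abstract_op), dispatch_strategy(composite_op)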