// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// python_function.cpp
1 #include <torch/csrc/autograd/python_function.h>
2 
3 #include <torch/csrc/python_headers.h>
4 #include <structmember.h>
5 #include <unordered_map>
6 #include <unordered_set>
7 #include <exception>
8 #include <ATen/ATen.h>
9 
10 #include <torch/csrc/THP.h>
11 #include <torch/csrc/autograd/grad_mode.h>
12 #include <torch/csrc/autograd/functions/accumulate_grad.h>
13 #include <torch/csrc/autograd/functions/basic_ops.h>
14 #include <torch/csrc/autograd/functions/utils.h>
15 #include <torch/csrc/autograd/python_cpp_function.h>
16 #include <torch/csrc/autograd/python_hook.h>
17 #include <torch/csrc/autograd/saved_variable.h>
18 #include <torch/csrc/autograd/python_anomaly_mode.h>
19 #include <torch/csrc/jit/tracer.h>
20 #include <torch/csrc/jit/python_tracer.h>
21 #include <torch/csrc/DynamicTypes.h>
22 #include <torch/csrc/utils/auto_gil.h>
23 #include <torch/csrc/Exceptions.h>
24 
25 #include <exception>
26 #include <functional>
27 #include <memory>
28 #include <stdexcept>
29 #include <string>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33 
34 using namespace torch;
35 using namespace torch::autograd;
36 using namespace torch::jit;
37 using at::Tensor;
38 
// The Python class object for torch.autograd.Function; populated during
// module initialization (not in this chunk).
PyObject *THPFunctionClass = nullptr;
40 
// Raise a Python error via THPUtils_setError and unwind with python_error
// when `condition` does not hold. Wrapped in do { } while (0) so the macro
// expands to exactly one statement and composes safely inside if/else chains
// (the original bare `if` form was vulnerable to the dangling-else problem).
#define THPFunction_assert(condition, ...)                                      \
  do {                                                                          \
    if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } \
  } while (0)
43 
44 namespace torch { namespace autograd {
45 
// Snapshot of a Variable's type/device/shape/requires_grad. Recorded for each
// forward input/output so that backward can later materialize zero-filled
// gradients for slots that received no gradient (see VariableInfo::zeros).
VariableInfo::VariableInfo(const Variable& var)
  : type(&var.type())
  , device(var.device())
  , size(var.sizes().vec())
  , requires_grad(var.requires_grad()) {
}
52 
// Allocates a zero tensor matching the recorded type and size. The device
// guard is switched to the recorded device first so the allocation lands there.
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  // NB: This will NOT work if we ever get mixed device gradients
  device_guard.reset_device(device);
  return at::zeros(size, type->options());
}
58 
59 auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
60  AutoGIL gil;
61 
62  THPObjectPtr pyInputs(PyTuple_New(inputs.size()));
63  if (!pyInputs) throw python_error();
64 
65  for (size_t i = 0; i != inputs.size(); ++i) {
66  PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i]));
67  }
68 
69  THPObjectPtr r(PyObject_CallMethod(
70  obj, "_do_backward", "OO", pyInputs.get(), Py_True));
71  if (!r) throw python_error();
72 
73  auto num_outputs = PyTuple_GET_SIZE(r.get());
74  tensor_list tensor_results(num_outputs);
75  for (int i = 0; i != num_outputs; ++i) {
76  PyObject* obj = PyTuple_GET_ITEM(r.get(), i);
77  if (obj != Py_None) {
78  if (!THPVariable_Check(obj)) {
79  std::string msg("expected Variable (got '");
80  msg += THPUtils_typename(obj);
81  msg += "')'";
82  throw std::runtime_error(msg);
83  }
84  tensor_results[i] = ((THPVariable*)obj)->cdata.data();
85  }
86  }
87 
88  // XXX: this might get requires_grad wrong - there's no way to figure out
89  // if _do_backward didn't use ctx.saved_tensors and as a result some
90  // Variables might require grad, even if no args do. Unfortunately, this
91  // leads to unexpected error messages ("no nodes require computing gradients"),
92  // but I don't have a better idea. These functions would raise an error
93  // in backward anyway.
94  return wrap_outputs(
95  inputs,
96  std::move(tensor_results),
97  [this](edge_list&& next_edges) {
98  return std::make_shared<Error>(
99  name() + " is not differentiable twice", std::move(next_edges));
100  });
101 }
102 
// NOTE: this function is written in a way that assumes it's only called for backward;
// it's used by engine.cpp. This is responsible for forwarding a call from
// C++'s Function::apply to a Python method "apply".
auto PyFunction::apply(variable_list&& inputs) -> variable_list {
  AutoGIL gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  // Legacy (old-style) functions take a separate codepath.
  // NOTE(review): if the "_is_legacy" attribute is missing, the failed lookup
  // leaves a Python AttributeError pending and we silently fall through as
  // non-legacy — confirm the attribute is always present on these objects.
  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
  if (_legacy == Py_True) {
    return legacy_apply(inputs);
  }

  // Massage a C++ variable_list into a Python arguments tuple
  auto num_inputs = inputs.size();
  THPObjectPtr pyInputs(PyTuple_New(num_inputs));
  if (!pyInputs) throw python_error();
  auto& output_info = py_fn->output_info;
  for (size_t i = 0; i < num_inputs; ++i) {
    PyObject* input;
    if (inputs[i].defined()) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      // Undefined incoming grads become zero tensors shaped like the
      // corresponding forward output (recorded in output_info).
      input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
    }
    if (!input) throw python_error();
    // PyTuple_SET_ITEM steals the reference.
    PyTuple_SET_ITEM(pyInputs.get(), i, input);
  }

  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn) throw python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r) throw python_error();
  ensure_tuple(r);

  auto& is_variable_input = py_fn->is_variable_input;
  int num_outputs = PyTuple_GET_SIZE(r.get());
  int num_forward_inputs = is_variable_input.size();
  // Returning too many results is ok, but only as long as they're all None.
  // Truncate the result tuple in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (int i = num_forward_inputs; i < num_outputs; i++) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r) throw python_error();
    }
  }

  // Now the number of gradients should match
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got " ;
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // Massage the Python results tuple back into a C++ variable_list
  variable_list results;
  results.reserve(num_outputs);
  auto& input_info = py_fn->input_info;
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(r.get(), i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      // Non-Variable forward inputs must not receive gradients.
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output == Py_None) {
      // input_info holds entries for *variable* inputs only, so index it by
      // results.size() (non-variable slots were skipped above).
      auto& info = input_info[results.size()];
      if (info.requires_grad) {
        // Grad is required but the function returned None: synthesize zeros.
        results.emplace_back(info.zeros(_device_guard));
      } else {
        // No grad required: leave an undefined Variable in the slot.
        results.emplace_back();
      }
    } else {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(((THPVariable*)output)->cdata);
    }
  }

  return results;
}
200 
201 auto PyFunction::is_traceable() -> bool {
202  AutoGIL gil;
203  THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")};
204  if (!forward_class) throw python_error();
205  THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")};
206  if (!traceable_py_bool) throw python_error();
207  return traceable_py_bool == Py_True;
208 }
209 
210 auto PyFunction::release_variables() -> void {
211  AutoGIL gil;
212  auto f = (THPFunction*) obj;
213  f->saved_variables.clear();
214  f->has_freed_buffers = 1;
215 }
216 
217 auto PyFunction::name() const -> std::string {
218  AutoGIL gil;
219  auto f = (THPFunction*) obj;
220  auto name = std::string(Py_TYPE(f)->tp_name);
221  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
222  if (_legacy == Py_True) {
223  name += "LegacyBackward";
224  }
225  return name;
226 }
227 
228 auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> {
229  return THPFunction_asFunction((THPFunction*)obj);
230 }
231 
232 }} // namespace torch::autograd
233 
// Traverse and clear are required for supporting Python's GC cycle handling.
// tp_traverse: report every owned PyObject reference to the cycle collector.
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
{
  // Visit the Python dicts held by any Python-implemented pre/post hooks.
  // Note: Py_VISIT returns from this function on a non-zero visitor result.
  for (const auto& hook : self->cdata.pre_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  for (const auto& hook : self->cdata.post_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  Py_VISIT(self->to_save);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  // NOTE(review): needs_input_grad is released in THPFunction_clear but not
  // visited here — confirm whether it can participate in reference cycles.
  return 0;
}
252 
// tp_clear: drop every reference that might participate in a reference
// cycle, leaving the object alive but empty. Also called from dealloc.
static int THPFunction_clear(THPFunction *self)
{
  self->cdata.clear_input_metadata();

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);

  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  // Moving the hooks out makes sure to first disassociate them from the
  // function, but without destroying any of them. They will get deleted when
  // exiting this scope. This is important, because deleting Python objects can
  // trigger deletion of other objects, and they can reference this function,
  // seeing it in a half-deleted state.
  auto pre_hooks = std::move(self->cdata.pre_hooks());
  auto post_hooks = std::move(self->cdata.post_hooks());

  return 0;
}
278 
// tp_dealloc: mirror of THPFunction_new — untrack from the GC, release all
// Python references, then explicitly run the destructors of the members that
// were placement-constructed inside the PyObject memory block.
static void THPFunction_dealloc(THPFunction* self)
{
  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  self->cdata.~PyFunction();
  self->output_info.~vector();
  self->input_info.~vector();
  self->saved_variables.~vector();
  self->is_variable_input.~vector();
  Py_TYPE(self)->tp_free((PyObject*)self);
}
290 
// tp_new: allocates the Python object and placement-constructs the C++
// members that live inside the PyObject memory block (paired with the
// explicit destructor calls in THPFunction_dealloc).
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  PyObject* obj = type->tp_alloc(type, 0);
  if (!obj) return nullptr;
  // Python zero-initializes the object memory, so there's no need to initialize
  // most fields
  THPFunction* self = (THPFunction*)obj;
  // cdata keeps a (non-owning) pointer back to the enclosing PyObject.
  new (&self->cdata) PyFunction(obj);
  new (&self->output_info) std::vector<VariableInfo>();
  new (&self->input_info) std::vector<VariableInfo>();
  new (&self->saved_variables) std::vector<SavedVariable>();
  new (&self->is_variable_input) std::vector<bool>();
  return obj;
}
305 
307 // Forward
309 
// Maps raw tensor PyObject* keys to their THPVariable wrappers.
// NOTE(review): this alias appears unused in the visible portion of the file.
using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
311 
312 // Bump the counters of all recorded dirty input tensors, adding each of them
313 // into dirty_inputs. Also does some sanity checking.
314 static std::vector<PyObject*> _mark_dirty(THPFunction *self)
315 {
316  // Increase versions of modified tensors
317  std::vector<PyObject*> dirty_inputs;
318  if (!self->dirty_tensors) return dirty_inputs;
319 
320  THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
321  "internal error: dirty_tensors attribute is expected to be a tuple "
322  "but is %s", THPUtils_typename(self->dirty_tensors));
323  Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
324  for (int i = 0; i < num_dirty; i++) {
325  PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
326  THPFunction_assert(THPVariable_Check(obj), "mark_dirty can "
327  "only accept variables, but argument %d is of type %s", i,
328  THPUtils_typename(obj));
329 
330  dirty_inputs.push_back(obj);
331  auto variable = (THPVariable*)obj;
332  variable->cdata.bump_version();
333  }
334  // We're not going to ever need this so let's remove references now
335  Py_CLEAR(self->dirty_tensors);
336  return dirty_inputs;
337 }
338 
339 static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self);
340 
// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables. We save the gradient function (self)
// to the variable if the output requires grad.
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace. A mapping of *input*
// tensors to variables (t2var) is used to test if this occurred, and
// the set of dirty tensors (dirty_inputs) is used to figure out what to
// do in this case. After this method is run, t2var is extended with
// mappings for output tensors as well.
static void _wrap_outputs(THPFunction *self,
    PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable)
{
  // Only build a grad_fn when the call is actually differentiable.
  auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr;
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
  if (is_executable) {
    self->output_info.clear();
    self->output_info.reserve(num_outputs);
  }

  // Set of raw input objects, used to detect outputs that alias inputs.
  std::unordered_set<PyObject*> inputs;
  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
  for (int i = 0; i < num_inputs; i++) {
    inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i));
  }

  auto non_differentiable = _parse_non_differentiable(self);
  auto dirty_inputs = _mark_dirty(self);

  auto as_variable = [&](PyObject* obj, int i) -> Variable {
    if (THPVariable_Check(obj)) {
      return ((THPVariable*)obj)->cdata;
    }
    throw TypeError("%s.forward: expected Variable (got %s) for return value %d",
        Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
  };

  // Sets the grad_fn and output_nr of an output Variable.
  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        return;
      }
      // NB: we don't support returning non-differentiable views that could require grad
      if (var.is_view()) {
        throw std::runtime_error("Returning Variables sharing storage with other Variables "
                                 "that require grad is not supported in Python functions. "
                                 "Please submit a feature request if you hit this error.");
      }
      // Return detached aliases of inputs, instead of changing their requires_grad
      // property.
      if (is_input) {
        var = var.detach();
      } else {
        var.detach_();
      }
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // If the input was modified, transplant the grad_fn in the graph:
      // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.grad().reset();
      var.clear_hooks();
      if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
        // NOTE(review): the dynamic_cast result is dereferenced without a
        // null check; this relies on every grad accumulator being an
        // AccumulateGrad — confirm that invariant holds.
        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
        grad_acc->variable.reset();
      }
      if (cdata) {
        var.rebase_history({cdata, output_nr});
      }
    } else if (is_input) {
      // An input has been returned, but it wasn't modified. Return it as a view
      // so that we can attach a new grad_fn to the Variable.
      var = var.view_as(var);
      var.set_gradient_edge({cdata, output_nr});
    } else if (cdata) {
      var.set_gradient_edge({cdata, output_nr});
    }
  };

  for (int i = 0; i < num_outputs; i++) {
    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);

    bool is_input = inputs.count(obj) > 0;
    bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end();
    bool is_differentiable = is_executable && non_differentiable.count(obj) == 0;

    // Note that output Variables may be repeated. In that case, the last call
    // to set_history wins.
    auto var = as_variable(obj, i);
    if (cdata) {
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == (int)output_nr);
    }
    set_history(var, i, is_input, is_modified, is_differentiable);

    if (is_executable) {
      // Remember type/device/shape so backward can synthesize zero grads.
      self->output_info.emplace_back(var);
    }

    // PyTuple_SET_ITEM steals the reference returned by THPVariable_Wrap.
    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var));
  }
}
447 
448 // Save any variables that requested by to_save
449 static void _save_variables(THPFunction* self)
450 {
451  if (!self->to_save) return;
452 
453  THPFunction_assert(PyTuple_Check(self->to_save), "autograd internal "
454  "error: to_save attribute is expected to be a tuple but is %s",
455  THPUtils_typename(self->to_save));
456  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
457  self->saved_variables.clear();
458  self->saved_variables.reserve(num_saved);
459  auto cdata_ptr = &self->cdata;
460  for (int i = 0; i < num_saved; i++) {
461  PyObject *obj = PyTuple_GET_ITEM(self->to_save, i);
462  if (obj == Py_None) {
463  self->saved_variables.emplace_back();
464  continue;
465  } else if (THPVariable_Check(obj)) {
466  auto variable = (THPVariable*)obj;
467  bool is_output = variable->cdata.grad_fn().get() == cdata_ptr;
468  self->saved_variables.emplace_back(variable->cdata, is_output);
469  } else {
470  throw TypeError(
471  "save_for_backward can only save variables, but argument %d is of "
472  "type %s", i, Py_TYPE(obj)->tp_name);
473  }
474  }
475  // Free .to_save
476  Py_CLEAR(self->to_save);
477 }
478 
479 // Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
480 static std::unordered_set<PyObject*>
481 _parse_non_differentiable(THPFunction *self)
482 {
483  std::unordered_set<PyObject*> set;
484  if (!self->non_differentiable) return set;
485 
486  THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
487  "internal error: non_differentiable attribute is expected to be a "
488  "tuple but is %s", THPUtils_typename(self->non_differentiable));
489  Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
490  set.reserve(num_nondiff);
491  for (int i = 0; i < num_nondiff; i++) {
492  PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
493  THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable "
494  "only accepts variable arguments, but got %s", THPUtils_typename(t));
495  set.insert(t);
496  }
497  Py_CLEAR(self->non_differentiable);
498  return set;
499 }
500 
502  THPObjectPtr input_tuple;
503  variable_list input_vars;
504 };
505 
// Per-call flags computed by unpack_input while scanning forward() arguments.
struct InputFlags {
  bool is_executable = false;           // grad mode on AND some input requires grad
  edge_list next_edges;                 // gradient edges collected from the inputs
  THPObjectPtr needs_input_grad;        // Python tuple of bools, one per argument
  std::vector<bool> is_variable_input;  // per-argument: was it a Variable?
};
512 
513 template<bool enforce_variables>
514 std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
515  UnpackedInput unpacked;
516  InputFlags flags;
517 
518  auto num_args = PyTuple_GET_SIZE(args);
519  unpacked.input_tuple = PyTuple_New(num_args);
520  flags.needs_input_grad = PyTuple_New(num_args);
521  for (int i = 0; i < num_args; i++) {
522  PyObject *arg = PyTuple_GET_ITEM(args, i);
523 
524  bool is_variable = THPVariable_Check(arg);
525  flags.is_variable_input.push_back(is_variable);
526  if (!is_variable) {
527  // TODO: remove this code path once Variable and Tensor are merged in Python
528  if (enforce_variables) {
529  THPUtils_setError("expected a Variable argument, but got %s",
530  THPUtils_typename(arg));
531  throw python_error();
532  }
533  Py_INCREF(Py_False);
534  PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
535  } else {
536  THPVariable* variable = (THPVariable*)arg;
537  unpacked.input_vars.push_back(variable->cdata);
538  PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False;
539  Py_INCREF(needs_grad);
540  PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
541  }
542  Py_INCREF(arg);
543  PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
544  }
545 
546  flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
547  flags.next_edges = collect_next_edges(unpacked.input_vars);
548  return std::make_pair(std::move(unpacked), std::move(flags));
549 }
550 
551 static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
552  if (tracer::isTracing()) {
553  std::ostringstream oss;
554  oss << "Attempted to trace " << name;
555  oss << ", but tracing of legacy functions is not supported";
556  throw std::runtime_error(oss.str());
557  }
558 }
559 
// If a JIT trace is active, pre-record a Python-op node for this function
// call. Returns the created trace Node, or nullptr when not tracing.
static Node* _trace_pre_record(
    PyObject* op_obj,
    PyObject *input_objects,
    const variable_list& input_vars) {
  if (!jit::tracer::isTracing()) {
    return nullptr;
  }

  // Save scalar args and the calling convention
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i);
    // Calling-convention string: 'd' marks a Variable argument, 'c' marks a
    // non-Variable (scalar/constant) argument that is kept alive in
    // scalar_args for replay.
    if (THPVariable_Check(arg_object)) {
      arg_types.push_back('d');
    } else {
      arg_types.push_back('c');
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  // Hold a strong reference to the op while the tracer records it.
  Py_INCREF(op_obj);
  auto pyobj = THPObjectPtr(op_obj);
  return jit::tracer::preRecordPythonTrace(
      std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}
590 
// If a JIT trace is active, finish recording the op started by
// _trace_pre_record: mark inplace-ness, type the outputs, and associate each
// output Variable with its trace Value.
static void _trace_post_record(
    Node* node,
    PyObject* op_obj,
    const variable_list& input_vars,
    PyObject *output_objects,
    bool is_inplace,
    bool unpack_output) {
  if (!jit::tracer::isTracing()) {
    return;
  }

  node->i_(attr::inplace, is_inplace);

  // Isolate C variable ptrs in a vector
  // NOTE(review): output_vars appears to be unused below — confirm.
  int num_outputs = PyTuple_GET_SIZE(output_objects);
  variable_list output_vars(num_outputs);
  auto graph = node->owningGraph();
  node->addOutput();
  if (!unpack_output) {
    // forward() returned a tuple: type the single node output as a tuple and
    // insert a TupleUnpack so each element gets its own trace Value.
    std::vector<TypePtr> tuple_values(num_outputs, TensorType::get());
    TypePtr tuple_type = TupleType::create(std::move(tuple_values));
    node->output()->setType(tuple_type);
    auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
    node = unpacked;
  }
  for (int i = 0; i < num_outputs; ++i) {
    auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
    Value* value = node->outputs()[i];
    if (var->cdata.defined()) {
      value->inferTypeFrom(var->cdata);
      jit::tracer::setValueTrace(autograd::as_variable_ref(var->cdata), value);
    }
  }
}
625 
// Shared post-forward path for both the legacy and static-method codepaths:
// wraps the raw forward() results into Variables connected to grad_fn,
// records tracing info, saves requested variables when executable, and
// returns either a single object or a tuple, mirroring forward()'s return.
PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked,
                          PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable,
                          Node* node) {
  // unpack_output is true when forward() returned a single (non-tuple) value.
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());

  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs) throw python_error();

  grad_fn->cdata.clear_input_metadata();

  // Record type, device, and size information about inputs
  if (is_executable) {
    grad_fn->input_info.clear();
    grad_fn->input_info.reserve(unpacked.input_vars.size());
    for (auto& var : unpacked.input_vars) {
      grad_fn->input_info.emplace_back(var);
    }
  }

  // A non-null dirty_tensors attribute (set by ctx.mark_dirty) is the signal
  // that this op ran in-place.
  bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
  _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
  _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
  if (is_executable) {
    _save_variables(grad_fn);
  } else {
    // Remove unnecessary attributes
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = nullptr;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = nullptr;
  }

  // Unpack the output, unless .forward() returned a tuple
  if (unpack_output) {
    PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}
669 
// Legacy codepath
// Python binding for old-style Function forward: unpacks inputs (all must be
// Variables), runs the Python forward() with grad mode off, and post-processes
// the outputs via process_outputs.
PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
{
  HANDLE_TH_ERRORS
  torch::autograd::profiler::RecordFunction record(Py_TYPE(self)->tp_name,
      Function::peek_at_next_sequence_nr());

  auto info_pair = unpack_input<true>(_inputs);
  auto& unpacked_input = info_pair.first;
  auto& input_info = info_pair.second;
  bool is_executable = input_info.is_executable;
  self->cdata.set_next_edges(std::move(input_info.next_edges));
  self->needs_input_grad = input_info.needs_input_grad.release();

  // We don't support tracing in the legacy code path
  _assert_not_tracing(Py_TYPE(self)->tp_name, unpacked_input.input_vars);

  // Now we're ready to call a forward (implemented in Python)
  THPObjectPtr raw_output;
  {
    // The forward computation itself must not be recorded by autograd.
    AutoGradMode grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward"));
    if (!forward_fn) return nullptr;
    raw_output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
    if (!raw_output) return nullptr;
  }

  return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output),
      is_executable, nullptr);
  END_HANDLE_TH_ERRORS
}
701 
702 PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
703 {
704  HANDLE_TH_ERRORS
705  torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name,
706  Function::peek_at_next_sequence_nr());
707 
708  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
709  if (!backward_cls) return nullptr;
710  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
711  if (!ctx_obj) return nullptr;
712  THPFunction* ctx = (THPFunction*)ctx_obj.get();
713 
714  // Prepare inputs and allocate context (grad fn)
715  auto info_pair = unpack_input<false>(inputs);
716  UnpackedInput& unpacked_input = info_pair.first;
717  InputFlags& input_info = info_pair.second;
718 
719  // Record input nodes if tracing
720  auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
721 
722  // Initialize backward function (and ctx)
723  bool is_executable = input_info.is_executable;
724  ctx->cdata.set_next_edges(std::move(input_info.next_edges));
725  ctx->needs_input_grad = input_info.needs_input_grad.release();
726  ctx->is_variable_input = std::move(input_info.is_variable_input);
727 
728  // Prepend ctx to input_tuple, in preparation for static method call
729  auto num_args = PyTuple_GET_SIZE(inputs);
730  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
731  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
732  for (int i = 0; i < num_args; ++i) {
733  PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
734  Py_INCREF(arg);
735  PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
736  }
737 
738  // Call forward
739  THPObjectPtr tensor_outputs;
740  {
741  AutoGradMode grad_mode(false);
742  THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
743  if (!forward_fn) return nullptr;
744  tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
745  if (!tensor_outputs) return nullptr;
746  }
747 
748  return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs),
749  is_executable, node);
750  END_HANDLE_TH_ERRORS
751 }
752 
753 
755 // Backward
757 
758 static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_grad_output)
759 {
760  at::OptionalDeviceGuard device_guard;
761  int num_grads = PyTuple_GET_SIZE(raw_grads.get());
762  // First, check if any of grads is None. If not, there's nothing to do
763  bool has_none = false;
764  for (int i = 0; i < num_grads; i++) {
765  has_none |= PyTuple_GET_ITEM(raw_grads.get(), i) == Py_None;
766  }
767  if (!has_none)
768  return;
769 
770  THPObjectPtr grads;
771  grads = PyTuple_New(num_grads);
772  if (!grads) throw python_error();
773 
774  // Look for Nones and replace them with new buffers
775  auto& grads_info = is_grad_output ? self->output_info : self->input_info;
776  AT_ASSERT(grads_info.size() == (size_t)num_grads);
777  for (int i = 0; i < num_grads; i++) {
778  PyObject *grad = PyTuple_GET_ITEM(raw_grads.get(), i);
779  if (grad == Py_None) {
780  grad = THPVariable_Wrap(grads_info[i].zeros(device_guard));
781  if (!grad) throw python_error();
782  } else {
783  Py_INCREF(grad);
784  }
785  PyTuple_SET_ITEM(grads.get(), i, grad);
786  }
787  raw_grads = grads.release();
788 }
789 
790 static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input)
791 {
792  int num_grads = PyTuple_GET_SIZE(grad_input.get());
793  const int num_outputs = self->cdata.num_outputs();
794  if (num_grads > num_outputs) {
795  // Check that all extra grads are none
796  bool all_none = true;
797  for (int i = num_outputs; i < num_grads; i++) {
798  all_none = (PyTuple_GET_ITEM(grad_input.get(), i) == Py_None);
799  if (!all_none) break;
800  }
801  // If yes, slice the tuple
802  if (all_none) {
803  num_grads = num_outputs;
804  grad_input = PyTuple_GetSlice(grad_input.get(), 0, num_grads);
805  if (!grad_input) throw python_error();
806  }
807  }
808 }
809 
// Python binding for _do_backward on legacy functions: validates the incoming
// grad tuple, fills None slots with zero buffers, calls self.backward(*grads),
// then validates and zero-fills the returned gradient tuple.
PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
{
  try {
    Py_ssize_t num_args = args ? PyTuple_GET_SIZE(args) : 0;
    THPUtils_assert(num_args == 2, "_do_backward expects exactly two arguments");
    PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0);
    PyObject *retain_variables = PyTuple_GET_ITEM(args, 1);
    if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) {
      THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)");
      return nullptr;
    }
    THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(),
        "%s got an invalid number of gradients (expected %d got %d)",
        THPUtils_typename(self), self->cdata.num_inputs(),
        PyTuple_GET_SIZE(raw_grad_output));

    // Some of the output might have been unused, so we have to allocate
    // zero-filled buffers instead
    Py_INCREF(raw_grad_output);
    THPObjectPtr grad_output(raw_grad_output);
    _prepare_grads(self, grad_output, true);

    // self.backward(*grad_output)
    THPObjectPtr backward_fn(PyObject_GetAttrString((PyObject*)self, "backward"));
    THPUtils_assert(backward_fn.get(), "function %s doesn't implement a required "
        "'backward' method", THPUtils_typename((PyObject*)self));
    THPObjectPtr grad_input(PyObject_CallObject(backward_fn, grad_output.get()));
    if (!grad_input) return nullptr;
    ensure_tuple(grad_input);

    // We allow functions to return more gradients, than there were outputs,
    // if and only if the additional ones are all None
    _trim_grad_input(self, grad_input);
    int num_grads = PyTuple_GET_SIZE(grad_input.get());
    int num_outputs = self->cdata.num_outputs();
    THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of "
        "gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
        num_outputs, num_grads);

    // If any of the remaining grad_inputs are None, zero them.
    _prepare_grads(self, grad_input, false);
    return grad_input.release();

  } catch (python_error& e) {
    // The Python error is already set; just return NULL to propagate it.
    return nullptr;
  } catch (std::exception& e) {
    // Convert C++ exceptions into Python RuntimeErrors.
    THPUtils_setError(e.what());
    return nullptr;
  }
}
860 
862 // Other methods / attributes
864 
865 PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var)
866 {
867  THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a variable");
868  THPVariable *var = (THPVariable*)_var;
869  std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(
870  var->backward_hooks, var->cdata.output_nr()));
871  self->cdata.add_pre_hook(std::move(hook));
872  Py_RETURN_NONE;
873 }
874 
// Registers a Python backward hook on this function's underlying C++ node by
// delegating to the shared registerFunctionHook helper.
PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
{
  return torch::autograd::registerFunctionHook(self->cdata, hook);
}
879 
880 static PyObject *unpack_saved_variables(
881  THPFunction *self,
882  const std::function<PyObject*(const Variable&)>& unpack_fn)
883 {
884  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
885  auto& saved_variables = self->saved_variables;
886  if (saved_variables.empty())
887  return PyTuple_New(0);
888 
889  int num_saved = saved_variables.size();
890  THPObjectPtr saved(PyTuple_New(num_saved));
891  if (!saved)
892  return nullptr;
893  auto saved_for = THPFunction_asFunction(self);
894  for (int i = 0; i < num_saved; i++) {
895  auto unpacked_var = saved_variables[i].unpack(saved_for);
896  THPObjectPtr value;
897  if (!unpacked_var.defined()) {
898  Py_INCREF(Py_None);
899  value = Py_None;
900  } else {
901  value = unpack_fn(unpacked_var);
902  }
903  PyTuple_SET_ITEM(saved.get(), i, value.release());
904  }
905  return saved.release();
906 }
907 
908 PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused)
909 {
910  HANDLE_TH_ERRORS
911  return unpack_saved_variables(self, [](const Variable& var) {
912  return THPVariable_Wrap(var);
913  });
914  END_HANDLE_TH_ERRORS
915 }
916 
917 PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)
918 {
919  HANDLE_TH_ERRORS
920  auto r = PyErr_WarnEx(PyExc_DeprecationWarning,
921  "'saved_variables' is deprecated; use 'saved_tensors'", 0);
922  if (r != 0) throw python_error();
923  return unpack_saved_variables(self, [](const Variable& var) {
924  return THPVariable_Wrap(var);
925  });
926  END_HANDLE_TH_ERRORS
927 }
928 
929 PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
930 {
931  const auto num_outputs = self->cdata.num_outputs();
932  THPObjectPtr result(PyTuple_New(num_outputs));
933  if (!result)
934  return nullptr;
935  for (uint32_t i = 0; i < num_outputs; i++) {
936  THPObjectPtr fn_tuple(PyTuple_New(2));
937  if (!fn_tuple) return nullptr;
938  const auto& edge = self->cdata.next_edge(i);
939  PyObject* fn = functionToPyObject(edge.function);
940  if (!fn) return nullptr;
941  PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
942  PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
943  PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
944  }
945  return result.release();
946 }
947 
// Getter for `metadata`: exposes the underlying Function's metadata dict.
// NOTE(review): the static_cast assumes Function::metadata() always yields a
// PyAnomalyMetadata in this build — confirm against the metadata factory.
PyObject *THPFunction_metadata(THPFunction *self, void *_unused)
{
  auto metadata = static_cast<PyAnomalyMetadata*>(self->cdata.metadata())->dict();

  // dict() returns a borrowed reference; return a new one to the caller.
  Py_INCREF(metadata);
  return metadata;
}
955 
956 typedef PyObject *(*getter)(PyObject *, void *);
957 typedef int (*setter)(PyObject *, PyObject *, void *);
958 
959 namespace {
960 
961 template<PyObject* THPFunction::*ptr>
962 PyObject* getObject(PyObject* obj, void* _unused) {
963  auto self = (THPFunction*)obj;
964  PyObject* value = self->*ptr;
965  if (!value) {
966  Py_RETURN_NONE;
967  }
968  Py_INCREF(value);
969  return value;
970 }
971 
972 template<PyObject* THPFunction::*ptr>
973 int setObject(PyObject* obj, PyObject* value, void* _unused) {
974  auto self = (THPFunction*)obj;
975  if (value == Py_None) {
976  value = nullptr;
977  }
978  Py_XDECREF((self->*ptr));
979  Py_XINCREF(value);
980  self->*ptr = value;
981  return 0;
982 }
983 
984 template<typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
985 PyObject* getMember(PyObject* obj, void* _unused) {
986  auto self = (THPFunction*)obj;
987  return Convert(self->*ptr);
988 }
989 
990 template<typename M, M Function::*ptr, PyObject* (*Convert)(long)>
991 PyObject* getImplMember(PyObject* obj, void* _unused) {
992  auto self = (THPFunction*)obj;
993  return Convert(self->cdata.*ptr);
994 }
995 
996 PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
997  Py_RETURN_TRUE;
998 }
999 
1000 }
1001 
// Attribute table for _FunctionBase. Each entry is {name, get, set, doc,
// closure}; entries without a setter are read-only from Python. The table is
// terminated by the {nullptr} sentinel required by CPython.
static struct PyGetSetDef THPFunction_properties[] = {
  {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr},
  {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr},
  {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr},
  {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr},
  {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr},
  {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr},
  {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr},
  {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
  {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
  {nullptr}
};
1014 
// Method table for _FunctionBase. `apply` is a class method taking varargs;
// the hook registrars take a single object (METH_O). Terminated by the
// {nullptr} sentinel required by CPython.
static struct PyMethodDef THPFunction_methods[] = {
  {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
  {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, nullptr},
  {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr},
  {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr},
  {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr},
  {nullptr}
};
1023 
// Python type object backing torch._C._FunctionBase. The initializer is
// positional and its order is fixed by CPython's PyTypeObject layout — do not
// reorder; each field is annotated with its slot name. The type supports
// subclassing (BASETYPE) and participates in GC (HAVE_GC with
// traverse/clear hooks).
PyTypeObject THPFunctionType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._FunctionBase", /* tp_name */
  sizeof(THPFunction), /* tp_basicsize */
  0, /* tp_itemsize */
  (destructor)THPFunction_dealloc, /* tp_dealloc */
  nullptr, /* tp_print */
  nullptr, /* tp_getattr */
  nullptr, /* tp_setattr */
  nullptr, /* tp_reserved */
  nullptr, /* tp_repr */
  nullptr, /* tp_as_number */
  nullptr, /* tp_as_sequence */
  nullptr, /* tp_as_mapping */
  nullptr, /* tp_hash */
  nullptr, /* tp_call */
  nullptr, /* tp_str */
  nullptr, /* tp_getattro */
  nullptr, /* tp_setattro */
  nullptr, /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  nullptr, /* tp_doc */
  (traverseproc)THPFunction_traverse, /* tp_traverse */
  (inquiry)THPFunction_clear, /* tp_clear */
  nullptr, /* tp_richcompare */
  0, /* tp_weaklistoffset */
  nullptr, /* tp_iter */
  nullptr, /* tp_iternext */
  THPFunction_methods, /* tp_methods */
  nullptr, /* tp_members */
  THPFunction_properties, /* tp_getset */
  nullptr, /* tp_base */
  nullptr, /* tp_dict */
  nullptr, /* tp_descr_get */
  nullptr, /* tp_descr_set */
  0, /* tp_dictoffset */
  nullptr, /* tp_init */
  nullptr, /* tp_alloc */
  THPFunction_new /* tp_new */
};
1064 
1065 bool THPFunction_initModule(PyObject *module)
1066 {
1067  if (PyType_Ready(&THPFunctionType) < 0)
1068  return false;
1069  Py_INCREF(&THPFunctionType);
1070  PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType);
1071  return true;
1072 }
1073 
// Deleter used by the shared_ptr handles created in THPFunction_asFunction.
// It does not free the C++ object (whose storage lives inside the Python
// THPFunction); it drops the owning Python reference, taking the GIL first
// since destruction may happen from arbitrary C++ threads.
struct Decref {
  void operator()(PyFunction* p) const {
    AutoGIL gil;
    Py_DECREF(p->obj);
  }
};
1080 
1081 // Similar to shared_from_this. There's a problem that the Python object
1082 // and its cdata depend on each other being alive, so we can't keep
1083 // shared_ptrs as members, but we'd like to be able to manage the lifetime of
1084 // the objects using shared_ptrs in the C++ graph. This returns a new
1085 // shared_ptr, which will decrement the Python reference count when it's
1086 // destructed. WARNING: it's generally not safe to create weak_ptrs from
1087 // these shared_ptrs since multiple shared_ptrs may control the same underlying
1088 // object.
1089 std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
1090 {
1091  if (!self) {
1092  return std::shared_ptr<PyFunction>();
1093  }
1094 
1095  Py_INCREF((PyObject*)self);
1096  return std::shared_ptr<PyFunction>(&self->cdata, Decref());
1097 }
Variable & grad() override
Accesses the gradient Variable of this Variable.
Definition: variable.h:373
void set_gradient_edge(Edge edge) noexcept
Set the gradient edge – i.e.
Definition: variable.h:682
bool is_view() const noexcept
Returns true if this Variable is a view of another Variable.
Definition: variable.h:734
A OptionalDeviceGuard is an RAII class that sets a device to some value on initialization, and resets the device to its original value on destruction.
Definition: DeviceGuard.h:119
Variable detach() const
Returns a copy of this Variable that is detached from its autograd graph and has a blank version counter.
Definition: variable.h:673
void rebase_history(Edge gradient_edge)
Update the grad_fn of an existing Variable.
Definition: variable.cpp:236
bool is_leaf() const noexcept
True if this Variable is a leaf and thus does not have a grad_fn.
Definition: variable.h:691
std::shared_ptr< Function > try_get_grad_accumulator() const
Attempts to get a pointer to the gradient accumulator of the Variable, if it still exists...
Definition: variable.h:669
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery.
Definition: variable.h:85
Definition: jit_type.h:17
void detach_()
Like detach(), but removes this Variable in-place.
Definition: variable.cpp:134
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
uint32_t output_nr() const noexcept
Returns the input index of the gradient Function to which this Variable is connected.
Definition: variable.h:687
void reset_device(at::Device device)
Sets the device to the given one.
Definition: DeviceGuard.h:147