#include <torch/csrc/autograd/python_function.h>

#include <torch/csrc/python_headers.h>
#include <structmember.h>
#include <unordered_map>
#include <unordered_set>

#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/Exceptions.h>

using namespace torch;
PyObject *THPFunctionClass = nullptr;
#define THPFunction_assert(condition, ...)                                     \
  if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }

namespace torch { namespace autograd {
VariableInfo::VariableInfo(const Variable& var)
  : type(&var.type())
  , device(var.device())
  , size(var.sizes().vec())
  , requires_grad(var.requires_grad()) {
}

Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  device_guard.reset_device(device);
  return at::zeros(size, type->options());
}
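
// Backward entry point for old-style (pre-static-method) autograd Functions:
// the incoming gradients are handed to the Python `_do_backward` method, and
// whatever tensors it returns are wrapped back into Variables.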
auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
  AutoGIL gil;

  THPObjectPtr pyInputs(PyTuple_New(inputs.size()));
  if (!pyInputs) throw python_error();

  for (size_t i = 0; i != inputs.size(); ++i) {
    PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i]));
  }

  THPObjectPtr r(PyObject_CallMethod(
      obj, "_do_backward", "OO", pyInputs.get(), Py_True));
  if (!r) throw python_error();

  auto num_outputs = PyTuple_GET_SIZE(r.get());
  tensor_list tensor_results(num_outputs);
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* obj = PyTuple_GET_ITEM(r.get(), i);
    if (obj != Py_None) {
      if (!THPVariable_Check(obj)) {
        std::string msg("expected Variable (got '");
        msg += THPUtils_typename(obj);
        msg += "')";
        throw std::runtime_error(msg);
      }
      tensor_results[i] = ((THPVariable*)obj)->cdata.data();
    }
  }

  // Legacy backwards are not differentiable again, so wrap the results with
  // an Error node that raises if a second derivative is requested.
  return wrap_outputs(inputs, std::move(tensor_results),
      [this](edge_list&& next_edges) {
        return std::make_shared<Error>(
            name() + " is not differentiable twice", std::move(next_edges));
      });
}
// Backward entry point for new-style functions: called by the autograd engine
// with the gradients of the forward outputs. Dispatches to legacy_apply() for
// old-style functions, otherwise calls the Python `apply` of the backward
// class and validates the gradients it returns.
auto PyFunction::apply(variable_list&& inputs) -> variable_list {
  AutoGIL gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
  if (_legacy == Py_True) {
    return legacy_apply(inputs);
  }

  // Massage a C++ variable_list into a Python arguments tuple
  auto num_inputs = inputs.size();
  THPObjectPtr pyInputs(PyTuple_New(num_inputs));
  if (!pyInputs) throw python_error();
  auto& output_info = py_fn->output_info;
  for (size_t i = 0; i < num_inputs; ++i) {
    PyObject* input;
    if (inputs[i].defined()) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
    }
    if (!input) throw python_error();
    PyTuple_SET_ITEM(pyInputs.get(), i, input);
  }

  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn) throw python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r) throw python_error();
  ensure_tuple(r);
  auto& is_variable_input = py_fn->is_variable_input;
  int num_outputs = PyTuple_GET_SIZE(r.get());
  int num_forward_inputs = is_variable_input.size();

  // Returning too many gradients is allowed, but only if the extra ones are
  // all None; in that case the tuple is truncated.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (int i = num_forward_inputs; i < num_outputs; i++) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r) throw python_error();
    }
  }

  // Now the number of gradients should match the number of forward inputs.
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }
  // Massage the Python results tuple back into a C++ variable_list.
  variable_list results;
  results.reserve(num_outputs);
  auto& input_info = py_fn->input_info;
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(r.get(), i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output == Py_None) {
      auto& info = input_info[results.size()];
      if (info.requires_grad) {
        results.emplace_back(info.zeros(_device_guard));
      } else {
        results.emplace_back();
      }
    } else {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(((THPVariable*)output)->cdata);
    }
  }

  return results;
}
auto PyFunction::is_traceable() -> bool {
  AutoGIL gil;
  THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")};
  if (!forward_class) throw python_error();
  THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")};
  if (!traceable_py_bool) throw python_error();
  return traceable_py_bool == Py_True;
}
auto PyFunction::release_variables() -> void {
  AutoGIL gil;
  auto f = (THPFunction*) obj;
  f->saved_variables.clear();
  f->has_freed_buffers = 1;
}
auto PyFunction::name() const -> std::string {
  AutoGIL gil;
  auto f = (THPFunction*) obj;
  auto name = std::string(Py_TYPE(f)->tp_name);
  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
  if (_legacy == Py_True) {
    name += "LegacyBackward";
  }
  return name;
}
auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> {
  return THPFunction_asFunction((THPFunction*)obj);
}

}} // namespace torch::autograd
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
{
  for (const auto& hook : self->cdata.pre_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  for (const auto& hook : self->cdata.post_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  Py_VISIT(self->to_save);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  return 0;
}
static int THPFunction_clear(THPFunction *self)
{
  self->cdata.clear_input_metadata();

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);

  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  // Moving the hooks out first detaches them from the function without
  // destroying them; they are released when this scope exits.
  auto pre_hooks = std::move(self->cdata.pre_hooks());
  auto post_hooks = std::move(self->cdata.post_hooks());

  return 0;
}
static void THPFunction_dealloc(THPFunction* self)
{
  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  self->cdata.~PyFunction();
  self->output_info.~vector();
  self->input_info.~vector();
  self->saved_variables.~vector();
  self->is_variable_input.~vector();
  Py_TYPE(self)->tp_free((PyObject*)self);
}
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  PyObject* obj = type->tp_alloc(type, 0);
  if (!obj) return nullptr;
  // Python zero-initializes the object memory, so only the non-trivially
  // constructible members need placement-new.
  THPFunction* self = (THPFunction*)obj;
  new (&self->cdata) PyFunction(obj);
  new (&self->output_info) std::vector<VariableInfo>();
  new (&self->input_info) std::vector<VariableInfo>();
  new (&self->saved_variables) std::vector<SavedVariable>();
  new (&self->is_variable_input) std::vector<bool>();
  return obj;
}
using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
// Bumps the version counter of every tensor recorded via ctx.mark_dirty()
// and returns them, with some sanity checking.
static std::vector<PyObject*> _mark_dirty(THPFunction *self)
{
  std::vector<PyObject*> dirty_inputs;
  if (!self->dirty_tensors) return dirty_inputs;

  THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
      "internal error: dirty_tensors attribute is expected to be a tuple "
      "but is %s", THPUtils_typename(self->dirty_tensors));
  Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
  for (int i = 0; i < num_dirty; i++) {
    PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
    THPFunction_assert(THPVariable_Check(obj), "mark_dirty can "
        "only accept variables, but argument %d is of type %s", i,
        THPUtils_typename(obj));

    dirty_inputs.push_back(obj);
    auto variable = (THPVariable*)obj;
    variable->cdata.bump_version();
  }
  // We won't need this anymore, so drop the references now.
  Py_CLEAR(self->dirty_tensors);
  return dirty_inputs;
}
static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self);
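
// Wraps the raw Python objects returned from forward() into Variables, hooks
// them up to this function in the graph, and applies the special handling
// requested through ctx.mark_dirty() and ctx.mark_non_differentiable().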
static void _wrap_outputs(THPFunction *self,
    PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable)
{
  auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr;
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);

  self->output_info.clear();
  self->output_info.reserve(num_outputs);

  std::unordered_set<PyObject*> inputs;
  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
  for (int i = 0; i < num_inputs; i++) {
    inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i));
  }

  auto non_differentiable = _parse_non_differentiable(self);
  auto dirty_inputs = _mark_dirty(self);
  auto as_variable = [&](PyObject* obj, int i) -> Variable {
    if (THPVariable_Check(obj)) {
      return ((THPVariable*)obj)->cdata;
    }
    throw TypeError("%s.forward: expected Variable (got %s) for return value %d",
        Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
  };
  // Sets the grad_fn and output_nr of an output Variable.
  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        return;
      }
      // Non-differentiable outputs that still require grad must not alias
      // another Variable's storage.
      if (var.is_view()) {
        throw std::runtime_error("Returning Variables sharing storage with other Variables "
            "that require grad is not supported in Python functions. "
            "Please submit a feature request if you hit this error.");
      }
      var.detach_();
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // The input was modified in-place: detach it from its old grad
      // accumulator and rebase its history onto this function.
      var.grad().reset();
      if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
        grad_acc->variable.reset();
      }
      if (cdata) {
        var.rebase_history({cdata, output_nr});
      }
    } else if (is_input) {
      // An unmodified input was returned: make it a view of itself so a new
      // grad_fn can be attached without touching the original Variable.
      var = var.view_as(var);
      var.set_gradient_edge({cdata, output_nr});
    } else if (cdata) {
      var.set_gradient_edge({cdata, output_nr});
    }
  };
  for (int i = 0; i < num_outputs; i++) {
    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);

    bool is_input = inputs.count(obj) > 0;
    bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end();
    bool is_differentiable = is_executable && non_differentiable.count(obj) == 0;

    auto var = as_variable(obj, i);
    if (cdata) {
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == (int)output_nr);
    }
    set_history(var, i, is_input, is_modified, is_differentiable);

    if (is_executable) {
      self->output_info.emplace_back(var);
    }

    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var));
  }
}
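
// Converts the tensors stashed in ctx.to_save (by ctx.save_for_backward())
// into SavedVariables, so that in-place modifications can be detected when
// they are unpacked in backward.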
static void _save_variables(THPFunction* self)
{
  if (!self->to_save) return;

  THPFunction_assert(PyTuple_Check(self->to_save), "autograd internal "
      "error: to_save attribute is expected to be a tuple but is %s",
      THPUtils_typename(self->to_save));
  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
  self->saved_variables.clear();
  self->saved_variables.reserve(num_saved);
  auto cdata_ptr = &self->cdata;
  for (int i = 0; i < num_saved; i++) {
    PyObject *obj = PyTuple_GET_ITEM(self->to_save, i);
    if (obj == Py_None) {
      self->saved_variables.emplace_back();
      continue;
    } else if (THPVariable_Check(obj)) {
      auto variable = (THPVariable*)obj;
      bool is_output = variable->cdata.grad_fn().get() == cdata_ptr;
      self->saved_variables.emplace_back(variable->cdata, is_output);
    } else {
      throw TypeError(
          "save_for_backward can only save variables, but argument %d is of "
          "type %s", i, Py_TYPE(obj)->tp_name);
    }
  }
  // Free .to_save
  Py_CLEAR(self->to_save);
}
static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self)
{
  std::unordered_set<PyObject*> set;
  if (!self->non_differentiable) return set;

  THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
      "internal error: non_differentiable attribute is expected to be a "
      "tuple but is %s", THPUtils_typename(self->non_differentiable));
  Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
  set.reserve(num_nondiff);
  for (int i = 0; i < num_nondiff; i++) {
    PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
    THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable "
        "only accepts variable arguments, but got %s", THPUtils_typename(t));
    set.insert(t);
  }
  Py_CLEAR(self->non_differentiable);
  return set;
}
struct UnpackedInput {
  THPObjectPtr input_tuple;
  variable_list input_vars;
};

struct InputFlags {
  bool is_executable = false;
  edge_list next_edges;
  THPObjectPtr needs_input_grad;
  std::vector<bool> is_variable_input;
};
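
// Splits a Python argument tuple into Variable and non-Variable arguments,
// recording for each position whether it held a Variable and whether that
// Variable requires grad; also decides whether the call is executable (grad
// mode enabled and at least one input requiring grad) and collects next edges.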
template<bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
  UnpackedInput unpacked;
  InputFlags flags;

  auto num_args = PyTuple_GET_SIZE(args);
  unpacked.input_tuple = PyTuple_New(num_args);
  flags.needs_input_grad = PyTuple_New(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg = PyTuple_GET_ITEM(args, i);

    bool is_variable = THPVariable_Check(arg);
    flags.is_variable_input.push_back(is_variable);
    if (!is_variable) {
      if (enforce_variables) {
        THPUtils_setError("expected a Variable argument, but got %s",
            THPUtils_typename(arg));
        throw python_error();
      }
      Py_INCREF(Py_False);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
    } else {
      THPVariable* variable = (THPVariable*)arg;
      unpacked.input_vars.push_back(variable->cdata);
      PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False;
      Py_INCREF(needs_grad);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
    }

    Py_INCREF(arg);
    PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
  }

  flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
  flags.next_edges = collect_next_edges(unpacked.input_vars);
  return std::make_pair(std::move(unpacked), std::move(flags));
}
static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
  if (tracer::isTracing()) {
    std::ostringstream oss;
    oss << "Attempted to trace " << name;
    oss << ", but tracing of legacy functions is not supported";
    throw std::runtime_error(oss.str());
  }
}
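
// Tracing support: if the JIT tracer is active, pre-record a Python op node
// for this function before forward runs, and fill in its outputs afterwards.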
static Node* _trace_pre_record(
    PyObject* op_obj,
    PyObject *input_objects,
    const variable_list& input_vars) {
  if (!jit::tracer::isTracing()) {
    return nullptr;
  }

  // Save scalar args and the calling convention.
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i);
    if (THPVariable_Check(arg_object)) {
      arg_types.push_back('d');
    } else {
      arg_types.push_back('c');
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  Py_INCREF(op_obj);
  auto pyobj = THPObjectPtr(op_obj);
  return jit::tracer::preRecordPythonTrace(
      std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}
static void _trace_post_record(
    Node* node,
    PyObject* op_obj,
    const variable_list& input_vars,
    PyObject *output_objects,
    bool is_inplace,
    bool unpack_output) {
  if (!jit::tracer::isTracing()) {
    return;
  }

  node->i_(attr::inplace, is_inplace);

  int num_outputs = PyTuple_GET_SIZE(output_objects);
  variable_list output_vars(num_outputs);
  auto graph = node->owningGraph();
  if (!unpack_output) {
    std::vector<TypePtr> tuple_values(num_outputs, TensorType::get());
    TypePtr tuple_type = TupleType::create(std::move(tuple_values));
    node->output()->setType(tuple_type);
    auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
    node = unpacked;
  }
  for (int i = 0; i < num_outputs; ++i) {
    auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
    Value* value = node->outputs()[i];
    if (var->cdata.defined()) {
      value->inferTypeFrom(var->cdata);
      jit::tracer::setValueTrace(autograd::as_variable_ref(var->cdata), value);
    }
  }
}
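
// Shared tail of _do_forward and apply: wraps the raw forward() result,
// records input metadata, saves variables (or drops the bookkeeping
// attributes when grad is not needed), and unpacks single-element tuples.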
PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked,
                          PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable,
                          torch::jit::Node* node) {
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());

  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs) throw python_error();

  grad_fn->cdata.clear_input_metadata();

  // Record type, device, and size information about the inputs.
  if (is_executable) {
    grad_fn->input_info.clear();
    grad_fn->input_info.reserve(unpacked.input_vars.size());
    for (auto& var : unpacked.input_vars) {
      grad_fn->input_info.emplace_back(var);
    }
  }

  bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
  _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
  _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
  if (is_executable) {
    _save_variables(grad_fn);
  } else {
    // Remove unnecessary attributes.
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = nullptr;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = nullptr;
  }

  // Unpack the output, unless forward() itself returned a tuple.
  if (unpack_output) {
    PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}
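
// Implements the legacy Function._do_forward entry point (old-style functions
// with a non-static forward/backward).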
PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
{
  HANDLE_TH_ERRORS
  torch::autograd::profiler::RecordFunction record(Py_TYPE(self)->tp_name,
      Function::peek_at_next_sequence_nr());

  auto info_pair = unpack_input<true>(_inputs);
  auto& unpacked_input = info_pair.first;
  auto& input_info = info_pair.second;
  bool is_executable = input_info.is_executable;
  self->cdata.set_next_edges(std::move(input_info.next_edges));
  self->needs_input_grad = input_info.needs_input_grad.release();

  // Tracing legacy functions is not supported.
  _assert_not_tracing(Py_TYPE(self)->tp_name, unpacked_input.input_vars);

  // Call the Python-implemented forward with grad disabled.
  THPObjectPtr raw_output;
  {
    AutoGradMode grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward"));
    if (!forward_fn) return nullptr;
    raw_output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
    if (!raw_output) return nullptr;
  }

  return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output),
                         is_executable, nullptr);
  END_HANDLE_TH_ERRORS
}
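
// Implements the Python classmethod Function.apply() for new-style (static
// method) autograd functions. A minimal sketch of the Python-side usage this
// entry point serves (names are illustrative only):
//
//   class MyRelu(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, input):
//           ctx.save_for_backward(input)
//           return input.clamp(min=0)
//
//       @staticmethod
//       def backward(ctx, grad_output):
//           input, = ctx.saved_tensors
//           return grad_output * (input > 0).type_as(grad_output)
//
//   y = MyRelu.apply(x)  # routes through THPFunction_apply below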
PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
  HANDLE_TH_ERRORS
  torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name,
      Function::peek_at_next_sequence_nr());

  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls) return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  if (!ctx_obj) return nullptr;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  // Prepare inputs and allocate the context (grad_fn).
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Record input nodes if tracing.
  auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);

  // Initialize the backward function (and ctx).
  bool is_executable = input_info.is_executable;
  ctx->cdata.set_next_edges(std::move(input_info.next_edges));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // Prepend ctx to input_tuple, in preparation for the static method call.
  auto num_args = PyTuple_GET_SIZE(inputs);
  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
  for (int i = 0; i < num_args; ++i) {
    PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
    Py_INCREF(arg);
    PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
  }

  // Call forward with grad disabled.
  THPObjectPtr tensor_outputs;
  {
    AutoGradMode grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
    if (!forward_fn) return nullptr;
    tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
    if (!tensor_outputs) return nullptr;
  }

  return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs),
                         is_executable, node);
  END_HANDLE_TH_ERRORS
}
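
// Replaces None entries in a gradient tuple with zero-filled Variables whose
// shape/type was recorded in output_info (for grad outputs) or input_info
// (for grad inputs).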
static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_grad_output)
{
  at::OptionalDeviceGuard device_guard;
  int num_grads = PyTuple_GET_SIZE(raw_grads.get());
  // First, check whether any of the grads is None; if not, there is nothing to do.
  bool has_none = false;
  for (int i = 0; i < num_grads; i++) {
    has_none |= PyTuple_GET_ITEM(raw_grads.get(), i) == Py_None;
  }
  if (!has_none)
    return;

  THPObjectPtr grads(PyTuple_New(num_grads));
  if (!grads) throw python_error();

  // Look for Nones and replace them with zero-filled buffers.
  auto& grads_info = is_grad_output ? self->output_info : self->input_info;
  AT_ASSERT(grads_info.size() == (size_t)num_grads);
  for (int i = 0; i < num_grads; i++) {
    PyObject *grad = PyTuple_GET_ITEM(raw_grads.get(), i);
    if (grad == Py_None) {
      grad = THPVariable_Wrap(grads_info[i].zeros(device_guard));
      if (!grad) throw python_error();
    } else {
      Py_INCREF(grad);
    }
    PyTuple_SET_ITEM(grads.get(), i, grad);
  }
  raw_grads = grads.release();
}
static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input)
{
  int num_grads = PyTuple_GET_SIZE(grad_input.get());
  const int num_outputs = self->cdata.num_outputs();
  if (num_grads > num_outputs) {
    // Check that all the extra gradients are None.
    bool all_none = true;
    for (int i = num_outputs; i < num_grads; i++) {
      all_none = (PyTuple_GET_ITEM(grad_input.get(), i) == Py_None);
      if (!all_none) break;
    }
    if (all_none) {
      num_grads = num_outputs;
      grad_input = PyTuple_GetSlice(grad_input.get(), 0, num_grads);
    }
  }
}
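
// Legacy backward entry point (`_do_backward(grad_output, retain_variables)`):
// fills in missing grad outputs, calls the Python `backward` method, and
// checks that it returned the right number of gradients.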
PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
{
  try {
    Py_ssize_t num_args = args ? PyTuple_GET_SIZE(args) : 0;
    THPUtils_assert(num_args == 2, "_do_backward expects exactly two arguments");
    PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0);
    PyObject *retain_variables = PyTuple_GET_ITEM(args, 1);
    if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) {
      THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)");
      return nullptr;
    }
    THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(),
        "%s got an invalid number of gradients (expected %d got %d)",
        THPUtils_typename(self), self->cdata.num_inputs(),
        PyTuple_GET_SIZE(raw_grad_output));

    // Some grad outputs may be missing; replace them with zero-filled buffers.
    Py_INCREF(raw_grad_output);
    THPObjectPtr grad_output(raw_grad_output);
    _prepare_grads(self, grad_output, true);

    // self.backward(*grad_output)
    THPObjectPtr backward_fn(PyObject_GetAttrString((PyObject*)self, "backward"));
    THPUtils_assert(backward_fn.get(), "function %s doesn't implement a required "
        "'backward' method", THPUtils_typename((PyObject*)self));
    THPObjectPtr grad_input(PyObject_CallObject(backward_fn, grad_output.get()));
    if (!grad_input) return nullptr;
    ensure_tuple(grad_input);

    // Returning more gradients than outputs is allowed only if the extra ones
    // are all None.
    _trim_grad_input(self, grad_input);
    int num_grads = PyTuple_GET_SIZE(grad_input.get());
    int num_outputs = self->cdata.num_outputs();
    THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of "
        "gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
        num_outputs, num_grads);

    // Replace any remaining None gradients with zero-filled buffers.
    _prepare_grads(self, grad_input, false);
    return grad_input.release();

  } catch (std::exception& e) {
    THPUtils_setError(e.what());
    return nullptr;
  }
}
PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var)
{
  THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a variable");
  THPVariable *var = (THPVariable*)_var;
  std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(
      var->backward_hooks, var->cdata.output_nr()));
  self->cdata.add_pre_hook(std::move(hook));
  Py_RETURN_NONE;
}
PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
{
  return torch::autograd::registerFunctionHook(self->cdata, hook);
}
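
// Common helper for the saved_tensors / saved_variables properties: unpacks
// each SavedVariable (raising ERR_BACKWARD_TWICE if the buffers were already
// freed) and converts it to a Python object with unpack_fn.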
static PyObject *unpack_saved_variables(
    THPFunction *self,
    const std::function<PyObject*(const Variable&)>& unpack_fn)
{
  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);

  int num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(num_saved));
  if (!saved)
    return nullptr;
  auto saved_for = THPFunction_asFunction(self);
  for (int i = 0; i < num_saved; i++) {
    auto unpacked_var = saved_variables[i].unpack(saved_for);
    THPObjectPtr value;
    if (!unpacked_var.defined()) {
      Py_INCREF(Py_None);
      value = Py_None;
    } else {
      value = unpack_fn(unpacked_var);
    }
    PyTuple_SET_ITEM(saved.get(), i, value.release());
  }
  return saved.release();
}
PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused)
{
  HANDLE_TH_ERRORS
  return unpack_saved_variables(self, [](const Variable& var) {
    return THPVariable_Wrap(var);
  });
  END_HANDLE_TH_ERRORS
}
PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)
{
  HANDLE_TH_ERRORS
  auto r = PyErr_WarnEx(PyExc_DeprecationWarning,
      "'saved_variables' is deprecated; use 'saved_tensors'", 0);
  if (r != 0) throw python_error();
  return unpack_saved_variables(self, [](const Variable& var) {
    return THPVariable_Wrap(var);
  });
  END_HANDLE_TH_ERRORS
}
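
// Exposes this function's outgoing edges as a tuple of (function, input_nr)
// pairs, mirroring the next_functions attribute of C++-implemented nodes.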
PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
{
  const auto num_outputs = self->cdata.num_outputs();
  THPObjectPtr result(PyTuple_New(num_outputs));
  if (!result)
    return nullptr;
  for (uint32_t i = 0; i < num_outputs; i++) {
    THPObjectPtr fn_tuple(PyTuple_New(2));
    if (!fn_tuple) return nullptr;
    const auto& edge = self->cdata.next_edge(i);
    PyObject* fn = functionToPyObject(edge.function);
    if (!fn) return nullptr;
    PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
    PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
    PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
  }
  return result.release();
}
PyObject *THPFunction_metadata(THPFunction *self, void *_unused)
{
  auto metadata = static_cast<PyAnomalyMetadata*>(self->cdata.metadata())->dict();
  Py_INCREF(metadata);
  return metadata;
}
typedef PyObject *(*getter)(PyObject *, void *);
typedef int (*setter)(PyObject *, PyObject *, void *);
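
// Attribute plumbing: pointer-to-member templates that expose THPFunction
// fields (to_save, dirty_tensors, ...) as Python properties.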
template<PyObject* THPFunction::*ptr>
PyObject* getObject(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  PyObject* value = self->*ptr;
  if (!value) {
    Py_RETURN_NONE;
  }
  Py_INCREF(value);
  return value;
}

template<PyObject* THPFunction::*ptr>
int setObject(PyObject* obj, PyObject* value, void* _unused) {
  auto self = (THPFunction*)obj;
  if (value == Py_None) {
    value = nullptr;
  }
  Py_XDECREF((self->*ptr));
  Py_XINCREF(value);
  self->*ptr = value;
  return 0;
}

template<typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
PyObject* getMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->*ptr);
}

template<typename M, M Function::*ptr, PyObject* (*Convert)(long)>
PyObject* getImplMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->cdata.*ptr);
}

PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
  Py_RETURN_TRUE;
}
static struct PyGetSetDef THPFunction_properties[] = {
  {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr},
  {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr},
  {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr},
  {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr},
  {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr},
  {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr},
  {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr},
  {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
  {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
  {nullptr}
};
static struct PyMethodDef THPFunction_methods[] = {
  {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
  {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, nullptr},
  {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr},
  {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr},
  {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr},
  {nullptr}
};
PyTypeObject THPFunctionType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._FunctionBase",                       /* tp_name */
  (destructor)THPFunction_dealloc,                /* tp_dealloc */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  /* tp_flags */
  (traverseproc)THPFunction_traverse,             /* tp_traverse */
  (inquiry)THPFunction_clear,                     /* tp_clear */
  THPFunction_methods,                            /* tp_methods */
  THPFunction_properties,                         /* tp_getset */
};
bool THPFunction_initModule(PyObject *module)
{
  if (PyType_Ready(&THPFunctionType) < 0)
    return false;
  Py_INCREF(&THPFunctionType);
  PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType);
  return true;
}
std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
{
  if (!self) {
    return std::shared_ptr<PyFunction>();
  }

  Py_INCREF((PyObject*)self);
  return std::shared_ptr<PyFunction>(&self->cdata, Decref());
}