#include <torch/csrc/autograd/python_engine.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/PtrWrapper.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/utils/auto_gil.h>

#include <unordered_set>

#ifndef _WIN32
#include <pthread.h>
#endif

using namespace torch::autograd;

struct THPEngine {
    PyObject_HEAD
};

static torch::autograd::python::PythonEngine engine;

static Engine& get_python_engine() {
  return engine;
}
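// PythonEngine layers Python-specific behavior on top of the core autograd
// Engine: it keeps worker threads registered with the CPython runtime,
// carries python_error exceptions across threads, and records Python
// tracebacks for anomaly detection (torch.autograd.detect_anomaly).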
namespace torch { namespace autograd { namespace python {
void PythonEngine::thread_init(int device) {
  // Create a PyThreadState, but release the GIL. This lets AutoGIL calls
  // inside thread_main acquire the GIL without having to create a new
  // PyThreadState each time.
  AutoGIL gil;
  AutoNoGIL no_gil;
  Engine::thread_init(device);
}
void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) {
  // If the exception is a Python error, save it so it can be re-raised on
  // the thread that called backward().
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(task, e);
}
std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  // PyAnomalyMetadata records the Python traceback of the forward call that
  // created each node, for torch.autograd.detect_anomaly().
  return std::unique_ptr<AnomalyMetadata>(new PyAnomalyMetadata());
}
variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    const edge_list& outputs) {
  try {
    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
  } catch (python_error& e) {
    // Restore the saved Python error before unwinding into the caller.
    e.restore();
    throw;
  }
}

}}} // namespace torch::autograd::python
PyObject *THPEngineClass = nullptr;
static bool _reinitialize_engine = false;
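// The engine owns worker threads, which do not survive a fork(). A
// pthread_atfork handler (child_atfork, below) sets this flag in the child
// process; the next call into the engine then tears it down and rebuilds it.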
static void _maybe_reinitialize_engine_after_fork() {
  if (_reinitialize_engine) {
    // Destroy and reconstruct the engine in place so the child process gets
    // fresh worker threads.
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
}
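// Implements torch._C._EngineBase.run_backward, the entry point behind
// torch.autograd.backward() and torch.autograd.grad(). A rough sketch of the
// Python-side call (assumed, mirroring torch/autograd/__init__.py):
//
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, retain_graph, create_graph,
//       inputs, allow_unreachable=allow_unused)
//
// where the trailing two arguments are only supplied by grad().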
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;
  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));
  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    // gradient_edge() is (grad_fn, output_nr) for an interior Variable, or
    // (grad_accumulator, 0) for a leaf that requires grad.
    auto gradient_edge = variable.gradient_edge();
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));
    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      grads.push_back(((THPVariable*)grad)->cdata);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      // A None gradient is only allowed when the corresponding Tensor does
      // not require grad.
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad", i);
    }
  }
  // Each output edge names the graph location whose gradient should be
  // returned: (grad_fn, output_nr) for interior Variables, the gradient
  // accumulator for leaves, or an empty Edge for unreachable inputs.
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
        // Leaf Variable: read the gradient from its accumulator instead.
        grad_fn = input_var->cdata.try_get_grad_accumulator();
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }
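  // With roots, initial gradients, and requested outputs assembled, hand the
  // graph off to the engine. Execution runs on the engine's worker threads,
  // so the GIL must not be held while we wait for it.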
  variable_list outputs;
  {
    AutoNoGIL no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (int i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
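// queue_callback registers a Python callable that the engine invokes once the
// current backward pass completes; it is exposed on the same
// Variable._execution_engine object as run_backward.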
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  // Hold a reference to the callback for the lifetime of the lambda; the
  // custom deleter takes the GIL before dropping it, since the final release
  // may happen on an engine worker thread.
  std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    AutoGIL gil;
    THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) throw python_error();
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
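// Exposed to Python so callers (e.g. torch.utils.checkpoint) can check
// whether re-entrant recomputation is permitted at this point of the
// backward pass.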
PyObject* THPEngine_is_checkpoint_valid(PyObject *self) {
  HANDLE_TH_ERRORS
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  return type->tp_alloc(type, 0);
}
static struct PyMethodDef THPEngine_methods[] = {
  {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
  {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
  {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
  {nullptr}
};
PyTypeObject THPEngineType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._EngineBase",                      /* tp_name */
  sizeof(THPEngine),                           /* tp_basicsize */
  0,                                           /* tp_itemsize */
  nullptr,                                     /* tp_dealloc */
  nullptr, nullptr, nullptr, nullptr,          /* tp_print, tp_getattr, tp_setattr, tp_reserved */
  nullptr,                                     /* tp_repr */
  nullptr, nullptr, nullptr,                   /* tp_as_number, tp_as_sequence, tp_as_mapping */
  nullptr, nullptr, nullptr,                   /* tp_hash, tp_call, tp_str */
  nullptr, nullptr,                            /* tp_getattro, tp_setattro */
  nullptr,                                     /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,    /* tp_flags */
  nullptr,                                     /* tp_doc */
  nullptr, nullptr, nullptr,                   /* tp_traverse, tp_clear, tp_richcompare */
  0,                                           /* tp_weaklistoffset */
  nullptr, nullptr,                            /* tp_iter, tp_iternext */
  THPEngine_methods,                           /* tp_methods */
  nullptr, nullptr, nullptr, nullptr,          /* tp_members, tp_getset, tp_base, tp_dict */
  nullptr, nullptr,                            /* tp_descr_get, tp_descr_set */
  0,                                           /* tp_dictoffset */
  nullptr, nullptr,                            /* tp_init, tp_alloc */
  THPEngine_new                                /* tp_new */
};
static void child_atfork() {
  _reinitialize_engine = true;
}
bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
  set_default_engine_stub(get_python_engine);
  return true;
}
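// Python-side binding sketch (assumed, following torch/autograd/variable.py):
//
//   from torch._C import _ImperativeEngine as ImperativeEngine
//   Variable._execution_engine = ImperativeEngine()
//
// set_default_engine_stub registers get_python_engine so that C++ callers of
// Engine::get_default_engine() share this same Python-aware engine instance.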