Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_engine.cpp
#include <torch/csrc/autograd/python_engine.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/PtrWrapper.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/utils/auto_gil.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <unordered_set>
#include <memory> // for unique_ptr

using namespace torch::autograd;

struct THPEngine {
    PyObject_HEAD
};

// The Python autograd engine instance shared by all of the bindings below.
static torch::autograd::python::PythonEngine engine;

static Engine& get_python_engine() {
  return engine;
}

namespace torch { namespace autograd { namespace python {

void PythonEngine::thread_init(int device) {
  // Create a PyThreadState, but release the GIL. This lets AutoGIL calls
  // inside thread_main acquire the GIL without having to create a new
  // PyThreadState each time.
  AutoGIL gil;
  AutoNoGIL no_gil;
  Engine::thread_init(device);
}

void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) {
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(task, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::unique_ptr<AnomalyMetadata>(new PyAnomalyMetadata());
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    const edge_list& outputs) {
  try {
    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

}}} // namespace torch::autograd::python

PyObject *THPEngineClass = nullptr;

static bool _reinitialize_engine = false;

static void _maybe_reinitialize_engine_after_fork() {
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.~PythonEngine();
    // Reconstruct the engine in place so the child process starts fresh.
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
}

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;

  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    auto gradient_edge = variable.gradient_edge();
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      grads.push_back(((THPVariable*)grad)->cdata);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad", i);
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
        grad_fn = input_var->cdata.try_get_grad_accumulator();
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

  variable_list outputs;
  {
    AutoNoGIL no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (int i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

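// Implementation of torch._C._EngineBase.queue_callback: wraps the given
// Python callable so the engine invokes it with the GIL held, and so that
// its refcount is also dropped with the GIL held.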
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    AutoGIL gil;
    THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) throw python_error();
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

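// Implementation of torch._C._EngineBase.is_checkpoint_valid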
PyObject* THPEngine_is_checkpoint_valid(PyObject *self) {
  HANDLE_TH_ERRORS
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  return type->tp_alloc(type, 0);
}

static struct PyMethodDef THPEngine_methods[] = {
  {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
  {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
  {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
  {nullptr}
};


PyTypeObject THPEngineType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._EngineBase",                      /* tp_name */
  sizeof(THPEngine),                           /* tp_basicsize */
  0,                                           /* tp_itemsize */
  nullptr,                                     /* tp_dealloc */
  nullptr,                                     /* tp_print */
  nullptr,                                     /* tp_getattr */
  nullptr,                                     /* tp_setattr */
  nullptr,                                     /* tp_reserved */
  nullptr,                                     /* tp_repr */
  nullptr,                                     /* tp_as_number */
  nullptr,                                     /* tp_as_sequence */
  nullptr,                                     /* tp_as_mapping */
  nullptr,                                     /* tp_hash */
  nullptr,                                     /* tp_call */
  nullptr,                                     /* tp_str */
  nullptr,                                     /* tp_getattro */
  nullptr,                                     /* tp_setattro */
  nullptr,                                     /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,    /* tp_flags */
  nullptr,                                     /* tp_doc */
  nullptr,                                     /* tp_traverse */
  nullptr,                                     /* tp_clear */
  nullptr,                                     /* tp_richcompare */
  0,                                           /* tp_weaklistoffset */
  nullptr,                                     /* tp_iter */
  nullptr,                                     /* tp_iternext */
  THPEngine_methods,                           /* tp_methods */
  nullptr,                                     /* tp_members */
  nullptr,                                     /* tp_getset */
  nullptr,                                     /* tp_base */
  nullptr,                                     /* tp_dict */
  nullptr,                                     /* tp_descr_get */
  nullptr,                                     /* tp_descr_set */
  0,                                           /* tp_dictoffset */
  nullptr,                                     /* tp_init */
  nullptr,                                     /* tp_alloc */
  THPEngine_new                                /* tp_new */
};

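// pthread_atfork() child handler: mark the engine so it is rebuilt in the
// child process on next use (see _maybe_reinitialize_engine_after_fork above).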
static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
  set_default_engine_stub(get_python_engine);
  return true;
}
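Python-side usage (illustrative). THPEngine_initModule registers the type above as torch._C._ImperativeEngine, and torch.autograd drives backward passes through an instance of it via the run_backward binding. The sketch below is not part of this file; it only demonstrates the calling convention implied by the "OObb|Ob" argument parsing in THPEngine_run_backward, assuming a normal PyTorch installation.

import torch

# torch.autograd keeps a shared engine instance of its own; creating a
# private one here is purely for illustration.
engine = torch._C._ImperativeEngine()

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()

# tensors and grad_tensors are parallel tuples; keep_graph and create_graph
# are the two booleans. Without `inputs`, gradients accumulate into .grad.
engine.run_backward((y,), (torch.ones(()),), False, False)
print(x.grad)

# With `inputs` (and allow_unreachable) supplied, the gradients are returned
# as a tuple instead, mirroring torch.autograd.grad.
y2 = (x * x).sum()
(gx,) = engine.run_backward((y2,), (torch.ones(()),), False, False, (x,), False)
print(gx)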