Caffe2 - C++ API
A deep learning, cross platform ML framework
python_cpp_function.cpp
1 #include <torch/csrc/autograd/python_cpp_function.h>
2 
3 #include <torch/csrc/python_headers.h>
4 #include <memory>
5 #include <cstdio>
6 #include <typeindex>
7 #include <unordered_map>
8 
9 #include <torch/csrc/autograd/python_function.h>
10 #include <torch/csrc/autograd/python_variable.h>
11 #include <torch/csrc/autograd/python_hook.h>
12 #include <torch/csrc/autograd/python_anomaly_mode.h>
13 #include <torch/csrc/utils/auto_gil.h>
14 #include <torch/csrc/utils/python_strings.h>
15 #include <torch/csrc/DynamicTypes.h>
16 #include <torch/csrc/Exceptions.h>
17 
18 using namespace torch::autograd;
19 
20 namespace torch { namespace autograd {
21 
22 namespace {
23 
24 PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs)
25 {
26  if (kwargs && PyDict_Size(kwargs) != 0) {
27  return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported");
28  }
29 
30  int num_inputs = PyTuple_GET_SIZE(args);
31  int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
32  if (num_inputs != num_inputs_required) {
33  return PyErr_Format(PyExc_TypeError, "expected %d arguments, got %d instead",
34  num_inputs_required, num_inputs);
35  }
36  variable_list vars(num_inputs);
37  for (int i = 0; i != num_inputs; ++i) {
38  PyObject* arg = PyTuple_GET_ITEM(args, i);
39  if (arg == Py_None) {
40  continue;
41  }
42  if (!THPVariable_Check(arg)) {
43  return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable", i);
44  }
45  vars[i] = ((THPVariable*)arg)->cdata;
46  }
47 
48  variable_list output;
49 
50  HANDLE_TH_ERRORS {
51  AutoNoGIL nogil;
52  output = (*((THPCppFunction*)self)->cdata)(std::move(vars));
53  }
54  END_HANDLE_TH_ERRORS
55 
56  int num_outputs = output.size();
57  if (num_outputs == 1) {
58  // assume we want to unpack one element tuples for now
59  return THPVariable_Wrap(output[0]);
60  }
61 
62  THPObjectPtr tuple(PyTuple_New(num_outputs));
63  for (int i = 0; i != num_outputs; ++i) {
64  PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
65  }
66  return tuple.release();
67 }
68 
69 int THPCppFunction_traverse(PyObject* self, visitproc visit, void *arg)
70 {
71  auto& fn = *((THPCppFunction*)self)->cdata;
72  for (const auto& hook : fn.pre_hooks()) {
73  if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
74  Py_VISIT(pyhook->dict);
75  }
76  }
77  for (const auto& hook : fn.post_hooks()) {
78  if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
79  Py_VISIT(pyhook->dict);
80  }
81  }
82  return 0;
83 }
84 
85 int THPCppFunction_clear(PyObject* self)
86 {
87  auto f = (THPCppFunction*)self;
88  // Remove the weak ref of the c++ object if it exist
89  if (f->cdata) {
90  f->cdata->set_pyobj(nullptr);
91  }
92  f->cdata.reset();
93  return 0;
94 }
95 
96 void THPCppFunction_dealloc(PyObject* self)
97 {
98  THPCppFunction_clear(self);
99  ((THPCppFunction*)self)->cdata.~shared_ptr();
100  Py_TYPE(self)->tp_free(self);
101 }
102 
103 } // namespace
104 
105 PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook)
106 {
107  const auto num_next = self->cdata->num_outputs();
108  THPObjectPtr py_functions(PyTuple_New(num_next));
109  if (!py_functions) return nullptr;
110  for (size_t i = 0; i < num_next; ++i) {
111  auto& c_tuple = self->cdata->next_edge(i);
112  THPObjectPtr tuple(PyTuple_New(2));
113  if (!tuple) return nullptr;
114  PyObject *py_fn = functionToPyObject(c_tuple.function);
115  if (!py_fn) return nullptr;
116  PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
117  PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr);
118  if (!py_idx) return nullptr;
119  PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
120  PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
121  }
122  return py_functions.release();
123 }
124 
125 PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused)
126 {
127  auto metadata = static_cast<PyAnomalyMetadata*>(self->cdata->metadata())->dict();
128 
129  Py_INCREF(metadata);
130  return metadata;
131 }
132 
133 PyObject* THPCppFunction_requires_grad(THPCppFunction* self) {
134  Py_RETURN_TRUE;
135 }
136 
137 PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var)
138 {
139  if (!THPVariable_Check(_var)) {
140  return PyErr_Format(PyExc_TypeError, "_register_hook_dict expected a variable");
141  }
142  auto var = (THPVariable*)_var;
143  auto& fn = *((THPCppFunction*)self)->cdata;
144  std::unique_ptr<FunctionPreHook> hook(
145  new PyFunctionPreHook(var->backward_hooks, var->cdata.output_nr()));
146  fn.add_pre_hook(std::move(hook));
147  Py_RETURN_NONE;
148 }
149 
150 PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook)
151 {
152  auto& fn = *((THPCppFunction*)self)->cdata;
153  return registerFunctionHook(fn, hook);
154 }
155 
156 PyObject* THPCppFunction_name(PyObject* self) {
157  auto& fn = *((THPCppFunction*)self)->cdata;
158  return THPUtils_packString(fn.name());
159 }
160 
161 static struct PyMethodDef default_methods[] = {
162  THP_FUNCTION_DEFAULT_METHODS,
163  {nullptr}
164 };
165 
166 static struct PyGetSetDef default_properties[] = {
167  THP_FUNCTION_DEFAULT_PROPERTIES,
168  {nullptr}
169 };
170 
171 PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
172  PyGetSetDef* function_properties, PyMethodDef* function_methods)
173 {
174  type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
175  type.tp_name = name;
176  type.tp_basicsize = sizeof(THPCppFunction);
177  type.tp_call = THPCppFunction_call;
178  type.tp_methods = function_methods ? function_methods : default_methods;
179  type.tp_getset = function_properties ? function_properties : default_properties;
180  type.tp_dealloc = THPCppFunction_dealloc;
181  type.tp_traverse = THPCppFunction_traverse;
182  type.tp_clear = THPCppFunction_clear;
183  if (PyType_Ready(&type) < 0) {
184  auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
185  throw std::runtime_error(msg);
186  }
187  return &type;
188 }
189 
190 static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
191 
193  DefaultFunctionType() : type() {
194  _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
195  Py_INCREF(&type);
196  }
197 
198  PyTypeObject type;
199 };
200 
201 PyObject* functionToPyObject(const std::shared_ptr<Function>& cdata)
202 {
203  static DefaultFunctionType default_type;
204 
205  if (!cdata) {
206  Py_RETURN_NONE;
207  }
208 
209  if (auto pfw = dynamic_cast<PyFunction*>(cdata.get())) {
210  PyObject* obj = pfw->obj;
211  Py_INCREF(obj);
212  return obj;
213  }
214 
215  if (cdata->pyobj()) {
216  Py_INCREF(cdata->pyobj());
217  } else {
218  auto& fn = *cdata;
219  auto it = cpp_function_types.find(std::type_index(typeid(fn)));
220  PyTypeObject* type;
221  if (it == cpp_function_types.end()) {
222  type = &default_type.type;
223  } else {
224  type = (PyTypeObject*)it->second.get();
225  }
226 
227  THPObjectPtr obj(type->tp_alloc(type, 0));
228  if (!obj) return nullptr;
229  THPCppFunction* f = (THPCppFunction*)obj.get();
230  new (&f->cdata) std::shared_ptr<Function>(cdata);
231 
232  // No INCREF here as we only have a weak reference
233  cdata->set_pyobj(obj.release());
234  }
235 
236  return cdata->pyobj();
237 }
238 
239 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype)
240 {
241  Py_INCREF((PyObject*)pytype);
242  cpp_function_types[std::type_index(type)] = THPObjectPtr((PyObject*)pytype);
243 }
244 
245 PyObject* registerFunctionHook(Function& fn, PyObject* hook)
246 {
247  PyObject* dict = Py_None;
248  for (const auto& hook : fn.post_hooks()) {
249  if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
250  dict = pyhook->dict;
251  break;
252  }
253  }
254 
255  THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
256  if (!register_fn) return nullptr;
257  THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
258  if (!res) return nullptr;
259 
260  if (dict == Py_None) {
261  dict = PyTuple_GET_ITEM(res.get(), 0);
262  std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict));
263  fn.add_post_hook(std::move(hook));
264  }
265 
266  PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
267  Py_INCREF(handle);
268  return handle;
269 }
270 
271 }} // namespace torch::autograd
Definition: python_hook.h:17
Definition: jit_type.h:17