1 #include <torch/csrc/autograd/python_cpp_function.h> 3 #include <torch/csrc/python_headers.h> 7 #include <unordered_map> 9 #include <torch/csrc/autograd/python_function.h> 10 #include <torch/csrc/autograd/python_variable.h> 11 #include <torch/csrc/autograd/python_hook.h> 12 #include <torch/csrc/autograd/python_anomaly_mode.h> 13 #include <torch/csrc/utils/auto_gil.h> 14 #include <torch/csrc/utils/python_strings.h> 15 #include <torch/csrc/DynamicTypes.h> 16 #include <torch/csrc/Exceptions.h> 20 namespace torch {
namespace autograd {
// Python call protocol for a wrapped C++ autograd function (installed as
// tp_call by _initFunctionPyTypeObject).
// Only positional arguments are accepted. Their count must equal the wrapped
// function's num_inputs(), and each one must be a Variable.
// Returns: the single wrapped output Variable when the function produced
// exactly one output, otherwise a tuple holding every wrapped output.
// On argument errors, PyErr_Format sets a TypeError and nullptr is returned.
// NOTE(review): several original lines are not visible in this view. They
// include the unpacking of arguments into `vars`, the invocation that
// produces `output`, and the creation of `tuple`.
24 PyObject* THPCppFunction_call(PyObject*
self, PyObject* args, PyObject *kwargs)
// Any non-empty kwargs dict is rejected outright.
26 if (kwargs && PyDict_Size(kwargs) != 0) {
27 return PyErr_Format(PyExc_TypeError,
"keyword arguments are not supported");
// NOTE(review): PyTuple_GET_SIZE yields a Py_ssize_t, which is narrowed to
// int here (and printed with %d below).
30 int num_inputs = PyTuple_GET_SIZE(args);
31 int num_inputs_required = ((
THPCppFunction*)
self)->cdata->num_inputs();
// The arity must match exactly; there are no optional inputs.
32 if (num_inputs != num_inputs_required) {
33 return PyErr_Format(PyExc_TypeError,
"expected %d arguments, got %d instead",
34 num_inputs_required, num_inputs);
36 variable_list vars(num_inputs);
// Validate every positional argument; the first non-Variable is reported by
// its index.
37 for (
int i = 0; i != num_inputs; ++i) {
38 PyObject* arg = PyTuple_GET_ITEM(args, i);
42 if (!THPVariable_Check(arg)) {
43 return PyErr_Format(PyExc_TypeError,
"argument %d is not a Variable", i);
// NOTE(review): `output` is declared in elided lines. If its size() is
// unsigned, it is narrowed to int here.
56 int num_outputs = output.size();
// A single output is returned bare rather than as a 1-tuple.
57 if (num_outputs == 1) {
59 return THPVariable_Wrap(output[0]);
// PyTuple_SET_ITEM steals the reference produced by THPVariable_Wrap.
63 for (
int i = 0; i != num_outputs; ++i) {
64 PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
66 return tuple.release();
// Cyclic-GC traversal callback (installed as tp_traverse by
// _initFunctionPyTypeObject).
// It reports the `dict` held by each Python-level pre-hook and post-hook
// attached to the wrapped function to the garbage collector.
// NOTE(review): the binding of `fn` (presumably the wrapped C++ function of
// `self`) and the final return are not visible in this view.
69 int THPCppFunction_traverse(PyObject*
self, visitproc visit,
void *arg)
// Only hooks implemented in Python (PyFunctionPreHook) own a dict to visit.
// C++-only hooks fail the dynamic_cast and are skipped.
72 for (
const auto& hook : fn.pre_hooks()) {
73 if (
auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
74 Py_VISIT(pyhook->dict);
// The same applies to post-hooks (PyFunctionPostHook).
77 for (
const auto& hook : fn.post_hooks()) {
78 if (
auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
79 Py_VISIT(pyhook->dict);
// Cyclic-GC clear callback (installed as tp_clear by
// _initFunctionPyTypeObject, and also called directly from
// THPCppFunction_dealloc).
// The visible part resets the wrapped C++ function's stored Python-object
// pointer to nullptr, so the C++ side no longer refers to this Python object.
// NOTE(review): the binding of `f` and the rest of the body, including the
// return value, are not visible in this view.
85 int THPCppFunction_clear(PyObject*
self)
90 f->cdata->set_pyobj(
nullptr);
// Deallocator for wrapped C++ function objects (installed as tp_dealloc by
// _initFunctionPyTypeObject).
// It first runs the same logic as tp_clear, then releases the object's
// memory through the type's tp_free.
// NOTE(review): original lines between these statements are not visible in
// this view. They presumably include GC untracking and destruction of the
// `cdata` member.
96 void THPCppFunction_dealloc(PyObject*
self)
98 THPCppFunction_clear(
self);
100 Py_TYPE(
self)->tp_free(
self);
// Builds a tuple describing the wrapped function's next edges.
// There is one entry per index 0 .. num_outputs() - 1, and each entry is a
// 2-tuple (function, input_nr):
//   - the function is converted with functionToPyObject;
//   - input_nr is a Python int.
// Returns nullptr if any allocation or conversion fails.
// NOTE(review): the creation of `py_functions` and of each per-edge `tuple`
// is not visible in this view. Both are used through .get()/.release(), so
// they are presumably RAII holders that drop partially built tuples on the
// early returns.
// NOTE(review): the `hook` parameter is not used by any visible line.
105 PyObject* THPCppFunction_next_functions(
THPCppFunction*
self, PyObject* hook)
107 const auto num_next =
self->cdata->num_outputs();
109 if (!py_functions)
return nullptr;
110 for (
size_t i = 0; i < num_next; ++i) {
111 auto& c_tuple =
self->cdata->next_edge(i);
113 if (!tuple)
return nullptr;
114 PyObject *py_fn = functionToPyObject(c_tuple.function);
115 if (!py_fn)
return nullptr;
// PyTuple_SET_ITEM steals the reference, so from here on `tuple` owns py_fn.
116 PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
117 PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr);
118 if (!py_idx)
return nullptr;
119 PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
// Hand ownership of the finished pair to the outer tuple.
120 PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
122 return py_functions.release();
// Looks up the Python dict held in the wrapped function's anomaly-mode
// metadata.
// NOTE(review): the static_cast assumes metadata() always returns a
// PyAnomalyMetadata; the cast is unchecked.
// NOTE(review): the lines after the lookup (reference counting and the
// return) are not visible in this view.
125 PyObject* THPCppFunction_metadata(
THPCppFunction *
self,
void *_unused)
127 auto metadata =
static_cast<PyAnomalyMetadata*
>(
self->cdata->metadata())->dict();
// Attaches a FunctionPreHook to the wrapped function. The hook is built from
// data associated with the Variable `_var`.
// Raises TypeError (returns nullptr) if `_var` is not a Variable.
// NOTE(review): the binding of `fn`, the arguments passed to the
// FunctionPreHook constructor, and the return value are not visible in this
// view.
137 PyObject* THPCppFunction_register_hook_dict(PyObject*
self, PyObject* _var)
139 if (!THPVariable_Check(_var)) {
140 return PyErr_Format(PyExc_TypeError,
"_register_hook_dict expected a variable");
// Ownership of the new hook is moved into the function's pre-hook list.
144 std::unique_ptr<FunctionPreHook> hook(
146 fn.add_pre_hook(std::move(hook));
// Python-facing entry point for registering a hook object on the wrapped
// function. All work, and the returned object, come from
// registerFunctionHook.
// NOTE(review): the binding of `fn` (presumably the wrapped C++ function of
// `self`) is not visible in this view.
150 PyObject* THPCppFunction_register_hook(PyObject*
self, PyObject* hook)
153 return registerFunctionHook(fn, hook);
// Returns the wrapped function's name() converted to a Python string by
// THPUtils_packString.
// NOTE(review): the binding of `fn` is not visible in this view.
156 PyObject* THPCppFunction_name(PyObject*
self) {
158 return THPUtils_packString(fn.name());
// Method table used as tp_methods by _initFunctionPyTypeObject when the
// caller passes no method table of its own. It starts with the shared
// THP_FUNCTION_DEFAULT_METHODS entries.
// NOTE(review): the remaining entries and terminator are not visible here.
161 static struct PyMethodDef default_methods[] = {
162 THP_FUNCTION_DEFAULT_METHODS,
// Property table used as tp_getset by _initFunctionPyTypeObject when the
// caller passes no property table of its own. It starts with the shared
// THP_FUNCTION_DEFAULT_PROPERTIES entries.
// NOTE(review): the remaining entries and terminator are not visible here.
166 static struct PyGetSetDef default_properties[] = {
167 THP_FUNCTION_DEFAULT_PROPERTIES,
// Fills in `type` as a GC-tracked Python type wrapping a C++ autograd
// function, then readies it with PyType_Ready.
//   function_properties / function_methods: custom getset and method
//     tables. Passing nullptr selects default_properties / default_methods.
// Throws std::runtime_error naming `name` if PyType_Ready fails.
// NOTE(review): some original lines (e.g. remaining slot assignments and the
// return statement) are not visible in this view.
171 PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type,
const char* name,
172 PyGetSetDef* function_properties, PyMethodDef* function_methods)
// Py_TPFLAGS_HAVE_GC requires the tp_traverse/tp_clear slots set below.
174 type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
177 type.tp_call = THPCppFunction_call;
178 type.tp_methods = function_methods ? function_methods : default_methods;
179 type.tp_getset = function_properties ? function_properties : default_properties;
180 type.tp_dealloc = THPCppFunction_dealloc;
181 type.tp_traverse = THPCppFunction_traverse;
182 type.tp_clear = THPCppFunction_clear;
// A failure here is reported as a C++ exception rather than a Python error.
183 if (PyType_Ready(&type) < 0) {
184 auto msg = std::string(
"Unable to instantiate PyTypeObject for ") + name;
185 throw std::runtime_error(msg);
// Registry mapping a C++ function type (as std::type_index) to the Python
// type object used to wrap it. Entries are added by registerCppFunction and
// looked up by functionToPyObject. The THPObjectPtr values hold a reference
// to each registered type object.
190 static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
// Initializes the fallback "CppFunction" type. The nullptr arguments select
// the default property and method tables.
// NOTE(review): the definition enclosing this call is not visible in this
// view.
194 _initFunctionPyTypeObject(type,
"CppFunction",
nullptr,
nullptr);
// Returns the Python object that represents the C++ autograd function
// `cdata`. There are three visible paths:
//   - A PyFunction (defined in Python) uses the object it already carries.
//   - If `cdata` already has an associated Python object, that object is
//     reused after a Py_INCREF.
//   - Otherwise a new wrapper is created. Its type is the one registered for
//     the function's C++ type in cpp_function_types, or the default type.
//     The wrapper holds a copy of the shared_ptr and is stored back on
//     `cdata` through set_pyobj.
// Returns nullptr if allocating the new wrapper fails.
// NOTE(review): many original lines (handling of null `cdata`, the returns of
// the first two paths, the bindings of `fn`, `type`, `obj`, `f`) are not
// visible in this view. Reference-ownership details of the returned pointer
// should be confirmed against the full source.
201 PyObject* functionToPyObject(
const std::shared_ptr<Function>& cdata)
209 if (
auto pfw = dynamic_cast<PyFunction*>(cdata.get())) {
210 PyObject* obj = pfw->obj;
// A Python object already exists for this function, so reuse it.
215 if (cdata->pyobj()) {
216 Py_INCREF(cdata->pyobj());
// The type lookup is keyed on typeid(fn). `fn` presumably refers to *cdata,
// so the key would be the dynamic C++ type.
219 auto it = cpp_function_types.find(std::type_index(
typeid(fn)));
221 if (it == cpp_function_types.end()) {
// This is presumably the "CppFunction" default type.
222 type = &default_type.type;
224 type = (PyTypeObject*)it->second.get();
228 if (!obj)
return nullptr;
// The wrapper's cdata storage is constructed in place with placement new,
// copying the shared_ptr, so the wrapper co-owns the C++ function.
230 new (&f->cdata) std::shared_ptr<Function>(cdata);
233 cdata->set_pyobj(obj.release());
236 return cdata->pyobj();
// Registers `pytype` as the Python type used to wrap C++ functions whose type
// is `type`; functionToPyObject consults this mapping.
// A new reference to `pytype` is taken and stored in the registry. If `type`
// was already registered, the previous entry is overwritten.
239 void registerCppFunction(
const std::type_info& type, PyTypeObject* pytype)
241 Py_INCREF((PyObject*)pytype);
242 cpp_function_types[std::type_index(type)] =
THPObjectPtr((PyObject*)pytype);
245 PyObject* registerFunctionHook(Function& fn, PyObject* hook)
247 PyObject* dict = Py_None;
248 for (
const auto& hook : fn.post_hooks()) {
249 if (
auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
255 THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass,
"_register_hook"));
256 if (!register_fn)
return nullptr;
257 THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook,
nullptr));
258 if (!res)
return nullptr;
260 if (dict == Py_None) {
261 dict = PyTuple_GET_ITEM(res.get(), 0);
263 fn.add_post_hook(std::move(hook));
266 PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);