1 #include <torch/csrc/autograd/python_cpp_function.h>     3 #include <torch/csrc/python_headers.h>     7 #include <unordered_map>     9 #include <torch/csrc/autograd/python_function.h>    10 #include <torch/csrc/autograd/python_variable.h>    11 #include <torch/csrc/autograd/python_hook.h>    12 #include <torch/csrc/autograd/python_anomaly_mode.h>    13 #include <torch/csrc/utils/auto_gil.h>    14 #include <torch/csrc/utils/python_strings.h>    15 #include <torch/csrc/DynamicTypes.h>    16 #include <torch/csrc/Exceptions.h>    20 namespace torch { 
namespace autograd {
    // Python tp_call handler for CppFunction objects.
    // Accepts positional arguments only. Raises TypeError when keyword arguments
    // are given, when the argument count differs from cdata->num_inputs(), or
    // (in the visible check) when an argument fails THPVariable_Check.
    // A single output is returned unwrapped, not as a 1-tuple; otherwise all
    // outputs are returned in a tuple.
    // NOTE(review): this listing omits several original lines, e.g. those
    // declaring/computing `output` and `tuple` and lines 39-41 inside the
    // argument loop; do not infer their contents from here.
    24 PyObject* THPCppFunction_call(PyObject* 
self, PyObject* args, PyObject *kwargs)
    26   if (kwargs && PyDict_Size(kwargs) != 0) {
    27     return PyErr_Format(PyExc_TypeError, 
"keyword arguments are not supported");
    // The call must supply exactly as many positional inputs as the wrapped
    // Function declares.
    30   int num_inputs = PyTuple_GET_SIZE(args);
    31   int num_inputs_required = ((
THPCppFunction*)
self)->cdata->num_inputs();
    32   if (num_inputs != num_inputs_required) {
    33     return PyErr_Format(PyExc_TypeError, 
"expected %d arguments, got %d instead",
    34                         num_inputs_required, num_inputs);
    // One slot per positional input.
    36   variable_list vars(num_inputs);
    37   for (
int i = 0; i != num_inputs; ++i) {
    38     PyObject* arg = PyTuple_GET_ITEM(args, i);
    42     if (!THPVariable_Check(arg)) {
    43       return PyErr_Format(PyExc_TypeError, 
"argument %d is not a Variable", i);
    // A single result is handed back directly rather than inside a 1-tuple.
    56   int num_outputs = output.size();
    57   if (num_outputs == 1) {
    59     return THPVariable_Wrap(output[0]);
    // Otherwise each output is wrapped into its tuple slot. PyTuple_SET_ITEM
    // steals the reference returned by THPVariable_Wrap. The tuple's reference
    // is then released to the caller.
    63   for (
int i = 0; i != num_outputs; ++i) {
    64     PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
    66   return tuple.release();
    // GC tp_traverse handler. It reports to the cycle collector the `dict` held
    // by each Python-backed hook on the Function: pre-hooks that are
    // PyFunctionPreHook and post-hooks that are PyFunctionPostHook. Hooks of any
    // other type fail the dynamic_cast and are skipped.
    // NOTE(review): the line defining `fn` is not in this listing. Presumably it
    // is the Function wrapped by `self`; confirm against the full file.
    69 int THPCppFunction_traverse(PyObject* 
self, visitproc visit, 
void *arg)
    72   for (
const auto& hook : fn.pre_hooks()) {
    73     if (
auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
    74       Py_VISIT(pyhook->dict);
    77   for (
const auto& hook : fn.post_hooks()) {
    78     if (
auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
    79       Py_VISIT(pyhook->dict);
    // GC tp_clear handler. The visible statement calls set_pyobj(nullptr) on the
    // wrapped Function (`f->cdata`), presumably dropping its back-pointer to
    // this Python object.
    // NOTE(review): the other lines of this function, including where `f`
    // comes from, are missing from this listing.
    85 int THPCppFunction_clear(PyObject* 
self)
    90     f->cdata->set_pyobj(
nullptr);
    // tp_dealloc handler. It runs THPCppFunction_clear on the object, then frees
    // the object's memory through its type's tp_free slot.
    // NOTE(review): original lines 97 and 99 are absent from this listing, so
    // there may be further steps between these two calls.
    96 void THPCppFunction_dealloc(PyObject* 
self)
    98   THPCppFunction_clear(
self);
   100   Py_TYPE(
self)->tp_free(
self);
   // Builds the `next_functions` tuple for a CppFunction. It has one entry per
   // edge i in [0, cdata->num_outputs()), taken from cdata->next_edge(i).
   // Each entry fills slot 0 with the Python object for the edge's function
   // (via functionToPyObject) and slot 1 with the edge's input_nr as a Python
   // int. Returns nullptr as soon as any allocation or conversion fails.
   // The `hook` parameter is not used by the visible lines.
   // NOTE(review): the lines creating `py_functions` and each per-edge `tuple`
   // are not in this listing.
   105 PyObject* THPCppFunction_next_functions(
THPCppFunction* 
self, PyObject* hook)
   107   const auto num_next = 
self->cdata->num_outputs();
   109   if (!py_functions) 
return nullptr;
   110   for (
size_t i = 0; i < num_next; ++i) {
   111     auto& c_tuple = 
self->cdata->next_edge(i);
   113     if (!tuple) 
return nullptr;
   // Slot 0: the function at the other end of the edge.
   114     PyObject *py_fn = functionToPyObject(c_tuple.function);
   115     if (!py_fn) 
return nullptr;
   116     PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
   // Slot 1: the input index on that function.
   117     PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr);
   118     if (!py_idx) 
return nullptr;
   119     PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
   // PyTuple_SET_ITEM steals the reference, so ownership of the per-edge tuple
   // is released into the outer tuple.
   120     PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
   122   return py_functions.release();
   // Getter for a CppFunction's metadata. It takes cdata->metadata(),
   // static_casts it to PyAnomalyMetadata* (an unchecked cast, which assumes
   // that concrete type), and reads its dict().
   // NOTE(review): the remainder of the body, including what is returned, is
   // not in this listing.
   125 PyObject* THPCppFunction_metadata(
THPCppFunction *
self, 
void *_unused)
   127   auto metadata = 
static_cast<PyAnomalyMetadata*
>(
self->cdata->metadata())->dict();
   // Implements `_register_hook_dict`. It raises TypeError unless `_var`
   // passes THPVariable_Check, then moves a FunctionPreHook into the wrapped
   // Function with add_pre_hook.
   // NOTE(review): the arguments used to construct the FunctionPreHook, and the
   // return statement, are not in this listing.
   137 PyObject* THPCppFunction_register_hook_dict(PyObject* 
self, PyObject* _var)
   139   if (!THPVariable_Check(_var)) {
   140     return PyErr_Format(PyExc_TypeError, 
"_register_hook_dict expected a variable");
   144   std::unique_ptr<FunctionPreHook> hook(
   146   fn.add_pre_hook(std::move(hook));
   // Implements `register_hook` for CppFunction objects. It forwards the
   // Function `fn` (presumably the one wrapped by `self`; its definition is not
   // in this listing) and the Python `hook` to registerFunctionHook, and
   // returns that call's result.
   150 PyObject* THPCppFunction_register_hook(PyObject* 
self, PyObject* hook)
   153   return registerFunctionHook(fn, hook);
   // Implements `name()`: returns fn.name() packed into a Python string via
   // THPUtils_packString.
   // NOTE(review): the line binding `fn` is not in this listing.
   156 PyObject* THPCppFunction_name(PyObject* 
self) {
   158   return THPUtils_packString(fn.name());
   // Method table that _initFunctionPyTypeObject installs when the caller
   // passes no method table (function_methods == nullptr).
   161 static struct PyMethodDef default_methods[] = {
   162   THP_FUNCTION_DEFAULT_METHODS,
   // Getter/setter table that _initFunctionPyTypeObject installs when the
   // caller passes no property table (function_properties == nullptr).
   166 static struct PyGetSetDef default_properties[] = {
   167   THP_FUNCTION_DEFAULT_PROPERTIES,
   // Fills in `type` as a Python wrapper type for a C++ autograd Function and
   // readies it with PyType_Ready.
   //   type                - the PyTypeObject to populate (modified in place).
   //   name                - used in the error message if PyType_Ready fails.
   //   function_properties - getset table; nullptr selects default_properties.
   //   function_methods    - method table; nullptr selects default_methods.
   // The type is GC-enabled (Py_TPFLAGS_HAVE_GC) and uses this file's tp_call,
   // tp_dealloc, tp_traverse and tp_clear handlers.
   // Throws std::runtime_error if PyType_Ready fails.
   // NOTE(review): a few lines of the body (e.g. original 173, 175-176) are not
   // in this listing.
   171 PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, 
const char* name,
   172   PyGetSetDef* function_properties, PyMethodDef* function_methods)
   174   type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
   177   type.tp_call = THPCppFunction_call;
   178   type.tp_methods = function_methods ? function_methods : default_methods;
   179   type.tp_getset = function_properties ? function_properties : default_properties;
   180   type.tp_dealloc = THPCppFunction_dealloc;
   // traverse/clear are required because the type is GC-tracked.
   181   type.tp_traverse = THPCppFunction_traverse;
   182   type.tp_clear = THPCppFunction_clear;
   183   if (PyType_Ready(&type) < 0) {
   184     auto msg = std::string(
"Unable to instantiate PyTypeObject for ") + name;
   185     throw std::runtime_error(msg);
   // Registry mapping a C++ Function's std::type_index to the Python type used
   // to wrap it. registerCppFunction writes to it; functionToPyObject reads it.
   190 static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
   // Fragment: initializes a wrapper type named "CppFunction" with the default
   // method and property tables (both arguments passed as nullptr).
   // NOTE(review): the enclosing definition's opening lines are not in this
   // listing. Presumably this is the `default_type` that functionToPyObject
   // falls back to; confirm against the full file.
   194     _initFunctionPyTypeObject(type, 
"CppFunction", 
nullptr, 
nullptr);
   // Returns the Python object that represents the C++ Function `cdata`.
   //  * If `cdata` is a PyFunction, its stored `obj` is read; the rest of that
   //    branch is not in this listing.
   //  * Else, if the Function already has a pyobj, that object is Py_INCREF'd
   //    and returned.
   //  * Else, a new wrapper is created:
   //    - its type is the one registered in cpp_function_types for
   //      typeid(fn), or `default_type` when none is registered;
   //    - its `cdata` field is placement-new'd as a copy of the shared_ptr;
   //    - the object is stored on the Function via set_pyobj.
   //    That pyobj is what gets returned. Returns nullptr if creating the
   //    object fails.
   // NOTE(review): several lines (e.g. those defining `fn`, `type`, `obj` and
   // `f`) are missing from this listing. `fn` is presumably *cdata, so the
   // lookup key would be the Function's dynamic type; confirm.
   201 PyObject* functionToPyObject(
const std::shared_ptr<Function>& cdata)
   209   if (
auto pfw = dynamic_cast<PyFunction*>(cdata.get())) {
   210     PyObject* obj = pfw->obj;
   215   if (cdata->pyobj()) {
   216     Py_INCREF(cdata->pyobj());
   // No existing wrapper: pick the Python type registered for this C++ type.
   219     auto it = cpp_function_types.find(std::type_index(
typeid(fn)));
   221     if (it == cpp_function_types.end()) {
   222       type = &default_type.type;
   224       type = (PyTypeObject*)it->second.get();
   228     if (!obj) 
return nullptr;
   // The object's storage was not C++-constructed here, so the shared_ptr
   // member is built in place with placement new.
   230     new (&f->cdata) std::shared_ptr<Function>(cdata);
   233     cdata->set_pyobj(obj.release());
   236   return cdata->pyobj();
   // Associates the Python type `pytype` with C++ Functions whose typeid equals
   // `type`, so that functionToPyObject wraps them with it. Any previous entry
   // for the same type is replaced (operator[] assignment).
   // The Py_INCREF gives the registry its own reference, presumably because
   // THPObjectPtr takes ownership of the pointer it wraps.
   239 void registerCppFunction(
const std::type_info& type, PyTypeObject* pytype)
   241   Py_INCREF((PyObject*)pytype);
   242   cpp_function_types[std::type_index(type)] = 
THPObjectPtr((PyObject*)pytype);
   245 PyObject* registerFunctionHook(Function& fn, PyObject* hook)
   247   PyObject* dict = Py_None;
   248   for (
const auto& hook : fn.post_hooks()) {
   249     if (
auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
   255   THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, 
"_register_hook"));
   256   if (!register_fn) 
return nullptr;
   257   THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, 
nullptr));
   258   if (!res) 
return nullptr;
   260   if (dict == Py_None) {
   261     dict = PyTuple_GET_ITEM(res.get(), 0);
   263     fn.add_post_hook(std::move(hook));
   266   PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);