Caffe2 - C++ API
A deep learning, cross platform ML framework
init.cpp
1 #include <Python.h>
2 #include <torch/csrc/autograd/functions/accumulate_grad.h>
3 #include <torch/csrc/autograd/functions/basic_ops.h>
4 #include <torch/csrc/autograd/functions/tensor.h>
5 #include <torch/csrc/autograd/functions/pybind.h>
6 #include <torch/csrc/autograd/python_cpp_function.h>
7 #include <torch/csrc/autograd/generated/python_functions.h>
8 #include <torch/csrc/jit/python_tracer.h>
9 #include <torch/csrc/utils/pybind.h>
10 #include <torch/csrc/utils/tuple_parser.h>
11 
12 using namespace torch::autograd;
13 using torch::TupleParser;
14 
16  DelayedError* operator()(PyObject* args) {
17  std::string msg;
18  int num_inputs;
19 
20  TupleParser parser(args, 2);
21  parser.parse(msg, "msg");
22  parser.parse(num_inputs, "num_inputs");
23 
24  return new DelayedError(msg, num_inputs);
25  }
26 };
27 
28 struct NoCtor {
29  Function* operator()(PyObject* args) {
30  throw std::runtime_error("Cannot construct");
31  }
32 };
33 
34 template<typename C, typename T>
35 static void addClass(PyObject* module, PyTypeObject& type, const char* name,
36  PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr)
37 {
38  createForwardFunctionPyTypeObject<T>(type, name, function_properties, function_methods);
39  Py_INCREF(&type);
40  PyModule_AddObject(module, name, (PyObject*)&type);
41  registerCppFunction(typeid(C), &type);
42 }
43 
44 template<typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr,
45  typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)>
46 PyObject* getTupleAttr(PyObject* obj, void* _unused)
47 {
48  HANDLE_TH_ERRORS
49  THPCppFunction* self = (THPCppFunction*)obj;
50  auto& arr = ((T*)(self->cdata.get()))->*ptr;
51  auto num_elems = arr.size();
52  THPObjectPtr py_tuple(PyTuple_New(num_elems));
53  if (!py_tuple) return nullptr;
54  for (size_t i = 0; i < num_elems; ++i) {
55  PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i]));
56  }
57  return py_tuple.release();
58  END_HANDLE_TH_ERRORS
59 }
60 
61 template<typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr,
62  typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)>
63 PyObject* getValueAttr(PyObject* obj, void* _unused)
64 {
65  HANDLE_TH_ERRORS
66  THPCppFunction* self = (THPCppFunction*)obj;
67  auto& val = ((T*)(self->cdata.get()))->*ptr;
68  return Convert(val);
69  END_HANDLE_TH_ERRORS
70 }
71 
72 static PyObject* accumulateGradVar(PyObject *_self, void* _unused)
73 {
74  THPCppFunction* self = (THPCppFunction*)_self;
75  auto grad_acc = (AccumulateGrad*)self->cdata.get();
76  return THPVariable_Wrap(grad_acc->variable);
77 }
78 
79 static struct PyGetSetDef accumulate_grad_properties[] = {
80  THP_FUNCTION_DEFAULT_PROPERTIES,
81  {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr},
82  {nullptr}
83 };
84 
85 void THPAutograd_initFunctions()
86 {
87  THPObjectPtr module(PyModule_New("torch._C._functions"));
88  if (!module) throw python_error();
89 
90  static PyTypeObject AccumulateGradClass;
91  addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties);
92 
93  static PyTypeObject ErrorClass;
94  addClass<Error, NoCtor>(module, ErrorClass, "Error");
95 
96  static PyTypeObject NotImplementedClass;
97  addClass<NotImplemented, NoCtor>(module, NotImplementedClass, "NotImplemented");
98 
99  static PyTypeObject DelayedErrorClass;
100  addClass<DelayedError, DelayedErrorCtor>(module, DelayedErrorClass, "DelayedError");
101 
102  static PyTypeObject CopyBackwardsClass;
103  addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
104 
105  static PyTypeObject CopySlicesClass;
106  addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");
107 
108  generated::initialize_autogenerated_functions();
109 
110  auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
111  if (!c_module) throw python_error();
112 
113  Py_INCREF(module);
114  if (PyModule_AddObject(c_module, "_functions", module) < 0) {
115  throw python_error();
116  }
117 }
Definition: static.cpp:64
Definition: init.cpp:28