Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_nn_functions.cpp
1 #include "python_nn_functions.h"
2 
3 // ${generated_comment}
4 
5 #include "torch/csrc/Device.h"
6 #include "torch/csrc/DynamicTypes.h"
7 #include "torch/csrc/Exceptions.h"
8 #include "torch/csrc/autograd/python_variable.h"
9 #include "torch/csrc/autograd/utils/wrap_outputs.h"
10 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
11 #include "torch/csrc/utils/python_arg_parser.h"
12 #include "torch/csrc/utils/structseq.h"
13 
14 #include "python_nn_functions_dispatch.h"
15 
16 using at::Tensor;
17 using at::Scalar;
18 using namespace torch::autograd::utils;
19 
20 namespace torch { namespace autograd {
21 
22 static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
23 {
24  HANDLE_TH_ERRORS
25  auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ false); // we don't want copy for nn.Module.to
26  auto& device = std::get<0>(parsed);
27  auto& scalarType = std::get<1>(parsed);
28  auto non_blocking = std::get<2>(parsed);
29  auto tuple = THPObjectPtr{PyTuple_New(3)};
30  if (!tuple) throw python_error();
31  if (device) {
32  PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
33  } else {
34  Py_INCREF(Py_None);
35  PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
36  }
37  if (scalarType) {
38  PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getDtype(*scalarType)));
39  } else {
40  Py_INCREF(Py_None);
41  PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
42  }
43  PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
44  return tuple.release();
45  END_HANDLE_TH_ERRORS
46 }
47 
48 ${py_methods}
49 
50 static PyMethodDef nn_functions[] = {
51  {"_parse_to", (PyCFunction)THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr},
52  ${py_method_defs}
53  {NULL}
54 };
55 
56 void initNNFunctions(PyObject* module) {
57 #if PY_MAJOR_VERSION == 2
58  PyObject* nn = Py_InitModule("torch._C._nn", nn_functions);
59  Py_XINCREF(nn); // Py_InitModule returns "borrowed" reference
60 #else
61  static struct PyModuleDef def = {
62  PyModuleDef_HEAD_INIT,
63  "torch._C._nn",
64  NULL,
65  -1,
66  nn_functions
67  };
68  PyObject* nn = PyModule_Create(&def);
69 #endif
70  if (!nn) {
71  throw python_error();
72  }
73  // steals a reference to nn
74  if (PyModule_AddObject(module, "_nn", nn) != 0) {
75  throw python_error();
76  }
77 }
78 
79 }} // namespace torch::autograd
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Definition: jit_type.h:17