Caffe2 - C++ API
A deep learning, cross platform ML framework
init.cpp
1 #include <torch/csrc/python_headers.h>
2 
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/autograd/grad_mode.h>
6 #include <torch/csrc/autograd/profiler.h>
7 #include <torch/csrc/autograd/python_function.h>
8 #include <torch/csrc/autograd/function.h>
9 
10 PyObject * THPAutograd_initExtension(PyObject *_unused)
11 {
12  auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch.tensor"));
13  if (!tensor_module) throw python_error();
14 
15  // NOTE: "leaks" THPVariableClass
16  THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
17  if (!THPVariableClass) throw python_error();
18 
19  auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
20  if (!autograd_module) throw python_error();
21 
22  // NOTE: "leaks" Function
23  THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
24  if (!THPFunctionClass) throw python_error();
25 
26  auto m = py::handle(autograd_module).cast<py::module>();
27 
28  py::class_<torch::autograd::profiler::Event>(m, "ProfilerEvent")
29  .def("kind", &torch::autograd::profiler::Event::kind)
30  .def(
31  "name",
32  [](const torch::autograd::profiler::Event& e) { return e.name(); })
33  .def("thread_id", &torch::autograd::profiler::Event::thread_id)
34  .def("device", &torch::autograd::profiler::Event::device)
35  .def("cpu_elapsed_us", &torch::autograd::profiler::Event::cpu_elapsed_us)
36  .def(
37  "cuda_elapsed_us", &torch::autograd::profiler::Event::cuda_elapsed_us)
38  .def("has_cuda", &torch::autograd::profiler::Event::has_cuda);
39  py::enum_<torch::autograd::profiler::ProfilerState>(m,"ProfilerState")
40  .value("Disabled", torch::autograd::profiler::ProfilerState::Disabled)
41  .value("CPU", torch::autograd::profiler::ProfilerState::CPU)
42  .value("CUDA", torch::autograd::profiler::ProfilerState::CUDA)
43  .value("NVTX", torch::autograd::profiler::ProfilerState::NVTX);
44 
45  m.def("_enable_profiler", torch::autograd::profiler::enableProfiler);
46  m.def("_disable_profiler", torch::autograd::profiler::disableProfiler);
47 
48  m.def("_push_range", [](std::string name) {
49  torch::autograd::profiler::pushRange(std::move(name));
50  });
51  m.def("_pop_range", []() { torch::autograd::profiler::popRange(); });
52 
53  Py_RETURN_TRUE;
54 }
55 
56 namespace torch { namespace autograd {
57 
58 static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
59  HANDLE_TH_ERRORS
60  if (!PyBool_Check(arg)) {
61  throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
62  }
63  GradMode::set_enabled(arg == Py_True);
64  Py_RETURN_NONE;
65  END_HANDLE_TH_ERRORS
66 }
67 
68 static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) {
69  HANDLE_TH_ERRORS
70  if (GradMode::is_enabled()) {
71  Py_RETURN_TRUE;
72  } else {
73  Py_RETURN_FALSE;
74  }
75  END_HANDLE_TH_ERRORS
76 }
77 
78 static PyObject * set_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
79  HANDLE_TH_ERRORS
80  if (!PyBool_Check(arg)) {
81  throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
82  }
83  AnomalyMode::set_enabled(arg == Py_True);
84  Py_RETURN_NONE;
85  END_HANDLE_TH_ERRORS
86 }
87 
88 static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
89  HANDLE_TH_ERRORS
90  if (AnomalyMode::is_enabled()) {
91  Py_RETURN_TRUE;
92  } else {
93  Py_RETURN_FALSE;
94  }
95  END_HANDLE_TH_ERRORS
96 }
97 
98 // autograd methods on torch._C
99 static PyMethodDef methods[] = {
100  {"set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr},
101  {"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS, nullptr},
102  {"set_anomaly_enabled", (PyCFunction)set_anomaly_mode_enabled, METH_O, nullptr},
103  {"is_anomaly_enabled", (PyCFunction)is_anomaly_mode_enabled, METH_NOARGS, nullptr},
104  {nullptr, nullptr, 0, nullptr}
105 };
106 
107 PyMethodDef* python_functions() {
108  return methods;
109 }
110 
111 }} // namespace torch::autograd
Definition: jit_type.h:17