1 #include <torch/csrc/python_headers.h> 3 #include <torch/csrc/Exceptions.h> 4 #include <torch/csrc/utils/pybind.h> 5 #include <torch/csrc/autograd/grad_mode.h> 6 #include <torch/csrc/autograd/profiler.h> 7 #include <torch/csrc/autograd/python_function.h> 8 #include <torch/csrc/autograd/function.h> 10 PyObject * THPAutograd_initExtension(PyObject *_unused)
12 auto tensor_module =
THPObjectPtr(PyImport_ImportModule(
"torch.tensor"));
16 THPVariableClass = PyObject_GetAttrString(tensor_module,
"Tensor");
19 auto autograd_module =
THPObjectPtr(PyImport_ImportModule(
"torch.autograd"));
23 THPFunctionClass = PyObject_GetAttrString(autograd_module,
"Function");
26 auto m = py::handle(autograd_module).cast<py::module>();
28 py::class_<torch::autograd::profiler::Event>(m,
"ProfilerEvent")
29 .def(
"kind", &torch::autograd::profiler::Event::kind)
33 .def(
"thread_id", &torch::autograd::profiler::Event::thread_id)
34 .def(
"device", &torch::autograd::profiler::Event::device)
35 .def(
"cpu_elapsed_us", &torch::autograd::profiler::Event::cpu_elapsed_us)
37 "cuda_elapsed_us", &torch::autograd::profiler::Event::cuda_elapsed_us)
38 .def(
"has_cuda", &torch::autograd::profiler::Event::has_cuda);
39 py::enum_<torch::autograd::profiler::ProfilerState>(m,
"ProfilerState")
40 .value(
"Disabled", torch::autograd::profiler::ProfilerState::Disabled)
41 .value(
"CPU", torch::autograd::profiler::ProfilerState::CPU)
42 .value(
"CUDA", torch::autograd::profiler::ProfilerState::CUDA)
43 .value(
"NVTX", torch::autograd::profiler::ProfilerState::NVTX);
45 m.def(
"_enable_profiler", torch::autograd::profiler::enableProfiler);
46 m.def(
"_disable_profiler", torch::autograd::profiler::disableProfiler);
48 m.def(
"_push_range", [](std::string name) {
49 torch::autograd::profiler::pushRange(std::move(name));
51 m.def(
"_pop_range", []() { torch::autograd::profiler::popRange(); });
56 namespace torch {
namespace autograd {
58 static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
60 if (!PyBool_Check(arg)) {
61 throw TypeError(
"enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
63 GradMode::set_enabled(arg == Py_True);
68 static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) {
70 if (GradMode::is_enabled()) {
78 static PyObject * set_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
80 if (!PyBool_Check(arg)) {
81 throw TypeError(
"enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
83 AnomalyMode::set_enabled(arg == Py_True);
88 static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
90 if (AnomalyMode::is_enabled()) {
99 static PyMethodDef methods[] = {
100 {
"set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O,
nullptr},
101 {
"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS,
nullptr},
102 {
"set_anomaly_enabled", (PyCFunction)set_anomaly_mode_enabled, METH_O,
nullptr},
103 {
"is_anomaly_enabled", (PyCFunction)is_anomaly_mode_enabled, METH_NOARGS,
nullptr},
104 {
nullptr,
nullptr, 0,
nullptr}
107 PyMethodDef* python_functions() {