Caffe2 - C++ API
A deep-learning, cross-platform ML framework
python.h
1 #pragma once
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/Device.h>
9 #include <torch/csrc/Dtype.h>
10 #include <torch/csrc/DynamicTypes.h>
11 #include <torch/csrc/python_headers.h>
12 #include <torch/csrc/utils/pybind.h>
13 
14 #include <iterator>
15 #include <string>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 
20 namespace torch {
21 namespace python {
22 namespace detail {
23 inline Device py_object_to_device(py::object object) {
24  PyObject* obj = object.ptr();
25  if (THPDevice_Check(obj)) {
26  return reinterpret_cast<THPDevice*>(obj)->device;
27  }
28  throw TypeError("Expected device");
29 }
30 
31 inline Dtype py_object_to_dtype(py::object object) {
32  PyObject* obj = object.ptr();
33  if (THPDtype_Check(obj)) {
34  return reinterpret_cast<THPDtype*>(obj)->scalar_type;
35  }
36  throw TypeError("Expected dtype");
37 }
38 
39 template <typename ModuleType>
40 using PyModuleClass =
41  py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
42 
46 template <typename ModuleType>
47 void bind_cpp_module_wrapper(
48  py::module module,
49  PyModuleClass<ModuleType> cpp_class,
50  const char* name) {
51  // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
52  // with a dynamically created class below.
53  py::object cpp_module =
54  py::module::import("torch.nn.cpp").attr("ModuleWrapper");
55 
56  // Grab the `type` class which we'll use as a metaclass to create a new class
57  // dynamically.
58  py::object type_metaclass =
59  py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
60 
61  // The `ModuleWrapper` constructor copies all functions to its own `__dict__`
62  // in its constructor, but we do need to give our dynamic class a constructor.
63  // Inside, we construct an instance of the original C++ module we're binding
64  // (the `torch::nn::Module` subclass), and then forward it to the
65  // `ModuleWrapper` constructor.
66  py::dict attributes;
67 
68  // `type()` always needs a `str`, but pybind11's `str()` method always creates
69  // a `unicode` object.
70 #if PY_MAJOR_VERSION < 3
71  py::object name_str =
72  py::reinterpret_steal<py::object>(PyString_FromString(name));
73 #else
74  py::object name_str = py::str(name);
75 #endif
76 
77  // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
78  // `torch.nn.Module`, and will delegate all calls to the C++ module we're
79  // binding.
80  py::object wrapper_class =
81  type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
82 
83  // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
84  // which replaces its methods with those of the C++ module.
85  wrapper_class.attr("__init__") = py::cpp_function(
86  [cpp_module, cpp_class](
87  py::object self, py::args args, py::kwargs kwargs) {
88  cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
89  },
90  py::is_method(wrapper_class));
91 
92  // Calling `my_module.my_class` now means that `my_class` is a subclass of
93  // `ModuleWrapper`, and whose methods call into the C++ module we're binding.
94  module.attr(name) = wrapper_class;
95 }
96 } // namespace detail
97 
107 template <typename ModuleType, typename... Extra>
108 py::class_<ModuleType, Extra...> add_module_bindings(
109  py::class_<ModuleType, Extra...> module) {
110  // clang-format off
111  return module
112  .def("train",
113  [](ModuleType& module, bool mode) { module.train(mode); },
114  py::arg("mode") = true)
115  .def("eval", [](ModuleType& module) { module.eval(); })
116  .def("clone", [](ModuleType& module) { return module.clone(); })
117  .def_property_readonly(
118  "training", [](ModuleType& module) { return module.is_training(); })
119  .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
120  .def_property_readonly( "_parameters", [](ModuleType& module) {
121  return module.named_parameters(/*recurse=*/false);
122  })
123  .def("parameters", [](ModuleType& module, bool recurse) {
124  return module.parameters(recurse);
125  },
126  py::arg("recurse") = true)
127  .def("named_parameters", [](ModuleType& module, bool recurse) {
128  return module.named_parameters(recurse);
129  },
130  py::arg("recurse") = true)
131  .def_property_readonly("_buffers", [](ModuleType& module) {
132  return module.named_buffers(/*recurse=*/false);
133  })
134  .def("buffers", [](ModuleType& module, bool recurse) {
135  return module.buffers(recurse); },
136  py::arg("recurse") = true)
137  .def("named_buffers", [](ModuleType& module, bool recurse) {
138  return module.named_buffers(recurse);
139  },
140  py::arg("recurse") = true)
141  .def_property_readonly(
142  "_modules", [](ModuleType& module) { return module.named_children(); })
143  .def("modules", [](ModuleType& module) { return module.modules(); })
144  .def("named_modules",
145  [](ModuleType& module, py::object /* unused */, std::string prefix) {
146  return module.named_modules(std::move(prefix));
147  },
148  py::arg("memo") = py::none(),
149  py::arg("prefix") = std::string())
150  .def("children", [](ModuleType& module) { return module.children(); })
151  .def("named_children",
152  [](ModuleType& module) { return module.named_children(); })
153  .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
154  if (THPDevice_Check(object.ptr())) {
155  module.to(
156  reinterpret_cast<THPDevice*>(object.ptr())->device,
157  non_blocking);
158  } else {
159  module.to(detail::py_object_to_dtype(object), non_blocking);
160  }
161  },
162  py::arg("dtype_or_device"),
163  py::arg("non_blocking") = false)
164  .def("to",
165  [](ModuleType& module,
166  py::object device,
167  py::object dtype,
168  bool non_blocking) {
169  if (device.is_none()) {
170  module.to(detail::py_object_to_dtype(dtype), non_blocking);
171  } else if (dtype.is_none()) {
172  module.to(detail::py_object_to_device(device), non_blocking);
173  } else {
174  module.to(
175  detail::py_object_to_device(device),
176  detail::py_object_to_dtype(dtype),
177  non_blocking);
178  }
179  },
180  py::arg("device"),
181  py::arg("dtype"),
182  py::arg("non_blocking") = false)
183  .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
184  .def("cpu", [](ModuleType& module) { module.to(kCPU); })
185  .def("float", [](ModuleType& module) { module.to(kFloat32); })
186  .def("double", [](ModuleType& module) { module.to(kFloat64); })
187  .def("half", [](ModuleType& module) { module.to(kFloat16); })
188  .def("__str__", [](ModuleType& module) { return module.name(); })
189  .def("__repr__", [](ModuleType& module) { return module.name(); });
190  // clang-format on
191 }
192 
214 template <typename ModuleType, bool force_enable = false>
215 torch::disable_if_t<
217  detail::PyModuleClass<ModuleType>>
218 bind_module(py::module module, const char* name) {
219  py::module cpp = module.def_submodule("cpp");
220  auto cpp_class =
221  add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
222  detail::bind_cpp_module_wrapper(module, cpp_class, name);
223  return cpp_class;
224 }
225 
250 template <
251  typename ModuleType,
252  typename =
253  torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
254 detail::PyModuleClass<ModuleType> bind_module(
255  py::module module,
256  const char* name) {
257  return bind_module<ModuleType, /*force_enable=*/true>(module, name)
258  .def("forward", &ModuleType::forward)
259  .def("__call__", &ModuleType::forward);
260 }
261 } // namespace python
262 } // namespace torch
Detects if a type T has a forward() method.
Definition: static.h:19
Definition: Dtype.h:9
Definition: jit_type.h:17