3 #include <torch/detail/static.h> 4 #include <torch/nn/module.h> 5 #include <torch/ordered_dict.h> 6 #include <torch/types.h> 8 #include <torch/csrc/Device.h> 9 #include <torch/csrc/Dtype.h> 10 #include <torch/csrc/DynamicTypes.h> 11 #include <torch/csrc/python_headers.h> 12 #include <torch/csrc/utils/pybind.h> 16 #include <unordered_map> 23 inline Device py_object_to_device(py::object
object) {
24 PyObject* obj =
object.ptr();
25 if (THPDevice_Check(obj)) {
26 return reinterpret_cast<THPDevice*
>(obj)->device;
28 throw TypeError(
"Expected device");
31 inline Dtype py_object_to_dtype(py::object
object) {
32 PyObject* obj =
object.ptr();
33 if (THPDtype_Check(obj)) {
34 return reinterpret_cast<THPDtype*
>(obj)->scalar_type;
36 throw TypeError(
"Expected dtype");
39 template <
typename ModuleType>
41 py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
46 template <
typename ModuleType>
47 void bind_cpp_module_wrapper(
49 PyModuleClass<ModuleType> cpp_class,
53 py::object cpp_module =
54 py::module::import(
"torch.nn.cpp").attr(
"ModuleWrapper");
58 py::object type_metaclass =
59 py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
70 #if PY_MAJOR_VERSION < 3 72 py::reinterpret_steal<py::object>(PyString_FromString(name));
74 py::object name_str = py::str(name);
80 py::object wrapper_class =
81 type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
85 wrapper_class.attr(
"__init__") = py::cpp_function(
86 [cpp_module, cpp_class](
87 py::object
self, py::args args, py::kwargs kwargs) {
88 cpp_module.attr(
"__init__")(
self, cpp_class(*args, **kwargs));
90 py::is_method(wrapper_class));
94 module.attr(name) = wrapper_class;
107 template <
typename ModuleType,
typename... Extra>
108 py::class_<ModuleType, Extra...> add_module_bindings(
109 py::class_<ModuleType, Extra...> module) {
113 [](ModuleType& module,
bool mode) { module.train(mode); },
114 py::arg(
"mode") =
true)
115 .def(
"eval", [](ModuleType& module) { module.eval(); })
116 .def(
"clone", [](ModuleType& module) {
return module.clone(); })
117 .def_property_readonly(
118 "training", [](ModuleType& module) {
return module.is_training(); })
119 .def(
"zero_grad", [](ModuleType& module) { module.zero_grad(); })
120 .def_property_readonly(
"_parameters", [](ModuleType& module) {
121 return module.named_parameters(
false);
123 .def(
"parameters", [](ModuleType& module,
bool recurse) {
124 return module.parameters(recurse);
126 py::arg(
"recurse") =
true)
127 .def(
"named_parameters", [](ModuleType& module,
bool recurse) {
128 return module.named_parameters(recurse);
130 py::arg(
"recurse") =
true)
131 .def_property_readonly(
"_buffers", [](ModuleType& module) {
132 return module.named_buffers(
false);
134 .def(
"buffers", [](ModuleType& module,
bool recurse) {
135 return module.buffers(recurse); },
136 py::arg(
"recurse") =
true)
137 .def(
"named_buffers", [](ModuleType& module,
bool recurse) {
138 return module.named_buffers(recurse);
140 py::arg(
"recurse") =
true)
141 .def_property_readonly(
142 "_modules", [](ModuleType& module) {
return module.named_children(); })
143 .def(
"modules", [](ModuleType& module) {
return module.modules(); })
144 .def(
"named_modules",
145 [](ModuleType& module, py::object , std::string prefix) {
146 return module.named_modules(std::move(prefix));
148 py::arg(
"memo") = py::none(),
149 py::arg(
"prefix") = std::string())
150 .def(
"children", [](ModuleType& module) {
return module.children(); })
151 .def(
"named_children",
152 [](ModuleType& module) {
return module.named_children(); })
153 .def(
"to", [](ModuleType& module, py::object
object,
bool non_blocking) {
154 if (THPDevice_Check(
object.ptr())) {
156 reinterpret_cast<THPDevice*>(
object.ptr())->device,
159 module.to(detail::py_object_to_dtype(
object), non_blocking);
162 py::arg(
"dtype_or_device"),
163 py::arg(
"non_blocking") =
false)
165 [](ModuleType& module,
169 if (device.is_none()) {
170 module.to(detail::py_object_to_dtype(dtype), non_blocking);
171 }
else if (dtype.is_none()) {
172 module.to(detail::py_object_to_device(device), non_blocking);
175 detail::py_object_to_device(device),
176 detail::py_object_to_dtype(dtype),
182 py::arg(
"non_blocking") =
false)
183 .def(
"cuda", [](ModuleType& module) { module.to(kCUDA); })
184 .def(
"cpu", [](ModuleType& module) { module.to(kCPU); })
185 .def(
"float", [](ModuleType& module) { module.to(kFloat32); })
186 .def(
"double", [](ModuleType& module) { module.to(kFloat64); })
187 .def(
"half", [](ModuleType& module) { module.to(kFloat16); })
188 .def(
"__str__", [](ModuleType& module) {
return module.name(); })
189 .def(
"__repr__", [](ModuleType& module) {
return module.name(); });
214 template <
typename ModuleType,
bool force_enable = false>
217 detail::PyModuleClass<ModuleType>>
218 bind_module(py::module module,
const char* name) {
219 py::module cpp = module.def_submodule(
"cpp");
221 add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
222 detail::bind_cpp_module_wrapper(module, cpp_class, name);
253 torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
254 detail::PyModuleClass<ModuleType> bind_module(
257 return bind_module<ModuleType,
true>(module, name)
258 .def(
"forward", &ModuleType::forward)
259 .def(
"__call__", &ModuleType::forward);
// `torch::detail::has_forward<T>` detects whether a type T has a forward() method.