3 #include <torch/detail/static.h>     4 #include <torch/nn/module.h>     5 #include <torch/ordered_dict.h>     6 #include <torch/types.h>     8 #include <torch/csrc/Device.h>     9 #include <torch/csrc/Dtype.h>    10 #include <torch/csrc/DynamicTypes.h>    11 #include <torch/csrc/python_headers.h>    12 #include <torch/csrc/utils/pybind.h>    16 #include <unordered_map>    23 inline Device py_object_to_device(py::object 
object) {
    24   PyObject* obj = 
object.ptr();
    25   if (THPDevice_Check(obj)) {
    26     return reinterpret_cast<THPDevice*
>(obj)->device;
    28   throw TypeError(
"Expected device");
    31 inline Dtype py_object_to_dtype(py::object 
object) {
    32   PyObject* obj = 
object.ptr();
    33   if (THPDtype_Check(obj)) {
    34     return reinterpret_cast<THPDtype*
>(obj)->scalar_type;
    36   throw TypeError(
"Expected dtype");
    39 template <
typename ModuleType>
    41     py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
    46 template <
typename ModuleType>
    47 void bind_cpp_module_wrapper(
    49     PyModuleClass<ModuleType> cpp_class,
    53   py::object cpp_module =
    54       py::module::import(
"torch.nn.cpp").attr(
"ModuleWrapper");
    58   py::object type_metaclass =
    59       py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
    70 #if PY_MAJOR_VERSION < 3    72       py::reinterpret_steal<py::object>(PyString_FromString(name));
    74   py::object name_str = py::str(name);
    80   py::object wrapper_class =
    81       type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
    85   wrapper_class.attr(
"__init__") = py::cpp_function(
    86       [cpp_module, cpp_class](
    87           py::object 
self, py::args args, py::kwargs kwargs) {
    88         cpp_module.attr(
"__init__")(
self, cpp_class(*args, **kwargs));
    90       py::is_method(wrapper_class));
    94   module.attr(name) = wrapper_class;
   107 template <
typename ModuleType, 
typename... Extra>
   108 py::class_<ModuleType, Extra...> add_module_bindings(
   109     py::class_<ModuleType, Extra...> module) {
   113           [](ModuleType& module, 
bool mode) { module.train(mode); },
   114           py::arg(
"mode") = 
true)
   115       .def(
"eval", [](ModuleType& module) { module.eval(); })
   116       .def(
"clone", [](ModuleType& module) { 
return module.clone(); })
   117       .def_property_readonly(
   118           "training", [](ModuleType& module) { 
return module.is_training(); })
   119       .def(
"zero_grad", [](ModuleType& module) { module.zero_grad(); })
   120       .def_property_readonly( 
"_parameters", [](ModuleType& module) {
   121             return module.named_parameters(
false);
   123       .def(
"parameters", [](ModuleType& module, 
bool recurse) {
   124             return module.parameters(recurse);
   126           py::arg(
"recurse") = 
true)
   127       .def(
"named_parameters", [](ModuleType& module, 
bool recurse) {
   128             return module.named_parameters(recurse);
   130           py::arg(
"recurse") = 
true)
   131       .def_property_readonly(
"_buffers", [](ModuleType& module) {
   132             return module.named_buffers(
false);
   134       .def(
"buffers", [](ModuleType& module, 
bool recurse) {
   135             return module.buffers(recurse); },
   136           py::arg(
"recurse") = 
true)
   137       .def(
"named_buffers", [](ModuleType& module, 
bool recurse) {
   138             return module.named_buffers(recurse);
   140           py::arg(
"recurse") = 
true)
   141       .def_property_readonly(
   142         "_modules", [](ModuleType& module) { 
return module.named_children(); })
   143       .def(
"modules", [](ModuleType& module) { 
return module.modules(); })
   144       .def(
"named_modules",
   145           [](ModuleType& module, py::object , std::string prefix) {
   146             return module.named_modules(std::move(prefix));
   148           py::arg(
"memo") = py::none(),
   149           py::arg(
"prefix") = std::string())
   150       .def(
"children", [](ModuleType& module) { 
return module.children(); })
   151       .def(
"named_children",
   152           [](ModuleType& module) { 
return module.named_children(); })
   153       .def(
"to", [](ModuleType& module, py::object 
object, 
bool non_blocking) {
   154             if (THPDevice_Check(
object.ptr())) {
   156                   reinterpret_cast<THPDevice*>(
object.ptr())->device,
   159               module.to(detail::py_object_to_dtype(
object), non_blocking);
   162           py::arg(
"dtype_or_device"),
   163           py::arg(
"non_blocking") = 
false)
   165           [](ModuleType& module,
   169               if (device.is_none()) {
   170                 module.to(detail::py_object_to_dtype(dtype), non_blocking);
   171               } 
else if (dtype.is_none()) {
   172                 module.to(detail::py_object_to_device(device), non_blocking);
   175                     detail::py_object_to_device(device),
   176                     detail::py_object_to_dtype(dtype),
   182           py::arg(
"non_blocking") = 
false)
   183       .def(
"cuda", [](ModuleType& module) { module.to(kCUDA); })
   184       .def(
"cpu", [](ModuleType& module) { module.to(kCPU); })
   185       .def(
"float", [](ModuleType& module) { module.to(kFloat32); })
   186       .def(
"double", [](ModuleType& module) { module.to(kFloat64); })
   187       .def(
"half", [](ModuleType& module) { module.to(kFloat16); })
   188       .def(
"__str__", [](ModuleType& module) { 
return module.name(); })
   189       .def(
"__repr__", [](ModuleType& module) { 
return module.name(); });
   214 template <
typename ModuleType, 
bool force_enable = false>
   217     detail::PyModuleClass<ModuleType>>
   218 bind_module(py::module module, 
const char* name) {
   219   py::module cpp = module.def_submodule(
"cpp");
   221       add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
   222   detail::bind_cpp_module_wrapper(module, cpp_class, name);
   253         torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
   254 detail::PyModuleClass<ModuleType> bind_module(
   257   return bind_module<ModuleType, 
true>(module, name)
   258       .def(
"forward", &ModuleType::forward)
   259       .def(
"__call__", &ModuleType::forward);
 
Detects if a type T has a forward() method.