The base class for all modules in PyTorch.
#include <module.h>
Public Types
| using | ModuleApplyFunction = std::function< void(Module &)> |
| using | ConstModuleApplyFunction = std::function< void(const Module &)> |
| using | NamedModuleApplyFunction = std::function< void(const std::string &, Module &)> |
| using | ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)> |
| using | ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)> |
| using | NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)> |
Public Member Functions
| | Module (std::string name) |
Tells the base Module about the name of the submodule.
| | Module () |
Constructs the module without immediate knowledge of the submodule's name.
| const std::string & | name () const noexcept |
Returns the name of the Module.
| virtual std::shared_ptr< Module > | clone (const optional< Device > &device=nullopt) const |
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
| void | apply (const ModuleApplyFunction &function) |
Applies the function to the Module and recursively to every submodule.
| void | apply (const ConstModuleApplyFunction &function) const |
Applies the function to the Module and recursively to every submodule.
| void | apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) |
Applies the function to the Module and recursively to every submodule.
| void | apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const |
Applies the function to the Module and recursively to every submodule.
| void | apply (const ModulePointerApplyFunction &function) const |
Applies the function to the Module and recursively to every submodule.
| void | apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const |
Applies the function to the Module and recursively to every submodule.
| std::vector< Tensor > | parameters (bool recurse=true) const |
Returns the parameters of this Module and if recurse is true, also recursively of every submodule.
| OrderedDict< std::string, Tensor > | named_parameters (bool recurse=true) const |
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
| std::vector< Tensor > | buffers (bool recurse=true) const |
Returns the buffers of this Module and if recurse is true, also recursively of every submodule.
| OrderedDict< std::string, Tensor > | named_buffers (bool recurse=true) const |
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
| std::vector< std::shared_ptr< Module > > | modules (bool include_self=true) const |
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.
| OrderedDict< std::string, std::shared_ptr< Module > > | named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const |
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.
| std::vector< std::shared_ptr< Module > > | children () const |
Returns the direct submodules of this Module.
| OrderedDict< std::string, std::shared_ptr< Module > > | named_children () const |
Returns an OrderedDict of the direct submodules of this Module and their keys.
| virtual void | train (bool on=true) |
Enables "training" mode.
| void | eval () |
Calls train(false) to enable "eval" mode.
| virtual bool | is_training () const noexcept |
True if the module is in training mode.
| virtual void | to (torch::Device device, torch::Dtype dtype, bool non_blocking=false) |
Recursively casts all parameters to the given dtype and device.
| virtual void | to (torch::Dtype dtype, bool non_blocking=false) |
Recursively casts all parameters to the given dtype.
| virtual void | to (torch::Device device, bool non_blocking=false) |
Recursively moves all parameters to the given device.
| virtual void | zero_grad () |
Recursively zeros out the grad value of each registered parameter.
| template<typename ModuleType > |
| ModuleType::ContainedType * | as () noexcept |
Attempts to cast this Module to the given ModuleType.
| template<typename ModuleType > |
| const ModuleType::ContainedType * | as () const noexcept |
Attempts to cast this Module to the given ModuleType.
| template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
| ModuleType * | as () noexcept |
Attempts to cast this Module to the given ModuleType.
| template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
| const ModuleType * | as () const noexcept |
Attempts to cast this Module to the given ModuleType.
| virtual void | save (serialize::OutputArchive &archive) const |
Serializes the Module into the given OutputArchive.
| virtual void | load (serialize::InputArchive &archive) |
Deserializes the Module from the given InputArchive.
| virtual void | pretty_print (std::ostream &stream) const |
Streams a pretty representation of the Module into the given stream.
Protected Member Functions
| Tensor & | register_parameter (std::string name, Tensor tensor, bool requires_grad=true) |
Registers a parameter with this Module.
| Tensor & | register_buffer (std::string name, Tensor tensor) |
Registers a buffer with this Module.
| template<typename ModuleType > |
| std::shared_ptr< ModuleType > | register_module (std::string name, std::shared_ptr< ModuleType > module) |
Registers a submodule with this Module.
| template<typename ModuleType > |
| std::shared_ptr< ModuleType > | register_module (std::string name, ModuleHolder< ModuleType > module_holder) |
Registers a submodule with this Module.
Friends
| template<typename Derived > |
| class | Cloneable |
| TORCH_API friend std::ostream & | operator<< (std::ostream &stream, const nn::Module &module) |
Pretty prints the given Module into the ostream.
.. note:: The design and implementation of this class are largely based on the Python API. You may want to consult the Python documentation for :py:class:`pytorch:torch.nn.Module` for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.
A distinction is made between three kinds of persistent data that may be associated with a Module:
1. Parameters: tensors that record gradients (e.g. the weight of a Linear module),
2. Buffers: tensors that do not record gradients, such as running statistics (e.g. mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.

The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type or device of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.
Parameters are registered with a Module via register_parameter(). Buffers are registered separately via register_buffer(). These methods are part of the protected API of Module and are typically invoked from within a concrete Module's constructor.
| torch::nn::Module::Module | ( | ) |
Constructs the module without immediate knowledge of the submodule's name.
The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
Definition at line 46 of file module.cpp.
| void torch::nn::Module::apply | ( | const ModuleApplyFunction & | function | ) |
Applies the function to the Module and recursively to every submodule.
The function must accept a Module&.
.. code-block:: cpp

  MyModule module;
  module->apply([](nn::Module& module) {
    std::cout << module.name() << std::endl;
  });
Definition at line 87 of file module.cpp.
| void torch::nn::Module::apply | ( | const ConstModuleApplyFunction & | function | ) | const |
Applies the function to the Module and recursively to every submodule.
The function must accept a const Module&.
.. code-block:: cpp

  MyModule module;
  module->apply([](const nn::Module& module) {
    std::cout << module.name() << std::endl;
  });
Definition at line 95 of file module.cpp.
| void torch::nn::Module::apply | ( | const NamedModuleApplyFunction & | function, |
| const std::string & | name_prefix = std::string() |
||
| ) |
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key, nn::Module& module) {
    std::cout << key << ": " << module.name() << std::endl;
  });
Definition at line 103 of file module.cpp.
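The key scheme described above can be illustrated without LibTorch. The following is a minimal standard-library sketch, not the library's implementation: `ToyModule`, its `children` member and the helper `collect_keys` are hypothetical names introduced only for illustration. It assumes the module itself is visited before its submodules and that a key is joined to a non-empty prefix with a dot.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a module tree. This is NOT torch::nn::Module.
struct ToyModule {
  using NamedApplyFunction =
      std::function<void(const std::string&, ToyModule&)>;

  // Submodules in registration order, each stored under its key.
  std::vector<std::pair<std::string, std::shared_ptr<ToyModule>>> children;

  // Visits this module first, then every submodule, depth first.
  void apply(const NamedApplyFunction& function,
             const std::string& name_prefix = std::string()) {
    function(name_prefix, *this);  // the module itself gets just the prefix
    for (auto& child : children) {
      assert(child.second != nullptr);
      // Join with '.' only when there is a prefix to join onto.
      const std::string key =
          name_prefix.empty() ? child.first : name_prefix + "." + child.first;
      child.second->apply(function, key);
    }
  }
};

// Builds root -> {encoder -> {fc}, head} and returns the keys in visit order.
std::vector<std::string> collect_keys(const std::string& name_prefix) {
  auto root = std::make_shared<ToyModule>();
  auto encoder = std::make_shared<ToyModule>();
  encoder->children.emplace_back("fc", std::make_shared<ToyModule>());
  root->children.emplace_back("encoder", encoder);
  root->children.emplace_back("head", std::make_shared<ToyModule>());

  std::vector<std::string> keys;
  root->apply(
      [&keys](const std::string& key, ToyModule&) { keys.push_back(key); },
      name_prefix);
  return keys;
}
```

With an empty prefix the root is reported under the empty string and its descendants as `encoder`, `encoder.fc` and `head`. With the prefix `net` the same tree yields `net`, `net.encoder`, `net.encoder.fc` and `net.head`.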
| void torch::nn::Module::apply | ( | const ConstNamedModuleApplyFunction & | function, |
| const std::string & | name_prefix = std::string() |
||
| ) | const |
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key, const nn::Module& module) {
    std::cout << key << ": " << module.name() << std::endl;
  });
Definition at line 115 of file module.cpp.
| void torch::nn::Module::apply | ( | const ModulePointerApplyFunction & | function | ) | const |
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::shared_ptr<Module>&.
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::shared_ptr<nn::Module>& module) {
    std::cout << module->name() << std::endl;
  });
Definition at line 127 of file module.cpp.
| void torch::nn::Module::apply | ( | const NamedModulePointerApplyFunction & | function, |
| const std::string & | name_prefix = std::string() |
||
| ) | const |
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key, const std::shared_ptr<nn::Module>& module) {
    std::cout << key << ": " << module->name() << std::endl;
  });
Definition at line 135 of file module.cpp.
|
noexcept |
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().

.. code-block:: cpp

  void initialize_weights(nn::Module& module) {
    torch::NoGradGuard no_grad;
    if (auto* linear = module.as<nn::Linear>()) {
      linear->weight.normal_(0.0, 0.02);
    }
  }

  MyModule module;
  module->apply(initialize_weights);
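The "attempt to cast, get a null pointer on mismatch" behaviour can be modelled with `dynamic_cast`. This is a sketch under that assumption rather than LibTorch code: `ToyModule`, `ToyLinear`, `ToyDropout` and `reset_linear_weights` are hypothetical names, and the ModuleHolder overloads (those returning `ContainedType`) are not modelled.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy polymorphic base. This is NOT torch::nn::Module.
struct ToyModule {
  virtual ~ToyModule() = default;

  // Returns this object as ModuleType*, or nullptr if it is not one.
  template <typename ModuleType>
  ModuleType* as() noexcept {
    return dynamic_cast<ModuleType*>(this);
  }
};

struct ToyLinear : ToyModule {
  double weight = 1.0;
};

struct ToyDropout : ToyModule {
  double rate = 0.5;
};

// Zeroes the weight of every ToyLinear and returns how many were touched.
int reset_linear_weights(
    const std::vector<std::shared_ptr<ToyModule>>& modules) {
  int touched = 0;
  for (const auto& module : modules) {
    assert(module != nullptr);
    // The cast doubles as a type test, so other module kinds are skipped.
    if (auto* linear = module->as<ToyLinear>()) {
      linear->weight = 0.0;
      ++touched;
    }
  }
  return touched;
}
```

Because a failed cast yields a null pointer instead of throwing, the visiting function can be applied to every node of a heterogeneous hierarchy without checking types up front.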
|
noexcept |
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().

.. code-block:: cpp

  void initialize_weights(nn::Module& module) {
    torch::NoGradGuard no_grad;
    if (auto* linear = module.as<nn::Linear>()) {
      linear->weight.normal_(0.0, 0.02);
    }
  }

  MyModule module;
  module->apply(initialize_weights);
|
noexcept |
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().

.. code-block:: cpp

  void initialize_weights(nn::Module& module) {
    torch::NoGradGuard no_grad;
    if (auto* linear = module.as<nn::Linear>()) {
      linear->weight.normal_(0.0, 0.02);
    }
  }

  MyModule module;
  module.apply(initialize_weights);
|
noexcept |
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().

.. code-block:: cpp

  void initialize_weights(nn::Module& module) {
    torch::NoGradGuard no_grad;
    if (auto* linear = module.as<nn::Linear>()) {
      linear->weight.normal_(0.0, 0.02);
    }
  }

  MyModule module;
  module.apply(initialize_weights);
| std::vector< Tensor > torch::nn::Module::buffers | ( | bool | recurse = true | ) | const |
Returns the buffers of this Module and if recurse is true, also recursively of every submodule.
Definition at line 166 of file module.cpp.
|
virtual |
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
.. attention:: Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.
Reimplemented in torch::nn::SequentialImpl, torch::nn::Cloneable< Derived >, torch::nn::Cloneable< BatchNormImpl >, torch::nn::Cloneable< SimpleContainer >, torch::nn::Cloneable< SequentialImpl >, torch::nn::Cloneable< Conv2dImpl >, torch::nn::Cloneable< RNNImpl >, torch::nn::Cloneable< Conv3dImpl >, torch::nn::Cloneable< LSTMImpl >, torch::nn::Cloneable< Net >, torch::nn::Cloneable< EmbeddingImpl >, torch::nn::Cloneable< LinearImpl >, torch::nn::Cloneable< DropoutImpl >, torch::nn::Cloneable< FeatureDropoutImpl >, torch::nn::Cloneable< Conv1dImpl >, and torch::nn::Cloneable< GRUImpl >.
Definition at line 78 of file module.cpp.
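The reason clone() only works through Cloneable can be seen in a small model of the curiously recurring template pattern. This is a sketch, not the LibTorch implementation: `ToyModule`, `ToyCloneable`, `ToyLinear`, `PlainModule` and `base_clone_throws` are hypothetical names, the copy relies on the derived type's copy constructor, and the optional device argument is left out.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Toy base class. This is NOT torch::nn::Module.
struct ToyModule {
  virtual ~ToyModule() = default;

  // The base cannot know the concrete type, so it can only fail.
  virtual std::shared_ptr<ToyModule> clone() const {
    throw std::logic_error("clone() requires deriving from ToyCloneable");
  }
};

// CRTP helper: it is templated on Derived, so it can copy a Derived.
template <typename Derived>
struct ToyCloneable : ToyModule {
  std::shared_ptr<ToyModule> clone() const override {
    // Sanity check that the template argument matches the dynamic type.
    assert(dynamic_cast<const Derived*>(this) != nullptr);
    return std::make_shared<Derived>(static_cast<const Derived&>(*this));
  }
};

struct ToyLinear : ToyCloneable<ToyLinear> {
  double weight = 0.0;
};

// A module that derives from the base directly and inherits the failing clone().
struct PlainModule : ToyModule {};

// Returns true if calling the inherited base clone() raises an error.
bool base_clone_throws() {
  PlainModule plain;
  try {
    plain.clone();
  } catch (const std::logic_error&) {
    return true;
  }
  return false;
}
```

Keeping clone() virtual on the base is what allows a caller holding only a base pointer to copy a module without naming its concrete type.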
| void torch::nn::Module::eval | ( | ) |
Calls train(false) to enable "eval" mode.
Do not override this method, override train() instead.
Definition at line 240 of file module.cpp.
|
virtual noexcept |
True if the module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
Definition at line 256 of file module.cpp.
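How a single flag can steer a module's behaviour is sketched below with standard-library types only. `ToyModule` and `ToyScale` are hypothetical names, the doubling in `forward` is a deliberately simplified stand-in for what a Dropout or BatchNorm layer does, and starting in training mode is an assumption of this toy.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy base class. This is NOT torch::nn::Module.
struct ToyModule {
  virtual ~ToyModule() = default;

  std::vector<std::shared_ptr<ToyModule>> children;

  // Sets the mode on every submodule, then on this module.
  virtual void train(bool on = true) {
    for (auto& child : children) {
      assert(child != nullptr);
      child->train(on);
    }
    is_training_ = on;
  }

  // Not virtual: subclasses customise train() instead.
  void eval() { train(false); }

  bool is_training() const noexcept { return is_training_; }

 private:
  bool is_training_ = true;  // assumed default for this toy
};

// Module with two code paths: doubles its input in training mode,
// passes it through unchanged in evaluation mode.
struct ToyScale : ToyModule {
  double forward(double x) const { return is_training() ? x * 2.0 : x; }
};
```

Calling eval() on the root is enough to switch every nested module, because train(false) walks the whole hierarchy.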
| std::vector< std::shared_ptr< Module > > torch::nn::Module::modules | ( | bool | include_self = true | ) | const |
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.
.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
Definition at line 187 of file module.cpp.
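The warning follows from how an object hands out a shared_ptr to itself. The sketch below assumes the mechanism is `std::enable_shared_from_this`, whose `shared_from_this()` throws `std::bad_weak_ptr` (guaranteed since C++17) when no shared_ptr owns the object. `ToyModule` and `stack_module_throws` are hypothetical names, and the toy `modules()` is non-const for brevity.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy module that can hand out a shared_ptr to itself.
// This is NOT torch::nn::Module.
struct ToyModule : std::enable_shared_from_this<ToyModule> {
  std::vector<std::shared_ptr<ToyModule>> children;

  // Returns all descendants; with include_self, this module comes first.
  std::vector<std::shared_ptr<ToyModule>> modules(bool include_self = true) {
    std::vector<std::shared_ptr<ToyModule>> result;
    if (include_self) {
      // Throws std::bad_weak_ptr unless a shared_ptr already owns *this.
      result.push_back(shared_from_this());
    }
    for (const auto& child : children) {
      assert(child != nullptr);
      // Children are always held in shared_ptrs, so this call is safe.
      for (auto& module : child->modules(/*include_self=*/true)) {
        result.push_back(module);
      }
    }
    return result;
  }
};

// Returns true if asking a stack-allocated module for itself throws.
bool stack_module_throws() {
  ToyModule on_stack;
  try {
    on_stack.modules(/*include_self=*/true);
  } catch (const std::bad_weak_ptr&) {
    return true;
  }
  return false;
}
```

A module created with `std::make_shared` can be listed together with its descendants, while a stack-allocated one can only be queried with `include_self` set to false.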
|
noexcept |
Returns the name of the Module.
A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.
Definition at line 53 of file module.cpp.
| OrderedDict< std::string, Tensor > torch::nn::Module::named_buffers | ( | bool | recurse = true | ) | const |
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
Definition at line 174 of file module.cpp.
| OrderedDict< std::string, std::shared_ptr< Module > > torch::nn::Module::named_children | ( | ) | const |
Returns an OrderedDict of the direct submodules of this Module and their keys.
Definition at line 228 of file module.cpp.
| OrderedDict< std::string, std::shared_ptr< Module > > torch::nn::Module::named_modules | ( | const std::string & | name_prefix = std::string(), |
| bool | include_self = true |
||
| ) | const |
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.
If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
Definition at line 202 of file module.cpp.
| OrderedDict< std::string, Tensor > torch::nn::Module::named_parameters | ( | bool | recurse = true | ) | const |
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
Definition at line 153 of file module.cpp.
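The effect of the recurse flag on the returned keys can be modelled with plain containers. This is a sketch, not LibTorch code: `ToyModule`, `NamedValues` and `parameter_keys` are hypothetical names, doubles stand in for tensors, and the dot-joined nested key (`<submodule key>.<parameter name>`) is an assumption modelled on the prefix scheme used by apply().

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy module holding named values. This is NOT torch::nn::Module.
struct ToyModule {
  using NamedValues = std::vector<std::pair<std::string, double>>;

  NamedValues params;  // this module's own parameters
  std::vector<std::pair<std::string, std::shared_ptr<ToyModule>>> children;

  // Own parameters first, then those of every submodule if recurse is true.
  NamedValues named_parameters(bool recurse = true) const {
    NamedValues result = params;
    if (!recurse) {
      return result;
    }
    for (const auto& child : children) {
      assert(child.second != nullptr);
      for (const auto& entry : child.second->named_parameters(true)) {
        // Prefix nested keys with the key of the submodule that owns them.
        result.emplace_back(child.first + "." + entry.first, entry.second);
      }
    }
    return result;
  }
};

// Builds root{scale} -> fc{weight, bias} and returns the resulting keys.
std::vector<std::string> parameter_keys(bool recurse) {
  ToyModule root;
  root.params.emplace_back("scale", 1.0);
  auto fc = std::make_shared<ToyModule>();
  fc->params.emplace_back("weight", 0.5);
  fc->params.emplace_back("bias", 0.0);
  root.children.emplace_back("fc", fc);

  std::vector<std::string> keys;
  for (const auto& entry : root.named_parameters(recurse)) {
    keys.push_back(entry.first);
  }
  return keys;
}
```

With recursion the toy returns `scale`, `fc.weight` and `fc.bias`; without it only `scale` is returned.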
| std::vector< Tensor > torch::nn::Module::parameters | ( | bool | recurse = true | ) | const |
Returns the parameters of this Module and if recurse is true, also recursively of every submodule.
Definition at line 143 of file module.cpp.
|
virtual |
Streams a pretty representation of the Module into the given stream.
By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.
Override this method to change the pretty print. The method returns void, so the representation is written directly into the given stream; use operator<< when chaining is needed.
Reimplemented in torch::nn::RNNImpl, torch::nn::SequentialImpl, torch::nn::ConvImpl< D, Derived >, torch::nn::ConvImpl< 1, Conv1dImpl >, torch::nn::ConvImpl< 3, Conv3dImpl >, torch::nn::ConvImpl< 2, Conv2dImpl >, torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, torch::nn::detail::RNNImplBase< GRUImpl >, torch::nn::FeatureDropoutImpl, torch::nn::BatchNormImpl, torch::nn::DropoutImpl, torch::nn::LinearImpl, and torch::nn::EmbeddingImpl.
Definition at line 325 of file module.cpp.
Registers a buffer with this Module.
A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

.. code-block:: cpp

  MyModule::MyModule() {
    mean_ = register_buffer("mean", torch::empty({num_features_}));
  }
Definition at line 315 of file module.cpp.
|
protected |
|
protected |
Registers a submodule with this Module.
This method deals with ModuleHolders.
Registering a module makes it available to methods such as modules(), clone() or to().
.. code-block:: cpp

  MyModule::MyModule() {
    submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  }
|
protected |
Registers a parameter with this Module.
A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().
.. code-block:: cpp

  MyModule::MyModule() {
    weight_ = register_parameter("weight", torch::randn({A, B}));
  }
Definition at line 301 of file module.cpp.
|
virtual |
Recursively casts all parameters to the given dtype and device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 244 of file module.cpp.
|
virtual |
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 248 of file module.cpp.
|
virtual |
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 252 of file module.cpp.