The base class for all modules in PyTorch.
#include <module.h>
Public Types

using ModuleApplyFunction = std::function<void(Module&)>
using ConstModuleApplyFunction = std::function<void(const Module&)>
using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>
using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module>&)>
using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>
Public Member Functions

Module(std::string name)
    Tells the base Module about the name of the submodule.

Module()
    Constructs the module without immediate knowledge of the submodule's name.

const std::string& name() const noexcept
    Returns the name of the Module.

virtual std::shared_ptr<Module> clone(const optional<Device>& device = nullopt) const
    Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.

void apply(const ModuleApplyFunction& function)
    Applies the function to the Module and recursively to every submodule.

void apply(const ConstModuleApplyFunction& function) const
    Applies the function to the Module and recursively to every submodule.

void apply(const NamedModuleApplyFunction& function, const std::string& name_prefix = std::string())
    Applies the function to the Module and recursively to every submodule.

void apply(const ConstNamedModuleApplyFunction& function, const std::string& name_prefix = std::string()) const
    Applies the function to the Module and recursively to every submodule.

void apply(const ModulePointerApplyFunction& function) const
    Applies the function to the Module and recursively to every submodule.

void apply(const NamedModulePointerApplyFunction& function, const std::string& name_prefix = std::string()) const
    Applies the function to the Module and recursively to every submodule.

std::vector<Tensor> parameters(bool recurse = true) const
    Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const
    Returns an OrderedDict with the parameters of this Module along with their keys, and, if recurse is true, also recursively of every submodule.

std::vector<Tensor> buffers(bool recurse = true) const
    Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const
    Returns an OrderedDict with the buffers of this Module along with their keys, and, if recurse is true, also recursively of every submodule.

std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const
    Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.

OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string& name_prefix = std::string(), bool include_self = true) const
    Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and, if include_self is true, also inserts a shared_ptr to this module in the first position.

std::vector<std::shared_ptr<Module>> children() const
    Returns the direct submodules of this Module.

OrderedDict<std::string, std::shared_ptr<Module>> named_children() const
    Returns an OrderedDict of the direct submodules of this Module and their keys.

virtual void train(bool on = true)
    Enables "training" mode.

void eval()
    Calls train(false) to enable "eval" mode.

virtual bool is_training() const noexcept
    True if the module is in training mode.

virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
    Recursively casts all parameters to the given dtype and device.

virtual void to(torch::Dtype dtype, bool non_blocking = false)
    Recursively casts all parameters to the given dtype.

virtual void to(torch::Device device, bool non_blocking = false)
    Recursively moves all parameters to the given device.

virtual void zero_grad()
    Recursively zeros out the grad value of each registered parameter.

template <typename ModuleType>
ModuleType::ContainedType* as() noexcept
    Attempts to cast this Module to the given ModuleType.

template <typename ModuleType>
const ModuleType::ContainedType* as() const noexcept
    Attempts to cast this Module to the given ModuleType.

template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* as() noexcept
    Attempts to cast this Module to the given ModuleType.

template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType* as() const noexcept
    Attempts to cast this Module to the given ModuleType.

virtual void save(serialize::OutputArchive& archive) const
    Serializes the Module into the given OutputArchive.

virtual void load(serialize::InputArchive& archive)
    Deserializes the Module from the given InputArchive.

virtual void pretty_print(std::ostream& stream) const
    Streams a pretty representation of the Module into the given stream.
Protected Member Functions

Tensor& register_parameter(std::string name, Tensor tensor, bool requires_grad = true)
    Registers a parameter with this Module.

Tensor& register_buffer(std::string name, Tensor tensor)
    Registers a buffer with this Module.

template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)
    Registers a submodule with this Module.

template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)
    Registers a submodule with this Module.
Friends

template <typename Derived>
class Cloneable

TORCH_API friend std::ostream& operator<<(std::ostream& stream, const nn::Module& module)
    Pretty prints the given Module into the ostream.
The base class for all modules in PyTorch.
.. note:: The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for :py:class:`pytorch:torch.nn.Module` for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.
A distinction is made between three kinds of persistent data that may be associated with a Module:

1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. the mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.
Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the protected API of Module and are typically invoked from within a concrete Module's constructor.
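For orientation, here is a minimal sketch of a concrete module wiring these pieces together; MyModule, its member names and its dimensions are illustrative, not part of the API:

.. code-block:: cpp

   #include <torch/torch.h>

   struct MyModule : torch::nn::Module {
     MyModule() {
       // Submodule, parameter and buffer registration, done in the constructor.
       linear = register_module("linear", torch::nn::Linear(4, 2));
       bias = register_parameter("bias", torch::zeros({2}));
       running_mean = register_buffer("running_mean", torch::zeros({2}));
     }

     torch::Tensor forward(torch::Tensor x) {
       return linear(x) + bias - running_mean;
     }

     torch::nn::Linear linear{nullptr};
     torch::Tensor bias, running_mean;
   };

After construction, parameters() yields both bias and the Linear submodule's own parameters, and to() or clone() traverses all three kinds of registered state.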
torch::nn::Module::Module()
Constructs the module without immediate knowledge of the submodule's name.
The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
Definition at line 46 of file module.cpp.
void torch::nn::Module::apply(const ModuleApplyFunction& function)
Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.

.. code-block:: cpp

   MyModule module;
   module->apply([](nn::Module& module) {
     std::cout << module.name() << std::endl;
   });
Definition at line 87 of file module.cpp.
void torch::nn::Module::apply(const ConstModuleApplyFunction& function) const
Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.

.. code-block:: cpp

   MyModule module;
   module->apply([](const nn::Module& module) {
     std::cout << module.name() << std::endl;
   });
Definition at line 95 of file module.cpp.
void torch::nn::Module::apply(const NamedModuleApplyFunction& function, const std::string& name_prefix = std::string())
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and is just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key, nn::Module& module) {
     std::cout << key << ": " << module.name() << std::endl;
   });
Definition at line 103 of file module.cpp.
void torch::nn::Module::apply(const ConstNamedModuleApplyFunction& function, const std::string& name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and is just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key, const nn::Module& module) {
     std::cout << key << ": " << module.name() << std::endl;
   });
Definition at line 115 of file module.cpp.
void torch::nn::Module::apply(const ModulePointerApplyFunction& function) const
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::shared_ptr<nn::Module>& module) {
     std::cout << module->name() << std::endl;
   });
Definition at line 127 of file module.cpp.
void torch::nn::Module::apply(const NamedModulePointerApplyFunction& function, const std::string& name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and is just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key, const std::shared_ptr<nn::Module>& module) {
     std::cout << key << ": " << module->name() << std::endl;
   });
Definition at line 135 of file module.cpp.
template <typename ModuleType>
ModuleType::ContainedType* torch::nn::Module::as() noexcept
Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);
template <typename ModuleType>
const ModuleType::ContainedType* torch::nn::Module::as() const noexcept
Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);
template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* torch::nn::Module::as() noexcept
Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);
template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType* torch::nn::Module::as() const noexcept
Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);
std::vector<Tensor> torch::nn::Module::buffers(bool recurse = true) const
Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.
Definition at line 166 of file module.cpp.
virtual std::shared_ptr<Module> torch::nn::Module::clone(const optional<Device>& device = nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
.. attention:: Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.
Reimplemented in torch::nn::SequentialImpl, torch::nn::Cloneable< Derived >, torch::nn::Cloneable< BatchNormImpl >, torch::nn::Cloneable< SimpleContainer >, torch::nn::Cloneable< SequentialImpl >, torch::nn::Cloneable< Conv2dImpl >, torch::nn::Cloneable< RNNImpl >, torch::nn::Cloneable< Conv3dImpl >, torch::nn::Cloneable< LSTMImpl >, torch::nn::Cloneable< Net >, torch::nn::Cloneable< EmbeddingImpl >, torch::nn::Cloneable< LinearImpl >, torch::nn::Cloneable< DropoutImpl >, torch::nn::Cloneable< FeatureDropoutImpl >, torch::nn::Cloneable< Conv1dImpl >, and torch::nn::Cloneable< GRUImpl >.
Definition at line 78 of file module.cpp.
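As a hedged sketch of the Cloneable pattern (the Net type and its members are hypothetical), the subclass implements reset(), which (re-)registers every member so that clone() can rebuild and then deep-copy the hierarchy:

.. code-block:: cpp

   #include <torch/torch.h>

   struct Net : torch::nn::Cloneable<Net> {
     Net() { reset(); }

     // clone() clears all registered state on the copy and then calls
     // reset(), so reset() must (re-)register every member.
     void reset() override {
       fc = register_module("fc", torch::nn::Linear(8, 8));
     }

     torch::nn::Linear fc{nullptr};
   };

   auto net = std::make_shared<Net>();
   auto copy = net->clone();  // deep copy; pass a Device to relocate it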
void torch::nn::Module::eval()
Calls train(false) to enable "eval" mode.
Do not override this method; override train() instead.
Definition at line 240 of file module.cpp.
virtual bool torch::nn::Module::is_training() const noexcept
True if the module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
Definition at line 256 of file module.cpp.
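For example (a minimal sketch using the built-in Dropout module), toggling the mode changes runtime behavior:

.. code-block:: cpp

   torch::nn::Dropout dropout(0.5);

   dropout->train();   // training mode: dropout is applied in forward()
   // dropout->is_training() == true

   dropout->eval();    // evaluation mode: forward() becomes the identity
   // dropout->is_training() == false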
std::vector<std::shared_ptr<Module>> torch::nn::Module::modules(bool include_self = true) const
Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
Definition at line 187 of file module.cpp.
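A hedged usage sketch, assuming a hypothetical MyModule stored in a shared_ptr as the warning requires:

.. code-block:: cpp

   auto module = std::make_shared<MyModule>();
   // include_self = true is safe here because the module lives in a shared_ptr.
   for (const auto& submodule : module->modules(/*include_self=*/true)) {
     std::cout << submodule->name() << std::endl;
   }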
const std::string& torch::nn::Module::name() const noexcept
Returns the name of the Module.
A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.
Definition at line 53 of file module.cpp.
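If RTTI is unavailable in your build, a minimal sketch of the manual-naming workaround (MyModule is hypothetical) looks like this:

.. code-block:: cpp

   struct MyModule : torch::nn::Module {
     // Name the module explicitly instead of relying on RTTI inference.
     MyModule() : torch::nn::Module("MyModule") {}
   };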
OrderedDict<std::string, Tensor> torch::nn::Module::named_buffers(bool recurse = true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and, if recurse is true, also recursively of every submodule.
Definition at line 174 of file module.cpp.
OrderedDict<std::string, std::shared_ptr<Module>> torch::nn::Module::named_children() const
Returns an OrderedDict of the direct submodules of this Module and their keys.
Definition at line 228 of file module.cpp.
OrderedDict<std::string, std::shared_ptr<Module>> torch::nn::Module::named_modules(const std::string& name_prefix = std::string(), bool include_self = true) const
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and, if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and is just name_prefix for the module itself).

.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
Definition at line 202 of file module.cpp.
OrderedDict<std::string, Tensor> torch::nn::Module::named_parameters(bool recurse = true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and, if recurse is true, also recursively of every submodule.
Definition at line 153 of file module.cpp.
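A brief usage sketch (MyModule is hypothetical); OrderedDict items expose key() and value():

.. code-block:: cpp

   auto module = std::make_shared<MyModule>();
   for (const auto& pair : module->named_parameters()) {
     // Nested parameters appear under dotted keys, e.g. "linear.weight".
     std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
   }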
std::vector<Tensor> torch::nn::Module::parameters(bool recurse = true) const
Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.
Definition at line 143 of file module.cpp.
virtual void torch::nn::Module::pretty_print(std::ostream& stream) const
Streams a pretty representation of the Module into the given stream.
By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.
Override this method to change the pretty print. The input stream should be returned from the method, to allow easy chaining.
Reimplemented in torch::nn::RNNImpl, torch::nn::SequentialImpl, torch::nn::ConvImpl< D, Derived >, torch::nn::ConvImpl< 1, Conv1dImpl >, torch::nn::ConvImpl< 3, Conv3dImpl >, torch::nn::ConvImpl< 2, Conv2dImpl >, torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, torch::nn::detail::RNNImplBase< GRUImpl >, torch::nn::FeatureDropoutImpl, torch::nn::BatchNormImpl, torch::nn::DropoutImpl, torch::nn::LinearImpl, and torch::nn::EmbeddingImpl.
Definition at line 325 of file module.cpp.
Tensor& torch::nn::Module::register_buffer(std::string name, Tensor tensor)  [protected]

Registers a buffer with this Module.
A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

.. code-block:: cpp

   MyModule::MyModule() {
     mean_ = register_buffer("mean", torch::empty({num_features_}));
   }
Definition at line 315 of file module.cpp.
template <typename ModuleType>
std::shared_ptr<ModuleType> torch::nn::Module::register_module(std::string name, std::shared_ptr<ModuleType> module)  [protected]

Registers a submodule with this Module.
template <typename ModuleType>
std::shared_ptr<ModuleType> torch::nn::Module::register_module(std::string name, ModuleHolder<ModuleType> module_holder)  [protected]
Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().
.. code-block:: cpp

   MyModule::MyModule() {
     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
   }
Tensor& torch::nn::Module::register_parameter(std::string name, Tensor tensor, bool requires_grad = true)  [protected]
Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().
.. code-block:: cpp

   MyModule::MyModule() {
     weight_ = register_parameter("weight", torch::randn({A, B}));
   }
Definition at line 301 of file module.cpp.
virtual void torch::nn::Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype and device.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 244 of file module.cpp.
virtual void torch::nn::Module::to(torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 248 of file module.cpp.
virtual void torch::nn::Module::to(torch::Device device, bool non_blocking = false)
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.
Definition at line 252 of file module.cpp.
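A hedged usage sketch combining the overloads (MyModule is hypothetical); non_blocking only has an effect for pinned-memory/GPU transfers:

.. code-block:: cpp

   auto module = std::make_shared<MyModule>();
   module->to(torch::kFloat64);  // change dtype only
   if (torch::cuda::is_available()) {
     // Move to GPU and cast to half precision; the copy may overlap with
     // host work when non_blocking is true and the source memory is pinned.
     module->to(torch::kCUDA, torch::kHalf, /*non_blocking=*/true);
   }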