Caffe2 - C++ API
A deep learning, cross-platform ML framework
torch::nn::Module Class Reference

The base class for all modules in PyTorch. More...

#include <module.h>

Inheritance diagram for torch::nn::Module:
[Diagram text omitted: Module is inherited by torch::nn::Cloneable< Derived > and its instantiations for the built-in modules (BatchNormImpl, Conv1dImpl, Conv2dImpl, Conv3dImpl, DropoutImpl, EmbeddingImpl, FeatureDropoutImpl, GRUImpl, LinearImpl, LSTMImpl, RNNImpl, SequentialImpl), as well as by various test modules.]

Public Types

using ModuleApplyFunction = std::function< void(Module &)>
 
using ConstModuleApplyFunction = std::function< void(const Module &)>
 
using NamedModuleApplyFunction = std::function< void(const std::string &, Module &)>
 
using ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)>
 
using ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)>
 
using NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)>
 

Public Member Functions

 Module (std::string name)
 Tells the base Module about the name of the submodule.
 
 Module ()
 Constructs the module without immediate knowledge of the submodule's name. More...
 
const std::string & name () const noexcept
 Returns the name of the Module. More...
 
virtual std::shared_ptr< Module > clone (const optional< Device > &device=nullopt) const
 Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules. More...
 
void apply (const ModuleApplyFunction &function)
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstModuleApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ModulePointerApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
std::vector< Tensor > parameters (bool recurse=true) const
 Returns the parameters of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_parameters (bool recurse=true) const
 Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< Tensor > buffers (bool recurse=true) const
 Returns the buffers of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_buffers (bool recurse=true) const
 Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< std::shared_ptr< Module > > modules (bool include_self=true) const
 Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
OrderedDict< std::string, std::shared_ptr< Module > > named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const
 Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
std::vector< std::shared_ptr< Module > > children () const
 Returns the direct submodules of this Module.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_children () const
 Returns an OrderedDict of the direct submodules of this Module and their keys. More...
 
virtual void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode. More...
 
virtual bool is_training () const noexcept
 True if the module is in training mode. More...
 
virtual void to (torch::Device device, torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype and device. More...
 
virtual void to (torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype. More...
 
virtual void to (torch::Device device, bool non_blocking=false)
 Recursively moves all parameters to the given device. More...
 
virtual void zero_grad ()
 Recursively zeros out the grad value of each registered parameter.
 
template<typename ModuleType >
ModuleType::ContainedType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType >
const ModuleType::ContainedType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
virtual void save (serialize::OutputArchive &archive) const
 Serializes the Module into the given OutputArchive.
 
virtual void load (serialize::InputArchive &archive)
 Deserializes the Module from the given InputArchive.
 
virtual void pretty_print (std::ostream &stream) const
 Streams a pretty representation of the Module into the given stream. More...
 

Protected Member Functions

Tensor & register_parameter (std::string name, Tensor tensor, bool requires_grad=true)
 Registers a parameter with this Module. More...
 
Tensor & register_buffer (std::string name, Tensor tensor)
 Registers a buffer with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, std::shared_ptr< ModuleType > module)
 Registers a submodule with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, ModuleHolder< ModuleType > module_holder)
 Registers a submodule with this Module. More...
 

Friends

template<typename Derived >
class Cloneable
 
TORCH_API friend std::ostream & operator<< (std::ostream &stream, const nn::Module &module)
 Pretty prints the given Module into the ostream.
 

Detailed Description

The base class for all modules in PyTorch.

.. note:: The design and implementation of this class is largely based on the Python API. You may want to consult the python documentation for :py:class:pytorch:torch.nn.Module for further clarification on certain methods or behavior.

A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.

A distinction is made between three kinds of persistent data that may be associated with a Module:

  1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
  2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
  3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.

The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deepcopy of a cloneable Module hierarchy.

Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the protected API of Module and are typically invoked from within a concrete Module's constructor.

Definition at line 62 of file module.h.

Constructor & Destructor Documentation

torch::nn::Module::Module ( )

Constructs the module without immediate knowledge of the submodule's name.

The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.

Definition at line 46 of file module.cpp.

Member Function Documentation

void torch::nn::Module::apply ( const ModuleApplyFunction &  function)

Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.

.. code-block:: cpp

   MyModule module;
   module->apply([](nn::Module& module) {
     std::cout << module.name() << std::endl;
   });

Definition at line 87 of file module.cpp.

void torch::nn::Module::apply ( const ConstModuleApplyFunction &  function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.

.. code-block:: cpp

   MyModule module;
   module->apply([](const nn::Module& module) {
     std::cout << module.name() << std::endl;
   });

Definition at line 95 of file module.cpp.

void torch::nn::Module::apply ( const NamedModuleApplyFunction &  function,
const std::string &  name_prefix = std::string() 
)

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key, nn::Module& module) {
     std::cout << key << ": " << module.name() << std::endl;
   });

Definition at line 103 of file module.cpp.

void torch::nn::Module::apply ( const ConstNamedModuleApplyFunction &  function,
const std::string &  name_prefix = std::string() 
) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key, const nn::Module& module) {
     std::cout << key << ": " << module.name() << std::endl;
   });

Definition at line 115 of file module.cpp.

void torch::nn::Module::apply ( const ModulePointerApplyFunction &  function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::shared_ptr<nn::Module>& module) {
     std::cout << module->name() << std::endl;
   });

Definition at line 127 of file module.cpp.

void torch::nn::Module::apply ( const NamedModulePointerApplyFunction &  function,
const std::string &  name_prefix = std::string() 
) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

.. code-block:: cpp

   MyModule module;
   module->apply([](const std::string& key,
                    const std::shared_ptr<nn::Module>& module) {
     std::cout << key << ": " << module->name() << std::endl;
   });

Definition at line 135 of file module.cpp.

template<typename ModuleType >
ModuleType::ContainedType * torch::nn::Module::as ( )
noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);

Definition at line 532 of file module.h.

template<typename ModuleType >
const ModuleType::ContainedType * torch::nn::Module::as ( ) const
noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);

Definition at line 539 of file module.h.

template<typename ModuleType , typename >
ModuleType * torch::nn::Module::as ( )
noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);

Definition at line 546 of file module.h.

template<typename ModuleType , typename >
const ModuleType * torch::nn::Module::as ( ) const
noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

.. code-block:: cpp

   void initialize_weights(nn::Module& module) {
     torch::NoGradGuard no_grad;
     if (auto* linear = module.as<nn::Linear>()) {
       linear->weight.normal_(0.0, 0.02);
     }
   }

   MyModule module;
   module->apply(initialize_weights);

Definition at line 551 of file module.h.

std::vector< Tensor > torch::nn::Module::buffers ( bool  recurse = true) const

Returns the buffers of this Module and if recurse is true, also recursively of every submodule.

Definition at line 166 of file module.cpp.

std::shared_ptr< Module > torch::nn::Module::clone ( const optional< Device > &  device = nullopt) const
virtual

Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.

Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.

.. attention:: Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.

Reimplemented in torch::nn::SequentialImpl, torch::nn::Cloneable< Derived >, torch::nn::Cloneable< BatchNormImpl >, torch::nn::Cloneable< SimpleContainer >, torch::nn::Cloneable< SequentialImpl >, torch::nn::Cloneable< Conv2dImpl >, torch::nn::Cloneable< RNNImpl >, torch::nn::Cloneable< Conv3dImpl >, torch::nn::Cloneable< LSTMImpl >, torch::nn::Cloneable< Net >, torch::nn::Cloneable< EmbeddingImpl >, torch::nn::Cloneable< LinearImpl >, torch::nn::Cloneable< DropoutImpl >, torch::nn::Cloneable< FeatureDropoutImpl >, torch::nn::Cloneable< Conv1dImpl >, and torch::nn::Cloneable< GRUImpl >.

Definition at line 78 of file module.cpp.

void torch::nn::Module::eval ( )

Calls train(false) to enable "eval" mode.

Do not override this method, override train() instead.

Definition at line 240 of file module.cpp.

bool torch::nn::Module::is_training ( ) const
virtual noexcept

True if the module is in training mode.

Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.

Definition at line 256 of file module.cpp.

std::vector< std::shared_ptr< Module > > torch::nn::Module::modules ( bool  include_self = true) const

Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.

.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

Definition at line 187 of file module.cpp.

const std::string & torch::nn::Module::name ( ) const
noexcept

Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.

Definition at line 53 of file module.cpp.

OrderedDict< std::string, Tensor > torch::nn::Module::named_buffers ( bool  recurse = true) const

Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.

Definition at line 174 of file module.cpp.

OrderedDict< std::string, std::shared_ptr< Module > > torch::nn::Module::named_children ( ) const

Returns an OrderedDict of the direct submodules of this Module and their keys.

Definition at line 228 of file module.cpp.

OrderedDict< std::string, std::shared_ptr< Module > > torch::nn::Module::named_modules ( const std::string &  name_prefix = std::string(),
bool  include_self = true 
) const

Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

.. warning:: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

Definition at line 202 of file module.cpp.

OrderedDict< std::string, Tensor > torch::nn::Module::named_parameters ( bool  recurse = true) const

Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.

Definition at line 153 of file module.cpp.

std::vector< Tensor > torch::nn::Module::parameters ( bool  recurse = true) const

Returns the parameters of this Module and if recurse is true, also recursively of every submodule.

Definition at line 143 of file module.cpp.

void torch::nn::Module::pretty_print ( std::ostream &  stream) const
virtual

Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.

Override this method to change the pretty print. Write the representation to the provided stream.

Reimplemented in torch::nn::RNNImpl, torch::nn::SequentialImpl, torch::nn::ConvImpl< D, Derived >, torch::nn::ConvImpl< 1, Conv1dImpl >, torch::nn::ConvImpl< 3, Conv3dImpl >, torch::nn::ConvImpl< 2, Conv2dImpl >, torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, torch::nn::detail::RNNImplBase< GRUImpl >, torch::nn::FeatureDropoutImpl, torch::nn::BatchNormImpl, torch::nn::DropoutImpl, torch::nn::LinearImpl, and torch::nn::EmbeddingImpl.

Definition at line 325 of file module.cpp.

Tensor & torch::nn::Module::register_buffer ( std::string  name,
Tensor  tensor 
)
protected

Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

.. code-block:: cpp

   MyModule::MyModule() {
     mean_ = register_buffer("mean", torch::empty({num_features_}));
   }

Definition at line 315 of file module.cpp.

template<typename ModuleType >
std::shared_ptr< ModuleType > torch::nn::Module::register_module ( std::string  name,
std::shared_ptr< ModuleType >  module 
)
protected

Registers a submodule with this Module.

Registering a module makes it available to methods such as modules(), clone() or to().

.. code-block:: cpp

   MyModule::MyModule() {
     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
   }

Definition at line 556 of file module.h.

template<typename ModuleType >
std::shared_ptr< ModuleType > torch::nn::Module::register_module ( std::string  name,
ModuleHolder< ModuleType >  module_holder 
)
protected

Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().

.. code-block:: cpp

   MyModule::MyModule() {
     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
   }

Definition at line 570 of file module.h.

Tensor & torch::nn::Module::register_parameter ( std::string  name,
Tensor  tensor,
bool  requires_grad = true 
)
protected

Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().

.. code-block:: cpp

   MyModule::MyModule() {
     weight_ = register_parameter("weight", torch::randn({A, B}));
   }

Definition at line 301 of file module.cpp.

void torch::nn::Module::to ( torch::Device  device,
torch::Dtype  dtype,
bool  non_blocking = false 
)
virtual

Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.

Definition at line 244 of file module.cpp.

void torch::nn::Module::to ( torch::Dtype  dtype,
bool  non_blocking = false 
)
virtual

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.

Definition at line 248 of file module.cpp.

void torch::nn::Module::to ( torch::Device  device,
bool  non_blocking = false 
)
virtual

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Reimplemented in torch::nn::detail::RNNImplBase< Derived >, torch::nn::detail::RNNImplBase< RNNImpl >, torch::nn::detail::RNNImplBase< LSTMImpl >, and torch::nn::detail::RNNImplBase< GRUImpl >.

Definition at line 252 of file module.cpp.


The documentation for this class was generated from the following files:

  * module.h
  * module.cpp