Caffe2 - C++ API
A deep learning, cross-platform ML framework
torch::nn::detail::RNNImplBase< Derived > Class Template Reference

Base class for all RNN implementations (intended for code sharing). More...

#include <rnn.h>

Inheritance diagram for torch::nn::detail::RNNImplBase< Derived >: inherits torch::nn::Cloneable< Derived >, which in turn inherits torch::nn::Module.

Public Types

enum  CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 }
 These must line up with the CUDNN mode codes: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t.
 
- Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function< void(Module &)>
 
using ConstModuleApplyFunction = std::function< void(const Module &)>
 
using NamedModuleApplyFunction = std::function< void(const std::string &, Module &)>
 
using ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)>
 
using ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)>
 
using NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)>
 

Public Member Functions

 RNNImplBase (const RNNOptionsBase &options_, optional< CuDNNMode > cudnn_mode=nullopt, int64_t number_of_gates=1)
 
void reset () override
 Initializes the parameters of the RNN module.
 
void to (torch::Device device, torch::Dtype dtype, bool non_blocking=false) override
 Overrides nn::Module::to() to call flatten_parameters() after the original operation. More...
 
void to (torch::Dtype dtype, bool non_blocking=false) override
 Recursively casts all parameters to the given dtype. More...
 
void to (torch::Device device, bool non_blocking=false) override
 Recursively moves all parameters to the given device. More...
 
void pretty_print (std::ostream &stream) const override
 Pretty prints the RNN module into the given stream.
 
void flatten_parameters ()
 Modifies the internal storage of weights for optimization purposes. More...
 
- Public Member Functions inherited from torch::nn::Cloneable< Derived >
std::shared_ptr< Module > clone (const optional< Device > &device=nullopt) const override
 Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module. More...
 
- Public Member Functions inherited from torch::nn::Module
 Module (std::string name)
 Tells the base Module about the name of the submodule.
 
 Module ()
 Constructs the module without immediate knowledge of the submodule's name. More...
 
const std::string & name () const noexcept
 Returns the name of the Module. More...
 
void apply (const ModuleApplyFunction &function)
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstModuleApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ModulePointerApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
std::vector< Tensor > parameters (bool recurse=true) const
 Returns the parameters of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_parameters (bool recurse=true) const
 Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< Tensor > buffers (bool recurse=true) const
 Returns the buffers of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_buffers (bool recurse=true) const
 Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< std::shared_ptr< Module > > modules (bool include_self=true) const
 Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
OrderedDict< std::string, std::shared_ptr< Module > > named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const
 Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
std::vector< std::shared_ptr< Module > > children () const
 Returns the direct submodules of this Module.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_children () const
 Returns an OrderedDict of the direct submodules of this Module and their keys. More...
 
virtual void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode. More...
 
virtual bool is_training () const noexcept
 True if the module is in training mode. More...
 
virtual void zero_grad ()
 Recursively zeros out the grad value of each registered parameter.
 
template<typename ModuleType >
ModuleType::ContainedType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType >
const ModuleType::ContainedType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
virtual void save (serialize::OutputArchive &archive) const
 Serializes the Module into the given OutputArchive.
 
virtual void load (serialize::InputArchive &archive)
 Deserializes the Module from the given InputArchive.
 

Data Fields

RNNOptionsBase options
 The RNN's options.
 
std::vector< Tensor > w_ih
 The weights for input x hidden gates.
 
std::vector< Tensor > w_hh
 The weights for hidden x hidden gates.
 
std::vector< Tensor > b_ih
 The biases for input x hidden gates.
 
std::vector< Tensor > b_hh
 The biases for hidden x hidden gates.
 

Protected Types

using RNNFunctionSignature = std::tuple< Tensor, Tensor >(const Tensor &, const Tensor &, TensorList, bool, int64_t, double, bool, bool, bool)
 The function signature of rnn_relu, rnn_tanh and gru.
 

Protected Member Functions

RNNOutput generic_forward (std::function< RNNFunctionSignature > function, const Tensor &input, Tensor state)
 A generic forward() used for RNN and GRU (but not LSTM!). More...
 
std::vector< Tensor > flat_weights () const
 Returns a flat vector of all weights, with layer weights following each other sequentially in (w_ih, w_hh, b_ih, b_hh) order. More...
 
bool any_parameters_alias () const
 Very simple check if any of the parameters (weights, biases) are the same.
 
- Protected Member Functions inherited from torch::nn::Module
Tensor & register_parameter (std::string name, Tensor tensor, bool requires_grad=true)
 Registers a parameter with this Module. More...
 
Tensor & register_buffer (std::string name, Tensor tensor)
 Registers a buffer with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, std::shared_ptr< ModuleType > module)
 Registers a submodule with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, ModuleHolder< ModuleType > module_holder)
 Registers a submodule with this Module. More...
 

Protected Attributes

int64_t number_of_gates_
 The number of gate weights/biases required by the RNN subclass.
 
optional< CuDNNMode > cudnn_mode_
 The cuDNN RNN mode, if this RNN subclass has any.
 
std::vector< Tensor > flat_weights_
 The cached result of the latest flat_weights() call.
 

Detailed Description

template<typename Derived>
class torch::nn::detail::RNNImplBase< Derived >

Base class for all RNN implementations (intended for code sharing).

Definition at line 56 of file rnn.h.
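
As a usage illustration (not part of the generated documentation): a minimal, hedged sketch of how a concrete module could derive from this base via CRTP. The class name MyRNNImpl, the option values, and the choice of CuDNNMode::RNN_TANH with a single gate are assumptions made for illustration; the library's own RNN, LSTM and GRU modules are the authoritative examples. The forward() declared here is sketched under generic_forward() below.

#include <torch/torch.h>

// Hypothetical module (illustrative only) deriving from the internal base class.
struct MyRNNImpl : torch::nn::detail::RNNImplBase<MyRNNImpl> {
  MyRNNImpl(int64_t input_size, int64_t hidden_size)
      : torch::nn::detail::RNNImplBase<MyRNNImpl>(
            torch::nn::detail::RNNOptionsBase(input_size, hidden_size),
            /*cudnn_mode=*/CuDNNMode::RNN_TANH,
            /*number_of_gates=*/1) {}

  // Sketched under generic_forward() below.
  torch::nn::RNNOutput forward(const torch::Tensor& input, torch::Tensor state = {});
};

// Wraps the Impl in a ModuleHolder, as the built-in RNN/LSTM/GRU modules do.
TORCH_MODULE(MyRNN);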

Member Function Documentation

template<typename Derived >
std::vector< Tensor > torch::nn::detail::RNNImplBase< Derived >::flat_weights ( ) const
protected

Returns a flat vector of all weights, with layer weights following each other sequentially in (w_ih, w_hh, b_ih, b_hh) order.

Definition at line 157 of file rnn.cpp.
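
The documented ordering can be illustrated with a hedged sketch. This is only a reconstruction of the described (w_ih, w_hh, b_ih, b_hh) per-layer order, not the actual body in rnn.cpp, which may handle details such as the no-bias case differently.

#include <torch/torch.h>
#include <vector>

// Illustrative only: assemble the flat weight list layer by layer in the
// documented (w_ih, w_hh, b_ih, b_hh) order.
std::vector<torch::Tensor> flat_weights_sketch(
    const std::vector<torch::Tensor>& w_ih,
    const std::vector<torch::Tensor>& w_hh,
    const std::vector<torch::Tensor>& b_ih,
    const std::vector<torch::Tensor>& b_hh) {
  std::vector<torch::Tensor> flat;
  for (size_t layer = 0; layer < w_ih.size(); ++layer) {
    flat.push_back(w_ih[layer]);
    flat.push_back(w_hh[layer]);
    if (layer < b_ih.size() && layer < b_hh.size()) {  // biases may be absent
      flat.push_back(b_ih[layer]);
      flat.push_back(b_hh[layer]);
    }
  }
  return flat;
}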

template<typename Derived >
void torch::nn::detail::RNNImplBase< Derived >::flatten_parameters ( )

Modifies the internal storage of weights for optimization purposes.

On CPU, this method should be called if any of the weight or bias vectors are changed (i.e. weights are added or removed). On GPU, it should be called any time the storage of any parameter is modified, e.g. any time a parameter is assigned a new value. This allows using the fast path in cuDNN implementations of respective RNN forward() methods. It is called once upon construction, inside reset().

Definition at line 111 of file rnn.cpp.
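
A hedged usage sketch of the rule above, using the library's LSTM module (the option values are illustrative). The to() overrides documented on this page already re-flatten internally; a manual call is only needed after you change a parameter's storage yourself, as in the hypothetical edit below. w_ih is the public data field documented above.

#include <torch/torch.h>

int main() {
  torch::nn::LSTM lstm(torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20));
  if (torch::cuda::is_available()) {
    lstm->to(torch::kCUDA);  // the to() override re-flattens internally

    // Hypothetical manual edit: give one weight fresh storage...
    torch::NoGradGuard no_grad;
    lstm->w_ih[0] = lstm->w_ih[0].clone();

    // ...then re-flatten so cuDNN can keep using its fast path.
    lstm->flatten_parameters();
  }
}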

template<typename Derived >
RNNOutput torch::nn::detail::RNNImplBase< Derived >::generic_forward ( std::function< RNNFunctionSignature > function, const Tensor & input, Tensor state )
protected

A generic forward() used for RNN and GRU (but not LSTM!).

Takes the ATen RNN function as first argument.

Definition at line 132 of file rnn.cpp.
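
Continuing the hypothetical MyRNNImpl sketch from the Detailed Description: a hedged illustration of how a derived forward() could dispatch through generic_forward(), passing an ATen function whose signature matches RNNFunctionSignature (here torch::rnn_tanh). This is a sketch of the mechanism, not the library's actual RNNImpl::forward().

// Illustrative only: the static_cast selects the overload of torch::rnn_tanh
// whose signature matches RNNFunctionSignature (input, hx, params, has_biases,
// num_layers, dropout, train, bidirectional, batch_first).
torch::nn::RNNOutput MyRNNImpl::forward(const torch::Tensor& input, torch::Tensor state) {
  return generic_forward(
      static_cast<RNNFunctionSignature*>(&torch::rnn_tanh), input, std::move(state));
}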

template<typename Derived >
void torch::nn::detail::RNNImplBase< Derived >::to ( torch::Device device, torch::Dtype dtype, bool non_blocking = false )
override virtual

Overrides nn::Module::to() to call flatten_parameters() after the original operation.

Reimplemented from torch::nn::Module.

Definition at line 80 of file rnn.cpp.

template<typename Derived >
void torch::nn::detail::RNNImplBase< Derived >::to ( torch::Dtype dtype, bool non_blocking = false )
override virtual

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Reimplemented from torch::nn::Module.

Definition at line 89 of file rnn.cpp.

template<typename Derived >
void torch::nn::detail::RNNImplBase< Derived >::to ( torch::Device device, bool non_blocking = false )
override virtual

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Reimplemented from torch::nn::Module.

Definition at line 95 of file rnn.cpp.
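
A hedged usage sketch of the three to() overloads documented above, using the library's GRU module (the option values are illustrative; non_blocking only has an effect for pinned-host/GPU transfers, as noted).

#include <torch/torch.h>

int main() {
  torch::nn::GRU gru(torch::nn::GRUOptions(/*input_size=*/8, /*hidden_size=*/16));

  gru->to(torch::kFloat64);  // dtype-only overload: recursive cast
  if (torch::cuda::is_available()) {
    gru->to(torch::kCUDA, /*non_blocking=*/true);                   // device-only overload
    gru->to(torch::kCUDA, torch::kFloat32, /*non_blocking=*/true);  // device + dtype
  }
}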


The documentation for this class was generated from the following files:
rnn.h
rnn.cpp