A multi-layer gated recurrent unit (GRU) module.
#include <rnn.h>
Public Member Functions
GRUImpl(int64_t input_size, int64_t hidden_size)
GRUImpl(const GRUOptions &options)
RNNOutput forward(const Tensor &input, Tensor state={})
  Applies the GRU module to an input sequence and input state.
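The gate equations that forward() applies are documented for Python's torch.nn.GRU. The following plain C++ sketch runs a single-layer, unidirectional GRU over a sequence of vectors using those equations. It does not use libtorch and is not its implementation; GateParams and GRULayerSketch are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Plain-C++ sketch of one GRU layer (NOT the libtorch implementation).
// Gate equations as documented for torch.nn.GRU:
//   r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
//   z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
//   n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
//   h' = (1 - z) * n + z * h
using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major: Mat[row][col]

// Hypothetical per-gate parameter bundle.
struct GateParams {
  Mat w_i;  // hidden_size x input_size
  Mat w_h;  // hidden_size x hidden_size
  Vec b_i;  // hidden_size
  Vec b_h;  // hidden_size
};

struct GRULayerSketch {
  GateParams r, z, n;  // reset, update, and new gates

  // Computes w * x + b.
  static Vec affine(const Mat& w, const Vec& x, const Vec& b) {
    assert(w.size() == b.size());
    Vec out(b);
    for (std::size_t i = 0; i < w.size(); ++i) {
      assert(w[i].size() == x.size());
      for (std::size_t j = 0; j < x.size(); ++j) out[i] += w[i][j] * x[j];
    }
    return out;
  }
  static double sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

  // One time step: returns the new hidden state.
  Vec step(const Vec& x, const Vec& h) const {
    Vec ri = affine(r.w_i, x, r.b_i), rh = affine(r.w_h, h, r.b_h);
    Vec zi = affine(z.w_i, x, z.b_i), zh = affine(z.w_h, h, z.b_h);
    Vec ni = affine(n.w_i, x, n.b_i), nh = affine(n.w_h, h, n.b_h);
    Vec out(h.size());
    for (std::size_t k = 0; k < h.size(); ++k) {
      double rk = sigmoid(ri[k] + rh[k]);
      double zk = sigmoid(zi[k] + zh[k]);
      double nk = std::tanh(ni[k] + rk * nh[k]);
      out[k] = (1.0 - zk) * nk + zk * h[k];
    }
    return out;
  }

  // Runs the whole sequence. An empty initial state is replaced by zeros
  // (assumed default, matching the Python module's h_0). Returns the hidden
  // state after every step; the last entry is the final state.
  std::vector<Vec> forward(const std::vector<Vec>& xs, Vec h, std::size_t hidden_size) const {
    if (h.empty()) h.assign(hidden_size, 0.0);
    std::vector<Vec> outputs;
    for (const Vec& x : xs) {
      h = step(x, h);
      outputs.push_back(h);
    }
    return outputs;
  }
};
```

With all weights and biases set to zero, every gate evaluates to r = z = 0.5 and n = 0, so each step simply halves the previous hidden state, which makes the recurrence easy to check by hand.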
Public Member Functions inherited from torch::nn::detail::RNNImplBase<GRUImpl>
RNNImplBase(const RNNOptionsBase &options_, optional<CuDNNMode> cudnn_mode=nullopt, int64_t number_of_gates=1)
void reset() override
  Initializes the parameters of the RNN module.
void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false) override
  Overrides nn::Module::to() to call flatten_parameters() after the original operation.
void to(torch::Dtype dtype, bool non_blocking=false) override
  Recursively casts all parameters to the given dtype.
void to(torch::Device device, bool non_blocking=false) override
  Recursively moves all parameters to the given device.
void pretty_print(std::ostream &stream) const override
  Pretty prints the RNN module into the given stream.
void flatten_parameters()
  Modifies the internal storage of weights for optimization purposes.
Public Member Functions inherited from torch::nn::Cloneable<GRUImpl>
std::shared_ptr<Module> clone(const optional<Device> &device=nullopt) const override
  Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module.
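To make the difference between a deep and a shallow copy concrete, here is a minimal sketch using a hypothetical ToyModule whose parameters are held through std::shared_ptr, loosely mirroring how module parameters are shared handles to storage. It is not libtorch code, and it ignores the optional device argument that the real clone() accepts.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy module: parameters and submodules are held by shared_ptr,
// so copying the struct itself (a shallow copy) shares all storage.
struct ToyModule {
  std::map<std::string, std::shared_ptr<std::vector<float>>> params;
  std::map<std::string, std::shared_ptr<ToyModule>> children;

  // Deep copy: fresh storage for every parameter, recursively for every child.
  std::shared_ptr<ToyModule> clone() const {
    auto copy = std::make_shared<ToyModule>();
    for (const auto& [name, p] : params) {
      assert(p && "parameter handles must be non-null");
      copy->params[name] = std::make_shared<std::vector<float>>(*p);
    }
    for (const auto& [name, child] : children) copy->children[name] = child->clone();
    return copy;
  }
};
```

Writing into the clone's parameters leaves the original untouched, while a plain struct copy still points at the same vectors.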
Public Member Functions inherited from torch::nn::Module
Module(std::string name)
  Tells the base Module about the name of the submodule.
Module()
  Constructs the module without immediate knowledge of the submodule's name.
const std::string &name() const noexcept
  Returns the name of the Module.
void apply(const ModuleApplyFunction &function)
  Applies the function to the Module and recursively to every submodule.
void apply(const ConstModuleApplyFunction &function) const
  Applies the function to the Module and recursively to every submodule.
void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
  Applies the function to the Module and recursively to every submodule.
void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
  Applies the function to the Module and recursively to every submodule.
void apply(const ModulePointerApplyFunction &function) const
  Applies the function to the Module and recursively to every submodule.
void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
  Applies the function to the Module and recursively to every submodule.
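The named apply() overloads visit the module itself and then every submodule. The sketch below models that traversal on a hypothetical Node type. The dot-joined qualified names (for example "encoder.gru") and the choice to pass name_prefix as the root's own name are assumptions of this sketch, not guarantees about libtorch's internals.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical module tree used only to illustrate a recursive, named apply():
// the node itself is visited first, then each submodule in insertion order.
struct Node {
  std::vector<std::pair<std::string, std::shared_ptr<Node>>> children;

  using NamedFn = std::function<void(const std::string&, Node&)>;

  void apply(const NamedFn& fn, const std::string& name_prefix = std::string()) {
    fn(name_prefix, *this);
    for (auto& [name, child] : children) {
      // Qualified name: "parent.child", or just "child" at the top level.
      std::string qualified = name_prefix.empty() ? name : name_prefix + "." + name;
      child->apply(fn, qualified);
    }
  }
};
```

A depth-first, pre-order walk like this one visits a parent before any of its descendants, which is what makes apply() convenient for tasks such as re-initializing every submodule.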
std::vector<Tensor> parameters(bool recurse=true) const
  Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.
OrderedDict<std::string, Tensor> named_parameters(bool recurse=true) const
  Returns an OrderedDict with the parameters of this Module along with their keys and, if recurse is true, also recursively of every submodule.
std::vector<Tensor> buffers(bool recurse=true) const
  Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.
OrderedDict<std::string, Tensor> named_buffers(bool recurse=true) const
  Returns an OrderedDict with the buffers of this Module along with their keys and, if recurse is true, also recursively of every submodule.
std::vector<std::shared_ptr<Module>> modules(bool include_self=true) const
  Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix=std::string(), bool include_self=true) const
  Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys and, if include_self is true, also inserts a shared_ptr to this module in the first position.
std::vector<std::shared_ptr<Module>> children() const
  Returns the direct submodules of this Module.
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const
  Returns an OrderedDict of the direct submodules of this Module and their keys.
virtual void train(bool on=true)
  Enables "training" mode.
void eval()
  Calls train(false) to enable "eval" mode.
virtual bool is_training() const noexcept
  Returns true if the module is in training mode.
virtual void zero_grad()
  Recursively zeros out the grad value of each registered parameter.
template<typename ModuleType>
ModuleType::ContainedType *as() noexcept
  Attempts to cast this Module to the given ModuleType.
template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept
  Attempts to cast this Module to the given ModuleType.
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept
  Attempts to cast this Module to the given ModuleType.
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept
  Attempts to cast this Module to the given ModuleType.
virtual void save(serialize::OutputArchive &archive) const
  Serializes the Module into the given OutputArchive.
virtual void load(serialize::InputArchive &archive)
  Deserializes the Module from the given InputArchive.
Additional Inherited Members
Public Types inherited from torch::nn::detail::RNNImplBase<GRUImpl>
enum CuDNNMode
  The enumerator values must line up with the cuDNN mode codes (cudnnRNNMode_t): https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t.
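Presumably the enumerators are converted to cuDNN's integer mode codes, so a reordering would silently select the wrong RNN type. Here is a hedged sketch of how such an enum can be pinned to the cudnnRNNMode_t codes (CUDNN_RNN_RELU = 0, CUDNN_RNN_TANH = 1, CUDNN_LSTM = 2, CUDNN_GRU = 3) with compile-time checks. CuDNNModeSketch and to_cudnn_code are hypothetical names, not the libtorch declaration.

```cpp
#include <cassert>

// Hypothetical enum mirroring cuDNN's cudnnRNNMode_t integer codes.
enum class CuDNNModeSketch : int { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };

// Compile-time guards: reordering the enumerators breaks the build instead of
// silently handing the wrong mode code to cuDNN.
static_assert(static_cast<int>(CuDNNModeSketch::RNN_RELU) == 0, "must match CUDNN_RNN_RELU");
static_assert(static_cast<int>(CuDNNModeSketch::RNN_TANH) == 1, "must match CUDNN_RNN_TANH");
static_assert(static_cast<int>(CuDNNModeSketch::LSTM) == 2, "must match CUDNN_LSTM");
static_assert(static_cast<int>(CuDNNModeSketch::GRU) == 3, "must match CUDNN_GRU");

// Converts the scoped enum to the raw integer code cuDNN expects.
constexpr int to_cudnn_code(CuDNNModeSketch mode) { return static_cast<int>(mode); }
```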
Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function<void(Module &)>
using ConstModuleApplyFunction = std::function<void(const Module &)>
using NamedModuleApplyFunction = std::function<void(const std::string &, Module &)>
using ConstNamedModuleApplyFunction = std::function<void(const std::string &, const Module &)>
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module> &)>
using NamedModulePointerApplyFunction = std::function<void(const std::string &, const std::shared_ptr<Module> &)>
Data Fields inherited from torch::nn::detail::RNNImplBase<GRUImpl>
RNNOptionsBase options
  The RNN's options.
std::vector<Tensor> w_ih
  The weights for the input-to-hidden gates.
std::vector<Tensor> w_hh
  The weights for the hidden-to-hidden gates.
std::vector<Tensor> b_ih
  The biases for the input-to-hidden gates.
std::vector<Tensor> b_hh
  The biases for the hidden-to-hidden gates.
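The shapes of the w_ih, w_hh, b_ih and b_hh entries follow from the sizes in the options. The following sketch computes them, assuming the layout documented for Python's torch.nn.GRU: three gates (reset, update, new) stacked along the first dimension, so every gate matrix has 3 * hidden_size rows, and layers after the first take hidden_size * num_directions inputs. Whether libtorch stores exactly one vector entry per layer and direction is an assumption here; GRUWeightShapes, gru_weight_shapes and gru_parameter_count are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Shapes of one layer/direction's GRU parameters (hypothetical helper type).
struct GRUWeightShapes {
  std::int64_t layer, direction;
  std::int64_t w_ih_rows, w_ih_cols;  // (3 * hidden, layer input size)
  std::int64_t w_hh_rows, w_hh_cols;  // (3 * hidden, hidden)
  std::int64_t bias_len;              // b_ih and b_hh: (3 * hidden)
};

std::vector<GRUWeightShapes> gru_weight_shapes(std::int64_t input_size, std::int64_t hidden_size,
                                               std::int64_t layers, bool bidirectional) {
  const std::int64_t number_of_gates = 3;  // reset, update, new
  const std::int64_t directions = bidirectional ? 2 : 1;
  const std::int64_t gate_size = number_of_gates * hidden_size;
  std::vector<GRUWeightShapes> shapes;
  for (std::int64_t l = 0; l < layers; ++l) {
    // Later layers consume the (direction-concatenated) output of the previous layer.
    const std::int64_t layer_input = (l == 0) ? input_size : hidden_size * directions;
    for (std::int64_t d = 0; d < directions; ++d)
      shapes.push_back({l, d, gate_size, layer_input, gate_size, hidden_size, gate_size});
  }
  return shapes;
}

// Total number of scalar parameters, optionally counting both bias vectors.
std::int64_t gru_parameter_count(const std::vector<GRUWeightShapes>& shapes, bool with_bias = true) {
  std::int64_t total = 0;
  for (const auto& s : shapes) {
    total += s.w_ih_rows * s.w_ih_cols + s.w_hh_rows * s.w_hh_cols;
    if (with_bias) total += 2 * s.bias_len;
  }
  return total;
}
```

For example, input_size = 10, hidden_size = 20 and two unidirectional layers give 1920 parameters in the first layer and 2520 in the second, 4440 in total.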
Protected Types inherited from torch::nn::detail::RNNImplBase<GRUImpl>
using RNNFunctionSignature = std::tuple<Tensor, Tensor>(const Tensor &, const Tensor &, TensorList, bool, int64_t, double, bool, bool, bool)
  The function signature of rnn_relu, rnn_tanh and gru.
Protected Member Functions inherited from torch::nn::detail::RNNImplBase<GRUImpl>
RNNOutput generic_forward(std::function<RNNFunctionSignature> function, const Tensor &input, Tensor state)
  A generic forward() used for RNN and GRU (but not LSTM!).
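generic_forward() receives the low-level function (rnn_relu, rnn_tanh or gru) as a std::function with RNNFunctionSignature. The sketch below uses hypothetical stand-in types (FakeTensor tracks only a shape) to show two steps such a generic forward plausibly performs: substituting a zero state of shape (layers * num_directions, batch, hidden_size) when none is given, as the Python GRU does for a missing h_0, and forwarding the options to the function. The parameter names in the comment (input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first) are an assumption based on the ATen gru operator, not taken from this page.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <tuple>
#include <vector>

// Hypothetical stand-in for a tensor: only the shape is tracked.
struct FakeTensor {
  std::vector<std::int64_t> shape;
  bool defined() const { return !shape.empty(); }
};
using FakeTensorList = std::vector<FakeTensor>;

// Same parameter order as RNNFunctionSignature; assumed names:
// (input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)
using FakeRNNFunction = std::tuple<FakeTensor, FakeTensor>(
    const FakeTensor&, const FakeTensor&, FakeTensorList, bool, std::int64_t, double, bool, bool, bool);

// Hypothetical subset of the RNN options.
struct FakeOptions {
  std::int64_t hidden_size = 1, layers = 1;
  bool with_bias = true, bidirectional = false, batch_first = false;
  double dropout = 0.0;
};

struct FakeRNNOutput { FakeTensor output, state; };

// Sketch: default the state if it is undefined, then delegate to `function`.
FakeRNNOutput generic_forward_sketch(const std::function<FakeRNNFunction>& function,
                                     const FakeOptions& options, const FakeTensorList& flat_weights,
                                     bool is_training, const FakeTensor& input, FakeTensor state) {
  assert(input.shape.size() == 3);  // (seq, batch, feature) or (batch, seq, feature)
  if (!state.defined()) {
    // Stands in for creating a zero tensor; only its shape is recorded here.
    const std::int64_t directions = options.bidirectional ? 2 : 1;
    const std::int64_t batch = input.shape.at(options.batch_first ? 0 : 1);
    state.shape = {options.layers * directions, batch, options.hidden_size};
  }
  auto [output, new_state] = function(input, state, flat_weights, options.with_bias, options.layers,
                                      options.dropout, is_training, options.bidirectional,
                                      options.batch_first);
  return {output, new_state};
}
```

Passing the operator in as a std::function is what lets one code path serve several recurrent variants that share this argument list.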
std::vector<Tensor> flat_weights() const
  Returns a flat vector of all weights, with layer weights following each other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
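A tiny sketch that lists labels in the order flat_weights() is documented to produce, grouping each layer's four tensors before moving on to the next layer. flat_weight_order is a hypothetical helper that returns strings instead of tensors, and it does not model bidirectional layers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Labels in the documented flat_weights() order: for each layer in turn,
// (w_ih, w_hh, b_ih, b_hh). Hypothetical helper for illustration only.
std::vector<std::string> flat_weight_order(int layers) {
  std::vector<std::string> order;
  for (int l = 0; l < layers; ++l) {
    const std::string idx = "[" + std::to_string(l) + "]";
    order.push_back("w_ih" + idx);
    order.push_back("w_hh" + idx);
    order.push_back("b_ih" + idx);
    order.push_back("b_hh" + idx);
  }
  return order;
}
```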
bool any_parameters_alias() const
  Performs a very simple check of whether any of the parameters (weights, biases) alias one another, that is, are the same.
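A minimal interpretation of such an alias check, assuming that "the same" means sharing the same data pointer: record each parameter's data address in a hash set and report the first duplicate. any_parameters_alias_sketch is hypothetical and takes raw addresses rather than tensors.

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Returns true if any two entries share the same data address.
// Each pointer is a hypothetical stand-in for a parameter tensor's data pointer.
bool any_parameters_alias_sketch(const std::vector<const void*>& data_ptrs) {
  std::unordered_set<const void*> seen;
  for (const void* p : data_ptrs) {
    if (!seen.insert(p).second) return true;  // insert failed: address already seen
  }
  return false;
}
```

Like any pointer-equality check, this sketch does not detect two views that overlap but start at different addresses.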
Protected Member Functions inherited from torch::nn::Module
Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
  Registers a parameter with this Module.
Tensor &register_buffer(std::string name, Tensor tensor)
  Registers a buffer with this Module.
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)
  Registers a submodule with this Module.
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)
  Registers a submodule with this Module.
Protected Attributes inherited from torch::nn::detail::RNNImplBase<GRUImpl>
int64_t number_of_gates_
  The number of gate weights/biases required by the RNN subclass.
optional<CuDNNMode> cudnn_mode_
  The cuDNN RNN mode, if this RNN subclass has any.
std::vector<Tensor> flat_weights_
  The cached result of the latest flat_weights() call.
Detailed Description
See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the exact behavior of this module.