Caffe2 - C++ API
A deep learning, cross platform ML framework
torch::nn::GRUImpl Class Reference

A multi-layer gated recurrent unit (GRU) module. More...

#include <rnn.h>

Inheritance diagram for torch::nn::GRUImpl:
torch::nn::GRUImpl derives from torch::nn::detail::RNNImplBase< GRUImpl >, which derives from torch::nn::Cloneable< GRUImpl >, which derives from torch::nn::Module.

Public Member Functions

 GRUImpl (int64_t input_size, int64_t hidden_size)
 
 GRUImpl (const GRUOptions &options)
 
RNNOutput forward (const Tensor &input, Tensor state={})
 Applies the GRU module to an input sequence and input state. More...
 
- Public Member Functions inherited from torch::nn::detail::RNNImplBase< GRUImpl >
 RNNImplBase (const RNNOptionsBase &options_, optional< CuDNNMode > cudnn_mode=nullopt, int64_t number_of_gates=1)
 
void reset () override
 Initializes the parameters of the RNN module.
 
void to (torch::Device device, torch::Dtype dtype, bool non_blocking=false) override
 Overrides nn::Module::to() to call flatten_parameters() after the original operation. More...
 
void to (torch::Dtype dtype, bool non_blocking=false) override
 Recursively casts all parameters to the given dtype. More...
 
void to (torch::Device device, bool non_blocking=false) override
 Recursively moves all parameters to the given device. More...
 
void pretty_print (std::ostream &stream) const override
 Pretty prints the RNN module into the given stream.
 
void flatten_parameters ()
 Modifies the internal storage of weights for optimization purposes. More...
 
- Public Member Functions inherited from torch::nn::Cloneable< GRUImpl >
std::shared_ptr< Module > clone (const optional< Device > &device=nullopt) const override
 Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module. More...
 
- Public Member Functions inherited from torch::nn::Module
 Module (std::string name)
 Tells the base Module about the name of the submodule.
 
 Module ()
 Constructs the module without immediate knowledge of the submodule's name. More...
 
const std::string & name () const noexcept
 Returns the name of the Module. More...
 
void apply (const ModuleApplyFunction &function)
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstModuleApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ModulePointerApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
std::vector< Tensor > parameters (bool recurse=true) const
 Returns the parameters of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_parameters (bool recurse=true) const
 Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< Tensor > buffers (bool recurse=true) const
 Returns the buffers of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_buffers (bool recurse=true) const
 Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< std::shared_ptr< Module > > modules (bool include_self=true) const
 Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
OrderedDict< std::string, std::shared_ptr< Module > > named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const
 Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
std::vector< std::shared_ptr< Module > > children () const
 Returns the direct submodules of this Module.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_children () const
 Returns an OrderedDict of the direct submodules of this Module and their keys. More...
 
virtual void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode. More...
 
virtual bool is_training () const noexcept
 True if the module is in training mode. More...
 
virtual void zero_grad ()
 Recursively zeros out the grad value of each registered parameter.
 
template<typename ModuleType >
ModuleType::ContainedType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType >
const ModuleType::ContainedType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
virtual void save (serialize::OutputArchive &archive) const
 Serializes the Module into the given OutputArchive.
 
virtual void load (serialize::InputArchive &archive)
 Deserializes the Module from the given InputArchive.
 

Additional Inherited Members

- Public Types inherited from torch::nn::detail::RNNImplBase< GRUImpl >
enum  CuDNNMode
 These must line up with the CUDNN mode codes: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t.
 
- Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function< void(Module &)>
 
using ConstModuleApplyFunction = std::function< void(const Module &)>
 
using NamedModuleApplyFunction = std::function< void(const std::string &, Module &)>
 
using ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)>
 
using ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)>
 
using NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)>
 
- Data Fields inherited from torch::nn::detail::RNNImplBase< GRUImpl >
RNNOptionsBase options
 The RNN's options.
 
std::vector< Tensor > w_ih
 The weights for input x hidden gates.
 
std::vector< Tensor > w_hh
 The weights for hidden x hidden gates.
 
std::vector< Tensor > b_ih
 The biases for input x hidden gates.
 
std::vector< Tensor > b_hh
 The biases for hidden x hidden gates.
 
- Protected Types inherited from torch::nn::detail::RNNImplBase< GRUImpl >
using RNNFunctionSignature = std::tuple< Tensor, Tensor >(const Tensor &, const Tensor &, TensorList, bool, int64_t, double, bool, bool, bool)
 The function signature of rnn_relu, rnn_tanh and gru.
 
- Protected Member Functions inherited from torch::nn::detail::RNNImplBase< GRUImpl >
RNNOutput generic_forward (std::function< RNNFunctionSignature > function, const Tensor &input, Tensor state)
 A generic forward() used for RNN and GRU (but not LSTM!). More...
 
std::vector< Tensor > flat_weights () const
 Returns a flat vector of all weights, with layer weights following each other sequentially in (w_ih, w_hh, b_ih, b_hh) order. More...
 
bool any_parameters_alias () const
 Very simple check if any of the parameters (weights, biases) are the same.
 
- Protected Member Functions inherited from torch::nn::Module
Tensor & register_parameter (std::string name, Tensor tensor, bool requires_grad=true)
 Registers a parameter with this Module. More...
 
Tensor & register_buffer (std::string name, Tensor tensor)
 Registers a buffer with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, std::shared_ptr< ModuleType > module)
 Registers a submodule with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, ModuleHolder< ModuleType > module_holder)
 Registers a submodule with this Module. More...
 
- Protected Attributes inherited from torch::nn::detail::RNNImplBase< GRUImpl >
int64_t number_of_gates_
 The number of gate weights/biases required by the RNN subclass.
 
optional< CuDNNMode > cudnn_mode_
 The cuDNN RNN mode, if this RNN subclass has any.
 
std::vector< Tensor > flat_weights_
 The cached result of the latest flat_weights() call.
 

Detailed Description

A multi-layer gated recurrent unit (GRU) module.

See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the exact behavior of this module.

Definition at line 234 of file rnn.h.
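
For orientation, the per-step computation that a GRU performs can be sketched in plain C++ without any libtorch dependency. This is an illustrative sketch only, not the code in rnn.h or rnn.cpp: it assumes the gate equations given in the torch.nn.GRU documentation linked above, handles a single sample and a single layer, and keeps the reset, update and new gates in separate hypothetical structures (`GateWeights`, `gru_cell_step`, `affine` and `sigmoid` are names invented here), whereas the module stores its weights in w_ih, w_hh, b_ih and b_hh.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper types for this sketch; they are not part of torch::nn.
using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major: Mat[row][col]

// Computes W * x + b for a (rows x cols) matrix W.
Vec affine(const Mat& W, const Vec& x, const Vec& b) {
  assert(W.size() == b.size());
  Vec out(b);
  for (std::size_t i = 0; i < W.size(); ++i) {
    assert(W[i].size() == x.size());
    for (std::size_t j = 0; j < x.size(); ++j) {
      out[i] += W[i][j] * x[j];
    }
  }
  return out;
}

double sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

// Weights of one gate (assumed layout, chosen for readability).
struct GateWeights {
  Mat w_i;  // hidden_size x input_size, applied to the input x
  Vec b_i;  // hidden_size
  Mat w_h;  // hidden_size x hidden_size, applied to the previous state h
  Vec b_h;  // hidden_size
};

// One GRU time step for a single sample:
//   r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
//   z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
//   n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
//   h' = (1 - z) * n + z * h
Vec gru_cell_step(const GateWeights& r_gate, const GateWeights& z_gate,
                  const GateWeights& n_gate, const Vec& x, const Vec& h) {
  const Vec ri = affine(r_gate.w_i, x, r_gate.b_i);
  const Vec rh = affine(r_gate.w_h, h, r_gate.b_h);
  const Vec zi = affine(z_gate.w_i, x, z_gate.b_i);
  const Vec zh = affine(z_gate.w_h, h, z_gate.b_h);
  const Vec ni = affine(n_gate.w_i, x, n_gate.b_i);
  const Vec nh = affine(n_gate.w_h, h, n_gate.b_h);

  Vec h_next(h.size());
  for (std::size_t k = 0; k < h.size(); ++k) {
    const double r = sigmoid(ri[k] + rh[k]);
    const double z = sigmoid(zi[k] + zh[k]);
    const double n = std::tanh(ni[k] + r * nh[k]);
    h_next[k] = (1.0 - z) * n + z * h[k];
  }
  return h_next;
}
```

With all weights and biases set to zero, both sigmoid gates evaluate to 0.5 and the candidate state to 0, so the step simply halves the previous hidden state, which makes a convenient sanity check.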

Member Function Documentation

RNNOutput torch::nn::GRUImpl::forward ( const Tensor & input,
Tensor state = {}
)

Applies the GRU module to an input sequence and input state.

The input should follow a (sequence, batch, features) layout unless batch_first is true, in which case the layout should be (batch, sequence, features).

Definition at line 286 of file rnn.cpp.
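
The difference between the two input layouts described for forward() can be made concrete with a small index calculation. The sketch below is an illustration under stated assumptions, not part of the library: `flat_offset` is a hypothetical helper, and it assumes a contiguous, row-major buffer, whereas real tensors may carry arbitrary strides.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: offset of element (t, b, f) in a contiguous row-major
// buffer holding seq_len * batch_size * features values.
//   batch_first == false: layout is (sequence, batch, features)
//   batch_first == true:  layout is (batch, sequence, features)
std::size_t flat_offset(std::size_t t, std::size_t b, std::size_t f,
                        std::size_t seq_len, std::size_t batch_size,
                        std::size_t features, bool batch_first) {
  assert(t < seq_len && b < batch_size && f < features);
  if (batch_first) {
    return (b * seq_len + t) * features + f;
  }
  return (t * batch_size + b) * features + f;
}
```

For a sequence length of 5, a batch size of 3 and 4 features, the element at time step 1 of sample 2 (feature 0) sits at offset 20 in the default layout and at offset 44 when batch_first is true. Passing data in the wrong layout therefore does not fail on shape alone when the sizes happen to match; it silently mixes time steps and samples.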


The documentation for this class was generated from the following files: