#include <torch/nn/cloneable.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <c10/util/Exception.h>

/// Common options for LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
  RNNOptionsBase(int64_t input_size, int64_t hidden_size);
  virtual ~RNNOptionsBase() = default;
  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, layers) = 1;
  TORCH_ARG(bool, with_bias) = true;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
  TORCH_ARG(bool, batch_first) = false;
};
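The TORCH_ARG macro generates a getter and a fluent setter for each option, so options are usually built by chaining. A minimal usage sketch, assuming the setter/getter pair TORCH_ARG conventionally expands to and assuming RNNOptionsBase lives in torch::nn::detail (all sizes are illustrative):

  // Build options for a 2-layer bidirectional recurrent module.
  auto options = torch::nn::detail::RNNOptionsBase(/*input_size=*/10, /*hidden_size=*/20)
                     .layers(2)
                     .dropout(0.5)
                     .bidirectional(true);
  int64_t hidden = options.hidden_size();  // getter form of the same TORCH_ARG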
/// Base class for all RNN implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
 public:
  /// These must line up with the cuDNN mode codes:
  /// https://docs.nvidia.com/deeplearning/sdk/cudnn-develope...
  enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };

  explicit RNNImplBase(
      const RNNOptionsBase& options_,
      optional<CuDNNMode> cudnn_mode = nullopt,
      int64_t number_of_gates = 1);

  void reset() override;

  /// Overrides nn::Module::to() so that parameters are re-flattened after the move.
  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override;
  void to(torch::Dtype dtype, bool non_blocking = false) override;
  void to(torch::Device device, bool non_blocking = false) override;

  void pretty_print(std::ostream& stream) const override;

  /// Flattens the weight storage so cuDNN RNN kernels can use their fast path.
  void flatten_parameters();

 protected:
  /// The function signature of rnn_relu, rnn_tanh and gru.
  using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
      const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);

  RNNOutput generic_forward(
      std::function<RNNFunctionSignature> function,
      const Tensor& input,
      Tensor state);

  std::vector<Tensor> flat_weights() const;

  bool any_parameters_alias() const;
};
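Because the cuDNN fast path depends on the flattened weight buffer, flatten_parameters() must run again whenever parameter storage changes; the to() overloads above handle that after moving the module. A hedged sketch of the intended call pattern, assuming the usual TORCH_MODULE holder types (GRU stands in for any RNNImplBase subclass; sizes are illustrative):

  #include <torch/torch.h>

  torch::nn::GRU gru(/*input_size=*/10, /*hidden_size=*/20);
  if (torch::cuda::is_available()) {
    // Moving the module re-flattens parameters, keeping the cuDNN path usable.
    gru->to(torch::kCUDA);
  }
  // After manually reassigning any weight tensor, re-flatten explicitly:
  gru->flatten_parameters();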
enum class RNNActivation : uint32_t { ReLU, Tanh };

struct TORCH_API RNNOptions {
  RNNOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, layers) = 1;
  TORCH_ARG(bool, with_bias) = true;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
  TORCH_ARG(bool, batch_first) = false;
  TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
};
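A minimal construction sketch for the Elman RNN configured through these options, again assuming the fluent TORCH_ARG setters and the usual TORCH_MODULE holder (sizes are illustrative):

  // A 2-layer Elman RNN with Tanh activation instead of the default ReLU.
  torch::nn::RNN rnn(torch::nn::RNNOptions(/*input_size=*/10, /*hidden_size=*/20)
                         .layers(2)
                         .activation(torch::nn::RNNActivation::Tanh));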
/// A multi-layer Elman RNN module with Tanh or ReLU activation.
RNNImpl(int64_t input_size, int64_t hidden_size);

void pretty_print(std::ostream& stream) const override;

/// A multi-layer long-short-term-memory (LSTM) module.
LSTMImpl(int64_t input_size, int64_t hidden_size);

/// A multi-layer gated recurrent unit (GRU) module.
GRUImpl(int64_t input_size, int64_t hidden_size);
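A hedged sketch of how a module's forward pass and the RNNOutput fields documented below fit together, assuming a forward() that takes an optional initial state (sizes are illustrative):

  #include <torch/torch.h>

  torch::nn::RNN rnn(/*input_size=*/10, /*hidden_size=*/20);
  // Default layout is (sequence, batch, features) since batch_first = false.
  auto input = torch::randn({5, 3, 10});
  auto out = rnn->forward(input);
  // out.output holds the per-step results; out.state seeds the next call.
  auto next = rnn->forward(input, out.state);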
RNNOutput
    The output of a single invocation of an RNN module's forward() method.
Tensor output
    The result of applying the specific RNN algorithm to the input tensor and input state.
Tensor state
    The new, updated state that can be fed into the RNN in the next forward step.

RNNOptionsBase
    Common options for LSTM and GRU modules.
RNNImplBase
    Base class for all RNN implementations (intended for code sharing).
CuDNNMode
    These must line up with the cuDNN mode codes: https://docs.nvidia.com/deeplearning/sdk/cudnn-develope...
std::tuple<Tensor, Tensor>(const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool) RNNFunctionSignature
    The function signature of rnn_relu, rnn_tanh and gru.

RNNOptionsBase options
    The RNN's options.
std::vector<Tensor> w_ih
    The weights for input x hidden gates.
std::vector<Tensor> w_hh
    The weights for hidden x hidden gates.
std::vector<Tensor> b_ih
    The biases for input x hidden gates.
std::vector<Tensor> b_hh
    The biases for hidden x hidden gates.
optional<CuDNNMode> cudnn_mode_
    The cuDNN RNN mode, if this RNN subclass has any.
int64_t number_of_gates_
    The number of gate weights/biases required by the RNN subclass.
std::vector<Tensor> flat_weights_
    The cached result of the latest flat_weights() call.

RNN
    A multi-layer Elman RNN module with Tanh or ReLU activation.
LSTM
    A multi-layer long-short-term-memory (LSTM) module.
GRU
    A multi-layer gated recurrent unit (GRU) module.

torch::Device
    Represents a compute device on which a tensor is located.
ArrayRef
    Represents a constant reference to an array (0 or more elements consecutively in memory)...
Cloneable
    The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
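The per-layer recurrence that the w_ih, w_hh, b_ih, and b_hh vectors implement can be written out directly. A simplified single-layer, uni-directional sketch (not the library's implementation; torch::linear and Tanh are used for illustration):

  #include <torch/torch.h>

  // One Elman step: h_t = tanh(x_t * W_ih^T + b_ih + h_prev * W_hh^T + b_hh).
  torch::Tensor elman_step(
      const torch::Tensor& x_t,     // input at time t, (batch, input_size)
      const torch::Tensor& h_prev,  // previous hidden state, (batch, hidden_size)
      const torch::Tensor& w_ih,    // (hidden_size, input_size)
      const torch::Tensor& w_hh,    // (hidden_size, hidden_size)
      const torch::Tensor& b_ih,    // (hidden_size)
      const torch::Tensor& b_hh) {  // (hidden_size)
    return torch::tanh(torch::linear(x_t, w_ih, b_ih) +
                       torch::linear(h_prev, w_hh, b_hh));
  }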