// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// rnn.h
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/modules/dropout.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <ATen/ATen.h>
9 #include <c10/util/Exception.h>
10 
11 #include <cstddef>
12 #include <functional>
13 #include <memory>
14 #include <vector>
15 
16 namespace torch {
17 namespace nn {
18 
20 struct TORCH_API RNNOutput {
27 };
28 
29 namespace detail {
30 
32 struct TORCH_API RNNOptionsBase {
33  RNNOptionsBase(int64_t input_size, int64_t hidden_size);
34  virtual ~RNNOptionsBase() = default;
36  TORCH_ARG(int64_t, input_size);
38  TORCH_ARG(int64_t, hidden_size);
40  TORCH_ARG(int64_t, layers) = 1;
42  TORCH_ARG(bool, with_bias) = true;
45  TORCH_ARG(double, dropout) = 0.0;
47  TORCH_ARG(bool, bidirectional) = false;
51  TORCH_ARG(bool, batch_first) = false;
52 };
53 
55 template <typename Derived>
56 class RNNImplBase : public torch::nn::Cloneable<Derived> {
57  public:
60  enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
61 
62  explicit RNNImplBase(
63  const RNNOptionsBase& options_,
64  optional<CuDNNMode> cudnn_mode = nullopt,
65  int64_t number_of_gates = 1);
66 
68  void reset() override;
69 
72  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
73  override;
74  void to(torch::Dtype dtype, bool non_blocking = false) override;
75  void to(torch::Device device, bool non_blocking = false) override;
76 
78  void pretty_print(std::ostream& stream) const override;
79 
88  void flatten_parameters();
89 
92 
94  std::vector<Tensor> w_ih;
96  std::vector<Tensor> w_hh;
98  std::vector<Tensor> b_ih;
100  std::vector<Tensor> b_hh;
101 
102  protected:
104  using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
105  /*input=*/const Tensor&,
106  /*state=*/const Tensor&,
107  /*params=*/TensorList,
108  /*has_biases=*/bool,
109  /*layers=*/int64_t,
110  /*dropout=*/double,
111  /*train=*/bool,
112  /*bidirectional=*/bool,
113  /*batch_first=*/bool);
114 
117  RNNOutput generic_forward(
118  std::function<RNNFunctionSignature> function,
119  const Tensor& input,
120  Tensor state);
121 
124  std::vector<Tensor> flat_weights() const;
125 
127  bool any_parameters_alias() const;
128 
131 
134 
136  std::vector<Tensor> flat_weights_;
137 };
138 } // namespace detail
139 
140 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
141 
/// Activation functions selectable for the `RNN` module
/// (see `RNNOptions::activation`).
enum class RNNActivation : uint32_t {ReLU, Tanh};
143 
145 struct TORCH_API RNNOptions {
146  RNNOptions(int64_t input_size, int64_t hidden_size);
147 
149  RNNOptions& tanh();
151  RNNOptions& relu();
152 
154  TORCH_ARG(int64_t, input_size);
156  TORCH_ARG(int64_t, hidden_size);
158  TORCH_ARG(int64_t, layers) = 1;
160  TORCH_ARG(bool, with_bias) = true;
163  TORCH_ARG(double, dropout) = 0.0;
165  TORCH_ARG(bool, bidirectional) = false;
169  TORCH_ARG(bool, batch_first) = false;
171  TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
172 };
173 
177 class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
178  public:
179  RNNImpl(int64_t input_size, int64_t hidden_size)
180  : RNNImpl(RNNOptions(input_size, hidden_size)) {}
181  explicit RNNImpl(const RNNOptions& options);
182 
184  void pretty_print(std::ostream& stream) const override;
185 
190  RNNOutput forward(const Tensor& input, Tensor state = {});
191 
192  RNNOptions options;
193 };
194 
199 TORCH_MODULE(RNN);
200 
201 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
202 
204 
208 class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
209  public:
210  LSTMImpl(int64_t input_size, int64_t hidden_size)
211  : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
212  explicit LSTMImpl(const LSTMOptions& options);
213 
218  RNNOutput forward(const Tensor& input, Tensor state = {});
219 };
220 
225 TORCH_MODULE(LSTM);
226 
227 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
228 
230 
234 class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
235  public:
236  GRUImpl(int64_t input_size, int64_t hidden_size)
237  : GRUImpl(GRUOptions(input_size, hidden_size)) {}
238  explicit GRUImpl(const GRUOptions& options);
239 
244  RNNOutput forward(const Tensor& input, Tensor state = {});
245 };
246 
251 TORCH_MODULE(GRU);
252 
253 } // namespace nn
254 } // namespace torch
optional< CuDNNMode > cudnn_mode_
The cuDNN RNN mode, if this RNN subclass has any.
Definition: rnn.h:133
int64_t number_of_gates_
The number of gate weights/biases required by the RNN subclass.
Definition: rnn.h:130
Tensor state
The new, updated state that can be fed into the RNN in the next forward step.
Definition: rnn.h:26
The output of a single invocation of an RNN module's forward() method.
Definition: rnn.h:20
std::vector< Tensor > w_hh
The weights for hidden x hidden gates.
Definition: rnn.h:96
Common options for LSTM and GRU modules.
Definition: rnn.h:32
CuDNNMode
These must line up with the CUDNN mode codes: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
Definition: rnn.h:60
Represents a compute device on which a tensor is located.
Definition: Device.h:30
std::vector< Tensor > b_ih
The biases for input x hidden gates.
Definition: rnn.h:98
A multi-layer gated recurrent unit (GRU) module.
Definition: rnn.h:234
Options for RNN modules.
Definition: rnn.h:145
A multi-layer Elman RNN module with Tanh or ReLU activation.
Definition: rnn.h:177
std::vector< Tensor > w_ih
The weights for input x hidden gates.
Definition: rnn.h:94
A multi-layer long-short-term-memory (LSTM) module.
Definition: rnn.h:208
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
Definition: jit_type.h:17
Tensor output
The result of applying the specific RNN algorithm to the input tensor and input state.
Definition: rnn.h:23
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
RNNOptionsBase options
The RNN's options.
Definition: rnn.h:91
std::tuple< Tensor, Tensor >(const Tensor &, const Tensor &, TensorList, bool, int64_t, double, bool, bool, bool) RNNFunctionSignature
The function signature of rnn_relu, rnn_tanh and gru.
Definition: rnn.h:113
Base class for all RNN implementations (intended for code sharing).
Definition: rnn.h:56
std::vector< Tensor > b_hh
The biases for hidden x hidden gates.
Definition: rnn.h:100
std::vector< Tensor > flat_weights_
The cached result of the latest flat_weights() call.
Definition: rnn.h:136