Caffe2 - C++ API
A deep learning, cross-platform ML framework
rnn.cpp
#include <torch/nn/modules/rnn.h>

#include <torch/nn/modules/dropout.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/util/Exception.h>

#include <array>
#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNOptionsBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
    : input_size_(input_size), hidden_size_(hidden_size) {}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename Derived>
RNNImplBase<Derived>::RNNImplBase(
    const RNNOptionsBase& options_,
    optional<CuDNNMode> cudnn_mode,
    int64_t number_of_gates)
    : options(options_),
      number_of_gates_(number_of_gates),
      cudnn_mode_(std::move(cudnn_mode)) {
  reset();
}

template <typename Derived>
void RNNImplBase<Derived>::reset() {
  w_ih.resize(options.layers_);
  w_hh.resize(options.layers_);
  b_ih.resize(options.layers_);
  b_hh.resize(options.layers_);

  const int64_t gate_size = options.hidden_size_ * number_of_gates_;

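  // Each weight matrix stacks the per-gate weights row-wise: gate_size is
  // 4 * hidden_size for an LSTM (input/forget/cell/output gates) and
  // 3 * hidden_size for a GRU (reset/update/new gates).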
  for (int64_t layer = 0; layer < options.layers_; ++layer) {
    const int64_t input_size =
        (layer == 0) ? options.input_size_ : options.hidden_size_;
    w_ih[layer] = this->register_parameter(
        "weight_ih_l" + std::to_string(layer),
        torch::empty({gate_size, input_size}));
    w_hh[layer] = this->register_parameter(
        "weight_hh_l" + std::to_string(layer),
        torch::empty({gate_size, options.hidden_size_}));

    if (options.with_bias_) {
      b_ih[layer] = this->register_parameter(
          "bias_ih_l" + std::to_string(layer), torch::empty({gate_size}));
      b_hh[layer] = this->register_parameter(
          "bias_hh_l" + std::to_string(layer), torch::empty({gate_size}));
    }
  }

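  // Initialize every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)),
  // matching the default initialization of the Python RNN modules.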
  {
    NoGradGuard no_grad;
    const auto stdv = 1.0 / std::sqrt(options.hidden_size_);
    for (auto& p : this->parameters()) {
      p.uniform_(-stdv, stdv);
    }
  }

  flatten_parameters();
}

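// Moving or casting parameters creates new tensors, so each to() overload
// below re-flattens the weights to refresh the cached flat_weights_ buffer.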
template <typename Derived>
void RNNImplBase<Derived>::to(
    torch::Device device,
    torch::Dtype dtype,
    bool non_blocking) {
  nn::Module::to(device, dtype, non_blocking);
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
  nn::Module::to(dtype, non_blocking);
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
  nn::Module::to(device, non_blocking);
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
  const std::string name = this->name();
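  // Strip the trailing "Impl" (4 characters) from e.g. "LSTMImpl".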
  const std::string name_without_impl = name.substr(0, name.size() - 4);
  stream << name_without_impl << "(input_size=" << options.input_size_
         << ", hidden_size=" << options.hidden_size_
         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
         << ")";
}

template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
  // Cache the flattened weight and bias vector.
  flat_weights_ = flat_weights();

  if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
    return;
  }

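  // Copy the weights into one contiguous buffer laid out as cuDNN expects,
  // so the cuDNN RNN kernels need not re-gather them on every call.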
  NoGradGuard no_grad;
  torch::_cudnn_rnn_flatten_weight(
      flat_weights_,
      /*weight_stride0=*/options.with_bias_ ? 4 : 2,
      options.input_size_,
      static_cast<int64_t>(*cudnn_mode_),
      options.hidden_size_,
      options.layers_,
      /*batch_first=*/options.batch_first_,
      /*bidirectional=*/options.bidirectional_);
}

template <typename Derived>
RNNOutput RNNImplBase<Derived>::generic_forward(
    std::function<RNNFunctionSignature> function,
    const Tensor& input,
    Tensor state) {
  if (!state.defined()) {
    // #layers, batch size, state size
    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
    state = torch::zeros(
        {options.layers_, batch_size, options.hidden_size_}, input.options());
  }
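  // The stateless ATen functions passed in here (torch::rnn_relu,
  // torch::rnn_tanh, torch::gru) all share RNNFunctionSignature and return
  // the pair (output, new_state).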
  Tensor output, new_state;
  std::tie(output, new_state) = function(
      input,
      std::move(state),
      flat_weights_,
      options.with_bias_,
      options.layers_,
      options.dropout_,
      this->is_training(),
      options.bidirectional_,
      options.batch_first_);
  return {output, new_state};
}

template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
  // Organize all weights in a flat vector in the order
  // (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
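  // This ordering is what _cudnn_rnn_flatten_weight and the stateless RNN
  // functions expect for the flat weight list.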
  std::vector<Tensor> flat;
  for (int64_t layer = 0; layer < options.layers_; layer++) {
    flat.push_back(w_ih[layer]);
    flat.push_back(w_hh[layer]);
    if (options.with_bias_) {
      flat.push_back(b_ih[layer]);
      flat.push_back(b_hh[layer]);
    }
  }
  return flat;
}

template <typename Derived>
bool RNNImplBase<Derived>::any_parameters_alias() const {
  // If any parameters alias, we fall back to the slower, copying code path.
  // This is a sufficient check, because overlapping parameter buffers that
  // don't completely alias would break the assumptions of the uniqueness check
  // in Module::named_parameters().
  std::unordered_set<void*> unique_data_ptrs;
  auto params = this->parameters();
  unique_data_ptrs.reserve(params.size());
  for (const auto& p : params) {
    unique_data_ptrs.emplace(p.data_ptr());
  }
  return unique_data_ptrs.size() != params.size();
}

template class RNNImplBase<LSTMImpl>;
template class RNNImplBase<GRUImpl>;
template class RNNImplBase<RNNImpl>;
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
    : input_size_(input_size), hidden_size_(hidden_size) {}

RNNOptions& RNNOptions::tanh() {
  return activation(RNNActivation::Tanh);
}

RNNOptions& RNNOptions::relu() {
  return activation(RNNActivation::ReLU);
}

RNNImpl::RNNImpl(const RNNOptions& options)
    : detail::RNNImplBase<RNNImpl>(
          detail::RNNOptionsBase(options.input_size_, options.hidden_size_)
              .layers(options.layers_)
              .with_bias(options.with_bias_)
              .dropout(options.dropout_)
              .bidirectional(options.bidirectional_)
              .batch_first(options.batch_first_),
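          // The cast relies on RNNActivation's enumerators lining up with
          // the first two CuDNNMode values (RNN_RELU, RNN_TANH).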
          static_cast<CuDNNMode>(options.activation_)),
      options(options) {}

void RNNImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::RNN(input_size=" << options.input_size_
         << ", hidden_size=" << options.hidden_size_
         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
         << ", activation="
         << (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu")
         << ")";
}

RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
  switch (options.activation_) {
    case RNNActivation::ReLU:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
          input,
          std::move(state));
    case RNNActivation::Tanh:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
          input,
          std::move(state));
    default:
      AT_ERROR("Unhandled RNN activation function!");
  }
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

LSTMImpl::LSTMImpl(const LSTMOptions& options)
    : detail::RNNImplBase<LSTMImpl>(
          options,
          CuDNNMode::LSTM,
          /*number_of_gates=*/4) {}

RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
  // It would be trickier to adapt `generic_forward` for the LSTM because its
  // output has a different dimensionality (3-tuple vs. 2-tuple), while we
  // always return one state variable (stacking the hidden/cell state into
  // one), which also makes the state variables going into `generic_forward`,
  // and the way we default-initialize the state when it is not passed,
  // slightly different. So we just re-implement it specifically for the LSTM
  // here.
  if (!state.defined()) {
    // 2 for hidden state and cell state, then #layers, batch size, state size
    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
    state = torch::zeros(
        {2, options.layers_, batch_size, options.hidden_size_},
        input.options());
  }
  Tensor output, hidden_state, cell_state;
  std::tie(output, hidden_state, cell_state) = torch::lstm(
      input,
      {state[0], state[1]},
      flat_weights_,
      options.with_bias_,
      options.layers_,
      options.dropout_,
      this->is_training(),
      options.bidirectional_,
      options.batch_first_);
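  // Re-stack the hidden and cell state into the single state tensor this
  // module returns.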
  return {output, torch::stack({hidden_state, cell_state})};
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

GRUImpl::GRUImpl(const GRUOptions& options)
    : detail::RNNImplBase<GRUImpl>(
          options,
          CuDNNMode::GRU,
          /*number_of_gates=*/3) {}

RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
  return generic_forward(
      static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
}
} // namespace nn
} // namespace torch
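A minimal usage sketch (not part of rnn.cpp) showing how the modules above are constructed and invoked. It assumes the RNNOutput struct (with output and state fields) and the builder-style LSTMOptions declared in rnn.h, and passes an undefined tensor as the state so that forward() default-initializes it to zeros, as in the code above.

#include <torch/torch.h>

#include <iostream>

int main() {
  // Two-layer LSTM; input is (seq_len, batch, input_size) because
  // batch_first defaults to false.
  torch::nn::LSTM lstm(
      torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20).layers(2));
  const auto input = torch::randn({5, 3, 10});
  // An undefined state tensor is default-initialized to zeros in forward().
  torch::nn::RNNOutput out = lstm->forward(input, torch::Tensor());
  std::cout << out.output.sizes() << '\n'; // (5, 3, 20)
  std::cout << out.state.sizes() << '\n';  // (2, 2, 3, 20): (h/c, layers, batch, hidden)
}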