#include <torch/nn/modules/rnn.h>

#include <torch/nn/modules/dropout.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/util/Exception.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace nn {
namespace detail {

RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
    : input_size_(input_size), hidden_size_(hidden_size) {}
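// Shared base for the RNN, LSTM, and GRU modules: stores the common options,
// the cuDNN mode to use (if any), and the number of gates per recurrent
// cell, then creates and initializes all parameters via reset().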
template <typename Derived>
RNNImplBase<Derived>::RNNImplBase(
    const RNNOptionsBase& options_,
    optional<CuDNNMode> cudnn_mode,
    int64_t number_of_gates)
    : options(options_),
      number_of_gates_(number_of_gates),
      cudnn_mode_(std::move(cudnn_mode)) {
  reset();
}
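// (Re)creates the per-layer weight and bias parameters and initializes every
// parameter uniformly in [-stdv, stdv] with stdv = 1 / sqrt(hidden_size),
// matching the initialization of the Python torch.nn.RNN module.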
template <typename Derived>
void RNNImplBase<Derived>::reset() {
  w_ih.resize(options.layers_);
  w_hh.resize(options.layers_);
  b_ih.resize(options.layers_);
  b_hh.resize(options.layers_);

  const int64_t gate_size = options.hidden_size_ * number_of_gates_;
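  // Layer 0 consumes the input features; every deeper layer consumes the
  // hidden state of the layer below, so its input width is the hidden size.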
  for (int64_t layer = 0; layer < options.layers_; ++layer) {
    const int64_t input_size =
        (layer == 0) ? options.input_size_ : options.hidden_size_;
    w_ih[layer] = this->register_parameter(
        "weight_ih_l" + std::to_string(layer),
        torch::empty({gate_size, input_size}));
    w_hh[layer] = this->register_parameter(
        "weight_hh_l" + std::to_string(layer),
        torch::empty({gate_size, options.hidden_size_}));

    if (options.with_bias_) {
      b_ih[layer] = this->register_parameter(
          "bias_ih_l" + std::to_string(layer), torch::empty({gate_size}));
      b_hh[layer] = this->register_parameter(
          "bias_hh_l" + std::to_string(layer), torch::empty({gate_size}));
    }
  }
  NoGradGuard no_grad;
  const auto stdv = 1.0 / std::sqrt(options.hidden_size_);
  for (auto& p : this->parameters()) {
    p.uniform_(-stdv, stdv);
  }

  flatten_parameters();
}
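// All `to()` overloads re-flatten the weights after moving the parameters:
// the cached flat weight vector still refers to the old tensors, so it (and
// any cuDNN flat buffer) must be rebuilt on the new device/dtype.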
template <typename Derived>
void RNNImplBase<Derived>::to(
    torch::Device device,
    torch::Dtype dtype,
    bool non_blocking) {
  nn::Module::to(device, dtype, non_blocking);
  flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
  nn::Module::to(dtype, non_blocking);
  flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
  nn::Module::to(device, non_blocking);
  flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
  const std::string name = this->name();
  const std::string name_without_impl = name.substr(0, name.size() - 4);
  stream << name_without_impl << "(input_size=" << options.input_size_
         << ", hidden_size=" << options.hidden_size_
         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
         << ")";
}
template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
  // Cache the flattened weight and bias vector.
  flat_weights_ = flat_weights();

  if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
    return;
  }

  NoGradGuard no_grad;
  torch::_cudnn_rnn_flatten_weight(
      flat_weights_,
      /*weight_stride0=*/options.with_bias_ ? 4 : 2,
      options.input_size_,
      static_cast<int64_t>(*cudnn_mode_),
      options.hidden_size_,
      options.layers_,
      options.batch_first_,
      options.bidirectional_);
}
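// Shared forward path for the RNN and GRU modules: default-initializes the
// state to zeros when none is given, then dispatches to the supplied ATen
// function (e.g. torch::rnn_tanh, torch::rnn_relu, or torch::gru).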
template <typename Derived>
RNNOutput RNNImplBase<Derived>::generic_forward(
    std::function<RNNFunctionSignature> function,
    const Tensor& input,
    Tensor state) {
  if (!state.defined()) {
    // #layers, batch size, state size
    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
    state = torch::zeros(
        {options.layers_, batch_size, options.hidden_size_}, input.options());
  }
  Tensor output, new_state;
  std::tie(output, new_state) = function(
      input,
      std::move(state),
      flat_weights_,
      options.with_bias_,
      options.layers_,
      options.dropout_,
      this->is_training(),
      options.bidirectional_,
      options.batch_first_);
  return {output, new_state};
}
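// Organizes all weights in a flat vector in the order expected by the ATen
// RNN functions: (w_ih, w_hh[, b_ih, b_hh]) for layer 0, then the same for
// layer 1, and so on.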
template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
  std::vector<Tensor> flat;
  for (int64_t layer = 0; layer < options.layers_; layer++) {
    flat.push_back(w_ih[layer]);
    flat.push_back(w_hh[layer]);
    if (options.with_bias_) {
      flat.push_back(b_ih[layer]);
      flat.push_back(b_hh[layer]);
    }
  }
  return flat;
}
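// Returns true if any two parameters share a data pointer. Aliased
// parameters break the assumption that every parameter owns a distinct
// buffer, so the module falls back to the slower, copying code path instead
// of the flattened one.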
template <typename Derived>
bool RNNImplBase<Derived>::any_parameters_alias() const {
  std::unordered_set<void*> unique_data_ptrs;
  auto params = this->parameters();
  unique_data_ptrs.reserve(params.size());
  for (const auto& p : params) {
    unique_data_ptrs.emplace(p.data_ptr());
  }
  return unique_data_ptrs.size() != params.size();
}

template class RNNImplBase<LSTMImpl>;
template class RNNImplBase<GRUImpl>;
template class RNNImplBase<RNNImpl>;

} // namespace detail
RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
    : input_size_(input_size), hidden_size_(hidden_size) {}
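// Convenience setters, so users can write `RNNOptions(...).tanh()` instead
// of `.activation(RNNActivation::Tanh)`.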
RNNOptions& RNNOptions::tanh() {
  return activation(RNNActivation::Tanh);
}

RNNOptions& RNNOptions::relu() {
  return activation(RNNActivation::ReLU);
}
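// Maps the RNN-specific options onto the generic RNNOptionsBase. The
// activation is reused as the cuDNN mode, relying on RNNActivation and
// CuDNNMode sharing their numeric values for the ReLU/Tanh variants.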
RNNImpl::RNNImpl(const RNNOptions& options)
    : detail::RNNImplBase<RNNImpl>(
          detail::RNNOptionsBase(options.input_size_, options.hidden_size_)
              .layers(options.layers_)
              .with_bias(options.with_bias_)
              .dropout(options.dropout_)
              .bidirectional(options.bidirectional_)
              .batch_first(options.batch_first_),
          static_cast<CuDNNMode>(options.activation_)),
      options(options) {}
void RNNImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::RNN(input_size=" << options.input_size_
         << ", hidden_size=" << options.hidden_size_
         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
         << ", activation="
         << (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu")
         << ")";
}
RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
  switch (options.activation_) {
    case RNNActivation::ReLU:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
          input,
          std::move(state));
    case RNNActivation::Tanh:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
          input,
          std::move(state));
    default:
      AT_ERROR("Unhandled RNN activation function!");
  }
}
LSTMImpl::LSTMImpl(const LSTMOptions& options)
    : detail::RNNImplBase<LSTMImpl>(
          options,
          CuDNNMode::LSTM,
          /*number_of_gates=*/4) {}

RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
  // The LSTM cannot reuse `generic_forward`: `torch::lstm` returns a 3-tuple
  // (output, hidden state, cell state), and the two states are stacked into
  // the single state tensor of `RNNOutput`.
  if (!state.defined()) {
    // 2 for hidden state and cell state, then #layers, batch size, state size
    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
    state = torch::zeros(
        {2, options.layers_, batch_size, options.hidden_size_},
        input.options());
  }
  Tensor output, hidden_state, cell_state;
  std::tie(output, hidden_state, cell_state) = torch::lstm(
      input,
      {state[0], state[1]},
      flat_weights_,
      options.with_bias_,
      options.layers_,
      options.dropout_,
      this->is_training(),
      options.bidirectional_,
      options.batch_first_);
  return {output, torch::stack({hidden_state, cell_state})};
}
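// A GRU cell has three gates (reset, update, new), hence
// `number_of_gates = 3`; its output shape matches the plain RNN, so the
// shared `generic_forward` can be reused.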
GRUImpl::GRUImpl(const GRUOptions& options)
    : detail::RNNImplBase<GRUImpl>(
          options,
          CuDNNMode::GRU,
          /*number_of_gates=*/3) {}

RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
  return generic_forward(
      static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
}

} // namespace nn
} // namespace torch