Caffe2 - C++ API
A deep learning, cross platform ML framework
inference_lstm_op.h
1 #ifndef LSTM_OP_H_
2 #define LSTM_OP_H_
3 
4 #include <c10/core/Tensor.h>
5 #include <algorithm>
6 #include <sstream>
7 #include <unordered_map>
8 #include <vector>
9 #include "caffe2/core/blob_serialization.h"
10 #include "caffe2/core/operator.h"
11 #include "caffe2/core/tensor.h"
12 #include "caffe2/utils/eigen_utils.h"
13 #include "caffe2/utils/math.h"
14 #include "lstm_utils.h"
15 
16 C10_DECLARE_CAFFE2_OPERATOR(LSTMOp);
17 
18 namespace caffe2 {
19 namespace {
20 
21 using t_tuple = std::tuple<Tensor, Tensor>;
22 
23 struct CellParams {
24  CellParams(
25  const Tensor& _w_ih,
26  const Tensor& _w_hh,
27  const Tensor& _b_ih,
28  const Tensor& _b_hh,
29  CPUContext* _context) {
30  initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
31  }
32 
33  CellParams(const CellParams& rhs) {
34  initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
35  }
36 
37  CellParams& operator=(const CellParams& rhs) {
38  initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
39  return *this;
40  }
41 
42  void initParams(
43  const Tensor& _w_ih,
44  const Tensor& _w_hh,
45  const Tensor& _b_ih,
46  const Tensor& _b_hh,
47  CPUContext* _context) {
48  w_ih = copy_ctor(_w_ih);
49  w_hh = copy_ctor(_w_hh);
50  b_ih = copy_ctor(_b_ih);
51  b_hh = copy_ctor(_b_hh);
52  context = _context;
53  }
54 
55  Tensor w_ih;
56  Tensor w_hh;
57  Tensor b_ih; /* optional */
58  Tensor b_hh; /* optional */
59  CPUContext* context;
60 
61  Tensor linear_ih(const Tensor& input) const {
62  return linear(input, w_ih, b_ih, context);
63  }
64  Tensor linear_hh(const Tensor& h) const {
65  return linear(h, w_hh, b_hh, context);
66  }
67 };
68 
69 struct LSTMCell {
70  explicit LSTMCell(CPUContext* context) : context_(context) {}
71  t_tuple operator()(
72  const Tensor& input,
73  const t_tuple& hidden,
74  const CellParams& params) const {
75  const auto& hx = std::get<0>(hidden);
76  const auto& cx = std::get<1>(hidden);
77  auto linear_ih = params.linear_ih(input);
78  auto linear_hh = params.linear_hh(hx);
79  auto gates = add(linear_ih, linear_hh, context_);
80  auto chunked_gates = chunk(gates, 4, 1, context_);
81  auto ingate = sigmoid(chunked_gates[0]);
82  auto forgetgate = sigmoid(chunked_gates[1]);
83  auto cellgate = tanh(chunked_gates[2], context_);
84  auto outgate = sigmoid(chunked_gates[3]);
85 
86  auto cy =
87  add(mul(forgetgate, cx, context_),
88  mul(ingate, cellgate, context_),
89  context_);
90  auto hy = mul(outgate, tanh(cy, context_), context_);
91  return std::make_tuple(std::move(hy), std::move(cy));
92  }
93  CPUContext* context_;
94 };
95 
/// Result of running a layer: the produced outputs plus the final hidden
/// state. Both fields are obtained from the constructor arguments through the
/// `copy_ctor` helper (declared outside this header's visible part).
template <typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;

  LayerOutput(const output_type& _outputs, const hidden_type& _hidden)
      : outputs(copy_ctor(_outputs)), final_hidden(copy_ctor(_hidden)) {}
};
106 
107 template <typename hidden_type, typename param_type>
108 struct Layer {
109  using output_type = LayerOutput<Tensor, hidden_type>;
110  virtual ~Layer() {}
111  virtual output_type operator()(
112  const Tensor& input,
113  const hidden_type& input_hidden,
114  const param_type& params) const = 0;
115 };
116 
117 struct FullLSTMLayer : Layer<t_tuple, CellParams> {
118  FullLSTMLayer(LSTMCell& cell, CPUContext* context)
119  : cell_(cell), context_(context) {}
120 
121  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
122  const std::vector<Tensor>& step_inputs,
123  const std::tuple<Tensor, Tensor>& input_hidden,
124  const CellParams& params) const {
125  std::vector<Tensor> step_outputs;
126  auto hidden = copy_ctor(input_hidden);
127 
128  for (size_t i = 0; i < step_inputs.size(); i++) {
129  hidden = cell_(step_inputs[i], hidden, params);
130  step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
131  }
132 
133  return {step_outputs, hidden};
134  }
135 
136  LayerOutput<Tensor, t_tuple> operator()(
137  const Tensor& inputs,
138  const std::tuple<Tensor, Tensor>& input_hidden,
139  const CellParams& params) const override {
140  auto unstacked_output =
141  (*this)(unbind(inputs, 0, context_), input_hidden, params);
142  return {stack(unstacked_output.outputs, 0, context_),
143  unstacked_output.final_hidden};
144  }
145  LSTMCell cell_;
146  CPUContext* context_;
147 };
148 
149 struct FullBidirectionalLSTMLayer
150  : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
151  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
152  using param_type = std::pair<CellParams, CellParams>;
153  using output_type = LayerOutput<Tensor, bidir_hidden_type>;
154 
155  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
156  : layer_(cell, context), context_(context) {}
157 
158  output_type operator()(
159  const Tensor& input,
160  const bidir_hidden_type& input_hidden,
161  const param_type& params) const override {
162  std::vector<Tensor> outputs;
163  auto step_inputs = unbind(input, 0, context_);
164  auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
165  auto fw_output = stack(fw_result.outputs, 0, context_);
166  outputs.push_back(copy_ctor(fw_output));
167  auto rev_step_inputs = reverse(std::move(step_inputs));
168  auto rev_result =
169  layer_(rev_step_inputs, input_hidden.second, params.second);
170  std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
171  auto rev_output = stack(rev_result.outputs, 0, context_);
172  outputs.push_back(copy_ctor(rev_output));
173  return {cat(outputs, fw_output.dim() - 1, context_),
174  std::make_pair(
175  std::move(fw_result.final_hidden),
176  std::move(rev_result.final_hidden))};
177  }
178 
179  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
180  std::reverse(x.begin(), x.end());
181  return std::move(x);
182  }
183 
184  private:
185  FullLSTMLayer layer_;
186  CPUContext* context_;
187 };
188 
189 template <typename hidden_type, typename weight_type>
190 LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
191  const Layer<hidden_type, weight_type>& layer,
192  const Tensor& input,
193  const std::vector<hidden_type>& hiddens,
194  const std::vector<weight_type>& weights,
195  int64_t num_layers) {
196  CAFFE_ENFORCE(
197  num_layers == hiddens.size(),
198  "Expected more hidden states in stacked_rnn");
199  CAFFE_ENFORCE(
200  num_layers == weights.size(), "Expected more weights in stacked_rnn");
201 
202  auto layer_input = input.UnsafeSharedInstance();
203  auto hidden_it = hiddens.begin();
204  auto weight_it = weights.begin();
205  std::vector<hidden_type> final_hiddens(num_layers);
206  for (int64_t l = 0; l < num_layers; ++l) {
207  auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
208  final_hiddens.at(l) = std::move(layer_output.final_hidden);
209  layer_input = std::move(layer_output.outputs);
210  }
211  return {layer_input, final_hiddens};
212 }
213 
214 std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
215  const Tensor& input,
216  const std::vector<CellParams>& params,
217  const Tensor& hx,
218  const Tensor& cx,
219  int64_t num_layers,
220  bool bidirectional,
221  CPUContext* context) {
222  using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
223  auto layer_hx = unbind(hx, 0, context);
224  auto layer_cx = unbind(cx, 0, context);
225  int64_t total_layers = layer_hx.size();
226  std::vector<std::tuple<Tensor, Tensor>> hiddens;
227  hiddens.reserve(total_layers);
228  for (int64_t i = 0; i < total_layers; ++i) {
229  hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
230  }
231  LSTMCell cell(context);
232  std::shared_ptr<stack_output> stack_output_ptr;
233  if (bidirectional) {
234  auto bidir_result = apply_layer_stack(
235  FullBidirectionalLSTMLayer{cell, context},
236  input,
237  pair_vec(hiddens),
238  pair_vec(params),
239  num_layers);
240  stack_output_ptr.reset(new stack_output(
241  bidir_result.outputs,
242  unpair_vec(std::move(bidir_result.final_hidden))));
243  } else {
244  auto result = apply_layer_stack(
245  FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
246  stack_output_ptr = std::make_shared<stack_output>(std::move(result));
247  }
248 
249  std::vector<Tensor> hy, cy;
250  hy.reserve(total_layers);
251  cy.reserve(total_layers);
252  for (auto& hidden : stack_output_ptr->final_hidden) {
253  hy.push_back(std::move(std::get<0>(hidden)));
254  cy.push_back(std::move(std::get<1>(hidden)));
255  }
256  return std::make_tuple(
257  std::move(stack_output_ptr->outputs),
258  stack(hy, 0, context),
259  stack(cy, 0, context));
260 }
261 
262 // Parses a flat list of parameter tensors into a list of CellParams
263 std::vector<CellParams> gather_params(
264  const std::vector<Tensor>& params,
265  bool has_biases,
266  CPUContext* context) {
267  Tensor undefined;
268  std::vector<CellParams> result;
269  if (has_biases) {
270  CAFFE_ENFORCE_EQ(
271  params.size() % 4, 0, "got an incorrect number of LSTM parameters");
272  for (size_t i = 0; i < params.size(); i += 4) {
273  result.emplace_back(
274  params[i], params[i + 1], params[i + 2], params[i + 3], context);
275  }
276  } else {
277  CAFFE_ENFORCE_EQ(
278  params.size() % 2, 0, "got an incorrect number of LSTM parameters");
279  for (size_t i = 0; i < params.size(); i += 2) {
280  result.emplace_back(
281  params[i], params[i + 1], undefined, undefined, context);
282  }
283  }
284  return result;
285 }
286 
287 class InferenceLSTMOp : public Operator<CPUContext> {
288  public:
289  template <class... Args>
290  explicit InferenceLSTMOp(Args&&... args)
291  : Operator(std::forward<Args>(args)...),
292  num_layers_(this->template GetSingleArgument<int64_t>("num_layers", 1)),
293  bidirectional_(
294  this->template GetSingleArgument<bool>("bidirectional", false)),
295  has_biases_(this->template GetSingleArgument<bool>("has_biases", true)),
296  batch_first_(
297  this->template GetSingleArgument<bool>("batch_first", false)) {}
298 
299  bool RunOnDevice() override;
300 
301  protected:
302  int64_t num_layers_;
303  bool bidirectional_;
304  bool has_biases_;
305  bool batch_first_;
306 };
307 
308 } // namespace
309 } // namespace caffe2
310 #endif // LSTM_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13