Caffe2 - C++ API
A deep learning, cross-platform ML framework
lstm_unit_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
18 #define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/conversions.h"
23 
24 namespace caffe2 {
25 namespace detail {
/// Logistic function: 1 / (1 + e^(-x)), with the result converted back to T.
template <typename T>
inline T sigmoid(T x) {
  // The double literals below mix with the exp() result, so the sum and the
  // division use the usual arithmetic conversions before narrowing to T.
  const auto decay = exp(-x);
  return 1. / (decay + 1.);
}
30 
/// Hyperbolic tangent expressed through the logistic function:
/// 2 * sigmoid(2 * x) - 1, with the result converted back to T.
template <typename T>
inline T host_tanh(T x) {
  // The argument is scaled with a double literal before it reaches sigmoid,
  // so sigmoid is instantiated on the type of (2. * x), not necessarily T.
  const auto scaled = 2. * x;
  return 2. * sigmoid(scaled) - 1.;
}
35 
36 template <typename T, typename Context>
37 void LSTMUnit(
38  int N,
39  int D,
40  int t,
41  const T* H_prev,
42  const T* C_prev,
43  const T* X,
44  const int32_t* seqLengths,
45  bool drop_states,
46  T* C,
47  T* H,
48  const float forget_bias,
49  Context* /*context*/) {
50  for (int n = 0; n < N; ++n) {
51  const bool valid = seqLengths == nullptr || t < seqLengths[n];
52 
53  for (int d = 0; d < D; ++d) {
54  if (!valid) {
55  if (drop_states) {
56  H[d] = 0;
57  C[d] = 0;
58  } else {
59  H[d] = H_prev[d];
60  C[d] = C_prev[d];
61  }
62  } else {
63  const T i = sigmoid(X[d]);
64  const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
65  const T o = sigmoid(X[2 * D + d]);
66  const T g = host_tanh(X[3 * D + d]);
67  const T c_prev = C_prev[d];
68  const T c = f * c_prev + i * g;
69  C[d] = c;
70  const T host_tanh_c = host_tanh(c);
71  H[d] = o * host_tanh_c;
72  }
73  }
74  H_prev += D;
75  C_prev += D;
76  X += 4 * D;
77  C += D;
78  H += D;
79  }
80 }
81 
82 template <typename T, typename Context>
83 void LSTMUnitGradient(
84  int N,
85  int D,
86  int t,
87  const T* C_prev,
88  const T* X,
89  const int32_t* seqLengths,
90  const T* C,
91  const T* H,
92  const T* C_diff,
93  const T* H_diff,
94  bool drop_states,
95  T* H_prev_diff,
96  T* C_prev_diff,
97  T* X_diff,
98  const float forget_bias,
99  Context* /*context*/) {
100  for (int n = 0; n < N; ++n) {
101  const bool valid = seqLengths == nullptr || t < seqLengths[n];
102 
103  for (int d = 0; d < D; ++d) {
104  T* c_prev_diff = C_prev_diff + d;
105  T* h_prev_diff = H_prev_diff + d;
106  T* i_diff = X_diff + d;
107  T* f_diff = X_diff + 1 * D + d;
108  T* o_diff = X_diff + 2 * D + d;
109  T* g_diff = X_diff + 3 * D + d;
110 
111  if (!valid) {
112  if (drop_states) {
113  *h_prev_diff = 0;
114  *c_prev_diff = 0;
115  } else {
116  *h_prev_diff = H_diff[d];
117  *c_prev_diff = C_diff[d];
118  }
119  *i_diff = 0;
120  *f_diff = 0;
121  *o_diff = 0;
122  *g_diff = 0;
123  } else {
124  const T i = sigmoid(X[d]);
125  const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
126  const T o = sigmoid(X[2 * D + d]);
127  const T g = host_tanh(X[3 * D + d]);
128  const T c_prev = C_prev[d];
129  const T c = C[d];
130  const T host_tanh_c = host_tanh(c);
131  const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
132  *c_prev_diff = c_term_diff * f;
133  *h_prev_diff = 0; // not used in 'valid' case
134  *i_diff = c_term_diff * g * i * (1 - i);
135  *f_diff = c_term_diff * c_prev * f * (1 - f);
136  *o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
137  *g_diff = c_term_diff * i * (1 - g * g);
138  }
139  }
140  C_prev += D;
141  X += 4 * D;
142  C += D;
143  H += D;
144  C_diff += D;
145  H_diff += D;
146  X_diff += 4 * D;
147  H_prev_diff += D;
148  C_prev_diff += D;
149  }
150 }
151 } // namespace detail
152 
153 template <typename Context>
154 class LSTMUnitOp : public Operator<Context> {
155  public:
156  LSTMUnitOp(const OperatorDef& operator_def, Workspace* ws)
157  : Operator<Context>(operator_def, ws),
158  forget_bias_(
159  static_cast<float>(OperatorBase::template GetSingleArgument<float>(
160  "forget_bias",
161  0.0))),
162  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
163  "sequence_lengths",
164  true)),
165  drop_states_(OperatorBase::template GetSingleArgument<bool>(
166  "drop_states",
167  false)) {}
168  USE_OPERATOR_CONTEXT_FUNCTIONS;
170 
171  template <typename T>
172  bool DoRunWithType() {
173  // handle potentially-missing sequence lengths input
174  const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
175 
176  // Extract N
177  const auto N = Input(CELL_T_M_1).dim(1);
178 
179  // Gates: 1xNxG
180  const auto G = Input(GATES).dim(2);
181  const auto D = Input(CELL_T_M_1).dim(2);
182 
183  CAFFE_ENFORCE_EQ(4 * D, G);
184  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
185  const auto* C_prev = Input(CELL_T_M_1).template data<T>();
186  const auto* X = Input(GATES).template data<T>();
187 
188  const int32_t* seqLengths = nullptr;
189  if (sequence_lengths_) {
190  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
191  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
192  }
193 
194  const auto t = static_cast<OperatorBase*>(this)
195  ->Input<Tensor<CPUContext>>(TIMESTEP)
196  .template data<int32_t>()[0];
197  Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
198  auto* C = Output(CELL_T)->template mutable_data<T>();
199  Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
200  auto* H = Output(HIDDEN_T)->template mutable_data<T>();
201  detail::LSTMUnit<T, Context>(
202  N,
203  D,
204  t,
205  H_prev,
206  C_prev,
207  X,
208  seqLengths,
209  drop_states_,
210  C,
211  H,
212  forget_bias_,
213  &context_);
214  return true;
215  }
216 
217  bool RunOnDevice() override {
218  return DoRunWithType<float>();
219  }
220 
221  protected:
222  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
223  // additional input tags are determined dynamically based on whether
224  // sequence_lengths is present.
225  OUTPUT_TAGS(HIDDEN_T, CELL_T);
226 
227  float forget_bias_;
228  bool sequence_lengths_;
229 
230  private:
231  bool drop_states_;
232 };
233 
234 template <typename Context>
235 class LSTMUnitGradientOp : public Operator<Context> {
236  public:
237  LSTMUnitGradientOp(const OperatorDef& operator_def, Workspace* ws)
238  : Operator<Context>(operator_def, ws),
239  forget_bias_(
240  static_cast<float>(OperatorBase::template GetSingleArgument<float>(
241  "forget_bias",
242  0.0))),
243  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
244  "sequence_lengths",
245  true)),
246  drop_states_(OperatorBase::template GetSingleArgument<bool>(
247  "drop_states",
248  false)) {}
249  USE_OPERATOR_CONTEXT_FUNCTIONS;
250 
251  template <typename T>
252  bool DoRunWithType() {
253  // handle potentially-missing sequence lengths input
254  const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
255  const size_t TIMESTEP = inputOffset;
256  const size_t HIDDEN_T = inputOffset + 1;
257  const size_t CELL_T = inputOffset + 2;
258  const size_t HIDDEN_T_GRAD = inputOffset + 3;
259  const size_t CELL_T_GRAD = inputOffset + 4;
260 
261  // Extract N
262  const auto N = Input(CELL_T_M_1).dim(1);
263 
264  // Gates: 1xNxG
265  const auto G = Input(GATES).dim(2);
266  const auto D = Input(CELL_T_M_1).dim(2);
267 
268  CAFFE_ENFORCE_EQ(4 * D, G);
269  const auto* C_prev = Input(CELL_T_M_1).template data<T>();
270  const auto* X = Input(GATES).template data<T>();
271  const auto t = static_cast<OperatorBase*>(this)
272  ->Input<Tensor<CPUContext>>(TIMESTEP)
273  .template data<int32_t>()[0];
274  const auto* C = Input(CELL_T).template data<T>();
275  const auto* H = Input(HIDDEN_T).template data<T>();
276  const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
277  const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
278 
279  const int32_t* seqLengths = nullptr;
280  if (sequence_lengths_) {
281  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
282  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
283  }
284 
285  Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
286  auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
287  Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
288  auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
289  Output(GATES_GRAD)->ResizeLike(Input(GATES));
290  auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
291 
292  detail::LSTMUnitGradient<T, Context>(
293  N,
294  D,
295  t,
296  C_prev,
297  X,
298  seqLengths,
299  C,
300  H,
301  C_diff,
302  H_diff,
303  drop_states_,
304  H_prev_diff,
305  C_prev_diff,
306  X_diff,
307  forget_bias_,
308  &context_);
309  return true;
310  }
311 
312  bool RunOnDevice() override {
313  return DoRunWithType<float>();
314  }
315 
316  protected:
317  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
318  // additional input tags are determined dynamically based on whether
319  // sequence_lengths is present.
320  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);
321 
322  float forget_bias_;
323  bool sequence_lengths_;
324 
325  private:
326  bool drop_states_;
327 };
328 } // namespace caffe2
329 
330 #endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.