Caffe2 - C++ API
A deep learning, cross platform ML framework
gru_unit_op.h
1 
17 #ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_
18 #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_
19 
#include <cmath>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 namespace detail {
26 
// Logistic sigmoid: 1 / (1 + exp(-x)).
//
// All arithmetic is carried out in T. std::exp picks the overload matching T
// (float stays float), whereas an unqualified exp() may resolve to the C
// library's double exp(double) and silently promote float inputs to double.
// For large negative x, exp(-x) overflows to +inf and the result saturates to
// 0; for large positive x it underflows to 0 and the result is 1.
template <typename T>
inline T sigmoid(T x) {
  return T(1) / (T(1) + std::exp(-x));
}
31 
// Hyperbolic tangent computed on the host via the logistic sigmoid, using the
// identity tanh(x) = 2 * sigmoid(2 * x) - 1.
template <typename T>
inline T host_tanh(T x) {
  const auto doubled = 2.0f * x;
  const auto s = sigmoid(doubled);
  return 2.0f * s - 1.0f;
}
36 
// Forward pass of a GRU unit over N rows of width D at timestep t.
//
// Each row of X holds three consecutive D-wide slices of gate
// pre-activations; only the second (update) and third (output) slices are
// read here. For a row that is still active (seqLengths is null, or
// t < seqLengths[n]):
//   u    = sigmoid(update)
//   H[d] = H_prev[d] * u + host_tanh(output) * (1 - u)
// For a row whose sequence has ended (t >= seqLengths[n]), H is zeroed when
// drop_states is true and copied from H_prev otherwise.
//
// Params:
//   N, D         number of rows and hidden width.
//   t            current timestep, compared against seqLengths.
//   H_prev       N * D previous hidden state (row-major).
//   X            N * 3D gate pre-activations (row-major).
//   seqLengths   optional per-row lengths; nullptr marks every row active.
//   drop_states  zero (true) or carry over (false) finished rows.
//   H            N * D output hidden state.
//   context      unused.
template <typename T, typename Context>
void GRUUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n, H_prev += D, X += 3 * D, H += D) {
    const bool finished = seqLengths != nullptr && t >= seqLengths[n];
    if (finished) {
      // Past the end of this row's sequence: reset or freeze its state.
      for (int d = 0; d < D; ++d) {
        H[d] = drop_states ? T(0) : H_prev[d];
      }
      continue;
    }

    const T* update_gate = X + D;
    const T* output_gate = X + 2 * D;
    for (int d = 0; d < D; ++d) {
      const T u = sigmoid(update_gate[d]);
      H[d] = H_prev[d] * u + host_tanh(output_gate[d]) * (1.0f - u);
    }
  }
}
72 
// Backward pass of GRUUnit over N rows of width D at timestep t.
//
// For an active row (seqLengths is null, or t < seqLengths[n]), with
// u = sigmoid(update) and o = host_tanh(output):
//   dH_prev   = dH * u
//   d_reset   = 0   (the reset slice does not enter this unit's forward)
//   d_update  = (dH * H_prev - dH * o) * u * (1 - u)
//   d_output  = dH * (1 - u) * (1 - o * o)
// For a finished row, all gate gradients are 0 and dH_prev is either 0
// (drop_states) or passed straight through from dH.
//
// Params:
//   N, D         number of rows and hidden width.
//   t            current timestep, compared against seqLengths.
//   H_prev       N * D previous hidden state.
//   X            N * 3D gate pre-activations.
//   seqLengths   optional per-row lengths; nullptr marks every row active.
//   H            N * D forward output; not read by this computation.
//   H_diff       N * D incoming gradient w.r.t. H.
//   drop_states  must match the forward pass setting.
//   H_prev_diff  N * D output gradient w.r.t. H_prev.
//   X_diff       N * 3D output gradient w.r.t. X.
//   context      unused.
template <typename T, typename Context>
void GRUUnitGradient(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n,
           H_prev += D,
           X += 3 * D,
           H += D,
           H_diff += D,
           X_diff += 3 * D,
           H_prev_diff += D) {
    T* reset_grad = X_diff;
    T* update_grad = X_diff + D;
    T* output_grad = X_diff + 2 * D;

    const bool finished = seqLengths != nullptr && t >= seqLengths[n];
    if (finished) {
      // The forward pass copied or zeroed the state, so no gate contributed.
      for (int d = 0; d < D; ++d) {
        H_prev_diff[d] = drop_states ? T(0) : H_diff[d];
        reset_grad[d] = 0;
        update_grad[d] = 0;
        output_grad[d] = 0;
      }
      continue;
    }

    for (int d = 0; d < D; ++d) {
      // Recompute the gate activations from the saved pre-activations.
      const T u = sigmoid(X[D + d]);
      const T o = host_tanh(X[2 * D + d]);

      // H_diff[d] is re-read on purpose (not cached) to keep the exact
      // read/write ordering of the computation.
      H_prev_diff[d] = H_diff[d] * u;
      reset_grad[d] = 0;
      update_grad[d] = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
      output_grad[d] = H_diff[d] * (1.0f - u) * (1.0f - o * o);
    }
  }
}
125 
126 } // namespace detail
127 
128 template <typename T, typename Context>
129 class GRUUnitOp : public Operator<Context> {
130  public:
131  GRUUnitOp(const OperatorDef& operator_def, Workspace* ws)
132  : Operator<Context>(operator_def, ws),
133  drop_states_(OperatorBase::template GetSingleArgument<bool>(
134  "drop_states",
135  false)),
136  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
137  "sequence_lengths",
138  true)) {}
139  USE_OPERATOR_CONTEXT_FUNCTIONS;
140 
141  bool RunOnDevice() override {
142  // handle potentially-missing sequence lengths input
143  const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
144 
145  // Extract N
146  const auto N = Input(HIDDEN_T_M_1).dim(1);
147 
148  // Gates: 1xNxG
149  const auto G = Input(GATES).dim(2);
150  const auto D = Input(HIDDEN_T_M_1).dim(2);
151 
152  CAFFE_ENFORCE_EQ(3 * D, G);
153  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
154  const auto* X = Input(GATES).template data<T>();
155 
156  const int32_t* seqLengths = nullptr;
157  if (sequence_lengths_) {
158  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
159  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
160  }
161 
162  const auto t = static_cast<OperatorBase*>(this)->
163  Input<Tensor<CPUContext>>(TIMESTEP).template data<int32_t>()[0];
164  Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
165  auto* H = Output(HIDDEN_T)->template mutable_data<T>();
166 
167  detail::GRUUnit<T, Context>(
168  N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
169  return true;
170  }
171 
172  protected:
173  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
174  // additional input tags are determined dynamically based on whether
175  // sequence_lengths is present.
176  OUTPUT_TAGS(HIDDEN_T);
177 
178  private:
179  bool drop_states_;
180  bool sequence_lengths_;
181 };
182 
183 template <typename T, typename Context>
184 class GRUUnitGradientOp : public Operator<Context> {
185  public:
186  GRUUnitGradientOp(const OperatorDef& operator_def, Workspace* ws)
187  : Operator<Context>(operator_def, ws),
188  drop_states_(OperatorBase::template GetSingleArgument<bool>(
189  "drop_states",
190  false)),
191  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
192  "sequence_lengths",
193  true)) {}
194  USE_OPERATOR_CONTEXT_FUNCTIONS;
195 
196  bool RunOnDevice() override {
197  // handle potentially-missing sequence lengths input
198  const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
199  const size_t TIMESTEP = inputOffset;
200  const size_t HIDDEN_T = inputOffset + 1;
201  const size_t HIDDEN_T_GRAD = inputOffset + 2;
202 
203  // Extract N
204  const auto N = Input(HIDDEN_T_M_1).dim(1);
205 
206  // Gates: 1xNxG
207  const auto G = Input(GATES).dim(2);
208  const auto D = Input(HIDDEN_T_M_1).dim(2);
209 
210  CAFFE_ENFORCE_EQ(3 * D, G);
211  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
212  const auto* X = Input(GATES).template data<T>();
213  const auto t = static_cast<OperatorBase*>(this)->
214  Input<Tensor<CPUContext>>(TIMESTEP).template data<int32_t>()[0];
215  const auto* H = Input(HIDDEN_T).template data<T>();
216  const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
217 
218  const int32_t* seqLengths = nullptr;
219  if (sequence_lengths_) {
220  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
221  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
222  }
223 
224  Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
225  auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
226  Output(GATES_GRAD)->ResizeLike(Input(GATES));
227  auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
228 
229  detail::GRUUnitGradient<T, Context>(
230  N,
231  D,
232  t,
233  H_prev,
234  X,
235  seqLengths,
236  H,
237  H_diff,
238  drop_states_,
239  H_prev_diff,
240  X_diff,
241  &context_);
242  return true;
243  }
244 
245  protected:
246  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
247  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);
248 
249  private:
250  bool drop_states_;
251  bool sequence_lengths_;
252 };
253 
254 } // namespace caffe2
255 
256 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.