Caffe2 - C++ API
A deep learning, cross platform ML framework
gru_unit_op.h
1 #ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_
2 #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_
3 
#include <cmath>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 namespace detail {
10 
/// Logistic sigmoid, 1 / (1 + e^(-x)), returned as T.
///
/// The constants are float literals; for T = double they are promoted, so the
/// arithmetic is done in double precision.
template <typename T>
inline T sigmoid(T x) {
  const auto denominator = 1.0f + exp(-x);
  return 1.0f / denominator;
}
15 
/// Hyperbolic tangent evaluated on the host.
///
/// This used to be computed as 2 * sigmoid(2x) - 1. That is mathematically
/// equal but cancels catastrophically near x = 0: in float, sigmoid(2x)
/// rounds to exactly 0.5 for tiny x, so the result collapsed to 0 (e.g. for
/// x = 1e-8f) and all relative precision was lost. std::tanh has no such
/// cancellation. The using-declaration keeps unqualified lookup, so a
/// user-defined T with its own tanh overload is still found by
/// argument-dependent lookup.
template <typename T>
inline T host_tanh(T x) {
  using std::tanh;
  return tanh(x);
}
20 
/// Forward step of a GRU cell at timestep t for a batch of N rows.
///
/// Per row, H_prev and H each hold D values and X holds 3 * D gate
/// pre-activations laid out as [reset | update | output]. The reset slice is
/// not read here. For a row that is active at step t:
///   u = sigmoid(update),  H = H_prev * u + tanh(output) * (1 - u).
/// A row is inactive when seqLengths is non-null and t >= seqLengths[n]. Its
/// output is then 0 if drop_states is set, otherwise a copy of H_prev.
/// The context argument is unused.
template <typename T, typename Context>
void GRUUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H,
    Context* /*context*/) {
  for (int row = 0; row < N; ++row, H_prev += D, X += 3 * D, H += D) {
    const bool active = (seqLengths == nullptr) || (t < seqLengths[row]);

    if (!active) {
      // Past the end of this row's sequence: zero the state or carry it.
      for (int d = 0; d < D; ++d) {
        H[d] = drop_states ? T(0) : H_prev[d];
      }
      continue;
    }

    const T* update_gate = X + D;      // X[D, 2D)
    const T* output_gate = X + 2 * D;  // X[2D, 3D)
    for (int d = 0; d < D; ++d) {
      const T u = sigmoid(update_gate[d]);
      H[d] = H_prev[d] * u + host_tanh(output_gate[d]) * (1.0f - u);
    }
  }
}
56 
/// Backward step of the GRU cell computed by GRUUnit, for timestep t.
///
/// Per row, H_diff / H_prev / H_prev_diff hold D values and X / X_diff hold
/// 3 * D values laid out as [reset | update | output]. For an active row,
/// with u = sigmoid(update) and o = tanh(output):
///   dH_prev = dH * u
///   dreset  = 0                         (this op does not use the reset gate)
///   dupdate = (dH * H_prev - dH * o) * u * (1 - u)
///   doutput = dH * (1 - u) * (1 - o^2)
/// For an inactive row (seqLengths non-null and t >= seqLengths[n]) all gate
/// gradients are 0. dH_prev is 0 when drop_states is set, otherwise dH is
/// passed straight through. The forward output H and the context argument
/// are not read.
template <typename T, typename Context>
void GRUUnitGradient(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff,
    Context* /*context*/) {
  (void)H;  // Not needed by the gradient formulas above.

  for (int row = 0; row < N; ++row) {
    const bool active = (seqLengths == nullptr) || (t < seqLengths[row]);
    T* reset_grad = X_diff;
    T* update_grad = X_diff + D;
    T* output_grad = X_diff + 2 * D;

    if (active) {
      for (int d = 0; d < D; ++d) {
        const T u = sigmoid(X[D + d]);
        const T o = host_tanh(X[2 * D + d]);

        // Writes happen in the same order as the reads of H_diff / H_prev so
        // that in-place (aliased) buffers behave exactly as before.
        H_prev_diff[d] = H_diff[d] * u;
        reset_grad[d] = 0;
        update_grad[d] = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
        output_grad[d] = H_diff[d] * (1.0f - u) * (1.0f - o * o);
      }
    } else {
      for (int d = 0; d < D; ++d) {
        H_prev_diff[d] = drop_states ? T(0) : H_diff[d];
        reset_grad[d] = 0;
        update_grad[d] = 0;
        output_grad[d] = 0;
      }
    }

    H_prev += D;
    X += 3 * D;
    H_diff += D;
    X_diff += 3 * D;
    H_prev_diff += D;
  }
}
109 
110 } // namespace detail
111 
112 template <typename T, typename Context>
113 class GRUUnitOp : public Operator<Context> {
114  public:
115  template <class... Args>
116  explicit GRUUnitOp(Args&&... args)
117  : Operator<Context>(std::forward<Args>(args)...),
118  drop_states_(
119  this->template GetSingleArgument<bool>("drop_states", false)),
120  sequence_lengths_(
121  this->template GetSingleArgument<bool>("sequence_lengths", true)) {}
122  USE_OPERATOR_CONTEXT_FUNCTIONS;
123 
124  bool RunOnDevice() override {
125  // handle potentially-missing sequence lengths input
126  const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
127 
128  // Extract N
129  const auto N = Input(HIDDEN_T_M_1).size(1);
130 
131  // Gates: 1xNxG
132  const auto G = Input(GATES).size(2);
133  const auto D = Input(HIDDEN_T_M_1).size(2);
134 
135  CAFFE_ENFORCE_EQ(3 * D, G);
136  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
137  const auto* X = Input(GATES).template data<T>();
138 
139  const int32_t* seqLengths = nullptr;
140  if (sequence_lengths_) {
141  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
142  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
143  }
144 
145  const auto t = static_cast<OperatorBase*>(this)
146  ->Input<Tensor>(TIMESTEP, CPU)
147  .template data<int32_t>()[0];
148  Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
149  auto* H = Output(HIDDEN_T)->template mutable_data<T>();
150 
151  detail::GRUUnit<T, Context>(
152  N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
153  return true;
154  }
155 
156  protected:
157  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
158  // additional input tags are determined dynamically based on whether
159  // sequence_lengths is present.
160  OUTPUT_TAGS(HIDDEN_T);
161 
162  private:
163  bool drop_states_;
164  bool sequence_lengths_;
165 };
166 
167 template <typename T, typename Context>
168 class GRUUnitGradientOp : public Operator<Context> {
169  public:
170  template <class... Args>
171  explicit GRUUnitGradientOp(Args&&... args)
172  : Operator<Context>(std::forward<Args>(args)...),
173  drop_states_(
174  this->template GetSingleArgument<bool>("drop_states", false)),
175  sequence_lengths_(
176  this->template GetSingleArgument<bool>("sequence_lengths", true)) {}
177  USE_OPERATOR_CONTEXT_FUNCTIONS;
178 
179  bool RunOnDevice() override {
180  // handle potentially-missing sequence lengths input
181  const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
182  const size_t TIMESTEP = inputOffset;
183  const size_t HIDDEN_T = inputOffset + 1;
184  const size_t HIDDEN_T_GRAD = inputOffset + 2;
185 
186  // Extract N
187  const auto N = Input(HIDDEN_T_M_1).size(1);
188 
189  // Gates: 1xNxG
190  const auto G = Input(GATES).size(2);
191  const auto D = Input(HIDDEN_T_M_1).size(2);
192 
193  CAFFE_ENFORCE_EQ(3 * D, G);
194  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
195  const auto* X = Input(GATES).template data<T>();
196  const auto t = static_cast<OperatorBase*>(this)
197  ->Input<Tensor>(TIMESTEP, CPU)
198  .template data<int32_t>()[0];
199  const auto* H = Input(HIDDEN_T).template data<T>();
200  const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
201 
202  const int32_t* seqLengths = nullptr;
203  if (sequence_lengths_) {
204  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
205  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
206  }
207 
208  Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
209  auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
210  Output(GATES_GRAD)->ResizeLike(Input(GATES));
211  auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
212 
213  detail::GRUUnitGradient<T, Context>(
214  N,
215  D,
216  t,
217  H_prev,
218  X,
219  seqLengths,
220  H,
221  H_diff,
222  drop_states_,
223  H_prev_diff,
224  X_diff,
225  &context_);
226  return true;
227  }
228 
229  protected:
230  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
231  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);
232 
233  private:
234  bool drop_states_;
235  bool sequence_lengths_;
236 };
237 
238 } // namespace caffe2
239 
240 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:70