// The next line is an extraction-collapsed run of the header prologue
// (include guard plus the caffe2 core/context, core/operator and utils/math
// includes) followed by the start of `sigmoid`.
// NOTE(review): the template header declaring T and the closing brace of
// sigmoid are not visible in this chunk — confirm against the full file.
//
// sigmoid: logistic function 1 / (1 + exp(-x)). For very negative x,
// exp(-x) overflows to +inf and the result cleanly saturates to 0.
// NOTE(review): `exp` is called unqualified; for T = float it may resolve to
// the C `double exp(double)` overload and compute in double — confirm.
1 #ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_ 2 #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 12 inline T sigmoid(
T x) {
13 return 1.0f / (1.0f + exp(-x));
// host_tanh: hyperbolic tangent computed through the identity
//   tanh(x) = 2 * sigmoid(2 * x) - 1
// so that it reuses the `sigmoid` helper defined in this file.
// NOTE(review): its template header and closing brace are not visible in
// this chunk.
17 inline T host_tanh(
T x) {
18 return 2.0f * sigmoid(2.0f * x) - 1.0f;
// GRUUnit — forward step of a GRU cell (fragment). The function name, most
// parameters, the handling of rows that are not valid, the per-row pointer
// advance and the closing braces are not visible in this chunk.
//
// For each of the N batch rows, a row is "valid" when no sequence lengths
// were supplied (seqLengths == nullptr) or timestep t is still inside that
// row's sequence. The visible update (presumably the valid-row path; the
// branch itself is not visible) computes, per hidden unit d:
//   H[d] = H_prev[d] * sigmoid(update) + tanh(output) * (1 - sigmoid(update))
// where `update` = X[D + d] and `output` = X[2 * D + d] are gate
// pre-activations.
// NOTE(review): the slot X[0 * D + d] (presumably the reset gate) is not
// read in the visible lines — presumably it is applied upstream; confirm.
21 template <
typename T,
typename Context>
28 const int32_t* seqLengths,
32 for (
int n = 0; n < N; ++n) {
// Rows whose sequence already ended are treated differently by code that
// is not visible here.
33 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
35 for (
int d = 0; d < D; ++d) {
// Gate blocks are D wide: offset D holds the update gate, offset 2 * D the
// output/candidate pre-activation.
43 const T update = X[1 * D + d];
44 const T output = X[2 * D + d];
45 T sigmoid_update = sigmoid(update);
// Blend of the previous state and the tanh candidate, weighted by the
// update gate.
46 H[d] = H_prev[d] * sigmoid_update +
47 host_tanh(output) * (1.0f - sigmoid_update);
// GRUUnitGradient — backward step of the GRU cell (fragment). The function
// name, most parameters, the branch structure, the per-row pointer advance
// and the closing braces are not visible in this chunk.
//
// Writes, per hidden unit d, the gradient w.r.t. the previous hidden state
// (H_prev_diff[d]) and w.r.t. the gate pre-activations stored in X_diff at
// offsets 0, D and 2 * D. With u = sigmoid(X[D + d]) and
// o = tanh(X[2 * D + d]), the forward step is H = H_prev * u + o * (1 - u),
// so:
//   dL/dH_prev   = dL/dH * u
//   dL/d(update) = dL/dH * (H_prev - o) * u * (1 - u)
//   dL/d(output) = dL/dH * (1 - u) * (1 - o^2)
// NOTE(review): reset_diff is set up but not written in the visible lines —
// confirm how the slot at offset 0 is filled.
57 template <
typename T,
typename Context>
64 const int32_t* seqLengths,
71 for (
int n = 0; n < N; ++n) {
72 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
74 for (
int d = 0; d < D; ++d) {
// Output slots for this hidden unit: previous-state gradient and the three
// gate-gradient blocks of X_diff.
75 T* h_prev_diff = H_prev_diff + d;
76 T* reset_diff = X_diff + 0 * D + d;
77 T* update_diff = X_diff + 1 * D + d;
78 T* output_diff = X_diff + 2 * D + d;
// Straight pass-through of the incoming gradient; presumably the path for
// rows past their sequence length (the surrounding condition is not visible).
84 *h_prev_diff = H_diff[d];
// Recompute the gate activations from the saved pre-activations.
91 const T u = sigmoid(X[1 * D + d]);
92 const T o = host_tanh(X[2 * D + d]);
94 *h_prev_diff = H_diff[d] * u;
// Chain rule through H = H_prev * u + o * (1 - u), with sigmoid' = u * (1 - u).
96 *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
// tanh'(x) = 1 - tanh(x)^2.
97 *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
// GRUUnitOp — forward GRU-unit operator (fragment). The class declaration
// line, constructor name, some member declarations (e.g. drop_states_) and
// closing braces are not visible in this chunk.
//
// Inputs: HIDDEN_T_M_1, GATES, optionally SEQ_LENGTHS, then a CPU int32
// timestep tensor. Output: HIDDEN_T.
// Arguments: "drop_states" (default false), "sequence_lengths" (default true).
112 template <
typename T,
typename Context>
// Constructor: reads the "drop_states" and "sequence_lengths" arguments.
115 template <
class... Args>
119 this->
template GetSingleArgument<bool>(
"drop_states",
false)),
121 this->
template GetSingleArgument<bool>(
"sequence_lengths",
true)) {}
122 USE_OPERATOR_CONTEXT_FUNCTIONS;
124 bool RunOnDevice()
override {
// Without sequence lengths, the timestep tensor occupies the SEQ_LENGTHS
// input slot.
126 const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
// Dims 1 and 2 are read, so the inputs must be at least 3-D; presumably
// HIDDEN_T_M_1 is (1, N, D) and GATES is (1, N, 3 * D) — confirm.
129 const auto N = Input(HIDDEN_T_M_1).size(1);
132 const auto G = Input(GATES).size(2);
133 const auto D = Input(HIDDEN_T_M_1).size(2);
// GATES must carry three D-wide gate blocks.
135 CAFFE_ENFORCE_EQ(3 * D, G);
136 const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
137 const auto* X = Input(GATES).template data<T>();
// Stays nullptr when the op was configured without sequence lengths.
139 const int32_t* seqLengths =
nullptr;
140 if (sequence_lengths_) {
// Exactly one sequence length per batch row.
141 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
142 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// The timestep is element [0] of input TIMESTEP, read as a CPU int32 tensor
// (the start of this expression is not visible here).
146 ->Input<Tensor>(TIMESTEP, CPU)
147 .template data<int32_t>()[0];
// HIDDEN_T takes the same shape as HIDDEN_T_M_1.
148 Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
149 auto* H = Output(HIDDEN_T)->template mutable_data<T>();
151 detail::GRUUnit<T, Context>(
152 N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
157 INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
160 OUTPUT_TAGS(HIDDEN_T);
164 bool sequence_lengths_;
// GRUUnitGradientOp — backward GRU-unit operator (fragment). The class
// declaration line, some member declarations and the argument list of the
// detail::GRUUnitGradient call are not visible in this chunk.
//
// Inputs: HIDDEN_T_M_1, GATES, optionally SEQ_LENGTHS, then TIMESTEP (CPU
// int32), HIDDEN_T and HIDDEN_T_GRAD at positions computed at run time.
// Outputs: HIDDEN_T_M_1_GRAD (shaped like HIDDEN_T_M_1) and GATES_GRAD
// (shaped like GATES).
// Arguments: "drop_states" (default false), "sequence_lengths" (default true).
167 template <
typename T,
typename Context>
// Constructor: reads the "drop_states" and "sequence_lengths" arguments.
170 template <
class... Args>
174 this->
template GetSingleArgument<bool>(
"drop_states",
false)),
176 this->
template GetSingleArgument<bool>(
"sequence_lengths",
true)) {}
177 USE_OPERATOR_CONTEXT_FUNCTIONS;
179 bool RunOnDevice()
override {
// The trailing inputs shift by one position when SEQ_LENGTHS is present.
181 const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
182 const size_t TIMESTEP = inputOffset;
183 const size_t HIDDEN_T = inputOffset + 1;
184 const size_t HIDDEN_T_GRAD = inputOffset + 2;
// Dims 1 and 2 are read, so the inputs must be at least 3-D (same layout
// assumption as the forward op — confirm).
187 const auto N = Input(HIDDEN_T_M_1).size(1);
190 const auto G = Input(GATES).size(2);
191 const auto D = Input(HIDDEN_T_M_1).size(2);
193 CAFFE_ENFORCE_EQ(3 * D, G);
194 const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
195 const auto* X = Input(GATES).template data<T>();
// The timestep is element [0] of input TIMESTEP, read as a CPU int32 tensor
// (the start of this expression is not visible here).
197 ->Input<Tensor>(TIMESTEP, CPU)
198 .template data<int32_t>()[0];
// NOTE(review): H is not referenced in the visible lines; the
// GRUUnitGradient call's arguments are not visible — confirm it is used.
199 const auto* H = Input(HIDDEN_T).template data<T>();
200 const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
// Stays nullptr when the op was configured without sequence lengths.
202 const int32_t* seqLengths =
nullptr;
203 if (sequence_lengths_) {
// Exactly one sequence length per batch row.
204 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
205 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// Gradient outputs mirror the shapes of the corresponding forward inputs.
208 Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
209 auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
210 Output(GATES_GRAD)->ResizeLike(Input(GATES));
211 auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
213 detail::GRUUnitGradient<T, Context>(
230 INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
231 OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);
235 bool sequence_lengths_;
240 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...