#ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
#define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"

namespace caffe2 {
namespace detail {

template <typename T>
inline T sigmoid(T x) {
  return 1. / (1. + exp(-x));
}

// tanh written in terms of sigmoid via the identity tanh(x) = 2 * sigmoid(2x) - 1.
template <typename T>
inline T host_tanh(T x) {
  return 2. * sigmoid(2. * x) - 1.;
}
template <typename T, typename Context>
void LSTMUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* C,
    T* H,
    const float forget_bias,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n) {
    const bool valid = seqLengths == nullptr || t < seqLengths[n];

    for (int d = 0; d < D; ++d) {
      if (!valid) {
        // Past this sample's sequence length: either zero the state or carry
        // the previous state through unchanged.
        if (drop_states) {
          H[d] = 0;
          C[d] = 0;
        } else {
          H[d] = H_prev[d];
          C[d] = C_prev[d];
        }
      } else {
        // X holds the gate pre-activations laid out as [i, f, o, g] blocks of size D.
        const T i = sigmoid(X[d]);
        const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
        const T o = sigmoid(X[2 * D + d]);
        const T g = host_tanh(X[3 * D + d]);
        const T c_prev = C_prev[d];
        const T c = f * c_prev + i * g;
        C[d] = c;
        const T host_tanh_c = host_tanh(c);
        H[d] = o * host_tanh_c;
      }
    }
    H_prev += D;
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
  }
}
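// Illustrative note (not part of the original header): for one sample n and one
// coordinate d, the update above is the standard LSTM cell recurrence
//
//   i = sigmoid(X[d])
//   f = sigmoid(X[D + d] + forget_bias)
//   o = sigmoid(X[2*D + d])
//   g = tanh(X[3*D + d])
//   c = f * c_prev + i * g
//   h = o * tanh(c)
//
// For example, with D = 1, all gate pre-activations zero, forget_bias = 0 and
// c_prev = 1: i = f = o = 0.5 and g = 0, so c = 0.5 and h = 0.5 * tanh(0.5) ≈ 0.23.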
template <typename T, typename Context>
void LSTMUnitGradient(
    int N,
    int D,
    int t,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* C,
    const T* H,
    const T* C_diff,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* C_prev_diff,
    T* X_diff,
    const float forget_bias,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n) {
    const bool valid = seqLengths == nullptr || t < seqLengths[n];

    for (int d = 0; d < D; ++d) {
      T* c_prev_diff = C_prev_diff + d;
      T* h_prev_diff = H_prev_diff + d;
      T* i_diff = X_diff + d;
      T* f_diff = X_diff + 1 * D + d;
      T* o_diff = X_diff + 2 * D + d;
      T* g_diff = X_diff + 3 * D + d;
      if (!valid) {
        // Past the sequence length: pass the incoming state gradients straight
        // through (or drop them) and send no gradient to the gates.
        if (drop_states) {
          *h_prev_diff = 0;
          *c_prev_diff = 0;
        } else {
          *h_prev_diff = H_diff[d];
          *c_prev_diff = C_diff[d];
        }
        *i_diff = 0;
        *f_diff = 0;
        *o_diff = 0;
        *g_diff = 0;
      } else {
        // Recompute the forward gate activations, then backpropagate through
        // c = f * c_prev + i * g and h = o * tanh(c).
        const T i = sigmoid(X[d]);
        const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
        const T o = sigmoid(X[2 * D + d]);
        const T g = host_tanh(X[3 * D + d]);
        const T c_prev = C_prev[d];
        const T c = C[d];
        const T host_tanh_c = host_tanh(c);
        const T c_term_diff =
            C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
        *c_prev_diff = c_term_diff * f;
        *h_prev_diff = 0;
        *i_diff = c_term_diff * g * i * (1 - i);
        *f_diff = c_term_diff * c_prev * f * (1 - f);
        *o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
        *g_diff = c_term_diff * i * (1 - g * g);
      }
    }
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
    C_diff += D;
    H_diff += D;
    X_diff += 4 * D;
    H_prev_diff += D;
    C_prev_diff += D;
  }
}

} // namespace detail
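// Illustrative derivation (not part of the original header): writing the forward
// pass as c = f*c_prev + i*g and h = o*tanh(c), and letting dC and dH be the
// gradients arriving at c and h, the chain rule gives
//
//   dc_total = dC + dH * o * (1 - tanh(c)^2)        (c_term_diff above)
//   dc_prev  = dc_total * f
//   di       = dc_total * g      * i * (1 - i)      (sigmoid derivative)
//   df       = dc_total * c_prev * f * (1 - f)
//   do       = dH * tanh(c)      * o * (1 - o)
//   dg       = dc_total * i      * (1 - g * g)      (tanh derivative)
//
// which is exactly what LSTMUnitGradient writes into C_prev_diff and X_diff.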
template <typename Context>
class LSTMUnitOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit LSTMUnitOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        forget_bias_(static_cast<float>(
            this->template GetSingleArgument<float>("forget_bias", 0.0))),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <typename T>
  bool DoRunWithType() {
    // The TIMESTEP input comes right after SEQ_LENGTHS, which is only present
    // when sequence_lengths is true.
    const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);

    // Extract N (batch size) and D (hidden size); GATES is 1 x N x G with G = 4 * D.
    const auto N = Input(CELL_T_M_1).size(1);
    const auto G = Input(GATES).size(2);
    const auto D = Input(CELL_T_M_1).size(2);

    CAFFE_ENFORCE_EQ(4 * D, G);
    const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
    const auto* C_prev = Input(CELL_T_M_1).template data<T>();
    const auto* X = Input(GATES).template data<T>();

    const int32_t* seqLengths = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    const auto t = static_cast<OperatorBase*>(this)
                       ->Input<Tensor>(TIMESTEP, CPU)
                       .template data<int32_t>()[0];
    Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
    auto* C = Output(CELL_T)->template mutable_data<T>();
    Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
    auto* H = Output(HIDDEN_T)->template mutable_data<T>();
    detail::LSTMUnit<T, Context>(
        N,
        D,
        t,
        H_prev,
        C_prev,
        X,
        seqLengths,
        drop_states_,
        C,
        H,
        forget_bias_,
        &context_);
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
  OUTPUT_TAGS(HIDDEN_T, CELL_T);

  float forget_bias_;
  bool sequence_lengths_;
  bool drop_states_;
};
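// Shape summary (illustrative, inferred from the size()/numel() checks above):
//   HIDDEN_T_M_1, CELL_T_M_1 : 1 x N x D   (previous hidden / cell state)
//   GATES                    : 1 x N x 4D  (pre-activations for i, f, o, g)
//   SEQ_LENGTHS (optional)   : N           (int32 per-sample sequence lengths)
//   TIMESTEP                 : int32 CPU tensor whose first element is the step t
//   HIDDEN_T, CELL_T         : 1 x N x D   (updated hidden / cell state)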
template <typename Context>
class LSTMUnitGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit LSTMUnitGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        forget_bias_(static_cast<float>(
            this->template GetSingleArgument<float>("forget_bias", 0.0))),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <typename T>
  bool DoRunWithType() {
    // The trailing inputs shift by one position when SEQ_LENGTHS is absent.
    const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
    const size_t TIMESTEP = inputOffset;
    const size_t HIDDEN_T = inputOffset + 1;
    const size_t CELL_T = inputOffset + 2;
    const size_t HIDDEN_T_GRAD = inputOffset + 3;
    const size_t CELL_T_GRAD = inputOffset + 4;

    const auto N = Input(CELL_T_M_1).size(1);
    const auto G = Input(GATES).size(2);
    const auto D = Input(CELL_T_M_1).size(2);

    CAFFE_ENFORCE_EQ(4 * D, G);
    const auto* C_prev = Input(CELL_T_M_1).template data<T>();
    const auto* X = Input(GATES).template data<T>();
    const auto t = static_cast<OperatorBase*>(this)
                       ->Input<Tensor>(TIMESTEP, CPU)
                       .template data<int32_t>()[0];
    const auto* C = Input(CELL_T).template data<T>();
    const auto* H = Input(HIDDEN_T).template data<T>();
    const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
    const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();

    const int32_t* seqLengths = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
    auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
    Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
    auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
    Output(GATES_GRAD)->ResizeLike(Input(GATES));
    auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();

    detail::LSTMUnitGradient<T, Context>(
        N,
        D,
        t,
        C_prev,
        X,
        seqLengths,
        C,
        H,
        C_diff,
        H_diff,
        drop_states_,
        H_prev_diff,
        C_prev_diff,
        X_diff,
        forget_bias_,
        &context_);
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);

  float forget_bias_;
  bool sequence_lengths_;
  bool drop_states_;
};
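// Usage note (illustrative, not part of the original header): these class
// templates are normally instantiated and registered for a concrete context in
// a companion .cc file, along the lines of
//
//   REGISTER_CPU_OPERATOR(LSTMUnit, LSTMUnitOp<CPUContext>);
//   REGISTER_CPU_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<CPUContext>);
//
// The registration names above follow the usual Caffe2 conventions and are an
// assumption here, not something this header itself defines.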
} // namespace caffe2

#endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_