3 #include "caffe2/core/operator.h" 7 template <
typename Context>
22 for (
auto i = 0; i < N; ++i) {
24 float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
25 float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
26 ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
// Dense Adam step that applies the update to the weights directly:
// nw = w + lr * correction * m / (sqrt(v) + eps_hat).
template <typename Context>
void adam_compute(
    int N,
    const float* w, const float* g, const float* m, const float* v,
    float* nw, float* nm, float* nv,
    float beta1, float beta2, float eps_hat, float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}
// Dense Adam step that also emits the effective (bias-corrected) gradient ng,
// in addition to updating the weights and moments.
template <typename Context>
void adam_compute_output_grad(
    int N,
    const float* w, const float* g, const float* m, const float* v,
    float* nw, float* nm, float* nv, float* ng,
    float beta1, float beta2, float eps_hat, float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + lr[0] * ngi;
  }
}
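// Editorial example, not part of the original header: the three helpers above
// share one per-element recurrence. The small function below restates it for a
// single element so the data flow is easier to follow; the name
// adam_example_step and its scalar signature are illustrative assumptions.
inline float adam_example_step(
    float w,           // current weight
    float g,           // gradient
    float& m,          // first moment, updated in place
    float& v,          // second moment, updated in place
    float beta1,
    float beta2,
    float eps_hat,
    float correction,  // sqrt(1 - beta2^t) / (1 - beta1^t)
    float lr) {
  m = m * beta1 + g * (1 - beta1);
  v = v * beta2 + g * g * (1 - beta2);
  // Same sign convention as adam_compute above: the step is added, so a
  // descending update corresponds to a negative learning rate.
  return w + lr * correction * m / (std::sqrt(v) + eps_hat);
}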
template <typename T, class Context>
class AdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // The iteration counter is expected to live on CPU.
    CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_1).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_2).numel());
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));

    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    // Fold both Adam bias corrections into a single scalar.
    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    if (OutputSize() == 3) {
      adam_compute<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    } else {
      // A fourth output was requested: also emit the effective gradient.
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      adam_compute_output_grad<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          Output(OUTPUT_GRAD)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
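// Editorial note, not in the original header: the `correction` scalar in
// AdamOp::RunOnDevice folds both Adam bias corrections into one factor. With
// the sign of the step carried by the learning rate input, the update applied
// above is, in LaTeX:
//
//   m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t
//   v_t      = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
//   \theta_t = \theta_{t-1}
//              + \mathrm{lr} \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}
//                            \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}
//
// This matches the textbook \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) form,
// except that epsilon here is added to \sqrt{v_t} (the uncorrected second
// moment) rather than to \sqrt{\hat{v}_t}.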
template <typename T, class Context>
class SparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes.
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          // Scalar fast path: one parameter per index.
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx);

          adam_compute(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    } else {
      // A fourth output was requested: also emit the effective gradient.
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx);

          adam_compute_output_grad(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              gradOut + offsetI,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
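// Editorial example, not part of the original header: SparseAdamOp only
// touches the parameter rows named in INDICES. The helper below restates the
// offset arithmetic used in the loops above; the function name and out-param
// signature are illustrative assumptions.
inline void adam_example_sparse_offsets(
    int64_t i,             // gradient row
    int64_t index,         // indices[i]: parameter row being updated
    int64_t block_size,    // PARAM.numel() / PARAM.size(0): elements per row
    int64_t* grad_offset,  // offsetI in the loops above
    int64_t* param_offset  // offsetIdx in the loops above
    ) {
  *grad_offset = i * block_size;
  *param_offset = index * block_size;
}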
template <typename T, class Context>
class RowWiseSparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes; MOMENT_2 holds a single value per parameter row.
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          // Scalar fast path: per-row and per-element moments coincide.
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx);

          // Note: the second moment is indexed per row (idx), not per element.
          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;

          // One shared second moment per row: mean of the squared gradients.
          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
          }
        }
      }
    } else {
      // A fourth output was requested: also emit the effective gradient.
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx);

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;
          float* ng = gradOut + offsetI;

          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
            nw[j] = w[j] + lr[0] * ngi;
          }
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};

} // namespace caffe2
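// Editorial example, not part of the original header: what distinguishes
// RowWiseSparseAdamOp is its second moment, a single scalar per parameter row
// updated from the mean of the squared gradients in that row (rather than one
// scalar per element). The function below restates that reduction; its name
// is an illustrative assumption.
inline float adam_example_row_wise_second_moment(
    const float* g,    // one gradient row of length block_size
    int block_size,
    float v_row_prev,  // previous per-row second moment
    float beta2) {
  float m2_sum = 0.f;
  for (int j = 0; j < block_size; ++j) {
    m2_sum += g[j] * g[j];
  }
  // The resulting v is shared by every element of the row when the update
  // divides by sqrt(v) + epsilon.
  return v_row_prev * beta2 + (m2_sum / block_size) * (1 - beta2);
}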