Caffe2 - C++ API
A deep learning, cross-platform ML framework
adam_op.h
#pragma once

#include <cmath>

#include "caffe2/core/operator.h"

namespace caffe2 {

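// The three kernels below all share the per-element Adam recurrences
//
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   step_t = lr * correction * m_t / (sqrt(v_t) + eps_hat)
//
// where the caller supplies correction = sqrt(1 - beta2^t) / (1 - beta1^t),
// i.e. both bias-correction factors of Adam folded into a single scalar.
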
// Computes the bias-corrected Adam step into ng without touching the
// parameters; nm and nv receive the updated moment estimates.
template <typename Context>
void adam_update(
    int N,
    const float* g,
    const float* m,
    const float* v,
    float* ng,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}

// Same recurrences as adam_update, but applies the step directly to the
// parameters: nw = w + step.
template <typename Context>
void adam_compute(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}

// Applies the step like adam_compute and additionally writes the effective
// gradient (the step before scaling by lr) into ng.
template <typename Context>
void adam_compute_output_grad(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float* ng,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + lr[0] * ngi;
  }
}

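// A minimal usage sketch for the raw adam_update kernel above, for CPU builds.
// This is illustrative only: it assumes a default-constructible CPUContext and
// pre-filled buffers g, m, v of length N (none of these names come from this
// header).
//
//   std::vector<float> ng(N), nm(N), nv(N);
//   float lr = 0.001f;
//   int64_t t = 1; // 1-based step count
//   float correction =
//       std::sqrt(1.f - std::pow(0.999f, t)) / (1.f - std::pow(0.9f, t));
//   CPUContext ctx;
//   adam_update<CPUContext>(
//       N, g.data(), m.data(), v.data(), ng.data(), nm.data(), nv.data(),
//       0.9f, 0.999f, 1e-8f, correction, &lr, &ctx);
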
template <typename T, class Context>
class AdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
  bool RunOnDevice() override {
    // The iteration counter lives on the CPU.
    CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_1).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_2).numel());
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));

    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
    if (OutputSize() == 3) {
      adam_compute<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    } else {
      // With a fourth output, also expose the effective gradient.
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      adam_compute_output_grad<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          Output(OUTPUT_GRAD)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    }

    return true;
  }

 protected:
  T beta1_{0.9};
  T beta2_{0.999};
  T epsilon_{1e-8};
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};

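// Worked example of the bias correction in AdamOp above: with the default
// beta1 = 0.9 and beta2 = 0.999 at the first step (t = 1),
//   correction = sqrt(1 - 0.999) / (1 - 0.9) = 0.03162... / 0.1 ≈ 0.316,
// which reproduces the standard bias-corrected step
// lr * m_hat / (sqrt(v_hat) + eps), except that epsilon here is added to the
// uncorrected sqrt(v).
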
template <typename T, class Context>
class SparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

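  // Instantiated by the DispatchHelper call above once per supported index
  // type (int32_t and int64_t).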
  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);

        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          adam_compute(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;

        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          adam_compute_output_grad(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              gradOut + offsetI,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
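// Note: only the rows selected by INDICES are written above. This assumes the
// operator runs in-place (outputs sharing storage with PARAM / MOMENT_1 /
// MOMENT_2, which the schema in the accompanying .cc registration is expected
// to enforce); otherwise unselected rows of the outputs would be left
// uninitialized.
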
// RowWiseSparseAdam keeps a single second-moment value per row: MOMENT_2 has
// one element per row of PARAM, updated with the gradient's mean square over
// the row. This shrinks optimizer state from roughly 2x the parameter size to
// about 1x, at the cost of a coarser per-row variance estimate.
template <typename T, class Context>
class RowWiseSparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);

        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;

          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
          }
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;

        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound, idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;
          float* ng = gradOut + offsetI;

          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
            nw[j] = w[j] + lr[0] * ngi;
          }
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};

} // namespace caffe2
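// A hedged sketch of running the registered "Adam" operator from C++. It
// assumes the CPU registration from adam_op.cc is linked in and that blobs
// with the names below already exist in the workspace; the blob names are
// illustrative, not part of this header.
//
//   caffe2::Workspace ws;
//   caffe2::OperatorDef def;
//   def.set_type("Adam");
//   for (const auto* in : {"param", "moment_1", "moment_2", "grad", "lr", "iter"}) {
//     def.add_input(in);
//   }
//   for (const auto* out : {"param", "moment_1", "moment_2"}) {
//     def.add_output(out);
//   }
//   auto op = caffe2::CreateOperator(def, &ws);
//   op->Run();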