Caffe2 - C++ API
A deep learning, cross platform ML framework
adam_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/operator.h"
20 
21 namespace caffe2 {
22 
/**
 * Element-wise Adam step that emits the scaled update rather than applying it.
 *
 * For every i in [0, N):
 *   nm[i] = m[i] * beta1 + g[i] * (1 - beta1)
 *   nv[i] = v[i] * beta2 + g[i] * g[i] * (1 - beta2)
 *   ng[i] = lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat)
 *
 * @param N          number of elements to process; a non-positive N is a no-op
 * @param g          input gradient, N floats
 * @param m          input first moment, N floats
 * @param v          input second moment, N floats
 * @param ng         output scaled update, N floats
 * @param nm         output first moment, N floats
 * @param nv         output second moment, N floats
 * @param beta1      decay factor applied to the first moment
 * @param beta2      decay factor applied to the second moment
 * @param eps_hat    constant added to sqrt(nv[i]) in the denominator
 * @param correction scalar multiplied into the learning rate
 * @param lr         pointer to the learning rate; only lr[0] is read
 *
 * Each element is read before the matching outputs are written, so nm/nv may
 * point at the same storage as m/v. lr[0] is read once up front, so lr must
 * not overlap any output buffer.
 */
template <typename Context>
void adam_update(
    int N,
    const float* g,
    const float* m,
    const float* v,
    float* ng,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  if (N <= 0) {
    // Keep the empty case free of any read through lr.
    return;
  }
  // Loop-invariant: same left-to-right product the per-element expression
  // started with, so results are unchanged.
  const float step = lr[0] * correction;
  for (int i = 0; i < N; ++i) {
    const float gi = g[i];
    const float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    const float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = step * mi / (std::sqrt(vi) + eps_hat);
  }
}
45 
/**
 * Element-wise Adam step that applies the scaled update to the parameters.
 *
 * For every i in [0, N):
 *   nm[i] = m[i] * beta1 + g[i] * (1 - beta1)
 *   nv[i] = v[i] * beta2 + g[i] * g[i] * (1 - beta2)
 *   nw[i] = w[i] + lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat)
 *
 * The update is added to w, so the sign of the step comes from lr[0].
 *
 * @param N          number of elements to process; a non-positive N is a no-op
 * @param w          input parameters, N floats
 * @param g          input gradient, N floats
 * @param m          input first moment, N floats
 * @param v          input second moment, N floats
 * @param nw         output parameters, N floats
 * @param nm         output first moment, N floats
 * @param nv         output second moment, N floats
 * @param beta1      decay factor applied to the first moment
 * @param beta2      decay factor applied to the second moment
 * @param eps_hat    constant added to sqrt(nv[i]) in the denominator
 * @param correction scalar multiplied into the learning rate
 * @param lr         pointer to the learning rate; only lr[0] is read
 *
 * Each element is read before the matching outputs are written, so nw/nm/nv
 * may point at the same storage as w/m/v. lr[0] is read once up front, so lr
 * must not overlap any output buffer.
 */
template <typename Context>
void adam_compute(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  if (N <= 0) {
    // Keep the empty case free of any read through lr.
    return;
  }
  // Loop-invariant: same left-to-right product the per-element expression
  // started with, so results are unchanged.
  const float step = lr[0] * correction;
  for (int i = 0; i < N; ++i) {
    const float gi = g[i];
    const float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    const float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    const float ng = step * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + ng;
  }
}
70 
71 template <typename T, class Context>
72 class AdamOp final : public Operator<Context> {
73  public:
74  USE_OPERATOR_CONTEXT_FUNCTIONS;
75  AdamOp(const OperatorDef& operator_def, Workspace* ws)
76  : Operator<Context>(operator_def, ws),
77  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
78  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
79  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
80  bool RunOnDevice() override {
81  // Iter live on the CPU
82  CAFFE_ENFORCE(OperatorBase::InputIsType<TensorCPU>(ITER));
83  CAFFE_ENFORCE(Input(LR).size() == 1);
84  CAFFE_ENFORCE(Input(GRAD).size() == Input(PARAM).size());
85  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_1).size());
86  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_2).size());
87  Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
88  Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
89  Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));
90 
91  const auto iter =
92  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
93 
94  const auto t = iter + 1;
95  const auto correction =
96  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
97  adam_compute<Context>(
98  Input(GRAD).size(),
99  Input(PARAM).template data<T>(),
100  Input(GRAD).template data<T>(),
101  Input(MOMENT_1).template data<T>(),
102  Input(MOMENT_2).template data<T>(),
103  Output(OUTPUT_PARAM)->template mutable_data<T>(),
104  Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
105  Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
106  beta1_,
107  beta2_,
108  epsilon_,
109  correction,
110  Input(LR).template data<T>(),
111  &context_);
112  return true;
113  }
114 
115  protected:
116  T beta1_{0.9};
117  T beta2_{0.999};
118  T epsilon_{1e-8};
119  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
120  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
121 };
122 
123 template <typename T, class Context>
124 class SparseAdamOp final : public Operator<Context> {
125  public:
126  USE_OPERATOR_CONTEXT_FUNCTIONS;
127  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
128  : Operator<Context>(operator_def, ws),
129  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
130  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
131  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
132 
133  bool RunOnDevice() override {
134  // Enforce shapes
135  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
136  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_2).size());
137  CAFFE_ENFORCE_EQ(
138  Input(PARAM).size_from_dim(1),
139  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
140  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
141 
143  this, Input(INDICES));
144  }
145 
146  template <typename SIndex>
147  bool DoRunWithType() {
148  const auto* lr = Input(LR).template data<T>();
149  const auto iter =
150  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
151 
152  const auto t = iter + 1;
153  const auto correction =
154  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
155 
156  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
157  auto n = Input(GRAD).size() / block_size;
158 
159  const auto* paramIn = Input(PARAM).template data<T>();
160  const auto* indices = Input(INDICES).template data<SIndex>();
161  const auto* gradIn = Input(GRAD).template data<T>();
162  const auto* moment1In = Input(MOMENT_1).template data<T>();
163  const auto* moment2In = Input(MOMENT_2).template data<T>();
164  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
165  auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
166  auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
167 
168  for (auto i = 0; i < n; ++i) {
169  auto idx = indices[i];
170 
171  if (block_size == 1) {
172  float gi = gradIn[i];
173  float mi = moment1Out[idx] =
174  moment1In[idx] * beta1_ + gi * (1 - beta1_);
175  float vi = moment2Out[idx] =
176  moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
177  paramOut[idx] =
178  paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
179 
180  } else {
181  auto offsetI = i * block_size;
182  auto offsetIdx = idx * block_size;
183 
184 #ifndef NDEBUG
185  CAFFE_ENFORCE_GE(
186  Input(PARAM).size(),
187  block_size + offsetIdx,
188  this->debug_def().input(PARAM),
189  ", out of bound, idx:",
190  idx,
191  " for input i:",
192  i,
193  " and block size:",
194  block_size);
195  CAFFE_ENFORCE_GE(
196  Input(GRAD).size(),
197  block_size + offsetI,
198  this->debug_def().input(GRAD),
199  ", out of bound idx, idx:",
200  idx,
201  " for input i:",
202  i);
203 #endif
204 
205  adam_compute(
206  block_size,
207  paramIn + offsetIdx,
208  gradIn + offsetI,
209  moment1In + offsetIdx,
210  moment2In + offsetIdx,
211  paramOut + offsetIdx,
212  moment1Out + offsetIdx,
213  moment2Out + offsetIdx,
214  beta1_,
215  beta2_,
216  epsilon_,
217  correction,
218  lr,
219  &context_);
220  }
221  }
222  return true;
223  }
224 
225  protected:
226  T beta1_;
227  T beta2_;
228  T epsilon_;
229  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
230  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
231 };
232 
233 template <typename T, class Context>
234 class RowWiseSparseAdamOp final : public Operator<Context> {
235  public:
236  USE_OPERATOR_CONTEXT_FUNCTIONS;
237  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
238  : Operator<Context>(operator_def, ws),
239  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
240  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
241  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
242 
243  bool RunOnDevice() override {
244  // Enforce shapes
245  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
246  CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_2).size());
247  CAFFE_ENFORCE_EQ(
248  Input(PARAM).size_from_dim(1),
249  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
250  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
251 
253  this, Input(INDICES));
254  }
255 
256  template <typename SIndex>
257  bool DoRunWithType() {
258  const auto* lr = Input(LR).template data<T>();
259  const auto iter =
260  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
261 
262  const auto t = iter + 1;
263  const auto correction =
264  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
265 
266  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
267  auto n = Input(GRAD).size() / block_size;
268 
269  const auto* paramIn = Input(PARAM).template data<T>();
270  const auto* indices = Input(INDICES).template data<SIndex>();
271  const auto* gradIn = Input(GRAD).template data<T>();
272  const auto* moment1In = Input(MOMENT_1).template data<T>();
273  const auto* moment2In = Input(MOMENT_2).template data<T>();
274  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
275  auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
276  auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
277 
278  for (auto i = 0; i < n; ++i) {
279  auto idx = indices[i];
280 
281  if (block_size == 1) {
282  float gi = gradIn[i];
283  float mi = moment1Out[idx] =
284  moment1In[idx] * beta1_ + gi * (1 - beta1_);
285  float vi = moment2Out[idx] =
286  moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
287  paramOut[idx] =
288  paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
289 
290  } else {
291  auto offsetI = i * block_size;
292  auto offsetIdx = idx * block_size;
293 
294 #ifndef NDEBUG
295  CAFFE_ENFORCE_GE(
296  Input(PARAM).size(),
297  block_size + offsetIdx,
298  this->debug_def().input(PARAM),
299  ", out of bound, idx:",
300  idx,
301  " for input i:",
302  i,
303  " and block size:",
304  block_size);
305  CAFFE_ENFORCE_GE(
306  Input(GRAD).size(),
307  block_size + offsetI,
308  this->debug_def().input(GRAD),
309  ", out of bound idx, idx:",
310  idx,
311  " for input i:",
312  i);
313 #endif
314 
315  const float* w = paramIn + offsetIdx;
316  const float* g = gradIn + offsetI;
317  const float* m1 = moment1In + offsetIdx;
318  const float* m2 = moment2In + idx;
319  float* nw = paramOut + offsetIdx;
320  float* nm1 = moment1Out + offsetIdx;
321  float* nm2 = moment2Out + idx;
322 
323  float m2_sum = 0.;
324  for (auto j = 0; j < block_size; ++j) {
325  float gj = g[j];
326  m2_sum += gj * gj;
327  }
328  float vi = nm2[0] =
329  m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
330  for (auto j = 0; j < block_size; ++j) {
331  float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
332  nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
333  }
334  }
335  }
336  return true;
337  }
338 
339  protected:
340  T beta1_;
341  T beta2_;
342  T epsilon_;
343  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
344  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
345 };
346 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.