Caffe2 - C++ API
A deep learning, cross platform ML framework
adagrad_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/perfkernels/adagrad.h"
5 
6 namespace caffe2 {
7 
8 template <typename Context>
9 void adagrad_update(
10  int N,
11  const float* w,
12  const float* g,
13  const float* h,
14  float* nw,
15  float* nh,
16  float epsilon,
17  float decay,
18  const float* lr,
19  Context* /*context*/) {
20  return adagrad_update(N, w, g, h, nw, nh, epsilon, decay, lr[0]);
21 }
22 
template <typename Context>
// Adagrad step that additionally materializes the per-element effective
// learning rate. For each i:
//   momentOut[i]      = decay * momentIn[i] + gradIn[i]^2
//   effectiveLROut[i] = lr[0] / (sqrt(momentOut[i]) + epsilon)
//   paramOut[i]       = paramIn[i] + effectiveLROut[i] * gradIn[i]
// lr points at a single scalar; the context argument is unused.
void adagrad_update_output_effective_lr(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* momentIn,
    float* paramOut,
    float* momentOut,
    float* effectiveLROut,
    float epsilon,
    float decay,
    const float* lr,
    Context* /*context*/) {
  const float cur_lr = lr[0];
  for (int i = 0; i < N; ++i) {
    const float g = gradIn[i];
    const float h = decay * momentIn[i] + g * g;
    momentOut[i] = h;
    const float eff = cur_lr / (std::sqrt(h) + epsilon);
    effectiveLROut[i] = eff;
    paramOut[i] = paramIn[i] + eff * g;
  }
}
44 
template <typename Context>
// Adagrad step that materializes both the per-element effective learning
// rate and the applied update. For each i:
//   momentOut[i]      = decay * momentIn[i] + gradIn[i]^2
//   effectiveLROut[i] = lr[0] / (sqrt(momentOut[i]) + epsilon)
//   updateOut[i]      = effectiveLROut[i] * gradIn[i]
//   paramOut[i]       = paramIn[i] + updateOut[i]
// lr points at a single scalar; the context argument is unused.
void adagrad_update_output_effective_lr_and_update(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* momentIn,
    float* paramOut,
    float* momentOut,
    float* effectiveLROut,
    float* updateOut,
    float epsilon,
    float decay,
    const float* lr,
    Context* /*context*/) {
  const float cur_lr = lr[0];
  for (int i = 0; i < N; ++i) {
    const float g = gradIn[i];
    const float h = decay * momentIn[i] + g * g;
    momentOut[i] = h;
    const float eff = cur_lr / (std::sqrt(h) + epsilon);
    effectiveLROut[i] = eff;
    const float upd = eff * g;
    updateOut[i] = upd;
    paramOut[i] = paramIn[i] + upd;
  }
}
68 
69 template <typename T, class Context>
70 class AdagradOp final : public Operator<Context> {
71  public:
72  USE_OPERATOR_CONTEXT_FUNCTIONS;
73  AdagradOp(const OperatorDef& operator_def, Workspace* ws)
74  : Operator<Context>(operator_def, ws),
75  epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)),
76  decay_(this->template GetSingleArgument<T>("decay", 1.0f)) {}
77 
78  bool RunOnDevice() override {
79  CAFFE_ENFORCE_EQ(
80  Input(GRAD).numel(),
81  Input(MOMENT_1).numel(),
82  "PARAM size: ",
83  Input(PARAM).numel(),
84  ", GRAD size: ",
85  Input(GRAD).numel(),
86  ", MOMENT_1 size: ",
87  Input(MOMENT_1).numel(),
88  ", LR size: ",
89  Input(LR).numel());
90 
91  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(PARAM).numel());
92  Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
93  Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
94  if (OutputSize() == 2) {
95  adagrad_update<Context>(
96  Input(GRAD).numel(),
97  Input(PARAM).template data<T>(),
98  Input(GRAD).template data<T>(),
99  Input(MOMENT_1).template data<T>(),
100  Output(OUTPUT_PARAM)->template mutable_data<T>(),
101  Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
102  epsilon_,
103  decay_,
104  Input(LR).template data<T>(),
105  &context_);
106  } else if (OutputSize() == 3) {
107  Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD));
108  adagrad_update_output_effective_lr<Context>(
109  Input(GRAD).numel(),
110  Input(PARAM).template data<T>(),
111  Input(GRAD).template data<T>(),
112  Input(MOMENT_1).template data<T>(),
113  Output(OUTPUT_PARAM)->template mutable_data<T>(),
114  Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
115  Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
116  epsilon_,
117  decay_,
118  Input(LR).template data<T>(),
119  &context_);
120  } else {
121  Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD));
122  Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD));
123  adagrad_update_output_effective_lr_and_update<Context>(
124  Input(GRAD).numel(),
125  Input(PARAM).template data<T>(),
126  Input(GRAD).template data<T>(),
127  Input(MOMENT_1).template data<T>(),
128  Output(OUTPUT_PARAM)->template mutable_data<T>(),
129  Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
130  Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
131  Output(OUTPUT_UPDATE)->template mutable_data<T>(),
132  epsilon_,
133  decay_,
134  Input(LR).template data<T>(),
135  &context_);
136  }
137 
138  return true;
139  }
140 
141  protected:
142  T epsilon_;
143  T decay_;
144  INPUT_TAGS(PARAM, MOMENT_1, GRAD, LR);
145  OUTPUT_TAGS(
146  OUTPUT_PARAM,
147  OUTPUT_MOMENT_1,
148  OUTPUT_EFFECTIVE_LR,
149  OUTPUT_UPDATE);
150 };
151 
152 template <typename T, class Context>
153 class SparseAdagradOp final : public Operator<Context> {
154  public:
155  USE_OPERATOR_CONTEXT_FUNCTIONS;
156  SparseAdagradOp(const OperatorDef& operator_def, Workspace* ws)
157  : Operator<Context>(operator_def, ws),
158  epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
159 
160  bool RunOnDevice() override {
161  // Enforce shapes
162  CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
163  CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
164  CAFFE_ENFORCE_EQ(
165  Input(PARAM).size_from_dim(1),
166  Input(GRAD).size_from_dim(Input(INDICES).dim()));
167 
169  this, Input(INDICES));
170  }
171 
172  template <typename SIndex>
173  bool DoRunWithType() {
174  const auto* lr = Input(LR).template data<T>();
175  const auto* indices = Input(INDICES).template data<SIndex>();
176  const auto* gradIn = Input(GRAD).template data<T>();
177  const auto* paramIn = Input(PARAM).template data<T>();
178  const auto* momentIn = Input(MOMENT_1).template data<T>();
179  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
180  auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
181 
182  auto n = Input(INDICES).numel();
183  if (n == 0) {
184  return true;
185  }
186 
187  auto block_size = Input(GRAD).numel() / n;
188  for (auto i = 0; i < n; ++i) {
189  auto idx = indices[i];
190  if (block_size == 1) {
191  float gi = gradIn[i];
192  float hi = momentOut[idx] = momentIn[idx] + gi * gi;
193  paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
194  } else {
195  auto offsetI = i * block_size;
196  auto offsetIdx = idx * block_size;
197 
198 #ifndef NDEBUG
199  CAFFE_ENFORCE_GE(
200  Input(PARAM).numel(),
201  block_size + offsetIdx,
202  this->debug_def().input(PARAM),
203  ", out of bound, idx:",
204  idx,
205  " for input i:",
206  i,
207  " and block size:",
208  block_size);
209  CAFFE_ENFORCE_GE(
210  Input(GRAD).numel(),
211  block_size + offsetI,
212  this->debug_def().input(GRAD),
213  ", out of bound idx, idx:",
214  idx,
215  " for input i:",
216  i);
217 #endif
218  adagrad_update(
219  block_size,
220  paramIn + offsetIdx,
221  gradIn + offsetI,
222  momentIn + offsetIdx,
223  paramOut + offsetIdx,
224  momentOut + offsetIdx,
225  epsilon_,
226  1.0f,
227  lr,
228  &context_);
229  }
230  }
231  return true;
232  }
233 
234  protected:
235  T epsilon_;
236  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
237  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
238 };
239 
240 template <typename T, class Context>
241 class RowWiseSparseAdagradOp final : public Operator<Context> {
242  public:
243  USE_OPERATOR_CONTEXT_FUNCTIONS;
244  RowWiseSparseAdagradOp(const OperatorDef& operator_def, Workspace* ws)
245  : Operator<Context>(operator_def, ws),
246  epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
247 
248  bool RunOnDevice() override {
249  // Enforce shapes
250  CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
251  CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
252  CAFFE_ENFORCE_EQ(
253  Input(PARAM).size_from_dim(1),
254  Input(GRAD).size_from_dim(Input(INDICES).dim()));
255 
257  this, Input(INDICES));
258  }
259 
260  template <typename SIndex>
261  bool DoRunWithType() {
262  const auto* lr = Input(LR).template data<T>();
263  const auto* indices = Input(INDICES).template data<SIndex>();
264  const auto* gradIn = Input(GRAD).template data<T>();
265  const auto* paramIn = Input(PARAM).template data<T>();
266  const auto* momentIn = Input(MOMENT_1).template data<T>();
267  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
268  auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
269 
270  auto n = Input(INDICES).numel();
271  if (n == 0) {
272  return true;
273  }
274 
275  auto block_size = Input(GRAD).numel() / n;
276 
277  for (auto i = 0; i < n; ++i) {
278  auto idx = indices[i];
279  if (block_size == 1) {
280  float gi = gradIn[i];
281  float hi = momentOut[idx] = momentIn[idx] + gi * gi;
282  paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
283  } else {
284  auto offsetI = i * block_size;
285  auto offsetIdx = idx * block_size;
286 
287 #ifndef NDEBUG
288  CAFFE_ENFORCE_GE(
289  Input(PARAM).numel(),
290  block_size + offsetIdx,
291  this->debug_def().input(PARAM),
292  ", out of bound, idx:",
293  idx,
294  " for input i:",
295  i,
296  " and block size:",
297  block_size);
298  CAFFE_ENFORCE_GE(
299  Input(GRAD).numel(),
300  block_size + offsetI,
301  this->debug_def().input(GRAD),
302  ", out of bound idx, idx:",
303  idx,
304  " for input i:",
305  i);
306 #endif
307 
308  const float* w = paramIn + offsetIdx;
309  const float* g = gradIn + offsetI;
310  const float* h = momentIn + idx;
311  float* nw = paramOut + offsetIdx;
312  float* nh = momentOut + idx;
313  float hs = 0.;
314  for (auto j = 0; j < block_size; ++j) {
315  float gj = g[j];
316  hs += gj * gj;
317  }
318  float hi = nh[0] = h[0] + hs / block_size;
319  float step = lr[0] / (std::sqrt(hi) + epsilon_);
320  for (auto j = 0; j < block_size; ++j) {
321  nw[j] = w[j] + g[j] * step;
322  }
323  }
324  }
325  return true;
326  }
327 
328  protected:
329  T epsilon_;
330  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
331  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
332 };
333 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current process.
Definition: blob.h:13