Caffe2 - C++ API
A deep learning, cross-platform ML framework
wngrad_op.h
#pragma once

#include "caffe2/core/operator.h"

namespace caffe2 {

// Dense WnGrad step. The accumulator h (seq_b) is a single scalar: each
// parameter moves by lr * g / (h + epsilon), and h then grows by the squared
// gradient norm divided by (h + epsilon).
template <typename Context>
void wngrad_update(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nw[i] = w[i] + lr[0] * gi / (h[0] + epsilon);
  }
  float nhTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nhTmp += gi * gi;
  }
  nhTmp /= (h[0] + epsilon);
  nh[0] = h[0] + nhTmp;
}
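Written out in the kernel's own notation (a transcription of the code above, not a formula taken from elsewhere), with w the parameters, g the gradient, b the scalar seq_b accumulator, and \eta the learning rate, one dense step is

\[ w_i \leftarrow w_i + \frac{\eta\, g_i}{b + \epsilon}, \qquad b \leftarrow b + \frac{\lVert g \rVert_2^2}{b + \epsilon}. \]

Note that the code adds \eta g, so a descent step corresponds to a negative learning rate or an already-negated gradient.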

template <typename Context>
void wngrad_update_output_effective_lr(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;
  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    paramOut[i] = paramIn[i] + effectiveLROut[0] * grad;
  }
}
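Reading this variant off the code, the only addition is that the scalar step size actually applied is written out as an extra output:

\[ \text{effective\_lr} = \frac{\eta}{b + \epsilon}, \qquad w_i \leftarrow w_i + \text{effective\_lr} \cdot g_i. \]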

template <typename Context>
void wngrad_update_output_effective_lr_and_update(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float* updateOut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;

  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    float update = updateOut[i] = effectiveLROut[0] * grad;
    paramOut[i] = paramIn[i] + update;
  }
}
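As a rough usage sketch (not part of this header), the free kernels can be driven directly with a CPUContext. The include paths below, and the choice of a negative learning rate to obtain a descent step, are assumptions of the example rather than something stated in this listing.

// Minimal sketch: calling wngrad_update on small stack buffers with a
// CPUContext. Include paths are assumed, not taken from this listing.
#include <cstdio>

#include "caffe2/core/context.h"   // assumed location of CPUContext
#include "caffe2/sgd/wngrad_op.h"  // assumed location of this header

int main() {
  const int N = 3;
  float w[N] = {1.0f, 2.0f, 3.0f};   // current parameters
  float g[N] = {0.1f, -0.2f, 0.3f};  // gradient
  float b[1] = {1.0f};               // scalar seq_b accumulator
  float lr[1] = {-0.05f};            // negative lr: the kernel adds lr * g
  float nw[N];
  float nb[1];

  caffe2::CPUContext ctx;
  caffe2::wngrad_update<caffe2::CPUContext>(
      N, w, g, b, nw, nb, /*epsilon=*/1e-5f, lr, &ctx);

  for (int i = 0; i < N; ++i) {
    std::printf("nw[%d] = %f\n", i, nw[i]);
  }
  std::printf("seq_b = %f\n", nb[0]);
  return 0;
}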

// Dense WnGrad operator. Takes PARAM, SEQ_B, GRAD, and LR, and always produces
// updated parameters and seq_b; with three or four outputs it additionally
// exposes the effective learning rate and the per-element update.
template <typename T, class Context>
class WngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  WngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel(),
        Input(PARAM).numel(),
        "PARAM size: ",
        Input(PARAM).numel(),
        ", GRAD size: ",
        Input(GRAD).numel(),
        ", SEQ_B size: ",
        Input(SEQ_B).numel(),
        ", LR size: ",
        Input(LR).numel());

    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_SEQ_B)->ResizeLike(Input(SEQ_B));
    if (OutputSize() == 2) {
      wngrad_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else if (OutputSize() == 3) {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      wngrad_update_output_effective_lr<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD));
      wngrad_update_output_effective_lr_and_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          Output(OUTPUT_UPDATE)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    }

    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B, OUTPUT_EFFECTIVE_LR, OUTPUT_UPDATE);
};
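At the operator level, a plausible way to exercise WngradOp from C++ is through an OperatorDef run against a Workspace, roughly as sketched below. This is only a sketch: it assumes the operator is registered under the type name "Wngrad" in the corresponding .cc file, that RunOperatorOnce is available from caffe2/core/operator.h, and that the blobs param, seq_b, grad, and lr have already been filled with float tensors (seq_b and lr holding one element each).

// Sketch only: wiring the dense Wngrad operator through an OperatorDef.
// Blob names and the "Wngrad" type string are assumptions of this example.
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

bool RunWngradOnce(caffe2::Workspace* ws) {
  caffe2::OperatorDef def;
  def.set_type("Wngrad");
  def.add_input("param");
  def.add_input("seq_b");   // scalar accumulator
  def.add_input("grad");
  def.add_input("lr");      // scalar learning rate
  def.add_output("param");  // updated in place, the usual pattern for SGD ops
  def.add_output("seq_b");
  auto* eps = def.add_arg();
  eps->set_name("epsilon");
  eps->set_f(1e-5f);
  return caffe2::RunOperatorOnce(def, ws);
}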

// Sparse WnGrad operator: only the parameter rows selected by INDICES are
// updated, while seq_b is grown from the full sparse gradient.
template <typename T, class Context>
class SparseWngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseWngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(SEQ_B).numel(), 1);
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    // Dispatch on the index type (int32 or int64).
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* seqBIn = Input(SEQ_B).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* seqBOut = Output(OUTPUT_SEQ_B)->template mutable_data<T>();

    auto n = Input(INDICES).numel();
    if (n == 0) {
      return true;
    }

    auto block_size = Input(GRAD).numel() / n;

    for (auto i = 0; i < n; ++i) {
      auto idx = indices[i];
      if (block_size == 1) {
        float gi = gradIn[i];
        paramOut[idx] = paramIn[idx] + lr[0] * gi / (seqBIn[0] + epsilon_);
      } else {
        auto offsetI = i * block_size;
        auto offsetIdx = idx * block_size;

#ifndef NDEBUG
        CAFFE_ENFORCE_GE(
            Input(PARAM).numel(),
            block_size + offsetIdx,
            this->debug_def().input(PARAM),
            ", out of bound, idx:",
            idx,
            " for input i:",
            i,
            " and block size:",
            block_size);
        CAFFE_ENFORCE_GE(
            Input(GRAD).numel(),
            block_size + offsetI,
            this->debug_def().input(GRAD),
            ", out of bound idx, idx:",
            idx,
            " for input i:",
            i);
#endif
        for (auto j = 0; j < block_size; ++j) {
          float gi = gradIn[offsetI + j];
          paramOut[offsetIdx + j] =
              paramIn[offsetIdx + j] + lr[0] * gi / (seqBIn[0] + epsilon_);
        }
      }
    }
    float seqBTmp = 0.0;
    for (auto i = 0; i < Input(GRAD).numel(); ++i) {
      float gi = gradIn[i];
      seqBTmp += gi * gi;
    }
    seqBTmp /= seqBIn[0];
    seqBOut[0] = seqBTmp + seqBIn[0];
    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B);
};

} // namespace caffe2
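For the sparse path, DoRunWithType above amounts to the following per-row update, where idx_k is the k-th entry of INDICES and g_k is the matching row (block) of GRAD; note that, as written, the seq_b denominator omits epsilon here, unlike the dense kernels:

\[ w_{\mathrm{idx}_k,:} \leftarrow w_{\mathrm{idx}_k,:} + \frac{\eta\, g_{k,:}}{b + \epsilon}, \qquad b \leftarrow b + \frac{\lVert g \rVert_2^2}{b}. \]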