3 #include "caffe2/core/operator.h" 4 #include "caffe2/perfkernels/adagrad.h" 8 template <
typename Context>
20 return adagrad_update(N, w, g, h, nw, nh, epsilon, decay, lr[0]);
23 template <
typename Context>
24 void adagrad_update_output_effective_lr(
28 const float* momentIn,
31 float* effectiveLROut,
36 for (
auto i = 0; i < N; ++i) {
37 float grad = gradIn[i];
38 float moment = momentOut[i] = decay * momentIn[i] + grad * grad;
39 float effective_lr = effectiveLROut[i] =
40 lr[0] / (std::sqrt(moment) + epsilon);
41 paramOut[i] = paramIn[i] + effective_lr * grad;
45 template <
typename Context>
46 void adagrad_update_output_effective_lr_and_update(
50 const float* momentIn,
53 float* effectiveLROut,
59 for (
auto i = 0; i < N; ++i) {
60 float grad = gradIn[i];
61 float moment = momentOut[i] = decay * momentIn[i] + grad * grad;
62 float effective_lr = effectiveLROut[i] =
63 lr[0] / (std::sqrt(moment) + epsilon);
64 float update = updateOut[i] = effective_lr * grad;
65 paramOut[i] = paramIn[i] + update;
69 template <
typename T,
class Context>
72 USE_OPERATOR_CONTEXT_FUNCTIONS;
75 epsilon_(this->
template GetSingleArgument<T>(
"epsilon", 1e-5f)),
76 decay_(this->
template GetSingleArgument<T>(
"decay", 1.0f)) {}
78 bool RunOnDevice()
override {
81 Input(MOMENT_1).numel(),
87 Input(MOMENT_1).numel(),
91 CAFFE_ENFORCE_EQ(
Input(GRAD).numel(),
Input(PARAM).numel());
92 Output(OUTPUT_PARAM)->ResizeLike(
Input(PARAM));
93 Output(OUTPUT_MOMENT_1)->ResizeLike(
Input(MOMENT_1));
94 if (OutputSize() == 2) {
95 adagrad_update<Context>(
97 Input(PARAM).template data<T>(),
98 Input(GRAD).template data<T>(),
99 Input(MOMENT_1).template data<T>(),
100 Output(OUTPUT_PARAM)->template mutable_data<T>(),
101 Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
104 Input(LR).template data<T>(),
106 }
else if (OutputSize() == 3) {
107 Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(
Input(GRAD));
108 adagrad_update_output_effective_lr<Context>(
110 Input(PARAM).template data<T>(),
111 Input(GRAD).template data<T>(),
112 Input(MOMENT_1).template data<T>(),
113 Output(OUTPUT_PARAM)->template mutable_data<T>(),
114 Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
115 Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
118 Input(LR).template data<T>(),
121 Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(
Input(GRAD));
122 Output(OUTPUT_UPDATE)->ResizeLike(
Input(GRAD));
123 adagrad_update_output_effective_lr_and_update<Context>(
125 Input(PARAM).template data<T>(),
126 Input(GRAD).template data<T>(),
127 Input(MOMENT_1).template data<T>(),
128 Output(OUTPUT_PARAM)->template mutable_data<T>(),
129 Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
130 Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
131 Output(OUTPUT_UPDATE)->template mutable_data<T>(),
134 Input(LR).template data<T>(),
144 INPUT_TAGS(PARAM, MOMENT_1, GRAD, LR);
152 template <
typename T,
class Context>
155 USE_OPERATOR_CONTEXT_FUNCTIONS;
158 epsilon_(this->
template GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
160 bool RunOnDevice()
override {
162 CAFFE_ENFORCE_EQ(
Input(PARAM).numel(),
Input(MOMENT_1).numel());
163 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
165 Input(PARAM).size_from_dim(1),
166 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
169 this,
Input(INDICES));
172 template <
typename SIndex>
173 bool DoRunWithType() {
174 const auto* lr =
Input(LR).template data<T>();
175 const auto* indices =
Input(INDICES).template data<SIndex>();
176 const auto* gradIn =
Input(GRAD).template data<T>();
177 const auto* paramIn =
Input(PARAM).template data<T>();
178 const auto* momentIn =
Input(MOMENT_1).template data<T>();
179 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
180 auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
182 auto n =
Input(INDICES).numel();
187 auto block_size =
Input(GRAD).numel() / n;
188 for (
auto i = 0; i < n; ++i) {
189 auto idx = indices[i];
190 if (block_size == 1) {
191 float gi = gradIn[i];
192 float hi = momentOut[idx] = momentIn[idx] + gi * gi;
193 paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
195 auto offsetI = i * block_size;
196 auto offsetIdx = idx * block_size;
200 Input(PARAM).numel(),
201 block_size + offsetIdx,
202 this->debug_def().input(PARAM),
203 ", out of bound, idx:",
211 block_size + offsetI,
212 this->debug_def().input(GRAD),
213 ", out of bound idx, idx:",
222 momentIn + offsetIdx,
223 paramOut + offsetIdx,
224 momentOut + offsetIdx,
236 INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
237 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
240 template <
typename T,
class Context>
243 USE_OPERATOR_CONTEXT_FUNCTIONS;
246 epsilon_(this->
template GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
248 bool RunOnDevice()
override {
250 CAFFE_ENFORCE_EQ(
Input(PARAM).sizes()[0],
Input(MOMENT_1).numel());
251 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
253 Input(PARAM).size_from_dim(1),
254 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
257 this,
Input(INDICES));
260 template <
typename SIndex>
261 bool DoRunWithType() {
262 const auto* lr =
Input(LR).template data<T>();
263 const auto* indices =
Input(INDICES).template data<SIndex>();
264 const auto* gradIn =
Input(GRAD).template data<T>();
265 const auto* paramIn =
Input(PARAM).template data<T>();
266 const auto* momentIn =
Input(MOMENT_1).template data<T>();
267 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
268 auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
270 auto n =
Input(INDICES).numel();
275 auto block_size =
Input(GRAD).numel() / n;
277 for (
auto i = 0; i < n; ++i) {
278 auto idx = indices[i];
279 if (block_size == 1) {
280 float gi = gradIn[i];
281 float hi = momentOut[idx] = momentIn[idx] + gi * gi;
282 paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
284 auto offsetI = i * block_size;
285 auto offsetIdx = idx * block_size;
289 Input(PARAM).numel(),
290 block_size + offsetIdx,
291 this->debug_def().input(PARAM),
292 ", out of bound, idx:",
300 block_size + offsetI,
301 this->debug_def().input(GRAD),
302 ", out of bound idx, idx:",
308 const float* w = paramIn + offsetIdx;
309 const float* g = gradIn + offsetI;
310 const float* h = momentIn + idx;
311 float* nw = paramOut + offsetIdx;
312 float* nh = momentOut + idx;
314 for (
auto j = 0; j < block_size; ++j) {
318 float hi = nh[0] = h[0] + hs / block_size;
319 float step = lr[0] / (std::sqrt(hi) + epsilon_);
320 for (
auto j = 0; j < block_size; ++j) {
321 nw[j] = w[j] + g[j] * step;
330 INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
331 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...