// NOTE(review): this chunk is an extraction of a Caffe2 header listing. The
// leading integer on each line is the original source line number, and several
// original lines (including the function signature) are not present here.
1 #include "caffe2/core/operator.h" 7 template <
typename Context>
// AdadeltaUpdate<Context>: element-wise Adadelta step over N elements.
// From the visible loop body it reads w, h, lr and writes nw, nh, nd; gi and di
// are defined on lines not shown (presumably gi = g[i], di = d[i] — TODO confirm).
21 for (
int i = 0; i < N; ++i) {
// Moving average of squared gradients; stored to nh[i] and kept in hi.
24 float hi = nh[i] = decay * h[i] + (1.0f - decay) * gi * gi;
// Scaled gradient: sqrt(di + epsilon) / sqrt(hi + epsilon) * gi.
25 float ng = (std::sqrt(di + epsilon) / std::sqrt(hi + epsilon)) * gi;
// The step is added to w; only lr[0] is read. Presumably the caller supplies a
// negative learning rate so this descends — TODO confirm with callers.
26 nw[i] = w[i] + lr[0] * ng;
// Moving average of squared scaled updates (feeds di on the next step).
27 nd[i] = decay * di + (1.0f - decay) * ng * ng;
// NOTE(review): fragment of the dense Adadelta operator class. The class header,
// the constructor signature and some AdadeltaUpdate call arguments are not
// present in this chunk (see gaps in the leading original line numbers).
33 template <
class Context>
36 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator arguments: "epsilon" (default 1e-5f) and "decay" (default 0.95f).
39 OP_SINGLE_ARG(
float,
"epsilon", epsilon_, 1e-5f),
40 OP_SINGLE_ARG(
float,
"decay", decay_, 0.95f) {}
// Validates input sizes and hyper-parameters, resizes each output like its
// corresponding input, then runs AdadeltaUpdate over the dense tensors.
42 bool RunOnDevice()
override {
// GRAD must match MOMENT_GRAD, MOMENT_DELTA and PARAM in element count.
// NOTE(review): the sparse variant uses CAFFE_ENFORCE_EQ for the same kind of
// check, which would presumably also report the two operand values on failure.
43 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(MOMENT_GRAD).numel());
44 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(MOMENT_DELTA).numel());
45 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(PARAM).numel());
// NOTE(review): no size check on LR is visible here, yet the kernel reads
// lr[0]. The sparse variant enforces Input(LR).numel() == 1; consider the same
// check so an empty LR blob cannot lead to an out-of-bounds read.
// Hyper-parameter ranges: epsilon >= 0 and 0 < decay < 1.
46 CAFFE_ENFORCE_GE(epsilon_, 0.0f);
47 CAFFE_ENFORCE_GT(decay_, 0.0f);
48 CAFFE_ENFORCE_LT(decay_, 1.0f);
// Each output takes the shape of its corresponding input.
50 Output(OUTPUT_PARAM)->ResizeLike(
Input(PARAM));
51 Output(OUTPUT_MOMENT_GRAD)->ResizeLike(
Input(MOMENT_GRAD));
52 Output(OUTPUT_MOMENT_DELTA)->ResizeLike(
Input(MOMENT_DELTA));
// Arguments on lines not shown (orig. 54, 59-60, 65+) presumably include the
// element count, epsilon_/decay_ and the context — TODO confirm. Visible
// inputs presumably map to w, g, h, d and outputs to nw, nh, nd.
53 AdadeltaUpdate<Context>(
55 Input(PARAM).template data<float>(),
56 Input(GRAD).template data<float>(),
57 Input(MOMENT_GRAD).template data<float>(),
58 Input(MOMENT_DELTA).template data<float>(),
61 Input(LR).template data<float>(),
62 Output(OUTPUT_PARAM)->template mutable_data<float>(),
63 Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(),
64 Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(),
// Positional input/output slots used by Input()/Output() above.
72 INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR);
73 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
// NOTE(review): fragment of the sparse Adadelta operator class. The class
// header, the constructor signature, the index-type dispatch call, parts of the
// bounds checks and of the per-block update call are not present in this chunk.
76 template <
class Context>
79 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator arguments: "epsilon" (default 1e-5f) and "decay" (default 0.95f).
82 OP_SINGLE_ARG(
float,
"epsilon", epsilon_, 1e-5f),
83 OP_SINGLE_ARG(
float,
"decay", decay_, 0.95f) {}
// Validates sizes/hyper-parameters, then runs DoRunWithType for the INDICES
// element type (the dispatch call itself, orig. lines 98-99, is not visible).
85 bool RunOnDevice()
override {
// Both moment buffers must have as many elements as PARAM; LR must be a
// single value (only lr[0] is read below).
87 CAFFE_ENFORCE_EQ(
Input(PARAM).numel(),
Input(MOMENT_GRAD).numel());
88 CAFFE_ENFORCE_EQ(
Input(PARAM).numel(),
Input(MOMENT_DELTA).numel());
89 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
// The enforce macro that opens this check (orig. line 90) is not visible; it
// presumably requires one PARAM row (dims from 1 on) to have as many elements
// as one GRAD slice past the INDICES dimensions — TODO confirm.
91 Input(PARAM).size_from_dim(1),
92 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
// Hyper-parameter ranges: epsilon >= 0 and 0 < decay < 1.
95 CAFFE_ENFORCE_GE(epsilon_, 0.0f);
96 CAFFE_ENFORCE_GT(decay_, 0.0f);
97 CAFFE_ENFORCE_LT(decay_, 1.0f);
100 this,
Input(INDICES));
// Applies the Adadelta step only to the PARAM/moment rows named by INDICES.
// NOTE(review): no ResizeLike of the outputs is visible and only indexed rows
// are written here; presumably the outputs alias the inputs (in-place) —
// confirm against the operator schema.
103 template <
typename SIndex>
104 bool DoRunWithType() {
105 const auto* lr =
Input(LR).template data<float>();
106 const auto* indices =
Input(INDICES).template data<SIndex>();
107 const auto* gradIn =
Input(GRAD).template data<float>();
108 const auto* paramIn =
Input(PARAM).template data<float>();
109 const auto* momentIn =
Input(MOMENT_GRAD).template data<float>();
110 const auto* momentDeltaIn =
Input(MOMENT_DELTA).template data<float>();
// The declaration of momentOut (orig. line 112) is not visible; line 113 is
// presumably its initializer.
111 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
113 Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>();
114 auto* momentDeltaOut =
115 Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>();
// Number of indexed rows.
117 auto n =
Input(INDICES).numel();
// Elements per indexed row. NOTE(review): this divides by n; orig. lines
// 118-121 are not visible and may hold an n == 0 early return — confirm.
122 auto block_size =
Input(GRAD).numel() / n;
// NOTE(review): if numel() returns a 64-bit type, the int loop counter could
// narrow for extremely large index tensors.
123 for (
int i = 0; i < n; ++i) {
124 auto idx = indices[i];
// Scalar path: same per-element update as AdadeltaUpdate, inlined. Gradient
// is read at position i; parameter and moments at row idx.
// NOTE(review): unlike the block path, no bounds check on idx is visible here,
// so a negative or too-large index would access memory out of bounds.
125 if (block_size == 1) {
126 float gi = gradIn[i];
127 float di = momentDeltaIn[idx];
128 float hi = momentOut[idx] =
129 decay_ * momentIn[idx] + (1.0f - decay_) * gi * gi;
130 float ng = (std::sqrt(di + epsilon_) / std::sqrt(hi + epsilon_)) * gi;
131 paramOut[idx] = paramIn[idx] + lr[0] * ng;
132 momentDeltaOut[idx] = decay_ * di + (1.0f - decay_) * ng * ng;
// Block path: offsetI is the start of the i-th GRAD slice, offsetIdx the
// start of row idx in PARAM and the moment buffers.
134 auto offsetI = i * block_size;
135 auto offsetIdx = idx * block_size;
// Partially visible bounds check (opener not shown): the idx row must fit
// inside PARAM.
139 Input(PARAM).numel(),
140 block_size + offsetIdx,
141 this->debug_def().input(PARAM),
142 ", out of bound, idx:",
// Partially visible bounds check (opener not shown): the i-th slice must fit
// inside GRAD.
150 block_size + offsetI,
151 this->debug_def().input(GRAD),
152 ", out of bound idx, idx:",
// Partially visible per-block update call (presumably AdadeltaUpdate over
// block_size elements — TODO confirm); row-idx slices of the moments/outputs.
161 momentIn + offsetIdx,
162 momentDeltaIn + offsetIdx,
166 paramOut + offsetIdx,
167 momentOut + offsetIdx,
168 momentDeltaOut + offsetIdx,
176 const float epsilon_;
// Positional input/output slots used by Input()/Output() above.
178 INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR);
179 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...