3 #include "caffe2/core/operator.h" 7 template <
typename Context>
18 for (
auto i = 0; i < N; ++i) {
20 nw[i] = w[i] + lr[0] * gi / (h[0] + epsilon);
23 for (
auto i = 0; i < N; ++i) {
27 nhTmp /= (h[0] + epsilon);
31 template <
typename Context>
32 void wngrad_update_output_effective_lr(
39 float* effectiveLROut,
43 effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
45 for (
auto i = 0; i < N; ++i) {
49 seqBTmp /= (seqBIn[0] + epsilon);
50 seqBOut[0] = seqBIn[0] + seqBTmp;
51 for (
auto i = 0; i < N; ++i) {
52 float grad = gradIn[i];
53 paramOut[i] = paramIn[i] + effectiveLROut[0] * grad;
57 template <
typename Context>
58 void wngrad_update_output_effective_lr_and_update(
65 float* effectiveLROut,
70 effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
72 for (
auto i = 0; i < N; ++i) {
76 seqBTmp /= (seqBIn[0] + epsilon);
77 seqBOut[0] = seqBIn[0] + seqBTmp;
79 for (
auto i = 0; i < N; ++i) {
80 float grad = gradIn[i];
81 float update = updateOut[i] = effectiveLROut[0] * grad;
82 paramOut[i] = paramIn[i] + update;
86 template <
typename T,
class Context>
89 USE_OPERATOR_CONTEXT_FUNCTIONS;
92 epsilon_(this->
template GetSingleArgument<T>(
"epsilon", 1e-5f)) {}
94 bool RunOnDevice()
override {
103 Input(SEQ_B).numel(),
107 Output(OUTPUT_PARAM)->ResizeLike(
Input(PARAM));
108 Output(OUTPUT_SEQ_B)->ResizeLike(
Input(SEQ_B));
109 if (OutputSize() == 2) {
110 wngrad_update<Context>(
112 Input(PARAM).template data<T>(),
113 Input(GRAD).template data<T>(),
114 Input(SEQ_B).template data<T>(),
115 Output(OUTPUT_PARAM)->template mutable_data<T>(),
116 Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
118 Input(LR).template data<T>(),
120 }
else if (OutputSize() == 3) {
121 Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(
Input(SEQ_B));
122 wngrad_update_output_effective_lr<Context>(
124 Input(PARAM).template data<T>(),
125 Input(GRAD).template data<T>(),
126 Input(SEQ_B).template data<T>(),
127 Output(OUTPUT_PARAM)->template mutable_data<T>(),
128 Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
129 Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
131 Input(LR).template data<T>(),
134 Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(
Input(SEQ_B));
135 Output(OUTPUT_UPDATE)->ResizeLike(
Input(GRAD));
136 wngrad_update_output_effective_lr_and_update<Context>(
138 Input(PARAM).template data<T>(),
139 Input(GRAD).template data<T>(),
140 Input(SEQ_B).template data<T>(),
141 Output(OUTPUT_PARAM)->template mutable_data<T>(),
142 Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
143 Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
144 Output(OUTPUT_UPDATE)->template mutable_data<T>(),
146 Input(LR).template data<T>(),
155 INPUT_TAGS(PARAM, SEQ_B, GRAD, LR);
156 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B, OUTPUT_EFFECTIVE_LR, OUTPUT_UPDATE);
159 template <
typename T,
class Context>
162 USE_OPERATOR_CONTEXT_FUNCTIONS;
165 epsilon_(this->
template GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
167 bool RunOnDevice()
override {
169 CAFFE_ENFORCE_EQ(
Input(SEQ_B).numel(), 1);
170 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
172 Input(PARAM).size_from_dim(1),
173 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
176 this,
Input(INDICES));
179 template <
typename SIndex>
180 bool DoRunWithType() {
181 const auto* lr =
Input(LR).template data<T>();
182 const auto* indices =
Input(INDICES).template data<SIndex>();
183 const auto* gradIn =
Input(GRAD).template data<T>();
184 const auto* paramIn =
Input(PARAM).template data<T>();
185 const auto* seqBIn =
Input(SEQ_B).template data<T>();
186 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
187 auto* seqBOut = Output(OUTPUT_SEQ_B)->template mutable_data<T>();
189 auto n =
Input(INDICES).numel();
194 auto block_size =
Input(GRAD).numel() / n;
196 for (
auto i = 0; i < n; ++i) {
197 auto idx = indices[i];
198 if (block_size == 1) {
199 float gi = gradIn[i];
200 paramOut[idx] = paramIn[idx] + lr[0] * gi / (seqBIn[0] + epsilon_);
202 auto offsetI = i * block_size;
203 auto offsetIdx = idx * block_size;
207 Input(PARAM).numel(),
208 block_size + offsetIdx,
209 this->debug_def().input(PARAM),
210 ", out of bound, idx:",
218 block_size + offsetI,
219 this->debug_def().input(GRAD),
220 ", out of bound idx, idx:",
225 for (
auto j = 0; j < block_size; ++j) {
226 float gi = gradIn[offsetI + j];
227 paramOut[offsetIdx + j] =
228 paramIn[offsetIdx + j] + lr[0] * gi / (seqBIn[0] + epsilon_);
233 for (
auto i = 0; i <
Input(GRAD).numel(); ++i) {
234 float gi = gradIn[i];
237 seqBTmp /= seqBIn[0];
238 seqBOut[0] = seqBTmp + seqBIn[0];
244 INPUT_TAGS(PARAM, SEQ_B, INDICES, GRAD, LR);
245 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...