5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename Context>
19 bool normalized_lr_adaption,
23 const float kEps = 1e-12f;
24 for (
auto i = 0; i < n; i++) {
25 x += grad[i] * effgrad[i];
26 if (normalized_lr_adaption) {
27 y += grad[i] * grad[i];
28 z += effgrad[i] * effgrad[i];
31 if (normalized_lr_adaption) {
32 y = fmax(std::sqrt(y), kEps);
33 z = fmax(std::sqrt(z), kEps);
34 nlr[0] = lr[0] * (1 - lr_alpha * x / (y * z));
36 nlr[0] = lr[0] - lr_alpha * x;
40 template <
typename T,
class Context>
45 lr_alpha_(this->
template GetSingleArgument<float>(
"lr_alpha", 0.01f)),
46 normalized_lr_adaption_(this->
template GetSingleArgument<bool>(
47 "normalized_lr_adaption",
49 USE_OPERATOR_CONTEXT_FUNCTIONS;
51 bool RunOnDevice()
override {
52 CAFFE_ENFORCE(
Input(LR).numel() == 1);
53 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(EFFGRAD).numel());
54 Output(OUTPUT_LR)->ResizeLike(
Input(LR));
57 Input(GRAD).template data<T>(),
58 Input(EFFGRAD).template data<T>(),
59 Input(LR).template data<T>(),
60 Output(OUTPUT_LR)->template mutable_data<T>(),
62 normalized_lr_adaption_,
69 bool normalized_lr_adaption_{
true};
70 INPUT_TAGS(LR, GRAD, EFFGRAD);
71 OUTPUT_TAGS(OUTPUT_LR);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...