Caffe2 - C++ API
A deep learning, cross platform ML framework
learning_rate_adaption_op.h
1 #pragma once
2 
3 #include <cfloat>
4 #include <cmath>
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
/**
 * Computes an adapted scalar learning rate from the alignment between a
 * gradient and an "effective" gradient.
 *
 * Let x = sum_i grad[i] * effgrad[i] (the dot product of the two vectors).
 *
 *  - When normalized_lr_adaption is true, the dot product is divided by the
 *    product of the two L2 norms (each clamped from below by a small epsilon
 *    so the division is always defined), and
 *        nlr[0] = lr[0] * (1 - lr_alpha * x / (|grad| * |effgrad|)).
 *  - Otherwise
 *        nlr[0] = lr[0] - lr_alpha * x.
 *
 * @param n        Number of elements read from both grad and effgrad.
 * @param grad     Gradient values; elements [0, n) are read.
 * @param effgrad  Effective gradient values; elements [0, n) are read.
 * @param lr       Current learning rate; only lr[0] is read.
 * @param nlr      Output; only nlr[0] is written.
 * @param lr_alpha Step size applied to the alignment term.
 * @param normalized_lr_adaption Selects the normalized update described above.
 * @param context  Unused.
 */
template <typename Context>
void lr_update(
    int n,
    const float* grad,
    const float* effgrad,
    const float* lr,
    float* nlr,
    float lr_alpha,
    bool normalized_lr_adaption,
    Context* /*context*/) {
  // Lower bound for each norm; keeps the denominator non-zero when either
  // vector is all zeros (or n == 0).
  const float kEps = 1e-12f;

  float x = 0; // dot(grad, effgrad)
  float y = 0; // sum of squares of grad, later its clamped L2 norm
  float z = 0; // sum of squares of effgrad, later its clamped L2 norm
  for (int i = 0; i < n; ++i) {
    x += grad[i] * effgrad[i];
    if (normalized_lr_adaption) {
      y += grad[i] * grad[i];
      z += effgrad[i] * effgrad[i];
    }
  }

  if (normalized_lr_adaption) {
    // std::fmax selects the float overload, avoiding a round trip through
    // double that the unqualified C fmax may perform.
    y = std::fmax(std::sqrt(y), kEps);
    z = std::fmax(std::sqrt(z), kEps);
    nlr[0] = lr[0] * (1 - lr_alpha * x / (y * z));
  } else {
    nlr[0] = lr[0] - lr_alpha * x;
  }
}
39 
40 template <typename T, class Context>
41 class LearningRateAdaptionOp final : public Operator<Context> {
42  public:
43  LearningRateAdaptionOp(const OperatorDef& operator_def, Workspace* ws)
44  : Operator<Context>(operator_def, ws),
45  lr_alpha_(this->template GetSingleArgument<float>("lr_alpha", 0.01f)),
46  normalized_lr_adaption_(this->template GetSingleArgument<bool>(
47  "normalized_lr_adaption",
48  true)) {}
49  USE_OPERATOR_CONTEXT_FUNCTIONS;
50 
51  bool RunOnDevice() override {
52  CAFFE_ENFORCE(Input(LR).numel() == 1);
53  CAFFE_ENFORCE(Input(GRAD).numel() == Input(EFFGRAD).numel());
54  Output(OUTPUT_LR)->ResizeLike(Input(LR));
55  lr_update<Context>(
56  Input(GRAD).numel(),
57  Input(GRAD).template data<T>(),
58  Input(EFFGRAD).template data<T>(),
59  Input(LR).template data<T>(),
60  Output(OUTPUT_LR)->template mutable_data<T>(),
61  lr_alpha_,
62  normalized_lr_adaption_,
63  &context_);
64  return true;
65  }
66 
67  protected:
68  T lr_alpha_{1e-2};
69  bool normalized_lr_adaption_{true};
70  INPUT_TAGS(LR, GRAD, EFFGRAD);
71  OUTPUT_TAGS(OUTPUT_LR);
72 };
73 
74 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13