Caffe2 - C++ API
A deep learning, cross-platform ML framework
lars_op.h
1 #ifndef CAFFE2_OPERATORS_LARS_OP_H_
2 #define CAFFE2_OPERATORS_LARS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class LarsOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  LarsOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  offset_(this->template GetSingleArgument<float>("offset", 0.5)),
18  lr_min_(this->template GetSingleArgument<float>("lr_min", 0.02)) {}
19 
20  bool RunOnDevice() override {
21  auto& X = Input(0);
22  auto& dX = Input(1);
23  CAFFE_ENFORCE(
24  dX.numel() == X.numel(), "Gradient size doesn't match parameter size.");
25  CAFFE_ENFORCE_GE(offset_, 0);
26  CAFFE_ENFORCE_GE(lr_min_, 0);
27 
28  auto& wd = Input(2);
29  auto& trust = Input(3);
30  auto& lr_max = Input(4);
31 
32  auto* lr_rescaled = Output(0, vector<int64_t>{1}, at::dtype<T>());
33 
34  ReinitializeTensor(&X_norm_tensor_, {1}, at::dtype<T>().device(Context::GetDeviceType()));
35  T* X_norm_ = X_norm_tensor_.template mutable_data<T>();
36 
37  ReinitializeTensor(&dX_norm_tensor_, {1}, at::dtype<T>().device(Context::GetDeviceType()));
38  T* dX_norm_ = dX_norm_tensor_.template mutable_data<T>();
39 
40  ComputeNorms(
41  dX.numel(),
42  X.template data<T>(),
43  dX.template data<T>(),
44  X_norm_,
45  dX_norm_);
46 
47  ComputeLearningRate(
48  wd.template data<T>(),
49  trust.template data<T>(),
50  lr_max.template data<T>(),
51  offset_,
52  lr_min_,
53  X_norm_,
54  dX_norm_,
55  lr_rescaled->template mutable_data<T>());
56 
57  return true;
58  }
59 
60  private:
61  // Compute the l2 norm of X_data and dX_data
62  void ComputeNorms(
63  int64_t N,
64  const T* X_data,
65  const T* dX_data,
66  T* X_norm,
67  T* dX_norm) {
68  math::SumSqr(N, X_data, X_norm, &context_);
69  math::Sqrt(1, X_norm, X_norm, &context_);
70  math::SumSqr(N, dX_data, dX_norm, &context_);
71  math::Sqrt(1, dX_norm, dX_norm, &context_);
72  }
73  // Compute the learning rate and apply clipping
74  void ComputeLearningRate(
75  const T* wd,
76  const T* trust,
77  const T* lr_max,
78  T offset,
79  T lr_min,
80  T* X_norm,
81  T* dX_norm,
82  T* lr_rescaled);
83 
84  T offset_;
85  T lr_min_;
86 
87  Tensor X_norm_tensor_;
88  Tensor dX_norm_tensor_;
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_LARS_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary; note that this will not do anything if ...
Definition: tensor.cc:127
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13