Caffe2 - C++ API
A deep learning, cross-platform ML framework
momentum_sgd_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 
5 namespace caffe2 {
6 
/**
 * Applies one momentum-SGD step to N contiguous elements.
 *
 * For every element i:
 *  - classic momentum:  v = lr * g[i] + momentum * m[i];  nm[i] = ng[i] = v
 *  - Nesterov momentum: v = momentum * m[i] + lr * g[i];  nm[i] = v;
 *                       ng[i] = (1 + momentum) * v - momentum * m[i]
 * Afterwards, if `param` is non-null, param[i] -= ng[i].
 *
 * @param N        number of elements to process
 * @param g        input gradient (N values)
 * @param m        input momentum buffer (N values)
 * @param ng       output adjusted gradient (N values)
 * @param nm       output momentum buffer (N values); may alias `m`
 * @param lr       learning rate; only lr[0] is read
 * @param momentum momentum coefficient
 * @param nesterov true selects the Nesterov variant
 * @param param    optional parameter buffer updated in place; may be null
 * @param context  unused by this generic implementation
 */
template <typename Context>
void momentum_sgd_update(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    const bool nesterov,
    float* param,
    Context* /*context*/) {
  const float step = lr[0];
  for (int idx = 0; idx < N; ++idx) {
    // Capture the old momentum before nm is written so nm == m is safe.
    const float prev = m[idx];
    float update;
    if (nesterov) {
      const float next = momentum * prev + step * g[idx];
      nm[idx] = next;
      update = (1 + momentum) * next - momentum * prev;
    } else {
      update = step * g[idx] + momentum * prev;
      nm[idx] = update;
    }
    ng[idx] = update;

    if (param != nullptr) {
      param[idx] -= ng[idx];
    }
  }
}
37 
38 template <typename T, class Context>
39 class MomentumSGDOp final : public Operator<Context> {
40  public:
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42  MomentumSGDOp(const OperatorDef& operator_def, Workspace* ws)
43  : Operator<Context>(operator_def, ws),
44  momentum_(this->template GetSingleArgument<T>("momentum", 0.0)),
45  nesterov_(this->template GetSingleArgument<int>("nesterov", 0)) {}
46 
47  bool RunOnDevice() override {
48  auto device_type = Context::GetDeviceType();
49  // Iter live on the CPU
50  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
51  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
52  CAFFE_ENFORCE(Input(LR).numel() == 1);
53  CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENTUM).numel());
54  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
55  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
56 
57  momentum_sgd_update<Context>(
58  Input(GRAD).numel(),
59  Input(GRAD).template data<T>(),
60  Input(MOMENTUM).template data<T>(),
61  Output(OUTPUT_GRAD)->template mutable_data<T>(),
62  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
63  Input(LR).template data<T>(),
64  momentum_,
65  nesterov_,
66  NULL,
67  &context_);
68  return true;
69  }
70 
71  protected:
72  T momentum_{0.9};
73  bool nesterov_;
74  INPUT_TAGS(GRAD, MOMENTUM, LR);
75  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
76 };
77 
78 template <typename T, class Context>
79 class MomentumSGDUpdateOp final : public Operator<Context> {
80  public:
81  USE_OPERATOR_CONTEXT_FUNCTIONS;
82  MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
83  : Operator<Context>(operator_def, ws),
84  momentum_(this->template GetSingleArgument<T>("momentum", 0.0)),
85  nesterov_(this->template GetSingleArgument<int>("nesterov", 0)) {}
86 
87  bool RunOnDevice() override {
88  auto device_type = Context::GetDeviceType();
89  // Iter live on the CPU
90  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
91  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
92  CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
93  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENTUM).numel());
94  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
95  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
96 
97  momentum_sgd_update<Context>(
98  Input(GRAD).numel(),
99  Input(GRAD).template data<T>(),
100  Input(MOMENTUM).template data<T>(),
101  Output(OUTPUT_GRAD)->template mutable_data<T>(),
102  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
103  Input(LR).template data<T>(),
104  momentum_,
105  nesterov_,
106  Output(OUTPUT_PARAM)->template mutable_data<T>(),
107  &context_);
108  return true;
109  }
110 
111  protected:
112  T momentum_{0.9};
113  bool nesterov_;
114  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
115  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
116 };
117 
118 template <typename T, class Context>
119 class SparseMomentumSGDUpdateOp final : public Operator<Context> {
120  public:
121  USE_OPERATOR_CONTEXT_FUNCTIONS;
122  SparseMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
123  : Operator<Context>(operator_def, ws),
124  momentum_(this->template GetSingleArgument<T>("momentum", 0.0)),
125  nesterov_(this->template GetSingleArgument<int>("nesterov", 0)) {}
126 
127  bool RunOnDevice() override {
128  // Resize [potentially] out-of-place blobs
129  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
130 
131  // Enforce shapes
132  CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
133  CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENTUM).numel());
134  CAFFE_ENFORCE_EQ(
135  Input(PARAM).size_from_dim(1),
136  Input(GRAD).size_from_dim(Input(INDICES).dim()));
137 
139  this, Input(INDICES));
140  }
141 
142  template <typename SIndex>
143  bool DoRunWithType() {
144  auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
145  auto n = Input(GRAD).numel() / block_size;
146 
147  const auto* gradIn = Input(GRAD).template data<T>();
148  const auto* momentumIn = Input(MOMENTUM).template data<T>();
149  const auto* lr = Input(LR).template data<T>();
150  const auto* paramIn = Input(PARAM).template data<T>();
151  const auto* indices = Input(INDICES).template data<SIndex>();
152 
153  auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
154  auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data<T>();
155  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
156 
157  for (auto i = 0; i < n; ++i) {
158  auto idx = indices[i];
159  auto offsetI = i * block_size;
160  auto offsetIdx = idx * block_size;
161 
162  CAFFE_ENFORCE(offsetIdx + block_size <= Input(PARAM).numel());
163  CAFFE_ENFORCE(offsetI + block_size <= Input(GRAD).numel());
164 
165  momentum_sgd_update<Context>(
166  block_size,
167  gradIn + offsetI,
168  momentumIn + offsetIdx,
169  gradOut + offsetI,
170  momentumOut + offsetIdx,
171  lr,
172  momentum_,
173  nesterov_,
174  paramOut + offsetIdx,
175  &context_);
176  }
177  return true;
178  }
179 
180  protected:
181  T momentum_;
182  bool nesterov_;
183  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES);
184  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
185 };
186 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13