// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// momentum_sgd_op.h
17 #pragma once
18 
19 #include "caffe2/core/operator.h"
20 
21 namespace caffe2 {
22 
// Applies one momentum-SGD step element-wise over N values.
//
// With LR = lr[0], for every i in [0, N):
//   classic (nesterov == false):
//     nm[i] = ng[i] = LR * g[i] + momentum * m[i]
//   Nesterov (nesterov == true):
//     nm[i] = momentum * m[i] + LR * g[i]
//     ng[i] = (1 + momentum) * nm[i] - momentum * (old m[i])
// When `param` is non-null, param[i] -= ng[i] is applied afterwards.
//
// Parameters:
//   N         number of elements to process.
//   g, m      input gradient and momentum (N values each).
//   ng, nm    outputs: adjusted gradient and new momentum (N values each).
//             Each element's inputs are read before that element's outputs
//             are written, so ng/nm may alias g/m for in-place use.
//   lr        learning rate; only lr[0] is read, through a plain dereference,
//             so it must be host-accessible.
//   momentum  momentum coefficient.
//   nesterov  selects the Nesterov variant.
//   param     optional parameters updated in place; nullptr skips the update.
//   context   unused by this generic implementation.
template <typename Context>
void momentum_sgd_update(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    const bool nesterov,
    float* param,
    Context* /*context*/) {
  const float learning_rate = *lr;
  for (int idx = 0; idx < N; ++idx) {
    // Capture the old momentum before nm[idx] (which may alias m) is written.
    const float prev_m = m[idx];
    float step;
    if (nesterov) {
      const float next_m = momentum * prev_m + learning_rate * g[idx];
      nm[idx] = next_m;
      // Look-ahead correction uses both the new and the old momentum.
      step = (1 + momentum) * next_m - momentum * prev_m;
    } else {
      step = learning_rate * g[idx] + momentum * prev_m;
      nm[idx] = step;
    }
    ng[idx] = step;
    if (param != nullptr) {
      param[idx] -= ng[idx];
    }
  }
}
53 
54 template <typename T, class Context>
55 class MomentumSGDOp final : public Operator<Context> {
56  public:
57  USE_OPERATOR_CONTEXT_FUNCTIONS;
58  MomentumSGDOp(const OperatorDef& operator_def, Workspace* ws)
59  : Operator<Context>(operator_def, ws),
60  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
61  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
62 
63  bool RunOnDevice() override {
64  // Iter live on the CPU
65  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
66  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
67  CAFFE_ENFORCE(Input(LR).size() == 1);
68  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
69  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
70  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
71 
72  momentum_sgd_update<Context>(
73  Input(GRAD).size(),
74  Input(GRAD).template data<T>(),
75  Input(MOMENTUM).template data<T>(),
76  Output(OUTPUT_GRAD)->template mutable_data<T>(),
77  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
78  Input(LR).template data<T>(),
79  momentum_,
80  nesterov_,
81  NULL,
82  &context_);
83  return true;
84  }
85 
86  protected:
87  T momentum_{0.9};
88  bool nesterov_;
89  INPUT_TAGS(GRAD, MOMENTUM, LR);
90  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
91 };
92 
93 template <typename T, class Context>
94 class MomentumSGDUpdateOp final : public Operator<Context> {
95  public:
96  USE_OPERATOR_CONTEXT_FUNCTIONS;
97  MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
98  : Operator<Context>(operator_def, ws),
99  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
100  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
101 
102  bool RunOnDevice() override {
103  // Iter live on the CPU
104  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
105  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
106  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
107  CAFFE_ENFORCE_EQ(Input(GRAD).size(), Input(MOMENTUM).size());
108  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
109  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
110 
111  momentum_sgd_update<Context>(
112  Input(GRAD).size(),
113  Input(GRAD).template data<T>(),
114  Input(MOMENTUM).template data<T>(),
115  Output(OUTPUT_GRAD)->template mutable_data<T>(),
116  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
117  Input(LR).template data<T>(),
118  momentum_,
119  nesterov_,
120  Output(OUTPUT_PARAM)->template mutable_data<T>(),
121  &context_);
122  return true;
123  }
124 
125  protected:
126  T momentum_{0.9};
127  bool nesterov_;
128  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
129  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
130 };
131 
132 template <typename T, class Context>
133 class SparseMomentumSGDUpdateOp final : public Operator<Context> {
134  public:
135  USE_OPERATOR_CONTEXT_FUNCTIONS;
136  SparseMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
137  : Operator<Context>(operator_def, ws),
138  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
139  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
140 
141  bool RunOnDevice() override {
142  // Resize [potentially] out-of-place blobs
143  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
144 
145  // Enforce shapes
146  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
147  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENTUM).size());
148  CAFFE_ENFORCE_EQ(Input(PARAM).size_from_dim(1),
149  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
150 
152  this, Input(INDICES));
153  }
154 
155  template <typename SIndex>
156  bool DoRunWithType() {
157  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
158  auto n = Input(GRAD).size() / block_size;
159 
160  const auto* gradIn = Input(GRAD).template data<T>();
161  const auto* momentumIn = Input(MOMENTUM).template data<T>();
162  const auto* lr = Input(LR).template data<T>();
163  const auto* paramIn = Input(PARAM).template data<T>();
164  const auto* indices = Input(INDICES).template data<SIndex>();
165 
166  auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
167  auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data<T>();
168  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
169 
170  for (auto i = 0; i < n; ++i) {
171  auto idx = indices[i];
172  auto offsetI = i * block_size;
173  auto offsetIdx = idx * block_size;
174 
175  CAFFE_ENFORCE(offsetIdx + block_size <= Input(PARAM).size());
176  CAFFE_ENFORCE(offsetI + block_size <= Input(GRAD).size());
177 
178  momentum_sgd_update<Context>(
179  block_size,
180  gradIn + offsetI,
181  momentumIn + offsetIdx,
182  gradOut + offsetI,
183  momentumOut + offsetIdx,
184  lr,
185  momentum_,
186  nesterov_,
187  paramOut + offsetIdx,
188  &context_);
189  }
190  return true;
191  }
192 
193  protected:
194  T momentum_;
195  bool nesterov_;
196  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES);
197  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
198 };
199 }
// Tensor is the basic class in Caffe2 that stores a contiguous block of memory with its shape information...
// Definition: tensor.h:109
// Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
// Definition: workspace.h:63
// Copyright (c) 2016-present, Facebook, Inc.