Caffe2 - C++ API
A deep learning, cross-platform ML framework
momentum_sgd_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
/**
 * Element-wise momentum-SGD step over N floats.
 *
 * With LR = lr[0], for every i in [0, N):
 *   - nesterov == false:
 *       nm[i] = ng[i] = LR * g[i] + momentum * m[i]
 *   - nesterov == true:
 *       nm[i] = momentum * m[i] + LR * g[i]
 *       ng[i] = (1 + momentum) * nm[i] - momentum * m[i]
 *   - if param is non-null: param[i] -= ng[i]
 *
 * g[i] and m[i] are read before ng[i] / nm[i] are written, so the call may be
 * made in place with ng == g and nm == m.
 *
 * @param N        number of elements to process; nothing is written if N <= 0
 * @param g        gradient input (N floats)
 * @param m        momentum input (N floats)
 * @param ng       adjusted-gradient output (N floats)
 * @param nm       updated-momentum output (N floats)
 * @param lr       pointer to the learning rate; only lr[0] is read
 * @param momentum momentum coefficient
 * @param nesterov selects the Nesterov variant of the update
 * @param param    optional parameter buffer updated in place; may be nullptr
 */
void momentum_sgd_update(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    const bool nesterov,
    float* param) {
  const float LR = lr[0];
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
  for (int i = 0; i < N; ++i) {
    // Load both inputs up front so in-place calls never observe their own
    // writes for this element.
    const float gi = g[i];
    const float mi = m[i];

    float step;
    if (!nesterov) {
      step = LR * gi + momentum * mi;
      nm[i] = step;
    } else {
      const float mi_new = momentum * mi + LR * gi;
      nm[i] = mi_new;
      step = (1 + momentum) * mi_new - momentum * mi;
    }
    ng[i] = step;

    if (param) {
      param[i] -= step;
    }
  }
}
36 
37 class IDEEPMomentumSGDOp final : public IDEEPOperator {
38  public:
39  USE_IDEEP_DEF_ALIASES();
40  USE_IDEEP_OPERATOR_FUNCTIONS();
41 
42  IDEEPMomentumSGDOp(const OperatorDef& operator_def, Workspace* ws)
43  : IDEEPOperator(operator_def, ws),
44  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
45  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
46 
47  bool RunOnDevice() override {
48  CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
49  if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
50  Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
51  }
52  if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
53  Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
54  }
55 
56  // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
57  const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
58  CAFFE_ENFORCE(lr.numel() == 1);
59 
60  momentum_sgd_update(
61  Input(GRAD).get_nelems(),
62  static_cast<float*>(Input(GRAD).get_data_handle()),
63  static_cast<float*>(Input(MOMENTUM).get_data_handle()),
64  static_cast<float*>(Output(OUTPUT_GRAD)->get_data_handle()),
65  static_cast<float*>(Output(OUTPUT_MOMENTUM)->get_data_handle()),
66  lr.template data<float>(),
67  momentum_,
68  nesterov_,
69  nullptr);
70  return true;
71  }
72 
73  protected:
74  float momentum_{0.9};
75  bool nesterov_;
76  INPUT_TAGS(GRAD, MOMENTUM, LR);
77  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
78 };
79 
81  public:
82  USE_IDEEP_DEF_ALIASES();
83  USE_IDEEP_OPERATOR_FUNCTIONS();
84  IDEEPMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
85  : IDEEPOperator(operator_def, ws),
86  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
87  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
88 
89  bool RunOnDevice() override {
90  CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
91  if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
92  Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
93  }
94  if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
95  Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
96  }
97 
98  // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
99  const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
100  CAFFE_ENFORCE(lr.numel() == 1);
101 
102  momentum_sgd_update(
103  Input(GRAD).get_nelems(),
104  static_cast<float*>(Input(GRAD).get_data_handle()),
105  static_cast<float*>(Input(MOMENTUM).get_data_handle()),
106  static_cast<float*>(Output(OUTPUT_GRAD)->get_data_handle()),
107  static_cast<float*>(Output(OUTPUT_MOMENTUM)->get_data_handle()),
108  lr.template data<float>(),
109  momentum_,
110  nesterov_,
111  static_cast<float*>(Output(OUTPUT_PARAM)->get_data_handle()));
112  return true;
113  }
114 
115  protected:
116  float momentum_{0.9};
117  bool nesterov_;
118  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
119  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
120 };
121 
// Operator registrations. Each invocation pairs a name (first argument) with
// one of the operator classes from this file (second argument). Presumably
// this makes the class available as the IDEEP implementation of that
// operator; the macro is defined outside this file -- confirm its effect there.
122 REGISTER_IDEEP_OPERATOR(MomentumSGD, IDEEPMomentumSGDOp);
123 REGISTER_IDEEP_OPERATOR(MomentumSGDUpdate, IDEEPMomentumSGDUpdateOp);
124 
125 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13