// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// momentum_sgd_op.cc
17 #include "momentum_sgd_op.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(MomentumSGD, MomentumSGDOp<float, CPUContext>);
22 OPERATOR_SCHEMA(MomentumSGD)
23  .NumInputs(3)
24  .NumOutputs(2)
25  .AllowInplace({{0, 0}, {1, 1}})
26  .TensorInferenceFunction(
27  [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
28  vector<TensorShape> out(2);
29  out[0] = in[0];
30  out[1] = in[1];
31  return out;
32  })
33  .SetDoc(R"DOC(
34 
35 Computes a momentum SGD update for an input gradient and momentum
36 parameters. Concretely, given inputs (grad, m, lr) and parameters
37 (momentum, nesterov), computes:
38 
39  if not nesterov:
40  adjusted_gradient = lr * grad + momentum * m
41  return (adjusted_gradient, adjusted_gradient)
42  else:
43  m_new = momentum * m + lr * grad
44  return ((1 + momentum) * m_new - momentum * m, m_new)
45 
46 Output is (grad, momentum)
47 
48 Note the difference to MomemtumSGDUpdate, which actually performs the
49 parameter update (and is thus faster).
50 )DOC");
51 SHOULD_NOT_DO_GRADIENT(MomentumSGD);
52 
53 REGISTER_CPU_OPERATOR(
54  MomentumSGDUpdate,
55  MomentumSGDUpdateOp<float, CPUContext>);
56 OPERATOR_SCHEMA(MomentumSGDUpdate)
57  .NumInputs(4)
58  .NumOutputs(3)
59  .AllowInplace({{0, 0}, {1, 1}, {3, 2}})
60  .TensorInferenceFunction(
61  [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
62  vector<TensorShape> out(3);
63  out[0] = in[0];
64  out[1] = in[1];
65  out[2] = in[3];
66  return out;
67  })
68  .SetDoc(R"DOC(
69 
70 Performs a momentum SGD update for an input gradient and momentum
71 parameters. Concretely, given inputs (grad, m, lr, param) and arguments
72 (momentum, nesterov), computes:
73 
74  if not nesterov:
75  adjusted_gradient = lr * grad + momentum * m
76  param = param - adjusted_gradient
77  return (adjusted_gradient, adjusted_gradient, param)
78  else:
79  m_new = momentum * m + lr * grad
80  param = param - ((1 + momentum) * m_new - momentum * m),
81  return ((1 + momentum) * m_new - momentum * m, m_new, param)
82 
83 Output is (grad, momentum, parameter).
84 
85 Note the difference to MomentumSGD, which returns a new gradient
86 but does not perform the parameter update.
87 
88 )DOC");
89 SHOULD_NOT_DO_GRADIENT(MomentumSGDUpdate);
90 
91 REGISTER_CPU_OPERATOR(
92  SparseMomentumSGDUpdate,
93  SparseMomentumSGDUpdateOp<float, CPUContext>);
94 OPERATOR_SCHEMA(SparseMomentumSGDUpdate)
95  .NumInputs(5)
96  .NumOutputs(3)
97  .AllowInplace({{0, 0}})
98  .EnforceInplace({{1, 1}, {3, 2}})
99  .TensorInferenceFunction(
100  [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
101  vector<TensorShape> out(3);
102  out[0] = in[0];
103  out[1] = in[1];
104  out[2] = in[3];
105  return out;
106  })
107  .SetDoc(R"DOC(
108 
109 Performs a momentum SGD update analogous to MomentumSGDUpdate, but using a
110 GradientSlice and indices into the full param and momentum tables. Both param
111 and momentum should be in-place (corresponding inputs and outputs should be the
112 same blobs).
113 
114 
115 
116 )DOC")
117  .Input(0, "grad", "GradientSlice with gradients for updated indices.")
118  .Input(1, "moment", "Momentum blob, same shape as param.")
119  .Input(2, "lr", "Learning rate.")
120  .Input(3, "param", "Full parameter blob.")
121  .Input(
122  4,
123  "indices",
124  "Indices (in first dimension of param) where updates are performed.")
125  .Output(0, "output_grad", "Adjusted gradient.")
126  .Output(1, "output_moment", "Updated momentum.")
127  .Output(2, "output_param", "Updated parameter")
128  .Arg("momentum", "Momentum hyperparameter.")
129  .Arg("nesterov", "(boolean) Whether to use Nesterov Accelerated Gradient.");
130 SHOULD_NOT_DO_GRADIENT(SparseMomentumSGDUpdate);
131 }
// Copyright (c) 2016-present, Facebook, Inc.