1 #include "momentum_sgd_op.h" 5 REGISTER_CPU_OPERATOR(MomentumSGD, MomentumSGDOp<float, CPUContext>);
6 OPERATOR_SCHEMA(MomentumSGD)
9 .AllowInplace({{0, 0}, {1, 1}})
10 .TensorInferenceFunction(
11 [](
const OperatorDef& ,
const vector<TensorShape>& in) {
12 vector<TensorShape> out(2);
19 Computes a momentum SGD update for an input gradient and momentum 20 parameters. Concretely, given inputs (grad, m, lr) and parameters 21 (momentum, nesterov), computes: 24 adjusted_gradient = lr * grad + momentum * m 25 return (adjusted_gradient, adjusted_gradient) 27 m_new = momentum * m + lr * grad 28 return ((1 + momentum) * m_new - momentum * m, m_new) 30 Output is (grad, momentum) 32 Note the difference to MomemtumSGDUpdate, which actually performs the 33 parameter update (and is thus faster). 35 SHOULD_NOT_DO_GRADIENT(MomentumSGD); 37 REGISTER_CPU_OPERATOR( 39 MomentumSGDUpdateOp<float, CPUContext>); 40 OPERATOR_SCHEMA(MomentumSGDUpdate) 43 .AllowInplace({{0, 0}, {1, 1}, {3, 2}}) 44 .TensorInferenceFunction( 45 [](const OperatorDef& ,
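// Illustrative only: a minimal standalone sketch of the dense update the doc
// above describes, written against plain float arrays. The helper name
// momentum_sgd_reference and its signature are assumptions for illustration,
// not part of the Caffe2 API.
void momentum_sgd_reference(
    const float* grad,
    const float* m,
    float* adjusted_grad,
    float* m_out,
    float lr,
    float momentum,
    bool nesterov,
    int n) {
  for (int i = 0; i < n; ++i) {
    if (!nesterov) {
      // Both outputs equal the adjusted gradient.
      const float adjusted = lr * grad[i] + momentum * m[i];
      adjusted_grad[i] = adjusted;
      m_out[i] = adjusted;
    } else {
      // Nesterov: update the momentum first, then form the lookahead gradient
      // from the old and new momentum.
      const float m_new = momentum * m[i] + lr * grad[i];
      adjusted_grad[i] = (1 + momentum) * m_new - momentum * m[i];
      m_out[i] = m_new;
    }
  }
}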
REGISTER_CPU_OPERATOR(
    MomentumSGDUpdate,
    MomentumSGDUpdateOp<float, CPUContext>);
OPERATOR_SCHEMA(MomentumSGDUpdate)
    .NumInputs(4)
    .NumOutputs(3)
    .AllowInplace({{0, 0}, {1, 1}, {3, 2}})
    .TensorInferenceFunction(
        [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
          vector<TensorShape> out(3);
          out[0] = in[0];
          out[1] = in[1];
          out[2] = in[3];
          return out;
        })
    .SetDoc(R"DOC(

Performs a momentum SGD update for an input gradient and momentum
parameters. Concretely, given inputs (grad, m, lr, param) and arguments
(momentum, nesterov), computes:

    if not nesterov:
        adjusted_gradient = lr * grad + momentum * m
        param = param - adjusted_gradient
        return (adjusted_gradient, adjusted_gradient, param)
    else:
        m_new = momentum * m + lr * grad
        param = param - ((1 + momentum) * m_new - momentum * m)
        return ((1 + momentum) * m_new - momentum * m, m_new, param)

Output is (grad, momentum, parameter).

Note the difference to MomentumSGD, which returns a new gradient
but does not perform the parameter update.

)DOC");
SHOULD_NOT_DO_GRADIENT(MomentumSGDUpdate);
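// Illustrative only: the same update as above, but also applying the step to
// the parameters in place, which is what MomentumSGDUpdate adds over
// MomentumSGD. The helper name momentum_sgd_update_reference is an assumption
// for illustration, not part of the Caffe2 API.
void momentum_sgd_update_reference(
    const float* grad,
    float* m, // momentum, updated in place
    float* param, // parameters, updated in place
    float lr,
    float momentum,
    bool nesterov,
    int n) {
  for (int i = 0; i < n; ++i) {
    float adjusted;
    if (!nesterov) {
      adjusted = lr * grad[i] + momentum * m[i];
      m[i] = adjusted;
    } else {
      const float m_new = momentum * m[i] + lr * grad[i];
      adjusted = (1 + momentum) * m_new - momentum * m[i];
      m[i] = m_new;
    }
    // The extra step relative to MomentumSGD: actually take the parameter step.
    param[i] -= adjusted;
  }
}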
REGISTER_CPU_OPERATOR(
    SparseMomentumSGDUpdate,
    SparseMomentumSGDUpdateOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseMomentumSGDUpdate)
    .NumInputs(5)
    .NumOutputs(3)
    .AllowInplace({{0, 0}})
    .EnforceInplace({{1, 1}, {3, 2}})
    .TensorInferenceFunction(
        [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
          vector<TensorShape> out(3);
          out[0] = in[0];
          out[1] = in[1];
          out[2] = in[3];
          return out;
        })
    .SetDoc(R"DOC(

Performs a momentum SGD update analogous to MomentumSGDUpdate, but using a
GradientSlice and indices into the full param and momentum tables. Both param
and momentum should be in-place (corresponding inputs and outputs should be the
same blobs).

)DOC")
    .Input(0, "grad", "GradientSlice with gradients for updated indices.")
    .Input(1, "moment", "Momentum blob, same shape as param.")
    .Input(2, "lr", "Learning rate.")
    .Input(3, "param", "Full parameter blob.")
    .Input(
        4,
        "indices",
        "Indices (in first dimension of param) where updates are performed.")
    .Output(0, "output_grad", "Adjusted gradient.")
    .Output(1, "output_moment", "Updated momentum.")
    .Output(2, "output_param", "Updated parameter.")
    .Arg("momentum", "Momentum hyperparameter.")
    .Arg("nesterov", "(boolean) Whether to use Nesterov Accelerated Gradient.");
SHOULD_NOT_DO_GRADIENT(SparseMomentumSGDUpdate);
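// Illustrative only: a sketch of the sparse variant, where grad is a
// GradientSlice holding one row per entry of indices and only those rows of
// param and moment are touched. The helper name sparse_momentum_sgd_reference,
// the scalar lr, and the row-major block_size layout are assumptions for
// illustration, not part of the Caffe2 API.
void sparse_momentum_sgd_reference(
    const float* grad, // num_indices x block_size slice of gradients
    const int* indices, // rows of param/moment to update
    float* moment, // full momentum table, same shape as param
    float* param, // full parameter table
    float lr,
    float momentum,
    bool nesterov,
    int num_indices,
    int block_size) {
  for (int i = 0; i < num_indices; ++i) {
    const float* g = grad + i * block_size; // i-th gradient row
    float* m = moment + indices[i] * block_size; // matching momentum row
    float* p = param + indices[i] * block_size; // matching parameter row
    for (int j = 0; j < block_size; ++j) {
      float adjusted;
      if (!nesterov) {
        adjusted = lr * g[j] + momentum * m[j];
        m[j] = adjusted;
      } else {
        const float m_new = momentum * m[j] + lr * g[j];
        adjusted = (1 + momentum) * m_new - momentum * m[j];
        m[j] = m_new;
      }
      p[j] -= adjusted;
    }
  }
}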
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...