Caffe2 - C++ API
A deep learning, cross platform ML framework
rmsprop_op.cc
1 
17 #include "rmsprop_op.h"
18 
19 #include "caffe2/utils/math.h"
20 
21 namespace caffe2 {
22 
23 template <>
24 void rmsprop_update<CPUContext>(
25  int N,
26  const float* g,
27  const float* ms,
28  const float* mom,
29  float* ng,
30  float* nms,
31  float* nmom,
32  float decay,
33  float momentum,
34  float epsilon,
35  const float* lr,
36  CPUContext* /*context*/) {
37  ConstEigenVectorArrayMap<float> gVec(g, N);
38  ConstEigenVectorArrayMap<float> msVec(ms, N);
39  ConstEigenVectorArrayMap<float> momVec(mom, N);
40  // Update new mean square estimate
41  EigenVectorArrayMap<float> nmsVec(nms, N);
42  nmsVec = msVec + (1.0f - decay) * (gVec * gVec - msVec);
43  // Update momentum estimate
44  EigenVectorArrayMap<float> nmomVec(nmom, N);
45  nmomVec = momVec * momentum + lr[0] * gVec / (epsilon + nmsVec).sqrt();
46  // New gradient is the momentum
47  EigenVectorArrayMap<float>(ng, N) = nmomVec;
48 }
49 
50 REGISTER_CPU_OPERATOR(RmsProp, RmsPropOp<float, CPUContext>);
51 OPERATOR_SCHEMA(RmsProp)
52  .NumInputs(4)
53  .NumOutputs(3)
54  .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
55  .SetDoc(R"DOC(
56 Computes the RMSProp update
57 (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
58 Concretely, given inputs (grad, mean_squares, mom, lr), computes:
59 
60  mean_squares_o = mean_squares + (1 - decay) * (square(grad) - mean_squares)
61  mom_o = momentum * mom + lr * grad / sqrt(epsilon + mean_squares_o)
62  grad_o = mom_o
63 
64 Returns (grad_o, mean_squares_o, mom_o).
65 )DOC");
66 SHOULD_NOT_DO_GRADIENT(RmsProp);
67 
68 }
Copyright (c) 2016-present, Facebook, Inc.