Caffe2 - C++ API
A deep learning, cross-platform ML framework
rmsprop_op.cc
1 #include "rmsprop_op.h"
2 
3 #include "caffe2/utils/eigen_utils.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
7 
8 template <>
9 void rmsprop_update<CPUContext>(
10  int N,
11  const float* g,
12  const float* ms,
13  const float* mom,
14  float* ng,
15  float* nms,
16  float* nmom,
17  float decay,
18  float momentum,
19  float epsilon,
20  const float* lr,
21  CPUContext* /*context*/) {
22  ConstEigenVectorArrayMap<float> gVec(g, N);
23  ConstEigenVectorArrayMap<float> msVec(ms, N);
24  ConstEigenVectorArrayMap<float> momVec(mom, N);
25  // Update new mean square estimate
26  EigenVectorArrayMap<float> nmsVec(nms, N);
27  nmsVec = msVec + (1.0f - decay) * (gVec * gVec - msVec);
28  // Update momentum estimate
29  EigenVectorArrayMap<float> nmomVec(nmom, N);
30  nmomVec = momVec * momentum + lr[0] * gVec / (epsilon + nmsVec).sqrt();
31  // New gradient is the momentum
32  EigenVectorArrayMap<float>(ng, N) = nmomVec;
33 }
34 
35 REGISTER_CPU_OPERATOR(RmsProp, RmsPropOp<float, CPUContext>);
36 OPERATOR_SCHEMA(RmsProp)
37  .NumInputs(4)
38  .NumOutputs(3)
39  .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
40  .SetDoc(R"DOC(
41 Computes the RMSProp update
42 (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
43 Concretely, given inputs (grad, mean_squares, mom, lr), computes:
44 
45  mean_squares_o = mean_squares + (1 - decay) * (square(grad) - mean_squares)
46  mom_o = momentum * mom + lr * grad / sqrt(epsilon + mean_squares_o)
47  grad_o = mom_o
48 
49 Returns (grad_o, mean_squares_o, mom_o).
50 )DOC");
51 SHOULD_NOT_DO_GRADIENT(RmsProp);
52 
53 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13