1 #include "rmsprop_op.h" 3 #include "caffe2/utils/eigen_utils.h" 4 #include "caffe2/utils/math.h" 9 void rmsprop_update<CPUContext>(
// CPU specialization of one element-wise RMSProp step over N floats:
//   nms  = ms + (1 - decay) * (g * g - ms)      // == decay * ms + (1 - decay) * g^2
//   nmom = momentum * mom + lr[0] * g / sqrt(epsilon + nms)
//   ng   = nmom
// Read-only Eigen array views over the N-element inputs (no copies are made).
22 ConstEigenVectorArrayMap<float> gVec(g, N);
23 ConstEigenVectorArrayMap<float> msVec(ms, N);
24 ConstEigenVectorArrayMap<float> momVec(mom, N);
// New mean square: moves ms toward g^2 by a fraction (1 - decay).
// The expression is coefficient-wise, so each nms[i] depends only on ms[i] and g[i].
// NOTE(review): this presumably keeps an in-place ms -> nms alias safe
// (cf. AllowInplace {1, 1}); confirm how RmsPropOp maps inputs/outputs.
26 EigenVectorArrayMap<float> nmsVec(nms, N);
27 nmsVec = msVec + (1.0f - decay) * (gVec * gVec - msVec);
// New momentum: decayed previous momentum plus the gradient scaled by the
// learning rate and by 1 / sqrt(epsilon + nms). epsilon is added inside the
// sqrt, matching the schema doc. Only lr[0] is read, so lr is treated as a scalar.
29 EigenVectorArrayMap<float> nmomVec(nmom, N);
30 nmomVec = momVec * momentum + lr[0] * gVec / (epsilon + nmsVec).sqrt();
// The emitted gradient is a copy of the new momentum. ng is written last,
// after every read of g above.
// NOTE(review): this presumably keeps an in-place grad -> grad_o alias safe
// (cf. AllowInplace {0, 0}); confirm against RmsPropOp.
32 EigenVectorArrayMap<float>(ng, N) = nmomVec;
// Register RmsPropOp<float, CPUContext> as the CPU implementation of the "RmsProp" operator.
35 REGISTER_CPU_OPERATOR(RmsProp, RmsPropOp<float, CPUContext>);
// Schema for RmsProp.
// Per its doc string, inputs are (grad, mean_squares, mom, lr) and outputs are
// (grad_o, mean_squares_o, mom_o). Each of the three outputs may overwrite the
// input at the same index (AllowInplace).
// In the CPU kernel, grad_o is written as a copy of mom_o.
// The op declares that no gradient should be computed for it (SHOULD_NOT_DO_GRADIENT).
36 OPERATOR_SCHEMA(RmsProp)
39 .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
41 Computes the RMSProp update 42 (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). 43 Concretely, given inputs (grad, mean_squares, mom, lr), computes: 45 mean_squares_o = mean_squares + (1 - decay) * (square(grad) - mean_squares) 46 mom_o = momentum * mom + lr * grad / sqrt(epsilon + mean_squares_o) 49 Returns (grad_o, mean_squares_o, mom_o). 51 SHOULD_NOT_DO_GRADIENT(RmsProp); A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...