3 #include "caffe2/core/common_omp.h" 4 #include "caffe2/core/operator.h" 8 template <
typename Context>
23 template <
typename T,
class Context>
26 USE_OPERATOR_CONTEXT_FUNCTIONS;
29 decay_(this->
template GetSingleArgument<float>(
"decay", 0.9f)),
30 momentum_(this->
template GetSingleArgument<float>(
"momentum", 0.0f)),
31 epsilon_(this->
template GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
32 bool RunOnDevice()
override {
33 CAFFE_ENFORCE(
Input(LR).numel() == 1);
34 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(MEAN_SQUARES).numel());
35 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(OUTPUT_MOMENTUM).numel());
36 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
37 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
38 Output(OUTPUT_MEAN_SQUARES)->ResizeLike(
Input(MEAN_SQUARES));
39 Output(OUTPUT_MOMENTUM)->ResizeLike(
Input(MOMENTUM));
40 rmsprop_update<Context>(
42 Input(GRAD).template data<T>(),
43 Input(MEAN_SQUARES).template data<T>(),
44 Input(MOMENTUM).template data<T>(),
45 Output(OUTPUT_GRAD)->template mutable_data<T>(),
46 Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
47 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
51 Input(LR).template data<T>(),
60 INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
61 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...