Caffe2 - C++ API
A deep learning, cross platform ML framework
rmsprop_op.h
1 #pragma once
2 
3 #include "caffe2/core/common_omp.h"
4 #include "caffe2/core/operator.h"
5 
6 namespace caffe2 {
7 
/**
 * Per-context RMSProp update routine (declaration only; the definition is
 * not part of this header, and the exact update rule lives there).
 *
 * Argument roles, as used by RmsPropOp::RunOnDevice in this file:
 *
 * @param N        Number of elements to process (the gradient's numel()).
 * @param g        Input gradient.
 * @param ms       Input mean-squares buffer.
 * @param mom      Input momentum buffer.
 * @param ng       Output gradient.
 * @param nms      Output mean-squares buffer.
 * @param nmom     Output momentum buffer.
 * @param decay    Operator argument "decay".
 * @param momentum Operator argument "momentum".
 * @param epsilon  Operator argument "epsilon".
 * @param lr       Pointer to the learning rate; the calling operator enforces
 *                 that the source tensor holds exactly one element.
 * @param context  Device context the update runs on.
 */
template <typename Context>
void rmsprop_update(
    int N,
    const float* g,
    const float* ms,
    const float* mom,
    float* ng,
    float* nms,
    float* nmom,
    float decay,
    float momentum,
    float epsilon,
    const float* lr,
    Context* context);
22 
23 template <typename T, class Context>
24 class RmsPropOp final : public Operator<Context> {
25  public:
26  USE_OPERATOR_CONTEXT_FUNCTIONS;
27  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator<Context>(operator_def, ws),
29  decay_(this->template GetSingleArgument<float>("decay", 0.9f)),
30  momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)),
31  epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
32  bool RunOnDevice() override {
33  CAFFE_ENFORCE(Input(LR).numel() == 1);
34  CAFFE_ENFORCE(Input(GRAD).numel() == Input(MEAN_SQUARES).numel());
35  CAFFE_ENFORCE(Input(GRAD).numel() == Input(OUTPUT_MOMENTUM).numel());
36  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
37  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
38  Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES));
39  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
40  rmsprop_update<Context>(
41  Input(GRAD).numel(),
42  Input(GRAD).template data<T>(),
43  Input(MEAN_SQUARES).template data<T>(),
44  Input(MOMENTUM).template data<T>(),
45  Output(OUTPUT_GRAD)->template mutable_data<T>(),
46  Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
47  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
48  decay_,
49  momentum_,
50  epsilon_,
51  Input(LR).template data<T>(),
52  &context_);
53  return true;
54  }
55 
56  protected:
57  T decay_{0.9};
58  T momentum_{0.0};
59  T epsilon_{1e-8};
60  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
61  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
62 };
63 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13