Caffe2 - C++ API
A deep learning, cross platform ML framework
rmsprop_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/common_omp.h"
20 #include "caffe2/core/operator.h"
21 
22 namespace caffe2 {
23 
// Context-templated RMSProp update kernel (declaration only; no definition is
// visible in this header -- presumably each Context supplies its own elsewhere).
//
// Parameter roles, as established by the call in RmsPropOp::RunOnDevice():
//   N        - element count; the caller passes the size of the gradient input.
//   g        - read-only gradient values.
//   ms       - read-only mean-squares values.
//   mom      - read-only momentum values.
//   ng       - writable buffer for the output gradient.
//   nms      - writable buffer for the output mean squares.
//   nmom     - writable buffer for the output momentum.
//   decay    - "decay" operator argument.
//   momentum - "momentum" operator argument.
//   epsilon  - "epsilon" operator argument.
//   lr       - pointer to the learning rate; the caller enforces that the
//              underlying tensor holds exactly one element.
//   context  - device context the kernel runs with.
//
// NOTE(review): the exact update arithmetic is not shown here; consult the
// per-Context definitions before relying on a specific formula.
24 template <typename Context>
25 void rmsprop_update(
26  int N,
27  const float* g,
28  const float* ms,
29  const float* mom,
30  float* ng,
31  float* nms,
32  float* nmom,
33  float decay,
34  float momentum,
35  float epsilon,
36  const float* lr,
37  Context* context);
38 
39 template <typename T, class Context>
40 class RmsPropOp final : public Operator<Context> {
41  public:
42  USE_OPERATOR_CONTEXT_FUNCTIONS;
43  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
44  : Operator<Context>(operator_def, ws),
45  decay_(OperatorBase::GetSingleArgument<float>("decay", 0.9f)),
46  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0f)),
47  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
48  bool RunOnDevice() override {
49  CAFFE_ENFORCE(Input(LR).size() == 1);
50  CAFFE_ENFORCE(Input(GRAD).size() == Input(MEAN_SQUARES).size());
51  CAFFE_ENFORCE(Input(GRAD).size() == Input(OUTPUT_MOMENTUM).size());
52  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
53  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
54  Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES));
55  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
56  rmsprop_update<Context>(
57  Input(GRAD).size(),
58  Input(GRAD).template data<T>(),
59  Input(MEAN_SQUARES).template data<T>(),
60  Input(MOMENTUM).template data<T>(),
61  Output(OUTPUT_GRAD)->template mutable_data<T>(),
62  Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
63  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
64  decay_,
65  momentum_,
66  epsilon_,
67  Input(LR).template data<T>(),
68  &context_);
69  return true;
70  }
71 
72  protected:
73  T decay_{0.9};
74  T momentum_{0.0};
75  T epsilon_{1e-8};
76  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
77  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
78 };
79 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.