Caffe2 - C++ API
A deep-learning, cross-platform ML framework
learning_rate_op.h
1 
17 #ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
18 #define CAFFE2_SGD_LEARNING_RATE_OP_H_
19 
20 #include <cfloat>
21 #include <cmath>
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/sgd/learning_rate_functors.h"
25 
26 namespace caffe2 {
27 
28 template <typename T, class Context>
29 class LearningRateOp final : public Operator<Context> {
30  public:
31  LearningRateOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  functor_(nullptr),
34  base_lr_(OperatorBase::template GetSingleArgument<float>(
35  "base_lr",
36  FLT_MAX)) {
37  CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
38  const string policy = OperatorBase::GetSingleArgument<string>("policy", "");
39  CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
40  if (policy == "fixed") {
41  functor_.reset(new FixedLearningRate<T>());
42  } else if (policy == "alter") {
43  bool active_first =
44  OperatorBase::template GetSingleArgument<bool>("active_first", true);
45  int64_t active_period = OperatorBase::template GetSingleArgument<int64_t>(
46  "active_period", -1);
47  int64_t inactive_period =
48  OperatorBase::template GetSingleArgument<int64_t>(
49  "inactive_period", -1);
50  DCHECK_GE(active_period, 0);
51  DCHECK_GE(inactive_period, 0);
52  functor_.reset(new AlternateLearningRate<T>(
53  active_period, inactive_period, active_first));
54  } else if (policy == "hill") {
55  int64_t num_iter =
56  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
57  DCHECK_GT(num_iter, 0);
58  T start_multiplier = OperatorBase::template GetSingleArgument<float>(
59  "start_multiplier", 0.);
60  DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
61  DCHECK_LE(start_multiplier, 1);
62  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
63  DCHECK_GT(gamma, 0);
64  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
65  DCHECK_GT(power, 0);
66  T end_multiplier =
67  OperatorBase::template GetSingleArgument<float>("end_multiplier", 0);
68  DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
69  DCHECK_LE(end_multiplier, 1);
70  functor_.reset(new HillLearningRate<T>(
71  num_iter, start_multiplier, gamma, power, end_multiplier));
72  } else if (policy == "step") {
73  int stepsize =
74  OperatorBase::template GetSingleArgument<int>("stepsize", 0);
75  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
76  DCHECK_GT(stepsize, 0);
77  DCHECK_GT(gamma, 0);
78  functor_.reset(new StepLearningRate<T>(stepsize, gamma));
79  } else if (policy == "exp") {
80  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
81  DCHECK_GT(gamma, 0);
82  functor_.reset(new ExpLearningRate<T>(gamma));
83  } else if (policy == "inv") {
84  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
85  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
86  DCHECK_GT(gamma, 0);
87  DCHECK_GT(power, 0);
88  functor_.reset(new InvLearningRate<T>(gamma, power));
89  } else if (policy == "poly") {
90  int max_iter = OperatorBase::template GetSingleArgument<int>("max_iter", -1);
91  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
92  DCHECK_GT(power, 0);
93  functor_.reset(new PolyLearningRate<T>(power, max_iter));
94  } else if (policy == "linearWarmup") {
95  T start_multiplier = OperatorBase::template GetSingleArgument<float>(
96  "start_multiplier", 0.);
97  int num_iter =
98  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
99  DCHECK_GT(start_multiplier, 0);
100  functor_.reset(
101  new LinearWarmupLearningRate<T>(start_multiplier, num_iter));
102  } else if (policy == "constantWarmup") {
103  T multiplier =
104  OperatorBase::template GetSingleArgument<float>("multiplier", 0.5);
105  int num_iter =
106  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
107  DCHECK_GT(multiplier, 0);
108  functor_.reset(new ConstantWarmupLearningRate<T>(multiplier, num_iter));
109  } else {
110  LOG(FATAL) << "Unknown learning rate policy: " << policy;
111  }
112  }
113  USE_OPERATOR_CONTEXT_FUNCTIONS;
114 
115  bool RunOnDevice() override {
116  int64_t iter =
117  OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
118  T learning_rate = base_lr_ * (*functor_)(iter);
119  // Write to output.
120  auto* output = Output(0);
121  output->Resize(vector<TIndex>());
122  context_.template Copy<T, CPUContext, Context>(
123  1, &learning_rate, Output(0)->template mutable_data<T>());
124  return true;
125  }
126 
127  private:
128  unique_ptr<LearningRateFunctor<T> > functor_;
129  T base_lr_;
130 
131 };
132 
133 } // namespace caffe2
134 
135 #endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.