Caffe2 - C++ API
A deep learning, cross-platform ML framework
learning_rate_op.h
1 #ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
2 #define CAFFE2_SGD_LEARNING_RATE_OP_H_
3 
4 #include <cfloat>
5 #include <cmath>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/sgd/learning_rate_functors.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class LearningRateOp final : public Operator<Context> {
14  public:
15  LearningRateOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  functor_(nullptr),
18  base_lr_(this->template GetSingleArgument<float>(
19  "base_lr",
20  FLT_MAX)) {
21  CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
22  const string policy = this->template GetSingleArgument<string>("policy", "");
23  CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
24  functor_.reset(createLearningRateFunctor(policy));
25  }
26  USE_OPERATOR_CONTEXT_FUNCTIONS;
27 
28  bool RunOnDevice() override {
29  int64_t iter =
30  OperatorBase::Input<Tensor>(0, CPU).template data<int64_t>()[0];
31  T learning_rate = cur_base_lr_ * (*functor_)(iter);
32  // Write to output.
33  auto* output = Output(0);
34  output->Resize(vector<int64_t>());
35  context_.template CopyFromCPU<T>(
36  1, &learning_rate, Output(0)->template mutable_data<T>());
37  return true;
38  }
39 
40  private:
41  unique_ptr<LearningRateFunctor<T>> functor_;
42  T base_lr_;
43  T base_lr_scale_;
44  T cur_base_lr_;
45 
46  LearningRateFunctor<T>* createLearningRateFunctor(
47  const string& policy,
48  const string& arg_prefix = "") {
49  if (policy != "composite") {
50  base_lr_scale_ =
51  this->template GetSingleArgument<float>(arg_prefix + "lr_scale", 1.0);
52  cur_base_lr_ = base_lr_scale_ * base_lr_;
53  }
54  if (policy == "fixed") {
55  return new FixedLearningRate<T>();
56  } else if (policy == "alter") {
57  bool active_first = this->template GetSingleArgument<bool>(
58  arg_prefix + "active_first", true);
59  int64_t active_period = this->template GetSingleArgument<int64_t>(
60  arg_prefix + "active_period", -1);
61  int64_t inactive_period =
62  this->template GetSingleArgument<int64_t>(
63  arg_prefix + "inactive_period", -1);
64  DCHECK_GE(active_period, 0);
65  DCHECK_GE(inactive_period, 0);
66  return new AlternateLearningRate<T>(
67  active_period, inactive_period, active_first);
68  } else if (policy == "hill") {
69  int64_t num_iter = this->template GetSingleArgument<int>(
70  arg_prefix + "num_iter", 0);
71  DCHECK_GT(num_iter, 0);
72  T start_multiplier = this->template GetSingleArgument<float>(
73  arg_prefix + "start_multiplier", 0.);
74  DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
75  DCHECK_LE(start_multiplier, 1);
76  T gamma = this->template GetSingleArgument<float>(
77  arg_prefix + "gamma", 0);
78  DCHECK_GT(gamma, 0);
79  T power = this->template GetSingleArgument<float>(
80  arg_prefix + "power", 0);
81  DCHECK_GT(power, 0);
82  T end_multiplier = this->template GetSingleArgument<float>(
83  arg_prefix + "end_multiplier", 0);
84  DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
85  DCHECK_LE(end_multiplier, 1);
86  return new HillLearningRate<T>(
87  num_iter, start_multiplier, gamma, power, end_multiplier);
88  } else if (policy == "step") {
89  int stepsize = this->template GetSingleArgument<int>(
90  arg_prefix + "stepsize", 0);
91  T gamma = this->template GetSingleArgument<float>(
92  arg_prefix + "gamma", 0);
93  DCHECK_GT(stepsize, 0);
94  DCHECK_GT(gamma, 0);
95  return new StepLearningRate<T>(stepsize, gamma);
96  } else if (policy == "exp") {
97  T gamma = this->template GetSingleArgument<float>(
98  arg_prefix + "gamma", 0);
99  DCHECK_GT(gamma, 0);
100  return new ExpLearningRate<T>(gamma);
101  } else if (policy == "inv") {
102  T gamma = this->template GetSingleArgument<float>(
103  arg_prefix + "gamma", 0);
104  T power = this->template GetSingleArgument<float>(
105  arg_prefix + "power", 0);
106  DCHECK_GT(gamma, 0);
107  DCHECK_GT(power, 0);
108  return new InvLearningRate<T>(gamma, power);
109  } else if (policy == "poly") {
110  int max_iter = this->template GetSingleArgument<int>(
111  arg_prefix + "max_iter", -1);
112  T power = this->template GetSingleArgument<float>(
113  arg_prefix + "power", 0);
114  DCHECK_GT(power, 0);
115  return new PolyLearningRate<T>(power, max_iter);
116  } else if (policy == "linearWarmup") {
117  T start_multiplier = this->template GetSingleArgument<float>(
118  arg_prefix + "start_multiplier", 0.);
119  int num_iter = this->template GetSingleArgument<int>(
120  arg_prefix + "num_iter", 0);
121  DCHECK_GE(start_multiplier, 0);
122  return new LinearWarmupLearningRate<T>(start_multiplier, num_iter);
123  } else if (policy == "constantWarmup") {
124  T multiplier = this->template GetSingleArgument<float>(
125  arg_prefix + "multiplier", 0.5);
126  int num_iter = this->template GetSingleArgument<int>(
127  arg_prefix + "num_iter", 0);
128  DCHECK_GT(multiplier, 0);
129  return new ConstantWarmupLearningRate<T>(multiplier, num_iter);
130  } else if (policy == "composite") {
131  std::vector<int> sub_policy_num_iters =
132  this->template GetRepeatedArgument<int>(
133  "sub_policy_num_iters");
134  std::list<CompositeLearningRateItem<T>> sub_policies;
135  CAFFE_ENFORCE_GT(
136  sub_policy_num_iters.size(),
137  0,
138  "Must specify at least one sub learning rate policy.");
139  for (int i = 0; i < sub_policy_num_iters.size(); ++i) {
140  CAFFE_ENFORCE_GT(
141  sub_policy_num_iters[i],
142  0,
143  "The number of iterations for sub learning rate policy should be positive.");
144  std::stringstream sub_policy_arg_prefix;
145  sub_policy_arg_prefix << "sub_policy_" << i << "_";
146  const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
147  const string sub_policy = this->template GetSingleArgument<string>(
148  sub_policy_arg_prefix_str + "policy", "");
149  if (sub_policy == "composite") {
150  CAFFE_THROW(
151  "Defining composite LR policy as a subpolicy of composite LR "
152  "policy is not allowed.");
153  }
154  sub_policies.push_back(CompositeLearningRateItem<T>(
155  sub_policy_num_iters[i],
156  createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
157  }
158  return new CompositeLearningRate<T>(sub_policies);
159  } else {
160  CAFFE_THROW("Unknown learning rate policy: ", policy);
161  return NULL;
162  }
163  }
164 };
165 
166 } // namespace caffe2
167 
168 #endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13