// LearningRateOp: an operator that produces the learning rate for the current
// iteration from a base learning rate and a policy-selected schedule functor.
// NOTE(review): this view of the file is fragmentary. The class declaration,
// the constructor signature, the head of its initializer list and several
// closing braces are on lines not visible here. The comments below describe
// only what the visible lines establish.
1 #ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_ 2 #define CAFFE2_SGD_LEARNING_RATE_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/sgd/learning_rate_functors.h" 12 template <
typename T,
class Context>
// Constructor (partial view): base_lr_ is initialized from a float operator
// argument. The argument's name and default value are on lines not visible here.
18 base_lr_(this->
template GetSingleArgument<float>(
// FLT_MAX serves as the "unset" sentinel for the base learning rate. The
// enforce below turns a missing value into an explicit error.
// NOTE(review): FLT_MAX is declared in <cfloat>. No such include is visible in
// this view, so confirm it is included on one of the missing lines.
21 CAFFE_ENFORCE_NE(base_lr_, FLT_MAX,
"Base learning rate must be set.");
// The schedule is chosen by the string argument "policy". Its default is the
// empty string, which the enforce below rejects.
22 const string policy = this->
template GetSingleArgument<string>(
"policy",
"");
23 CAFFE_ENFORCE(policy.size(),
"Must specify a learning rate policy.");
// Build the schedule functor for the chosen policy; functor_ (a unique_ptr)
// takes ownership of it.
24 functor_.reset(createLearningRateFunctor(policy));
// Caffe2 helper macro. Presumably it brings context_ (used in RunOnDevice)
// into scope; confirm against caffe2/core/operator.h.
26 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Computes the learning rate for the current iteration and writes it to
// output 0 as a one-element tensor of type T.
// Input 0 is read on the CPU, and its first int64_t element is used as the
// iteration counter. The declaration of `iter` and the function's return
// statement are on lines not visible in this view.
28 bool RunOnDevice()
override {
30 OperatorBase::Input<Tensor>(0, CPU).
template data<int64_t>()[0];
// Learning rate = cur_base_lr_ (base rate, possibly lr_scale-adjusted)
// multiplied by the schedule functor's value at this iteration.
31 T learning_rate = cur_base_lr_ * (*functor_)(iter);
// The output is resized with an empty dims vector, i.e. a 0-d tensor holding
// exactly one element.
33 auto* output = Output(0);
34 output->Resize(vector<int64_t>());
// The value lives in a host variable. Copy it into the output's storage
// through the operator's Context, since the output may live on a non-CPU device.
// NOTE(review): `output` already holds Output(0), so calling Output(0) again
// here is redundant; `output->template mutable_data<T>()` would be clearer.
35 context_.template CopyFromCPU<T>(
36 1, &learning_rate, Output(0)->template mutable_data<T>());
// Owns the learning-rate schedule functor created in the constructor.
41 unique_ptr<LearningRateFunctor<T>> functor_;
// createLearningRateFunctor (the head of its signature is not visible here):
// builds the LearningRateFunctor<T> for `policy`. Each policy's parameters are
// read from operator arguments whose names are prefixed with `arg_prefix`,
// which is empty by default.
// Policy names handled on the visible lines: fixed, alter, hill, step, exp,
// inv, poly, linearWarmup, constantWarmup, composite. Any other name appears to
// reach the CAFFE_THROW at the end; the enclosing `else` is not visible.
// NOTE(review): the statements that construct and return each branch's
// functor are largely on lines not visible in this view, so the concrete
// functor classes are not documented here.
// NOTE(review): the DCHECK_* argument checks below are conventionally
// debug-only (disabled when NDEBUG is defined). If these constraints must
// hold in release builds, CAFFE_ENFORCE_* is the stricter choice.
48 const string& arg_prefix =
"") {
// For every non-composite policy, an optional "lr_scale" argument (default
// 1.0) is read and cur_base_lr_ is set to base_lr_scale_ * base_lr_.
// NOTE(review): the target of the "lr_scale" read is on a line not visible
// here; it is presumably base_lr_scale_, which should be confirmed.
// NOTE(review): composite sub-policies are built by recursive calls (with
// non-composite sub-policies, see below), so each call appears to overwrite
// cur_base_lr_. After a composite policy is built, cur_base_lr_ would then
// hold the LAST sub-policy's lr_scale * base_lr_. Confirm this is intended.
49 if (policy !=
"composite") {
51 this->
template GetSingleArgument<float>(arg_prefix +
"lr_scale", 1.0);
52 cur_base_lr_ = base_lr_scale_ * base_lr_;
// "fixed": no parameters are read on the visible lines.
54 if (policy ==
"fixed") {
56 }
// "alter": reads active_first (bool, default true), active_period and
// inactive_period (int64_t, default -1). Both periods are checked to be >= 0,
// so the -1 defaults fail the (debug) checks and the periods are expected to
// be supplied explicitly.
else if (policy ==
"alter") {
57 bool active_first = this->
template GetSingleArgument<bool>(
58 arg_prefix +
"active_first",
true);
59 int64_t active_period = this->
template GetSingleArgument<int64_t>(
60 arg_prefix +
"active_period", -1);
61 int64_t inactive_period =
62 this->
template GetSingleArgument<int64_t>(
63 arg_prefix +
"inactive_period", -1);
64 DCHECK_GE(active_period, 0);
65 DCHECK_GE(inactive_period, 0);
67 active_period, inactive_period, active_first);
68 }
// "hill": num_iter (default 0, checked > 0), start_multiplier and
// end_multiplier (default 0, each checked to lie in [0, 1]), gamma and power
// (both default 0).
// NOTE(review): num_iter is stored in an int64_t but read with
// GetSingleArgument<int>, so values outside int range cannot be expressed.
else if (policy ==
"hill") {
69 int64_t num_iter = this->
template GetSingleArgument<int>(
70 arg_prefix +
"num_iter", 0);
71 DCHECK_GT(num_iter, 0);
72 T start_multiplier = this->
template GetSingleArgument<float>(
73 arg_prefix +
"start_multiplier", 0.);
74 DCHECK_GE(start_multiplier, 0);
75 DCHECK_LE(start_multiplier, 1);
76 T gamma = this->
template GetSingleArgument<float>(
77 arg_prefix +
"gamma", 0);
79 T power = this->
template GetSingleArgument<float>(
80 arg_prefix +
"power", 0);
82 T end_multiplier = this->
template GetSingleArgument<float>(
83 arg_prefix +
"end_multiplier", 0);
84 DCHECK_GE(end_multiplier, 0);
85 DCHECK_LE(end_multiplier, 1);
87 num_iter, start_multiplier, gamma, power, end_multiplier);
88 }
// "step": stepsize (int, default 0, checked > 0) and gamma (default 0).
else if (policy ==
"step") {
89 int stepsize = this->
template GetSingleArgument<int>(
90 arg_prefix +
"stepsize", 0);
91 T gamma = this->
template GetSingleArgument<float>(
92 arg_prefix +
"gamma", 0);
93 DCHECK_GT(stepsize, 0);
96 }
// "exp": gamma (default 0).
else if (policy ==
"exp") {
97 T gamma = this->
template GetSingleArgument<float>(
98 arg_prefix +
"gamma", 0);
101 }
// "inv": gamma and power (both default 0).
else if (policy ==
"inv") {
102 T gamma = this->
template GetSingleArgument<float>(
103 arg_prefix +
"gamma", 0);
104 T power = this->
template GetSingleArgument<float>(
105 arg_prefix +
"power", 0);
109 }
// "poly": max_iter (int, default -1) and power (default 0). No range check on
// max_iter is visible in this view.
else if (policy ==
"poly") {
110 int max_iter = this->
template GetSingleArgument<int>(
111 arg_prefix +
"max_iter", -1);
112 T power = this->
template GetSingleArgument<float>(
113 arg_prefix +
"power", 0);
116 }
// "linearWarmup": start_multiplier (default 0, checked >= 0) and num_iter
// (int, default 0).
else if (policy ==
"linearWarmup") {
117 T start_multiplier = this->
template GetSingleArgument<float>(
118 arg_prefix +
"start_multiplier", 0.);
119 int num_iter = this->
template GetSingleArgument<int>(
120 arg_prefix +
"num_iter", 0);
121 DCHECK_GE(start_multiplier, 0);
123 }
// "constantWarmup": multiplier (default 0.5, checked > 0) and num_iter (int,
// default 0).
else if (policy ==
"constantWarmup") {
124 T multiplier = this->
template GetSingleArgument<float>(
125 arg_prefix +
"multiplier", 0.5);
126 int num_iter = this->
template GetSingleArgument<int>(
127 arg_prefix +
"num_iter", 0);
128 DCHECK_GT(multiplier, 0);
130 }
// "composite": a sequence of sub-policies.
// - The repeated argument "sub_policy_num_iters" gives one iteration count per
//   sub-policy. It is read WITHOUT arg_prefix.
// - Per the error messages below, the list must be non-empty and every count
//   positive.
// - Sub-policy i reads its own arguments under the prefix "sub_policy_<i>_",
//   e.g. "sub_policy_0_policy".
// - A composite policy nested inside a composite policy is rejected.
else if (policy ==
"composite") {
131 std::vector<int> sub_policy_num_iters =
132 this->
template GetRepeatedArgument<int>(
133 "sub_policy_num_iters");
134 std::list<CompositeLearningRateItem<T>> sub_policies;
136 sub_policy_num_iters.size(),
138 "Must specify at least one sub learning rate policy.");
// NOTE(review): `int i` is compared against size(), which is an unsigned
// size_t. A size_t index would avoid the signed/unsigned comparison.
139 for (
int i = 0; i < sub_policy_num_iters.size(); ++i) {
141 sub_policy_num_iters[i],
143 "The number of iterations for sub learning rate policy should be positive.");
// Build this sub-policy's argument prefix, "sub_policy_<i>_".
144 std::stringstream sub_policy_arg_prefix;
145 sub_policy_arg_prefix <<
"sub_policy_" << i <<
"_";
146 const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
147 const string sub_policy = this->
template GetSingleArgument<string>(
148 sub_policy_arg_prefix_str +
"policy",
"");
149 if (sub_policy ==
"composite") {
151 "Defining composite LR policy as a subpolicy of composite LR " 152 "policy is not allowed.");
// Each sub-policy's functor is built by a recursive call using its own
// prefix, and is paired with its iteration count. It is presumably appended
// to sub_policies on lines not visible here.
155 sub_policy_num_iters[i],
156 createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
// Reached for a policy name not matched above; the enclosing `else` is not
// visible in this view.
160 CAFFE_THROW(
"Unknown learning rate policy: ", policy);
168 #endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...