// Learning-rate schedule functors for caffe2 SGD.
// Base interface: a pure-virtual, const `operator()(const int64_t iter)`
// that maps an iteration counter to a value of type T. Each functor below
// overrides it with one schedule (fixed, alternating, step, exp, inv, poly,
// warm-up variants, and a piecewise composite).
// NOTE(review): this view is heavily elided (original line numbers are fused
// into the text and many lines are missing), so comments here describe only
// what the visible lines show.
1 #ifndef CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ 2 #define CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ 7 #include "caffe2/core/context.h" 8 #include "caffe2/core/operator.h" 18 virtual T operator()(
const int64_t iter)
const = 0;
// Override whose iteration parameter is unnamed, i.e. the result does not
// depend on `iter`. Its body is elided in this view (presumably a constant
// schedule -- TODO confirm against the full file).
25 T operator()(
const int64_t )
const override {
// Alternating schedule, constructor (tail). Stores the length of the active
// phase, the length of the inactive phase, and whether each cycle starts
// with the active phase.
37 const int64_t active_period,
38 const int64_t inactive_period,
39 const bool active_first)
40 : active_period_(active_period),
41 inactive_period_(inactive_period),
42 active_first_(active_first) {}
// Returns 1 during the active phase and 0 during the inactive phase of a
// repeating cycle of length active_period_ + inactive_period_. The position
// in the cycle is `iter % cycle_length`. The first phase of each cycle
// (active if active_first_, otherwise inactive) covers positions
// [0, length of that phase).
43 T operator()(
const int64_t iter)
const override {
// NOTE(review): no validation is visible here. If active_period_ +
// inactive_period_ == 0 the modulo is undefined behaviour. A negative `iter`
// gives a negative remainder, which always falls in the first phase when
// the periods are non-negative.
44 if (iter % (active_period_ + inactive_period_) <
45 (active_first_ ? active_period_ : inactive_period_)) {
46 return active_first_ ? 1. : 0.;
48 return active_first_ ? 0. : 1.;
// Phase lengths, in iterations.
52 int64_t active_period_;
53 int64_t inactive_period_;
// Step-decay schedule, constructor (tail). Stores the step size and the
// decay factor gamma.
62 : stepsize_(stepsize), gamma_(gamma) {}
// Returns gamma_ raised to (iter / stepsize_), so the value is multiplied by
// gamma_ once per stepsize_ iterations. The quotient is computed before the
// cast to T. If stepsize_ is an integer type (its declaration is not
// visible here -- confirm), this is truncating integer division, giving a
// piecewise-constant staircase.
63 T operator()(
const int64_t iter)
const override {
64 return std::pow(gamma_, static_cast<T>(iter / stepsize_));
// Exponential schedule: returns gamma_ raised to the power iter.
76 T operator()(
const int64_t iter)
const override {
77 return std::pow(gamma_, static_cast<T>(iter));
// Inverse-decay schedule, constructor (tail). Stores gamma and power.
88 : gamma_(gamma), power_(power) {}
// Returns (1 + gamma_ * iter) ^ (-power_).
89 T operator()(
const int64_t iter)
const override {
90 return std::pow(
T(1) + gamma_ * iter, -power_);
// Polynomial-decay schedule, constructor (tail). Stores the exponent and the
// iteration count at which the base reaches zero.
101 : power_(power), max_iter_(max_iter) {}
// Returns (1 - iter / max_iter_) ^ power_, with the ratio computed in T.
// NOTE(review): nothing visible clamps iter. For iter > max_iter_ the base is
// negative, and std::pow returns NaN for a negative base with a non-integral
// exponent. max_iter_ == 0 divides by zero. Callers presumably keep
// 0 <= iter <= max_iter_ -- confirm.
102 T operator()(
const int64_t iter)
const override {
103 return std::pow(1 -
T(iter) /
T(max_iter_), power_);
// Linear warm-up schedule (class header elided in this view).
// Constructor (tail) stores the starting multiplier and the warm-up length.
110 template <
typename T>
114 : start_multiplier_(start_multiplier), num_iter_(num_iter) {}
// For iter < num_iter_, returns
//   start_multiplier_ + (1 - start_multiplier_) * iter / num_iter_,
// which ramps linearly from start_multiplier_ at iter == 0 toward 1 as iter
// approaches num_iter_. The body of the iter >= num_iter_ branch is elided
// here (presumably returns 1 -- TODO confirm).
115 T operator()(
const int64_t iter)
const override {
116 if (iter >= num_iter_) {
119 return start_multiplier_ + (1. - start_multiplier_) *
T(iter) /
T(num_iter_);
// Constant warm-up schedule (class header elided in this view).
// Constructor (tail) stores the warm-up multiplier and the warm-up length.
126 template <
typename T>
130 : multiplier_(multiplier), num_iter_(num_iter) {}
// Returns multiplier_ (converted to T) when the iter >= num_iter_ branch does
// not return, i.e. during warm-up. That branch's body is elided here
// (presumably returns 1 -- TODO confirm).
131 T operator()(
const int64_t iter)
const override {
132 if (iter >= num_iter_) {
135 return T(multiplier_);
// Warm-up followed by inverse decay with a floor (class header elided).
// The constructor builds two sub-schedules: linear_warmup_lr_ from
// (start_multiplier, num_iter) and inv_lr_ from (gamma, power). gamma and
// power are declared on elided lines, and num_iter_'s initializer is also
// elided (presumably num_iter -- confirm).
145 template <
typename T>
149 const int64_t num_iter,
150 const T start_multiplier,
153 const T end_multiplier)
154 : linear_warmup_lr_(start_multiplier, num_iter),
155 inv_lr_(gamma, power),
157 end_multiplier_(end_multiplier) {}
// For iter < num_iter_, delegates to the warm-up schedule. Afterwards it
// evaluates inv_lr_ on the iteration count measured from the end of warm-up
// (iter - num_iter_), and never lets the result drop below end_multiplier_.
158 T operator()(
const int64_t iter)
const override {
159 if (iter < num_iter_) {
160 return linear_warmup_lr_(iter);
162 return std::max(end_multiplier_, inv_lr_(iter - num_iter_));
// One entry of a piecewise (composite) schedule: a sub-policy together with
// the number of iterations it stays active (class header elided). policy_
// is stored exactly as passed in. It appears to be a raw pointer whose
// ownership is later adopted by a std::unique_ptr in the composite schedule's
// constructor (via reset) -- confirm before freeing it elsewhere.
171 template <
typename T>
175 : num_iter_(num_iter), policy_(policy) {}
// Piecewise schedule built from a sequence of (num_iter_, policy_) items
// (class header and constructor signature elided in this view).
// The constructor keys each sub-policy by its start iteration. The first
// item starts at 1, and each later item starts where the previous one's
// num_iter_ ends. Each item's policy_ is adopted by the map's unique_ptr, so
// this object owns and frees the sub-policies. Emptiness and non-positive
// num_iter_ are only checked with DCHECK (typically compiled out in release
// builds).
181 template <
typename T>
186 DCHECK_GT(sub_policies.size(), 0);
187 int64_t num_iter_start = 1;
188 for (
auto it = sub_policies.begin(); it != sub_policies.end(); ++it) {
189 DCHECK_GT(it->num_iter_, 0);
190 sub_policies_[num_iter_start].reset(it->policy_);
191 num_iter_start += it->num_iter_;
// Selects a sub-policy by start iteration. upper_bound returns the first
// entry whose start key is > iter. The elided line 197 presumably steps back
// one entry to the policy whose start is <= iter (confirm). Because the
// first key is 1, an iter < 1 makes upper_bound return begin() and trips the
// DCHECK, so iter is expected to be >= 1.
// NOTE(review): the chosen sub-policy receives the global iter, not the
// iteration count relative to its own start.
194 T operator()(
const int64_t iter)
const override {
195 auto sub_policy = sub_policies_.upper_bound(iter);
196 DCHECK(sub_policy != sub_policies_.begin());
198 return (*sub_policy->second)(iter);
// Start iteration -> owned sub-policy.
202 std::map<int64_t, std::unique_ptr<LearningRateFunctor<T>>> sub_policies_;
207 #endif // CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...