Caffe2 - Python API
A deep learning, cross-platform ML framework
lr_scheduler.py
import types
import math
import torch
from torch._six import inf
from collections import Counter
from functools import partial
from .optimizer import Optimizer

class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

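# Illustrative sketch (not part of the original file): a subclass only has to
# implement get_lr(); the base class handles the epoch bookkeeping and the
# assignment into optimizer.param_groups. The class name and its decay
# constant are assumptions chosen for the example.
class _ConstantDecayLR(_LRScheduler):
    """Hypothetical scheduler: multiply every base lr by 0.5 each epoch."""

    def get_lr(self):
        # self.base_lrs and self.last_epoch are maintained by _LRScheduler.
        return [base_lr * (0.5 ** self.last_epoch)
                for base_lr in self.base_lrs]
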
class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        self.optimizer = optimizer
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        self.last_epoch = last_epoch
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]

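# Illustrative sketch (not part of the original file): as the state_dict
# docstring above notes, plain functions and lambdas are skipped when saving,
# but a callable object's __dict__ is captured and restored. The warm-up
# callable below is a hypothetical example of a factor that survives a
# checkpoint round-trip.
class _WarmUp(object):
    """Hypothetical multiplicative factor: linear warm-up, then constant."""

    def __init__(self, warm_epochs=5):
        self.warm_epochs = warm_epochs  # captured via fn.__dict__ in state_dict()

    def __call__(self, epoch):
        return min(1.0, float(epoch + 1) / self.warm_epochs)

# Usage (assuming an existing `optimizer`):
#   scheduler = LambdaLR(optimizer, lr_lambda=_WarmUp(warm_epochs=10))
#   state = scheduler.state_dict()  # lr_lambdas entry holds {'warm_epochs': 10}
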
class StepLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 60
        >>> # lr = 0.0005  if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

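# Illustrative check (not part of the original file): because the decay is
# applied to the group's *current* lr, the schedule is recursive; when nothing
# else touches the lr it matches the closed form
# lr(epoch) = initial_lr * gamma ** (epoch // step_size).
# The SGD optimizer on a single dummy tensor stands in for a real model.
def _steplr_closed_form_demo():
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.05)
    sched = StepLR(opt, step_size=30, gamma=0.1)
    for epoch in range(90):
        sched.step()  # after this call, the scheduler's last_epoch == epoch
        assert abs(opt.param_groups[0]['lr'] - 0.05 * 0.1 ** (epoch // 30)) < 1e-9
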
class MultiStepLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epochs reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 80
        >>> # lr = 0.0005  if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
                for group in self.optimizer.param_groups]

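# Illustrative note (not part of the original file): milestones are stored in
# a collections.Counter, so a repeated milestone applies gamma that many times
# at once. The helper below is a hypothetical demonstration of that behavior.
def _multistep_repeated_milestone_demo():
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
    sched = MultiStepLR(opt, milestones=[5, 5], gamma=0.1)
    for _ in range(6):
        sched.step()
    # At epoch 5, gamma ** milestones[5] == 0.1 ** 2, so the lr drops
    # straight from 1.0 to 0.01.
    assert abs(opt.param_groups[0]['lr'] - 0.01) < 1e-12
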
class ExponentialLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

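# Illustrative sketch (not part of the original file): schedulers resume via
# state_dict()/load_state_dict(), mirroring optimizers. Note that only the
# scheduler's bookkeeping (gamma, base_lrs, last_epoch) is restored; with
# these recursive schedulers the current lr itself lives in the optimizer's
# param_groups. Names and values below are assumptions for the example.
def _exponential_resume_demo():
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    sched = ExponentialLR(opt, gamma=0.9)
    for _ in range(5):
        sched.step()
    saved = sched.state_dict()  # {'gamma': 0.9, 'base_lrs': [0.1], 'last_epoch': 4}
    fresh = ExponentialLR(opt, gamma=0.9)
    fresh.load_state_dict(saved)  # picks up the epoch counter at 4
    assert fresh.last_epoch == 4
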
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \eta_{t+1} = \eta_{min} + (\eta_t - \eta_{min})\frac{1 +
        \cos(\frac{T_{cur}+1}{T_{max}}\pi)}{1 + \cos(\frac{T_{cur}}{T_{max}}\pi)}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

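# Illustrative check (not part of the original file): when nothing else
# modifies the lr, the recursive update above reproduces the closed form from
# the docstring, eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2.
# The optimizer and hyper-parameters are assumptions for the example.
def _cosine_closed_form_demo():
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    sched = CosineAnnealingLR(opt, T_max=10, eta_min=0.001)
    for t in range(10):
        sched.step()  # after this call, the scheduler's last_epoch == t
        expected = 0.001 + (0.1 - 0.001) * (1 + math.cos(math.pi * t / 10)) / 2
        assert abs(opt.param_groups[0]['lr'] - expected) < 1e-9
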
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, reduces the learning rate.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in `max`
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        current = metrics
        if epoch is None:
            epoch = self.last_epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            return a < best * rel_epsilon

        elif mode == 'min' and threshold_mode == 'abs':
            return a < best - threshold

        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items()
                if key not in {'optimizer', 'is_better'}}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold,
                             threshold_mode=self.threshold_mode)