 Caffe2 - Python API A deep learning, cross platform ML framework
sgd.py
1 import torch
2 from .optimizer import Optimizer, required
3
4
5 class SGD(Optimizer):
6  r"""Implements stochastic gradient descent (optionally with momentum).
7
8  Nesterov momentum is based on the formula from
9  `On the importance of initialization and momentum in deep learning`__.
10
11  Args:
12  params (iterable): iterable of parameters to optimize or dicts defining
13  parameter groups
14  lr (float): learning rate
15  momentum (float, optional): momentum factor (default: 0)
16  weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
17  dampening (float, optional): dampening for momentum (default: 0)
18  nesterov (bool, optional): enables Nesterov momentum (default: False)
19
20  Example:
21  >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
22  >>> optimizer.zero_grad()
23  >>> loss_fn(model(input), target).backward()
24  >>> optimizer.step()
25
26  __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
27
28  .. note::
29  The implementation of SGD with Momentum/Nesterov subtly differs from
30  Sutskever et al. and implementations in some other frameworks.
31
32  Considering the specific case of Momentum, the update can be written as
33
34  .. math::
35  v = \rho * v + g \\
36  p = p - lr * v
37
38  where p, g, v and :math:`\rho` denote the parameters, gradient,
39  velocity, and momentum respectively.
40
41  This is in contrast to Sutskever et al. and
42  other frameworks which employ an update of the form
43
44  .. math::
45  v = \rho * v + lr * g \\
46  p = p - v
47
48  The Nesterov version is analogously modified.
49  """
50
51  def __init__(self, params, lr=required, momentum=0, dampening=0,
52  weight_decay=0, nesterov=False):
53  if lr is not required and lr < 0.0:
54  raise ValueError("Invalid learning rate: {}".format(lr))
55  if momentum < 0.0:
56  raise ValueError("Invalid momentum value: {}".format(momentum))
57  if weight_decay < 0.0:
58  raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
59
60  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
61  weight_decay=weight_decay, nesterov=nesterov)
62  if nesterov and (momentum <= 0 or dampening != 0):
63  raise ValueError("Nesterov momentum requires a momentum and zero dampening")
64  super(SGD, self).__init__(params, defaults)
65
66  def __setstate__(self, state):
67  super(SGD, self).__setstate__(state)
68  for group in self.param_groups:
69  group.setdefault('nesterov', False)
70
71  def step(self, closure=None):
72  """Performs a single optimization step.
73
74  Arguments:
75  closure (callable, optional): A closure that reevaluates the model
76  and returns the loss.
77  """
78  loss = None
79  if closure is not None:
80  loss = closure()
81
82  for group in self.param_groups:
83  weight_decay = group['weight_decay']
84  momentum = group['momentum']
85  dampening = group['dampening']
86  nesterov = group['nesterov']
87
88  for p in group['params']:
89  if p.grad is None:
90  continue
91  d_p = p.grad.data
92  if weight_decay != 0:
94  if momentum != 0:
95  param_state = self.state[p]
96  if 'momentum_buffer' not in param_state:
97  buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
98  else:
99  buf = param_state['momentum_buffer']
100  buf.mul_(momentum).add_(1 - dampening, d_p)
101  if nesterov:
102  d_p = d_p.add(momentum, buf)
103  else:
104  d_p = buf
105