import torch
from .optimizer import Optimizer, required


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """
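
    # Illustrative note (not part of the original module): with rho = 0.9,
    # lr = 0.1, dampening = 0 and a constant gradient g = 1.0, the two forms
    # in the note above keep different velocities:
    #
    #   this module:  v <- 0.9 * v + 1.0,        p <- p - 0.1 * v
    #   Sutskever:    v <- 0.9 * v + 0.1 * 1.0,  p <- p - v
    #
    # The parameter updates coincide while lr stays fixed, but the stored
    # velocity differs by a factor of lr, so the two schemes react
    # differently if lr is changed during training.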

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)
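
    # Illustrative constructions (not part of the original module), assuming
    # `model` is some torch.nn.Module:
    #
    #   SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)  # ok
    #   SGD(model.parameters(), lr=0.1, nesterov=True)                # raises
    #       ValueError: momentum defaults to 0, and Nesterov requires
    #       momentum > 0 with zero dampening (see the check above).
    #
    # Per-parameter-group dicts are also accepted; the `base` and `head`
    # submodules here are hypothetical, shown only for illustration:
    #
    #   SGD([{'params': model.base.parameters()},
    #        {'params': model.head.parameters(), 'lr': 1e-3}],
    #       lr=1e-2, momentum=0.9)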

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        # Param groups loaded from state dicts saved before the `nesterov`
        # option existed may lack the key; default it to False.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                # Parameters that did not receive a gradient are skipped.
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # L2 penalty: d_p <- d_p + weight_decay * p
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: seed the velocity buffer with the gradient.
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        # v <- momentum * v + (1 - dampening) * d_p
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        # Nesterov look-ahead: d_p <- d_p + momentum * v
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                # p <- p - lr * d_p
                p.data.add_(-group['lr'], d_p)

        return loss
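
# Usage sketch (illustrative, not part of the original module).  The basic
# zero_grad / backward / step loop is shown in the class docstring above;
# step() additionally accepts an optional closure that recomputes the loss,
# assuming some `model`, `loss_fn`, `input` and `target` from user code:
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#
#     def closure():
#         optimizer.zero_grad()
#         loss = loss_fn(model(input), target)
#         loss.backward()
#         return loss
#
#     loss = optimizer.step(closure)  # closure() runs first, then the update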