from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch
from torch._six import container_abcs  # collections.abc on newer versions


class _RequiredParameter(object):
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self):
        return "<required parameter>"

required = _RequiredParameter()


class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        self.defaults = defaults
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
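    # Usage sketch (not part of the original file): concrete optimizers such as
    # torch.optim.SGD build on this constructor, accepting either a flat iterable
    # of Tensors or a list of dicts defining per-group options.
    #
    #   >>> model = torch.nn.Linear(10, 2)
    #   >>> opt = torch.optim.SGD(model.parameters(), lr=0.1)
    #   >>> opt = torch.optim.SGD([
    #   ...     {'params': model.weight, 'lr': 0.01},
    #   ...     {'params': model.bias},
    #   ... ], lr=0.1)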
    def __getstate__(self):
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += '    {0}: {1}\n'.format(key, group[key])
        format_string += ')'
        return format_string
73 r"""Returns the state of the optimizer as a :class:`dict`. 75 It contains two entries: 77 * state - a dict holding current optimization state. Its content 78 differs between optimizer classes. 79 * param_groups - a dict containing all parameter groups 82 def pack_group(group):
83 packed = {k: v
for k, v
in group.items()
if k !=
'params'}
84 packed[
'params'] = [id(p)
for p
in group[
'params']]
86 param_groups = [pack_group(g)
for g
in self.
param_groups]
88 packed_state = {(id(k)
if isinstance(k, torch.Tensor)
else k): v
89 for k, v
in self.state.items()}
91 'state': packed_state,
92 'param_groups': param_groups,
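    # Sketch of the returned structure (illustrative values, not from the original
    # source): Tensors are replaced by their id() so the result is picklable.
    #
    #   >>> opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
    #   >>> sd = opt.state_dict()
    #   >>> sorted(sd.keys())
    #   ['param_groups', 'state']
    #   >>> sd['param_groups'][0]['lr']
    #   0.1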
    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state: map saved ids back onto the current parameters
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point state is assumed to match the dtype of the param
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})
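    # Checkpointing sketch (assumed workflow, mirroring the module API; the file
    # name is illustrative): save the state with torch.save and restore it on a
    # fresh optimizer built over the same parameters, in the same group order.
    #
    #   >>> torch.save(opt.state_dict(), 'opt.pt')
    #   >>> new_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    #   >>> new_opt.load_state_dict(torch.load('opt.pt'))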
158 r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" 160 for p
in group[
'params']:
161 if p.grad
is not None:
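    # Typical training-step sketch (assumed usage; `optimizer`, `loss_fn`, `model`,
    # `input` and `target` are illustrative names): gradients accumulate across
    # backward() calls, so zero_grad() is called once per step.
    #
    #   >>> optimizer.zero_grad()
    #   >>> loss = loss_fn(model(input), target)
    #   >>> loss.backward()
    #   >>> optimizer.step()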
    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError
    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)
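    # Fine-tuning sketch (assumed usage, illustrative layer names): a previously
    # frozen layer is unfrozen and handed to an existing optimizer with its own
    # learning rate.
    #
    #   >>> for p in model.classifier.parameters():
    #   ...     p.requires_grad_(True)
    #   >>> optimizer.add_param_group({'params': model.classifier.parameters(),
    #   ...                            'lr': 1e-4})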