1 from collections 
import defaultdict
     5 from copy 
import deepcopy
     6 from itertools 
import chain
    10     """Singleton class representing a required parameter for an Optimizer."""    12         return "<required parameter>"    18     r"""Base class for all optimizers.    21         Parameters need to be specified as collections that have a deterministic    22         ordering that is consistent between runs. Examples of objects that don't    23         satisfy those properties are sets and iterators over values of dictionaries.    26         params (iterable): an iterable of :class:`torch.Tensor` s or    27             :class:`dict` s. Specifies what Tensors should be optimized.    28         defaults: (dict): a dict containing default values of optimization    29             options (used when a parameter group doesn't specify them).    32     def __init__(self, params, defaults):
    35         if isinstance(params, torch.Tensor):
    36             raise TypeError(
"params argument given to the optimizer should be "    37                             "an iterable of Tensors or dicts, but got " +
    40         self.
state = defaultdict(dict)
    43         param_groups = list(params)
    44         if len(param_groups) == 0:
    45             raise ValueError(
"optimizer got an empty parameter list")
    46         if not isinstance(param_groups[0], dict):
    47             param_groups = [{
'params': param_groups}]
    49         for param_group 
in param_groups:
    52     def __getstate__(self):
    def __setstate__(self, state):
        """Restore the instance from a previously captured attribute mapping.

        Each entry of ``state`` is copied into the instance's ``__dict__``.
        An entry overwrites any existing attribute of the same name.
        Attributes not mentioned in ``state`` are left as they are.
        """
        instance_attrs = self.__dict__
        instance_attrs.update(state)
    62         format_string = self.__class__.__name__ + 
' ('    65             format_string += 
'Parameter Group {0}\n'.format(i)
    66             for key 
in sorted(group.keys()):
    68                     format_string += 
'    {0}: {1}\n'.format(key, group[key])
    73         r"""Returns the state of the optimizer as a :class:`dict`.    75         It contains two entries:    77         * state - a dict holding current optimization state. Its content    78             differs between optimizer classes.    79         * param_groups - a dict containing all parameter groups    82         def pack_group(group):
    83             packed = {k: v 
for k, v 
in group.items() 
if k != 
'params'}
    84             packed[
'params'] = [id(p) 
for p 
in group[
'params']]
    86         param_groups = [pack_group(g) 
for g 
in self.
param_groups]
    88         packed_state = {(id(k) 
if isinstance(k, torch.Tensor) 
else k): v
    89                         for k, v 
in self.state.items()}
    91             'state': packed_state,
    92             'param_groups': param_groups,
    95     def load_state_dict(self, state_dict):
    96         r"""Loads the optimizer state.    99             state_dict (dict): optimizer state. Should be an object returned   100                 from a call to :meth:`state_dict`.   103         state_dict = deepcopy(state_dict)
   106         saved_groups = state_dict[
'param_groups']
   108         if len(groups) != len(saved_groups):
   109             raise ValueError(
"loaded state dict has a different number of "   111         param_lens = (len(g[
'params']) 
for g 
in groups)
   112         saved_lens = (len(g[
'params']) 
for g 
in saved_groups)
   113         if any(p_len != s_len 
for p_len, s_len 
in zip(param_lens, saved_lens)):
   114             raise ValueError(
"loaded state dict contains a parameter group "   115                              "that doesn't match the size of optimizer's group")
   118         id_map = {old_id: p 
for old_id, p 
in   119                   zip(chain(*(g[
'params'] 
for g 
in saved_groups)),
   120                       chain(*(g[
'params'] 
for g 
in groups)))}
   122         def cast(param, value):
   123             r"""Make a deep copy of value, casting all tensors to device of param."""   124             if isinstance(value, torch.Tensor):
   127                 if param.is_floating_point():
   128                     value = value.to(param.dtype)
   129                 value = value.to(param.device)
   131             elif isinstance(value, dict):
   132                 return {k: cast(param, v) 
for k, v 
in value.items()}
   133             elif isinstance(value, container_abcs.Iterable):
   134                 return type(value)(cast(param, v) 
for v 
in value)
   141         state = defaultdict(dict)
   142         for k, v 
in state_dict[
'state'].items():
   145                 state[param] = cast(param, v)
   150         def update_group(group, new_group):
   151             new_group[
'params'] = group[
'params']
   154             update_group(g, ng) 
for g, ng 
in zip(groups, saved_groups)]
   155         self.
__setstate__({
'state': state, 
'param_groups': param_groups})
   158         r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""   160             for p 
in group[
'params']:
   161                 if p.grad 
is not None:
   def step(self, closure=None):
       r"""Perform a single optimization step (parameter update).

       This base implementation has no update rule and always raises.
       Concrete optimizers are expected to override it.

       Arguments:
           closure (callable, optional): A closure that reevaluates the model
               and returns the loss. Optional for most optimizers, so it
               defaults to ``None``. This keeps ``opt.step()`` a valid call
               and matches the documented contract.

       Raises:
           NotImplementedError: always, in this base implementation.
       """
       raise NotImplementedError
   174     def add_param_group(self, param_group):
   175         r"""Add a param group to the :class:`Optimizer` s `param_groups`.   177         This can be useful when fine tuning a pre-trained network as frozen layers can be made   178         trainable and added to the :class:`Optimizer` as training progresses.   181             param_group (dict): Specifies what Tensors should be optimized along with group   182             specific optimization options.   184         assert isinstance(param_group, dict), 
"param group must be a dict"   186         params = param_group[
'params']
   187         if isinstance(params, torch.Tensor):
   188             param_group[
'params'] = [params]
   189         elif isinstance(params, set):
   190             raise TypeError(
'optimizer parameters need to be organized in ordered collections, but '   191                             'the ordering of tensors in sets will change between runs. Please use a list instead.')
   193             param_group[
'params'] = list(params)
   195         for param 
in param_group[
'params']:
   196             if not isinstance(param, torch.Tensor):
   197                 raise TypeError(
"optimizer can only optimize Tensors, "   199             if not param.is_leaf:
   200                 raise ValueError(
"can't optimize a non-leaf Tensor")
   202         for name, default 
in self.defaults.items():
   203             if default 
is required 
and name 
not in param_group:
   204                 raise ValueError(
"parameter group didn't specify a value of required optimization parameter " +
   207                 param_group.setdefault(name, default)
   211             param_set.update(set(group[
'params']))
   213         if not param_set.isdisjoint(set(param_group[
'params'])):
   214             raise ValueError(
"some parameters appear in more than one parameter group")
   216         self.param_groups.append(param_group)
 
def __setstate__(self, state)
def add_param_group(self, param_group)
def typename(o)
Define basic utilities.