Caffe2 - Python API
A deep learning, cross-platform ML framework
optimizer.py
from collections import defaultdict
from torch._six import container_abcs

import torch
from copy import deepcopy
from itertools import chain


class _RequiredParameter(object):
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self):
        return "<required parameter>"

required = _RequiredParameter()


class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

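    # Example (illustrative sketch of constructing a concrete subclass such as
    # torch.optim.SGD; `model` and `classifier_lr` are assumed to exist). Plain
    # iterables of Tensors and lists of per-group dicts are both accepted:
    #
    #     >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    #     >>> optimizer = torch.optim.SGD([
    #     ...     {'params': model.base.parameters()},
    #     ...     {'params': model.classifier.parameters(), 'lr': classifier_lr},
    #     ... ], lr=0.01)  # groups without an explicit 'lr' use the default above
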
    def __getstate__(self):
        return {
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += '    {0}: {1}\n'.format(key, group[key])
        format_string += ')'
        return format_string

    def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups
        """
        # Save ids instead of Tensors
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use ids as keys
        packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

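    # Example of ``state_dict`` usage (illustrative sketch; `model`, `optimizer`,
    # and the checkpoint path are assumed): the returned dict is picklable and is
    # typically serialized alongside the model weights.
    #
    #     >>> torch.save({'model': model.state_dict(),
    #     ...             'optimizer': optimizer.state_dict()}, 'checkpoint.pth')
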
    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

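    # Example of ``load_state_dict`` usage (illustrative sketch, continuing the
    # checkpoint example above): the optimizer must already be constructed over
    # the same parameters, since the saved ids are matched positionally onto the
    # current param_groups.
    #
    #     >>> checkpoint = torch.load('checkpoint.pth')
    #     >>> model.load_state_dict(checkpoint['model'])
    #     >>> optimizer.load_state_dict(checkpoint['optimizer'])
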
    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()

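    # Example of the usual training-loop pattern (illustrative sketch; `dataset`,
    # `loss_fn`, and `model` are assumed): gradients are cleared before each
    # backward pass so they do not accumulate across iterations.
    #
    #     >>> for input, target in dataset:
    #     ...     optimizer.zero_grad()
    #     ...     loss = loss_fn(model(input), target)
    #     ...     loss.backward()
    #     ...     optimizer.step()
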
    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError

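    # Example of ``step`` with a closure (illustrative sketch; `loss_fn`, `model`,
    # `input`, and `target` are assumed): optimizers that re-evaluate the function
    # several times per step, such as torch.optim.LBFGS, require a closure that
    # clears the gradients, recomputes the loss, and returns it.
    #
    #     >>> def closure():
    #     ...     optimizer.zero_grad()
    #     ...     loss = loss_fn(model(input), target)
    #     ...     loss.backward()
    #     ...     return loss
    #     >>> optimizer.step(closure)
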
    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)
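
    # Example of ``add_param_group`` (illustrative sketch; `model.new_head` is an
    # assumed submodule): unfreezing an extra layer during fine-tuning and giving
    # it its own learning rate without rebuilding the optimizer.
    #
    #     >>> optimizer.add_param_group({'params': model.new_head.parameters(),
    #     ...                            'lr': 1e-4})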