Caffe2 - Python API
A deep learning, cross-platform ML framework
lbfgs.py
import torch
from functools import reduce
from .optimizer import Optimizer


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Arguments:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-5).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
    """

    def __init__(self, params, lr=1, max_iter=20, max_eval=None,
                 tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
                        tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
                        history_size=history_size, line_search_fn=line_search_fn)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()

    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        abs_grad_sum = flat_grad.abs().sum()

        if abs_grad_sum <= tolerance_grad:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
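                # the curvature condition y^T s > 0 (checked with a small
                # 1e-10 margin) must hold before the new (y, s) pair is
                # stored: it keeps 1/ro[i] and H_diag strictly positive, so
                # the implicit inverse-Hessian approximation stays positive
                # definite; otherwise the pair is skipped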
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

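                # the block below is the standard two-loop recursion (see
                # Nocedal & Wright, "Numerical Optimization"): starting from
                # q = -grad, the backward pass computes al[i] = ro[i] * s_i^T q
                # and subtracts al[i] * y_i; the result is scaled by
                # H_diag = s^T y / y^T y (the initial inverse-Hessian guess
                # H_0 = H_diag * I), and the forward pass adds
                # (al[i] - be_i) * s_i, leaving d approximately -H * grad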
                if 'ro' not in state:
                    state['ro'] = [None] * history_size
                    state['al'] = [None] * history_size
                ro = state['ro']
                al = state['al']

                for i in range(num_old):
                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(-al[i], old_dirs[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(al[i] - be_i, old_stps[i])

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / abs_grad_sum) * lr
            else:
                t = lr
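            # (on the first iteration the direction is plain steepest descent,
            # so the step is clamped by the gradient's l1 norm; later
            # iterations use lr directly, since the L-BFGS direction already
            # carries curvature scaling and a unit step is the usual choice)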

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                raise RuntimeError("line search function is not supported yet")
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    abs_grad_sum = flat_grad.abs().sum()
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            if abs_grad_sum <= tolerance_grad:
                break

            if gtd > -tolerance_change:
                break

            if d.mul(t).abs_().sum() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
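Because step() evaluates the objective itself (possibly several times per call, up to max_eval), LBFGS is driven differently from one-shot optimizers such as SGD: it must be given a closure that clears old gradients, recomputes the loss, calls backward(), and returns the loss. The following is a minimal usage sketch, not part of lbfgs.py; the model, data, and loss function are hypothetical placeholders, and the class is assumed to be exposed as torch.optim.LBFGS.

import torch

# hypothetical setup: any model and differentiable loss will do
model = torch.nn.Linear(10, 1)
inputs = torch.randn(64, 10)
targets = torch.randn(64, 1)
loss_fn = torch.nn.MSELoss()

optimizer = torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20,
                              history_size=100)

def closure():
    # the closure must zero gradients, recompute the loss and backpropagate,
    # because step() may call it several times within one invocation
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    return loss

for epoch in range(10):
    # unlike SGD-style optimizers, the closure is passed to step();
    # this implementation returns the loss evaluated at the start of the step
    loss = optimizer.step(closure)

Note that gradients have to be recomputed inside the closure rather than once outside the call, precisely because the inner loop above may re-evaluate the function on every iteration.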