import torch
from functools import reduce
from .optimizer import Optimizer
7 """Implements L-BFGS algorithm. 10 This optimizer doesn't support per-parameter options and parameter 11 groups (there can be only one). 14 Right now all parameters have to be on a single device. This will be 15 improved in the future. 18 This is a very memory intensive optimizer (it requires additional 19 ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory 20 try reducing the history size, or use a different algorithm. 23 lr (float): learning rate (default: 1) 24 max_iter (int): maximal number of iterations per optimization step 26 max_eval (int): maximal number of function evaluations per optimization 27 step (default: max_iter * 1.25). 28 tolerance_grad (float): termination tolerance on first order optimality 30 tolerance_change (float): termination tolerance on function 31 value/parameter changes (default: 1e-9). 32 history_size (int): update history size (default: 100). 35 def __init__(self, params, lr=1, max_iter=20, max_eval=None,
                 tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
                        tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
                        history_size=history_size, line_search_fn=line_search_fn)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()
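
    # _gather_flat_grad and _add_grad view all parameters as one flat vector;
    # this is why the optimizer currently requires every parameter to live on
    # a single device (see the warning in the class docstring).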
79 """Performs a single optimization step. 82 closure (callable): A closure that reevaluates the model 85 assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        abs_grad_sum = flat_grad.abs().sum()

        if abs_grad_sum <= tolerance_grad:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)
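                    # (the standard L-BFGS scaling: H_0 = (s^T y / y^T y) * I
                    # is used as the initial inverse Hessian approximation)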

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'ro' not in state:
                    state['ro'] = [None] * history_size
                    state['al'] = [None] * history_size
                ro = state['ro']
                al = state['al']

                for i in range(num_old):
                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(-al[i], old_dirs[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(al[i] - be_i, old_stps[i])
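                # d now holds the two-loop-recursion estimate of -H * grad,
                # i.e. the L-BFGS search direction for this iteration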

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / abs_grad_sum) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                raise RuntimeError("line search function is not supported yet")
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    abs_grad_sum = flat_grad.abs().sum()
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            if abs_grad_sum <= tolerance_grad:
                break

            if gtd > -tolerance_change:
                break

            if d.mul(t).abs_().sum() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
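

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the upstream module): it shows the
    # closure pattern that step() expects. The linear model, `inputs` and
    # `targets` below are synthetic placeholders.
    from torch import nn

    model = nn.Linear(10, 1)
    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)
    criterion = nn.MSELoss()
    optimizer = LBFGS(model.parameters(), lr=1, max_iter=20, history_size=10)

    def closure():
        # L-BFGS may evaluate the objective several times per step, so the
        # closure must zero the gradients, recompute the loss and backprop.
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        return loss

    for _ in range(5):
        print(float(optimizer.step(closure)))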