Caffe2 - C++ API
A deep learning, cross-platform ML framework
lbfgs.cpp
#include <torch/optim/lbfgs.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

#include <cmath>
#include <functional>
#include <vector>

namespace torch {
namespace optim {

LBFGSOptions::LBFGSOptions(double learning_rate)
    : learning_rate_(learning_rate) {}

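// Concatenate every parameter's gradient into one flat 1-D tensor so the
// L-BFGS update below can treat the whole parameter set as a single vector.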
Tensor LBFGS::gather_flat_grad() {
  std::vector<Tensor> views;
  for (auto& parameter : parameters_) {
    views.push_back(parameter.grad().view(-1));
  }
  return torch::cat(views);
}

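// Apply "parameter += step_size * update" in place, where `update` is a flat
// vector laid out in the same order as the flattened gradients above.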
void LBFGS::add_grad(const torch::Tensor& step_size, const Tensor& update) {
  NoGradGuard guard;
  int64_t offset = 0;
  for (auto& parameter : parameters_) {
    int64_t numel = parameter.numel();
    parameter.add_(
        update.slice(0, offset, offset + numel, 1).view_as(parameter),
        step_size.item<float>());
    offset += numel;
  }
}

torch::Tensor LBFGS::step(LossClosure closure) {
  torch::Tensor orig_loss = closure();
  torch::Tensor loss = orig_loss.clone();
  int64_t current_evals = 1;
  func_evals += 1;

  Tensor flat_grad = gather_flat_grad();
  Tensor abs_grad_sum = flat_grad.abs().sum();

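  // Optimality check: if the L1 norm of the gradient is already within
  // tolerance_grad_, return without taking a step.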
  if (abs_grad_sum.item<float>() <= options.tolerance_grad_) {
    return loss;
  }

  Tensor ONE = torch::tensor(1, flat_grad.options());

  int64_t n_iter = 0;
  while (n_iter < options.max_iter_) {
    n_iter++;
    state_n_iter++;

    if (state_n_iter == 1) {
      d = flat_grad.neg();
      H_diag = ONE;
      prev_flat_grad = flat_grad.clone();
    } else {
      Tensor y = flat_grad.sub(prev_flat_grad);
      Tensor s = d.mul(t);
      Tensor ys = y.dot(s);

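      // Curvature condition: only store the (s, y) pair when y.dot(s) is
      // sufficiently positive, which keeps the implicit inverse-Hessian
      // approximation positive definite.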
      if (ys.item<float>() > 1e-10) {
        // updating memory

        if (old_dirs.size() == options.history_size_) {
          // shift history by one (limited memory)
          old_dirs.pop_front();
          old_stps.pop_front();
        }

        // store new direction/step
        old_dirs.push_back(y);
        old_stps.push_back(s);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y);
      }

      int64_t num_old = old_dirs.size();

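      // Two-loop recursion: approximate d = -H_k * grad from the stored
      // {y_i, s_i} history, with ro_i = 1 / y_i.dot(s_i).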
      for (int64_t i = 0; i < num_old; i++) {
        ro.at(i) = ONE / old_dirs.at(i).dot(old_stps.at(i));
      }

      Tensor q = flat_grad.neg();
      for (int64_t i = num_old - 1; i >= 0; i--) {
        al.at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -al.at(i).item());
      }

      // Multiply by initial Hessian
      // r/d is the final direction
      Tensor r = q.mul(H_diag);
      d = r;

      for (int64_t i = 0; i < num_old; i++) {
        Tensor be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), (al.at(i) - be_i).item());
      }
      prev_flat_grad.copy_(flat_grad);
    }

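    // Step length: this implementation takes a plain learning-rate step
    // (no line search); the first iteration uses min(1, 1 / |grad|_1) as an
    // extra scale factor to keep the initial step conservative.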
    // reset initial guess for step size
    if (n_iter == 1) {
      t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate_;
    } else {
      t = torch::tensor(options.learning_rate_, torch::kFloat32);
    }

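    // Directional derivative of the loss along the search direction,
    // used in the termination checks below.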
    Tensor gtd = flat_grad.dot(d);
    add_grad(t, d);
    int64_t ls_func_evals = 0;
    if (n_iter != options.max_iter_) {
      // Re-evaluate the function only if this is not the last iteration;
      // in a stochastic setting there is little point in re-evaluating here.
      loss = closure();
      flat_grad = gather_flat_grad();
      abs_grad_sum = flat_grad.abs().sum();
      ls_func_evals = 1;
    }

    current_evals += ls_func_evals;

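    // Termination checks: iteration/evaluation budgets, gradient and
    // directional-derivative tolerances, and minimum change in step or loss.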
    if (n_iter == options.max_iter_) {
      break;
    } else if (current_evals >= options.max_eval_) {
      break;
    } else if (abs_grad_sum.item<float>() <= options.tolerance_grad_) {
      break;
    } else if (gtd.item<float>() > -options.tolerance_grad_) {
      break;
    } else if (
        d.mul(t).abs_().sum().item<float>() <= options.tolerance_change_) {
      break;
    } else if (
        std::abs(loss.item<float>() - prev_loss.item<float>()) <
        options.tolerance_change_) {
      break;
    }
  }
  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}
} // namespace optim
} // namespace torch
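For reference, step() takes a LossClosure, i.e. a std::function<Tensor()> that recomputes the loss and calls backward(). The sketch below shows one way this optimizer might be driven from the C++ frontend; the model, input, and target names are placeholders chosen for illustration, not part of this file, and the learning rate is arbitrary.

#include <torch/torch.h>

int main() {
  torch::nn::Linear model(10, 1);
  auto input = torch::randn({4, 10});
  auto target = torch::randn({4, 1});

  torch::optim::LBFGS optimizer(
      model->parameters(), torch::optim::LBFGSOptions(/*learning_rate=*/0.5));

  // step() may evaluate the closure several times, so the closure must
  // zero the gradients and call backward() on every invocation.
  auto final_loss = optimizer.step([&] {
    optimizer.zero_grad();
    auto loss = torch::mse_loss(model->forward(input), target);
    loss.backward();
    return loss;
  });
  return 0;
}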