#include <torch/optim/lbfgs.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

#include <cmath>
#include <functional>
#include <vector>

namespace torch {
namespace optim {

LBFGSOptions::LBFGSOptions(double learning_rate)
    : learning_rate_(learning_rate) {}
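
// Flattens each parameter's gradient to 1-D and concatenates the pieces,
// so the optimizer can treat all parameters as a single flat vector.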
Tensor LBFGS::gather_flat_grad() {
  std::vector<Tensor> views;
  for (auto& parameter : parameters_) {
    views.push_back(parameter.grad().view(-1));
  }
  return torch::cat(views);
}
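
// Applies `update` (a flat vector laid out like the output of
// gather_flat_grad()) back onto the individual parameters, scaled by
// `step_size`.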
void LBFGS::add_grad(const torch::Tensor& step_size, const Tensor& update) {
  NoGradGuard guard; // in-place parameter updates must not be recorded by autograd
  int64_t offset = 0;
  for (auto& parameter : parameters_) {
    int64_t numel = parameter.numel();
    parameter.add_(
        update.slice(0, offset, offset + numel, 1).view_as(parameter),
        step_size.item<float>());
    offset += numel;
  }
}
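
// One L-BFGS "step": evaluates the closure once, then runs up to max_iter_
// inner iterations, each building a quasi-Newton search direction from the
// curvature history and taking a fixed-length step along it.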
torch::Tensor LBFGS::step(LossClosure closure) {
  torch::Tensor orig_loss = closure();
  torch::Tensor loss = orig_loss.clone();
  int64_t current_evals = 1;
  func_evals += 1;

  Tensor flat_grad = gather_flat_grad();
  Tensor abs_grad_sum = flat_grad.abs().sum();

  if (abs_grad_sum.item<float>() <= options.tolerance_grad_) {
    return loss;
  }

  Tensor ONE = torch::tensor(1, flat_grad.options());

  int64_t n_iter = 0;
  while (n_iter < options.max_iter_) {
    n_iter++;
    state_n_iter++;
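
    // On the very first iteration there is no curvature history, so fall
    // back to steepest descent with an identity initial Hessian.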
    if (state_n_iter == 1) {
      d = flat_grad.neg();
      H_diag = ONE;
      prev_flat_grad = flat_grad.clone();
    } else {
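      // Curvature pair: y is the change in gradient, s the previous step.
      // The pair is only accepted when ys = y^T s is sufficiently positive,
      // which keeps the implicit Hessian approximation positive-definite.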
      Tensor y = flat_grad.sub(prev_flat_grad);
      Tensor s = d.mul(t);
      Tensor ys = y.dot(s);

      if (ys.item<float>() > 1e-10) {
        // updating memory
        if (old_dirs.size() == options.history_size_) {
          // shift history by one (limited-memory)
          old_dirs.pop_front();
          old_stps.pop_front();
        }

        // store new direction/step
        old_dirs.push_back(y);
        old_stps.push_back(s);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y);
      }
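
      // Two-loop recursion: computes d ~= -H * grad from the stored
      // {y, s} pairs without ever materializing the inverse Hessian H.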
      int64_t num_old = old_dirs.size();

      for (int64_t i = 0; i < num_old; i++) {
        ro.at(i) = ONE / old_dirs.at(i).dot(old_stps.at(i));
      }

      Tensor q = flat_grad.neg();
      for (int64_t i = num_old - 1; i >= 0; i--) {
        al.at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -al.at(i).item());
      }

      // Multiply by initial Hessian
      // r/d is the final direction
      Tensor r = q.mul(H_diag);
      d = r;

      for (int64_t i = 0; i < num_old; i++) {
        Tensor be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), (al.at(i) - be_i).item());
      }
      prev_flat_grad.copy_(flat_grad);
    }
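
    // Initial trial step: on the first iteration the step length is scaled
    // by min(1, 1 / |grad|_1) so the very first move stays conservative;
    // later iterations start from the raw learning rate.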
    if (n_iter == 1) {
      t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate_;
    } else {
      t = torch::tensor(options.learning_rate_, torch::kFloat32);
    }

    Tensor gtd = flat_grad.dot(d);
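
    // gtd is the directional derivative g^T d; it should be negative for a
    // descent direction and is checked against the tolerance below.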
    int64_t ls_func_evals = 0;
    if (n_iter != options.max_iter_) {
      // no line search, simply move with fixed-step
      add_grad(t, d);
      loss = closure();
      flat_grad = gather_flat_grad();
      abs_grad_sum = flat_grad.abs().sum();
      ls_func_evals = 1;
    }

    current_evals += ls_func_evals;
    func_evals += ls_func_evals;
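
    // Termination tests, in order: iteration budget, closure-evaluation
    // budget, vanishing gradient, loss of a descent direction, negligible
    // parameter change, and negligible change in loss.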
    if (n_iter == options.max_iter_) {
      break;
    } else if (current_evals >= options.max_eval_) {
      break;
    } else if (abs_grad_sum.item<float>() <= options.tolerance_grad_) {
      break;
    } else if (gtd.item<float>() > -options.tolerance_grad_) {
      break;
    } else if (
        d.mul(t).abs_().sum().item<float>() <= options.tolerance_change_) {
      break;
    } else if (
        std::abs(loss.item<float>() - prev_loss.item<float>()) <
        options.tolerance_change_) {
      break;
    }
  }

  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}

} // namespace optim
} // namespace torch
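
// A minimal usage sketch (not part of this file). `model`, `input`, and
// `target` are placeholder names; step() drives the inner loop itself, so
// the closure only needs to zero gradients, recompute the loss, and call
// backward():
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(), torch::optim::LBFGSOptions(0.5));
//   auto loss = optimizer.step([&] {
//     optimizer.zero_grad();
//     auto output = model->forward(input);
//     auto loss = torch::mse_loss(output, target);
//     loss.backward();
//     return loss;
//   });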