Caffe2 - C++ API
A deep learning, cross-platform ML framework
lbfgs.h
#pragma once

#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>

#include <deque>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace optim {

/// Hyperparameters for the L-BFGS optimizer.
struct TORCH_API LBFGSOptions {
  LBFGSOptions(double learning_rate);
  TORCH_ARG(double, learning_rate);
  TORCH_ARG(int64_t, max_iter) = 20;
  TORCH_ARG(int64_t, max_eval) = 25;
  TORCH_ARG(float, tolerance_grad) = 1e-5;
  TORCH_ARG(float, tolerance_change) = 1e-9;
  TORCH_ARG(size_t, history_size) = 100;
};

/// L-BFGS optimizer. It derives from LossClosureOptimizer because the
/// algorithm may evaluate the objective and its gradients several times per
/// step, so the loss closure must be supplied to step().
class TORCH_API LBFGS : public LossClosureOptimizer {
 public:
  template <typename ParameterContainer>
  explicit LBFGS(ParameterContainer&& parameters, const LBFGSOptions& options)
      : LossClosureOptimizer(std::forward<ParameterContainer>(parameters)),
        options(options),
        ro(options.history_size_),
        al(options.history_size_) {}

  torch::Tensor step(LossClosure closure) override;

  LBFGSOptions options;

  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  // Internal optimizer state: search direction, step size, previous gradient
  // and loss, plus the curvature history used by the two-loop recursion.
  Tensor d{torch::empty({0})};
  Tensor H_diag{torch::empty({0})};
  Tensor prev_flat_grad{torch::empty({0})};
  Tensor t{torch::zeros(1)};
  Tensor prev_loss{torch::zeros(1)};
  std::vector<Tensor> ro;
  std::vector<Tensor> al;
  std::deque<Tensor> old_dirs;
  std::deque<Tensor> old_stps;
  int64_t func_evals{0};
  int64_t state_n_iter{0};

 private:
  LBFGS() : options(0) {}

  Tensor gather_flat_grad();
  void add_grad(const torch::Tensor& step_size, const Tensor& update);

  // Serializes the internal state buffers listed above.
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    archive("d", self.d, /*is_buffer=*/true);
    archive("t", self.t, /*is_buffer=*/true);
    archive("H_diag", self.H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", self.prev_loss, /*is_buffer=*/true);
    optim::serialize(archive, "old_dirs", self.old_dirs);
    optim::serialize(archive, "old_stps", self.old_stps);
  }
};
} // namespace optim
} // namespace torch
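
A minimal usage sketch of this API. The model, input, and target below are illustrative placeholders, not part of lbfgs.h; it assumes the usual torch::nn and torch::optim headers pulled in via <torch/torch.h>.

// Sketch: fit a small linear model with LBFGS.
#include <torch/torch.h>

int main() {
  torch::nn::Linear model(10, 1);               // hypothetical model
  auto input = torch::randn({32, 10});          // hypothetical data
  auto target = torch::randn({32, 1});

  torch::optim::LBFGS optimizer(
      model->parameters(),
      torch::optim::LBFGSOptions(/*learning_rate=*/1.0)
          .max_iter(20)
          .history_size(100));

  // step() takes a LossClosure because L-BFGS may re-evaluate the objective
  // and its gradients several times within a single optimization step.
  auto loss = optimizer.step([&]() -> torch::Tensor {
    optimizer.zero_grad();
    auto prediction = model->forward(input);
    auto loss = torch::mse_loss(prediction, target);
    loss.backward();
    return loss;
  });
}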
Referenced declarations:
torch::optim::LossClosureOptimizer (optimizer.h:110) — optimizer that requires the loss function to be supplied to the step() function, as it may evaluate it multiple times per step.
torch::serialize::InputArchive (input-archive.h:32) — a recursive representation of tensors that can be deserialized from a file or stream.
torch::optim::LossClosure (optimizer.h:113) — std::function<Tensor()>; a loss function closure, which is expected to return the loss value.
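
The save() and load() overrides in the header persist the optimizer's internal buffers (d, t, H_diag, prev_flat_grad, prev_loss, old_dirs, old_stps). A minimal sketch of writing and restoring that state via the archive types referenced above; the file name is illustrative.

// Persist the L-BFGS internal state and restore it later.
torch::serialize::OutputArchive output;
optimizer.save(output);
output.save_to("lbfgs_state.pt");   // illustrative file name

torch::serialize::InputArchive input;
input.load_from("lbfgs_state.pt");
optimizer.load(input);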