Caffe2 - C++ API
A deep learning, cross-platform ML framework
optimizer.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 
5 #include <algorithm>
6 #include <functional>
7 #include <iterator>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 
12 // Forward declarations confuse Doxygen
13 #ifndef DOXYGEN_SHOULD_SKIP_THIS
14 namespace at {
15 class Tensor;
16 } // namespace at
17 
18 namespace torch {
19 using at::Tensor;
20 namespace serialize {
21 class OutputArchive;
22 class InputArchive;
23 } // namespace serialize
24 } // namespace torch
25 #endif // DOXYGEN_SHOULD_SKIP_THIS
26 
27 namespace torch {
28 namespace optim {
29 namespace detail {
34 class TORCH_API OptimizerBase {
35  public:
37  explicit OptimizerBase(std::vector<Tensor> parameters);
38 
39  virtual ~OptimizerBase() = default;
40 
42  void add_parameters(const std::vector<Tensor>& parameters);
43 
45  virtual void zero_grad();
46 
48  const std::vector<Tensor>& parameters() const noexcept;
49 
51  std::vector<Tensor>& parameters() noexcept;
52 
54  size_t size() const noexcept;
55 
57  virtual void save(serialize::OutputArchive& archive) const;
58 
60  virtual void load(serialize::InputArchive& archive);
61 
62  protected:
63  OptimizerBase() = default;
64 
67  template <typename T>
68  T& buffer_at(std::vector<T>& buffers, size_t index) {
69  if (buffers.size() <= index) {
70  const auto old_size = buffers.size();
71  buffers.resize(index + 1);
72  std::fill(buffers.begin() + old_size, buffers.end(), T{0});
73  }
74  return buffers[index];
75  }
76 
80  Tensor& buffer_at(std::vector<Tensor>& buffers, size_t index);
81 
83  std::vector<Tensor> parameters_;
84 };
85 
87 TORCH_API serialize::OutputArchive& operator<<(
88  serialize::OutputArchive& archive,
89  const OptimizerBase& optimizer);
90 
92 TORCH_API serialize::InputArchive& operator>>(
93  serialize::InputArchive& archive,
94  OptimizerBase& optimizer);
95 } // namespace detail
96 
101  public:
102  using detail::OptimizerBase::OptimizerBase;
103  virtual void step() = 0;
104 };
105 
111  public:
113  using LossClosure = std::function<Tensor()>;
114  using detail::OptimizerBase::OptimizerBase;
115  virtual Tensor step(LossClosure closure) = 0;
116 };
117 
118 } // namespace optim
119 } // namespace torch
std::vector<Tensor> parameters_
The parameters this optimizer optimizes.
Definition: optimizer.h:83
Optimizer that defines a required step() method that takes no arguments and produces no values.
Definition: optimizer.h:100
T& buffer_at(std::vector<T>& buffers, size_t index)
Accesses a buffer at the given index. If the index is out of range, the buffer vector is first grown with T{0}-valued entries.
Definition: optimizer.h:68
Definition: jit_type.h:17 (orphaned tooltip: the symbol it belongs to was lost in extraction; possibly a namespace link)
Optimizer that requires the loss function to be supplied to the step() function, as it may evaluate the loss function during the step.
Definition: optimizer.h:110
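Flush-To-Zero and Denormals-Are-Zero mode. (Orphaned tooltip for a symbol defined outside this file; it does not describe anything in optimizer.h.)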
Base class for all optimizers; it does not yet define a step() mechanism.
Definition: optimizer.h:34
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
std::function<Tensor()> LossClosure
A loss function closure, which is expected to return the loss value.
Definition: optimizer.h:113