Caffe2 - C++ API
A deep learning, cross-platform ML framework
adam.h
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/nn/module.h>
5 #include <torch/optim/optimizer.h>
6 #include <torch/optim/serialize.h>
7 
8 #include <utility>
9 #include <vector>
10 
// Forward declarations of the archive types that appear (by reference only)
// in the signatures of Adam::save and Adam::load below.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
17 
18 namespace torch {
19 namespace optim {
20 
21 struct TORCH_API AdamOptions {
22  /* implicit */ AdamOptions(double learning_rate);
23  TORCH_ARG(double, learning_rate);
24  TORCH_ARG(double, beta1) = 0.9;
25  TORCH_ARG(double, beta2) = 0.999;
26  TORCH_ARG(double, weight_decay) = 0;
27  TORCH_ARG(double, eps) = 1e-8;
28  TORCH_ARG(bool, amsgrad) = false;
29 };
30 
31 class TORCH_API Adam : public Optimizer {
32  public:
33  template <typename ParameterContainer>
34  explicit Adam(ParameterContainer&& parameters, const AdamOptions& options)
35  : Optimizer(std::forward<ParameterContainer>(parameters)),
36  options(options) {}
37 
38  void step() override;
39 
40  void save(serialize::OutputArchive& archive) const override;
41  void load(serialize::InputArchive& archive) override;
42 
43  AdamOptions options;
44 
45  std::vector<int64_t> step_buffers;
46  std::vector<Tensor> exp_average_buffers;
47  std::vector<Tensor> exp_average_sq_buffers;
48  std::vector<Tensor> max_exp_average_sq_buffers;
49 
50  private:
51  Adam() : options(0) {}
52 
53  template <typename Self, typename Archive>
54  static void serialize(Self& self, Archive& archive) {
55  _TORCH_OPTIM_SERIALIZE(step_buffers);
56  _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
57  _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
58  _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
59  }
60 };
61 } // namespace optim
62 } // namespace torch
Optimizer that defines a required step() method that takes no arguments and produces no values...
Definition: optimizer.h:100
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32