Caffe2 - C++ API
A deep-learning, cross-platform ML framework
adagrad.h
1 #pragma once
2 
3 #include <torch/nn/pimpl.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/optim/serialize.h>
6 #include <torch/types.h>
7 
8 #include <utility>
9 #include <vector>
10 
// Forward declarations of the archive types named in the optimizer's
// save()/load() signatures; only pointers/references to them are needed here.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
17 
18 namespace torch {
19 namespace optim {
20 
21 struct TORCH_API AdagradOptions {
22  AdagradOptions(double learning_rate);
23  TORCH_ARG(double, learning_rate);
24  TORCH_ARG(double, lr_decay) = 0;
25  TORCH_ARG(double, weight_decay) = 0;
26 };
27 
28 class TORCH_API Adagrad : public Optimizer {
29  public:
30  template <typename ParameterContainer>
31  explicit Adagrad(
32  ParameterContainer&& parameters,
33  const AdagradOptions& options)
34  : Optimizer(std::forward<ParameterContainer>(parameters)),
35  options(options) {}
36 
37  void step() override;
38 
39  AdagradOptions options;
40 
41  void save(serialize::OutputArchive& archive) const override;
42  void load(serialize::InputArchive& archive) override;
43 
44  std::vector<Tensor> sum_buffers;
45  std::vector<int64_t> step_buffers;
46 
47  private:
48  Adagrad() : options(0) {}
49 
50  template <typename Self, typename Archive>
51  static void serialize(Self& self, Archive& archive) {
52  _TORCH_OPTIM_SERIALIZE(sum_buffers);
53  _TORCH_OPTIM_SERIALIZE(step_buffers);
54  }
55 };
56 } // namespace optim
57 } // namespace torch
Optimizer that defines a required step() method that takes no arguments and produces no values...
Definition: optimizer.h:100
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32