Caffe2 - C++ API
A deep-learning, cross-platform ML framework
rmsprop.h
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/nn/module.h>
5 #include <torch/optim/optimizer.h>
6 #include <torch/optim/serialize.h>
7 #include <torch/types.h>
8 
9 #include <functional>
10 #include <memory>
11 #include <string>
12 #include <vector>
13 
14 namespace torch {
15 namespace serialize {
16 class OutputArchive;
17 class InputArchive;
18 } // namespace serialize
19 } // namespace torch
20 
21 namespace torch {
22 namespace optim {
23 
24 struct TORCH_API RMSpropOptions {
25  RMSpropOptions(double learning_rate);
26  TORCH_ARG(double, learning_rate);
27  TORCH_ARG(double, alpha) = 0.99;
28  TORCH_ARG(double, eps) = 1e-8;
29  TORCH_ARG(double, weight_decay) = 0;
30  TORCH_ARG(double, momentum) = 0;
31  TORCH_ARG(bool, centered) = false;
32 };
33 
34 class TORCH_API RMSprop : public Optimizer {
35  public:
36  template <typename ParameterContainer>
37  explicit RMSprop(
38  ParameterContainer&& parameters,
39  const RMSpropOptions& options)
40  : Optimizer(std::forward<ParameterContainer>(parameters)),
41  options(options) {}
42 
43  void step() override;
44 
45  RMSpropOptions options;
46 
47  void save(serialize::OutputArchive& archive) const override;
48  void load(serialize::InputArchive& archive) override;
49 
50  std::vector<Tensor> square_average_buffers;
51  std::vector<Tensor> momentum_buffers;
52  std::vector<Tensor> grad_average_buffers;
53 
54  private:
55  RMSprop() : options(0) {}
56 
57  template <typename Self, typename Archive>
58  static void serialize(Self& self, Archive& archive) {
59  _TORCH_OPTIM_SERIALIZE(square_average_buffers);
60  _TORCH_OPTIM_SERIALIZE(momentum_buffers);
61  _TORCH_OPTIM_SERIALIZE(grad_average_buffers);
62  }
63 };
64 } // namespace optim
65 } // namespace torch
Optimizer that defines a required step() method that takes no arguments and produces no values...
Definition: optimizer.h:100
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32