Caffe2 - C++ API
A deep learning, cross-platform ML framework
sgd.h
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/nn/module.h>
5 #include <torch/optim/optimizer.h>
6 #include <torch/types.h>
7 
8 #include <cstddef>
9 #include <utility>
10 #include <vector>
11 
// Forward declarations only: SGD::save and SGD::load take these archive types
// by reference, so their full definitions are not required in this header.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
18 
19 namespace torch {
20 namespace optim {
21 
22 struct TORCH_API SGDOptions {
23  /* implicit */ SGDOptions(double learning_rate);
24  TORCH_ARG(double, learning_rate);
25  TORCH_ARG(double, momentum) = 0;
26  TORCH_ARG(double, dampening) = 0;
27  TORCH_ARG(double, weight_decay) = 0;
28  TORCH_ARG(bool, nesterov) = false;
29 };
30 
31 class TORCH_API SGD : public Optimizer {
32  public:
33  template <typename ParameterContainer>
34  explicit SGD(ParameterContainer&& parameters, const SGDOptions& options)
35  : Optimizer(std::forward<ParameterContainer>(parameters)),
36  options(options) {}
37 
38  void step() override;
39 
40  void save(serialize::OutputArchive& archive) const override;
41  void load(serialize::InputArchive& archive) override;
42 
43  SGDOptions options;
44 
45  std::vector<Tensor> momentum_buffers;
46 
47  private:
48  SGD() : options(0) {}
49 
51  size_t iteration_{0};
52 };
53 } // namespace optim
54 } // namespace torch
Optimizer that defines a required step() method that takes no arguments and produces no values...
Definition: optimizer.h:100
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32