Caffe2 - C++ API
A deep learning, cross platform ML framework
sgd.cpp
1 #include <torch/optim/sgd.h>
2 
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/nn/pimpl.h>
5 #include <torch/optim/optimizer.h>
6 #include <torch/optim/serialize.h>
7 #include <torch/types.h>
8 #include <torch/utils.h>
9 
10 #include <ATen/ATen.h>
11 
12 #include <functional>
13 
14 namespace torch {
15 namespace optim {
16 SGDOptions::SGDOptions(double learning_rate) : learning_rate_(learning_rate) {}
17 
18 void SGD::step() {
19  for (size_t i = 0; i < parameters_.size(); ++i) {
20  Tensor p = parameters_.at(i);
21 
22  if (!p.grad().defined()) {
23  continue;
24  }
25 
26  auto update = p.grad();
27 
28  if (options.weight_decay_ > 0) {
29  update += options.weight_decay_ * p;
30  }
31 
32  if (options.momentum_ != 0) {
33  const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening_;
34  auto& momentum = buffer_at(momentum_buffers, i);
35  momentum = (options.momentum_ * momentum) + (dampening * update);
36  if (options.nesterov_) {
37  // See github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617
38  // for notes on this implementation of nesterov momentum.
39  update += options.momentum_ * momentum;
40  } else {
41  update = momentum;
42  }
43  }
44 
45  NoGradGuard guard;
46  p.add_(-options.learning_rate_ * update);
47  }
48  iteration_ += 1;
49 }
50 
51 void SGD::save(serialize::OutputArchive& archive) const {
52  optim::serialize(archive, "momentum_buffers", momentum_buffers);
53 }
54 
56  optim::serialize(archive, "momentum_buffers", momentum_buffers);
57 }
58 } // namespace optim
59 } // namespace torch
void load(serialize::InputArchive &archive) override
Deserializes the optimizer state from the given archive.
Definition: sgd.cpp:55
std::vector< Tensor > parameters_
The parameters this optimizer optimizes.
Definition: optimizer.h:83
void save(serialize::OutputArchive &archive) const override
Serializes the optimizer state into the given archive.
Definition: sgd.cpp:51
T & buffer_at(std::vector< T > &buffers, size_t index)
Accesses a buffer at the given index.
Definition: optimizer.h:68
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32