Caffe2 - C++ API
A deep learning, cross-platform ML framework
adam.cpp
1 #include <torch/optim/adam.h>
2 
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/nn/module.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/utils.h>
7 
8 #include <ATen/ATen.h>
9 
10 #include <cmath>
11 #include <functional>
12 
13 namespace torch {
14 namespace optim {
15 AdamOptions::AdamOptions(double learning_rate)
16  : learning_rate_(learning_rate) {}
17 
18 void Adam::step() {
19  for (size_t i = 0; i < parameters_.size(); ++i) {
20  Tensor p = parameters_.at(i);
21  if (!p.grad().defined()) {
22  continue;
23  }
24 
25  if (options.weight_decay_ > 0) {
26  p.grad() = p.grad() + options.weight_decay_ * p;
27  }
28 
29  auto& exp_average = buffer_at(exp_average_buffers, i);
30  auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
31 
32  buffer_at(step_buffers, i) += 1;
33 
34  exp_average.mul_(options.beta1_).add_(p.grad(), 1 - options.beta1_);
35  exp_average_sq.mul_(options.beta2_)
36  .addcmul_(p.grad(), p.grad(), 1 - options.beta2_);
37 
38  Tensor denom = exp_average_sq;
39  if (options.amsgrad_) {
40  auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
41  max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
42  denom = max_exp_average_sq;
43  }
44 
45  const auto bias_correction1 =
46  1 - std::pow(options.beta1_, buffer_at(step_buffers, i));
47  const auto bias_correction2 =
48  1 - std::pow(options.beta2_, buffer_at(step_buffers, i));
49  const auto step_size =
50  options.learning_rate_ * std::sqrt(bias_correction2) / bias_correction1;
51 
52  NoGradGuard guard;
53  p.addcdiv_(exp_average, denom.sqrt() + options.eps_, -step_size);
54  }
55 }
56 
57 void Adam::save(serialize::OutputArchive& archive) const {
58  serialize(*this, archive);
59 }
60 
61 void Adam::load(serialize::InputArchive& archive) {
62  serialize(*this, archive);
63 }
64 } // namespace optim
65 } // namespace torch
void load(serialize::InputArchive &archive) override
Deserializes the optimizer state from the given archive.
Definition: adam.cpp:61
std::vector< Tensor > parameters_
The parameters this optimizer optimizes.
Definition: optimizer.h:83
void save(serialize::OutputArchive &archive) const override
Serializes the optimizer state into the given archive.
Definition: adam.cpp:57
T & buffer_at(std::vector< T > &buffers, size_t index)
Accesses a buffer at the given index.
Definition: optimizer.h:68
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32