Caffe2 - C++ API
A deep learning, cross-platform ML framework
rmsprop.cpp
1 #include <torch/optim/rmsprop.h>
2 
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/serialize/archive.h>
5 #include <torch/utils.h>
6 
7 #include <ATen/ATen.h>
8 
9 #include <functional>
10 
11 namespace torch {
12 namespace optim {
13 
14 RMSpropOptions::RMSpropOptions(double learning_rate)
15  : learning_rate_(learning_rate) {}
16 
19 void RMSprop::step() {
20  for (size_t i = 0; i < parameters_.size(); ++i) {
21  Tensor p = parameters_.at(i);
22  if (!p.grad().defined()) {
23  continue;
24  }
25 
26  if (options.weight_decay_ > 0) {
27  p.grad() = p.grad() + options.weight_decay_ * p;
28  }
29 
30  auto square_average = buffer_at(square_average_buffers, i);
31  square_average.mul_(options.alpha_)
32  .addcmul_(p.grad(), p.grad(), 1.0 - options.alpha_);
33 
34  Tensor average;
35  if (options.centered_ > 0) {
36  auto& grad_average = buffer_at(grad_average_buffers, i);
37  grad_average.mul_(options.alpha_).add_(p.grad(), 1.0 - options.alpha_);
38  average = square_average.addcmul(grad_average, grad_average, -1.0)
39  .sqrt()
40  .add_(options.eps_);
41  } else {
42  average = square_average.sqrt().add_(options.eps_);
43  }
44 
45  NoGradGuard guard;
46  if (options.momentum_ > 0) {
47  auto& momentum = buffer_at(momentum_buffers, i);
48  momentum.mul_(options.momentum_).addcdiv_(p.grad(), average);
49  p.add_(momentum, -options.learning_rate_);
50  } else {
51  p.addcdiv_(p.grad(), average, -options.learning_rate_);
52  }
53  }
54 }
55 
57  serialize(*this, archive);
58 }
59 
61  serialize(*this, archive);
62 }
63 } // namespace optim
64 } // namespace torch
std::vector< Tensor > parameters_
The parameters this optimizer optimizes.
Definition: optimizer.h:83
void load(serialize::InputArchive &archive) override
Deserializes the optimizer state from the given archive.
Definition: rmsprop.cpp:60
void step() override
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py.
Definition: rmsprop.cpp:19
void save(serialize::OutputArchive &archive) const override
Serializes the optimizer state into the given archive.
Definition: rmsprop.cpp:56
T & buffer_at(std::vector< T > &buffers, size_t index)
Accesses a buffer at the given index.
Definition: optimizer.h:68
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32