1 #include <torch/optim/sgd.h> 3 #include <torch/csrc/autograd/variable.h> 4 #include <torch/nn/pimpl.h> 5 #include <torch/optim/optimizer.h> 6 #include <torch/optim/serialize.h> 7 #include <torch/types.h> 8 #include <torch/utils.h> 10 #include <ATen/ATen.h> 16 SGDOptions::SGDOptions(
double learning_rate) : learning_rate_(learning_rate) {}
22 if (!p.grad().defined()) {
26 auto update = p.grad();
28 if (options.weight_decay_ > 0) {
29 update += options.weight_decay_ * p;
32 if (options.momentum_ != 0) {
33 const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening_;
34 auto& momentum =
buffer_at(momentum_buffers, i);
35 momentum = (options.momentum_ * momentum) + (dampening * update);
36 if (options.nesterov_) {
39 update += options.momentum_ * momentum;
46 p.add_(-options.learning_rate_ * update);
52 optim::serialize(archive,
"momentum_buffers", momentum_buffers);
56 optim::serialize(archive,
"momentum_buffers", momentum_buffers);
void load(serialize::InputArchive &archive) override
Deserializes the optimizer state from the given archive.
std::vector< Tensor > parameters_
The parameters this optimizer optimizes.
void save(serialize::OutputArchive &archive) const override
Serializes the optimizer state into the given archive.
T & buffer_at(std::vector< T > &buffers, size_t index)
Accesses a buffer at the given index.