Caffe2 - C++ API
A deep learning, cross-platform ML framework
optimizer.cpp
#include <torch/optim/optimizer.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace optim {
namespace detail {
OptimizerBase::OptimizerBase(std::vector<Tensor> parameters)
    : parameters_(std::move(parameters)) {}

void OptimizerBase::add_parameters(const std::vector<Tensor>& parameters) {
  parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
}

void OptimizerBase::zero_grad() {
  for (auto& parameter : parameters_) {
    if (parameter.grad().defined()) {
      parameter.grad().detach_();
      parameter.grad().zero_();
    }
  }
}

const std::vector<Tensor>& OptimizerBase::parameters() const noexcept {
  return parameters_;
}

std::vector<Tensor>& OptimizerBase::parameters() noexcept {
  return parameters_;
}

size_t OptimizerBase::size() const noexcept {
  return parameters_.size();
}

Tensor& OptimizerBase::buffer_at(std::vector<Tensor>& buffers, size_t index) {
  if (buffers.size() <= index) {
    buffers.reserve(index);
    for (auto i = buffers.size(); i <= index; ++i) {
      buffers.push_back(torch::zeros_like(parameters_.at(i)));
    }
  }
  // Copy the buffer to the device and dtype of the parameter.
  const auto& parameter = parameters_.at(index);
  const auto& buffer = buffers.at(index);
  if (buffer.device() != parameter.device() ||
      buffer.dtype() != parameter.dtype()) {
    buffers[index] = buffer.to(parameter.device(), parameter.scalar_type());
  }
  return buffers[index];
}

// Default (no-op) serialization hooks; concrete optimizers override these to
// persist their per-parameter state.
void OptimizerBase::save(serialize::OutputArchive& archive) const {}
void OptimizerBase::load(serialize::InputArchive& archive) {}

/// Serializes an `OptimizerBase` into an `OutputArchive`.
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const OptimizerBase& optimizer) {
  optimizer.save(archive);
  return archive;
}

/// Deserializes an `OptimizerBase` from an `InputArchive`.
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    OptimizerBase& optimizer) {
  optimizer.load(archive);
  return archive;
}
} // namespace detail
} // namespace optim
} // namespace torch
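
The facilities above reach user code through concrete optimizers such as torch::optim::SGD, which derives from detail::OptimizerBase via torch::optim::Optimizer. What follows is a minimal usage sketch, not part of this file: the tensor shapes, the learning rate, and the filename sgd_state.pt are illustrative choices.

#include <torch/torch.h>

int main() {
  // Two leaf tensors to optimize; they end up in OptimizerBase::parameters_.
  auto w = torch::randn({3, 3}, torch::requires_grad());
  auto b = torch::randn({3}, torch::requires_grad());
  torch::optim::SGD sgd(
      std::vector<torch::Tensor>{w, b}, torch::optim::SGDOptions(/*lr=*/0.01));

  // One training step: backward() fills the gradients, step() applies the
  // update, and zero_grad() detaches and zeros every defined gradient.
  auto loss = (w * w + b.sum()).sum();
  loss.backward();
  sgd.step();
  sgd.zero_grad();

  // Streaming the optimizer into an archive goes through the operator<<
  // defined above, which dispatches to the concrete optimizer's save().
  torch::serialize::OutputArchive archive;
  archive << sgd;
  archive.save_to("sgd_state.pt");
}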