#include <torch/optim/optimizer.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

namespace torch {
namespace optim {
namespace detail {

OptimizerBase::OptimizerBase(std::vector<Tensor> parameters)
    : parameters_(std::move(parameters)) {}
void OptimizerBase::zero_grad() {
  for (auto& parameter : parameters_) {
    if (parameter.grad().defined()) {
      parameter.grad().detach_();
      parameter.grad().zero_();
    }
  }
}
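// Usage sketch (illustrative, not part of this file): in a typical training
// loop, gradients are cleared before the next backward pass; `model` and
// `batch` are assumed names.
//
//   torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.01);
//   optimizer.zero_grad();
//   auto loss = model->forward(batch).sum();
//   loss.backward();
//   optimizer.step();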
Tensor& OptimizerBase::buffer_at(std::vector<Tensor>& buffers, size_t index) {
  // Lazily grow the buffer vector, filling it with zero tensors shaped like
  // the corresponding parameters.
  if (buffers.size() <= index) {
    buffers.reserve(index);
    for (auto i = buffers.size(); i <= index; ++i) {
      buffers.push_back(torch::zeros_like(parameters_.at(i)));
    }
  }
  // Copy the buffer to the device and dtype of its parameter if they differ.
  const auto& buffer = buffers.at(index);
  const auto& parameter = parameters_.at(index);
  if (buffer.device() != parameter.device() ||
      buffer.dtype() != parameter.dtype()) {
    buffers[index] = buffer.to(parameter.device(), parameter.scalar_type());
  }
  return buffers[index];
}
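// Usage sketch (an assumption, not part of this file): a concrete optimizer
// can call buffer_at to lazily materialize per-parameter state, e.g. a
// momentum buffer; `momentum_buffers_` is a hypothetical member here.
//
//   auto& momentum = buffer_at(momentum_buffers_, i);
//   momentum.mul_(0.9).add_(parameters_.at(i).grad());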
void OptimizerBase::save(serialize::OutputArchive& archive) const {}
void OptimizerBase::load(serialize::InputArchive& archive) {}

} // namespace detail

/// Serializes an `OptimizerBase` into an `OutputArchive`.
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const OptimizerBase& optimizer) {
  optimizer.save(archive);
  return archive;
}

/// Deserializes an `OptimizerBase` from an `InputArchive`.
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    OptimizerBase& optimizer) {
  optimizer.load(archive);
  return archive;
}

} // namespace optim
} // namespace torch
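The streaming operators above delegate to the virtual save()/load() methods,
so any concrete optimizer can be round-tripped through an archive. A minimal
sketch, assuming an already constructed optimizer and an illustrative file
name:

  torch::serialize::OutputArchive output;
  output << optimizer;              // invokes optimizer.save(output)
  output.save_to("optimizer.pt");

  torch::serialize::InputArchive input;
  input.load_from("optimizer.pt");
  input >> optimizer;               // invokes optimizer.load(input)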
Member reference for torch::optim::detail::OptimizerBase:

Base class for all optimizers; it does not yet define a step() mechanism.

std::vector<Tensor> parameters_
    The parameters this optimizer optimizes.

void add_parameters(const std::vector<Tensor>& parameters)
    Adds the given vector of parameters to the optimizer's parameter list.

const std::vector<Tensor>& parameters() const noexcept
    Provides a const reference to the parameters this optimizer holds.

size_t size() const noexcept
    Returns the number of parameters referenced by the optimizer.

virtual void zero_grad()
    Zeros out the gradients of all parameters.

T& buffer_at(std::vector<T>& buffers, size_t index)
    Accesses a buffer at the given index.

virtual void save(serialize::OutputArchive& archive) const
    Serializes the optimizer state into the given archive.

virtual void load(serialize::InputArchive& archive)
    Deserializes the optimizer state from the given archive.
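A short usage sketch of the parameter-management API above, assuming the
concrete torch::optim::SGD subclass (the tensor names are illustrative):

  auto w = torch::randn({10, 5}, torch::requires_grad());
  auto b = torch::randn({5}, torch::requires_grad());

  torch::optim::SGD optimizer({w}, /*lr=*/0.01);
  optimizer.add_parameters({b});    // parameter list now holds w and b
  optimizer.size();                 // returns 2

  auto loss = w.sum() + b.sum();
  loss.backward();
  optimizer.step();                 // SGD supplies the step() mechanism
  optimizer.zero_grad();            // detaches and zeros both gradients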