Caffe2 - C++ API
A deep learning, cross platform ML framework
torch::optim::SGD Class Reference
Inheritance: torch::optim::SGD derives from torch::optim::Optimizer, which derives from torch::optim::detail::OptimizerBase.

Public Member Functions

template<typename ParameterContainer >
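 SGD (ParameterContainer &&parameters, const SGDOptions &options)
 Constructs the optimizer over the given parameters, configured by options.
 
void step () override
 Performs a single optimization step, updating each parameter from its gradient.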
 
void save (serialize::OutputArchive &archive) const override
 Serializes the optimizer state into the given archive.
 
void load (serialize::InputArchive &archive) override
 Deserializes the optimizer state from the given archive.
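The page lists step() without describing the update it applies. As an illustration only, here is a minimal standalone sketch of SGD with momentum in the common formulation v = momentum * v + g, p = p - learning_rate * v. It assumes plain std::vector<double> in place of Tensor and hypothetical SgdSketch and SgdSketchOptions types; it is not the libtorch implementation and ignores any options SGDOptions may carry beyond these two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for SGDOptions, holding only what this sketch uses.
struct SgdSketchOptions {
  double learning_rate = 0.1;
  double momentum = 0.0;  // 0 means plain gradient descent
};

// Hypothetical, framework-free model of SGD over scalar parameters.
struct SgdSketch {
  SgdSketchOptions options;
  std::vector<double> momentum_buffers;  // one entry per parameter, grown on demand

  // One optimization step over all parameters.
  void step(std::vector<double>& params, const std::vector<double>& grads) {
    assert(params.size() == grads.size());
    if (momentum_buffers.size() < params.size()) {
      momentum_buffers.resize(params.size(), 0.0);
    }
    for (std::size_t i = 0; i < params.size(); ++i) {
      double update = grads[i];
      if (options.momentum != 0.0) {
        // v = momentum * v + g; the parameter then moves along v.
        momentum_buffers[i] = options.momentum * momentum_buffers[i] + grads[i];
        update = momentum_buffers[i];
      }
      params[i] -= options.learning_rate * update;  // p = p - learning_rate * update
    }
  }
};
```

For example, with learning_rate = 0.1, momentum = 0.9, a parameter of 1.0 and a constant gradient of 2.0, the first step moves the parameter to 0.8 and the second to 0.42, because the momentum buffer grows from 2.0 to 3.8. With momentum set to 0, the buffer values are never read and the update reduces to plain gradient descent.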
 
Public Member Functions inherited from torch::optim::detail::OptimizerBase
 OptimizerBase (std::vector< Tensor > parameters)
 Constructs the Optimizer from a vector of parameters.
 
void add_parameters (const std::vector< Tensor > &parameters)
 Adds the given vector of parameters to the optimizer's parameter list.
 
virtual void zero_grad ()
 Zeros out the gradients of all parameters.
 
const std::vector< Tensor > & parameters () const noexcept
 Provides a const reference to the parameters this optimizer holds.
 
std::vector< Tensor > & parameters () noexcept
 Provides a reference to the parameters this optimizer holds.
 
size_t size () const noexcept
 Returns the number of parameters referenced by the optimizer.
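The inherited members above manage a flat list of parameters. A minimal sketch of that bookkeeping follows, assuming a hypothetical ParamSketch struct (a scalar value plus its gradient) in place of Tensor and a hypothetical OptimizerBaseSketch class; only the member names mirror this page, and the bodies are illustrative rather than the library's.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a Tensor: one scalar value and its gradient.
struct ParamSketch {
  double value = 0.0;
  double grad = 0.0;
};

// Hypothetical model of the parameter bookkeeping inherited from OptimizerBase.
class OptimizerBaseSketch {
 public:
  explicit OptimizerBaseSketch(std::vector<ParamSketch> parameters)
      : parameters_(std::move(parameters)) {}
  virtual ~OptimizerBaseSketch() = default;

  // Appends the given parameters after the ones already held.
  void add_parameters(const std::vector<ParamSketch>& parameters) {
    parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
  }

  // Sets every gradient back to zero.
  virtual void zero_grad() {
    for (auto& p : parameters_) {
      p.grad = 0.0;
    }
  }

  const std::vector<ParamSketch>& parameters() const noexcept { return parameters_; }
  std::vector<ParamSketch>& parameters() noexcept { return parameters_; }

  // Number of parameters currently held.
  std::size_t size() const noexcept { return parameters_.size(); }

 protected:
  std::vector<ParamSketch> parameters_;  // the parameters being optimized
};
```

Making zero_grad() virtual, as the listing does, lets a derived optimizer extend the reset. The usual pattern is to call it before each backward pass so that gradients from the previous iteration do not accumulate into the new ones.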
 

Data Fields

SGDOptions options
 
std::vector< Tensor > momentum_buffers
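save() and load() are documented only as (de)serializing the optimizer state; for SGD, the per-parameter state visible on this page is momentum_buffers. The round trip below is a sketch under assumptions: hypothetical save_buffers and load_buffers functions, and a plain-text std::ostream/std::istream format standing in for serialize::OutputArchive and serialize::InputArchive, whose actual format is not described here.

```cpp
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <istream>
#include <ostream>
#include <sstream>
#include <utility>
#include <vector>

// Hypothetical stand-in for saving state: writes the count, then each value.
void save_buffers(std::ostream& out, const std::vector<double>& buffers) {
  out << buffers.size();
  for (double b : buffers) {
    // 17 significant digits let any double be read back exactly.
    out << ' ' << std::setprecision(17) << b;
  }
  out << '\n';
}

// Hypothetical stand-in for loading state: replaces `buffers` with what was saved.
void load_buffers(std::istream& in, std::vector<double>& buffers) {
  std::size_t count = 0;
  in >> count;
  std::vector<double> loaded(count);
  for (double& b : loaded) {
    in >> b;
  }
  assert(!in.fail() && "truncated or malformed optimizer state");
  buffers = std::move(loaded);
}
```

Writing the element count first lets the loader size the vector before reading. In this sketch, buffers are matched to parameters by position only, so the parameters must be registered in the same order when the state is loaded back.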
 

Additional Inherited Members

Protected Member Functions inherited from torch::optim::detail::OptimizerBase
template<typename T >
T & buffer_at (std::vector< T > &buffers, size_t index)
 Accesses a buffer at the given index.
 
Tensor & buffer_at (std::vector< Tensor > &buffers, size_t index)
 Accesses a buffer at the given index and converts it to the type of the parameter at the corresponding index (a no-op if the types already match).
 
Protected Attributes inherited from torch::optim::detail::OptimizerBase
std::vector< Tensor > parameters_
 The parameters this optimizer optimizes.
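The buffer_at helpers are described only as accessors. One plausible design, and the assumption behind this sketch, is that the accessor grows the buffer vector on demand, so an optimizer can create its per-parameter state (such as momentum_buffers) lazily the first time an index is touched. The free function below is hypothetical; it models only the generic overload, not the Tensor overload's conversion to the parameter's type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical free-function model of the generic buffer_at overload.
// Assumption: missing entries up to `index` are created value-initialized (T{}).
template <typename T>
T& buffer_at(std::vector<T>& buffers, std::size_t index) {
  if (buffers.size() <= index) {
    buffers.resize(index + 1);  // new slots are value-initialized, e.g. 0.0 for double
  }
  assert(index < buffers.size());
  return buffers[index];
}
```

Returning a reference lets the caller update the stored state in place, for example `auto& v = buffer_at(momentum, i); v = mu * v + g;`, without a separate write-back.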
 

Detailed Description

Definition at line 31 of file sgd.h.

