Caffe2 - C++ API
A deep learning, cross platform ML framework
torch::nn::LinearImpl Class Reference

Applies a linear transformation with optional bias.

#include <linear.h>

Inheritance: torch::nn::LinearImpl derives from torch::nn::Cloneable< LinearImpl >, which derives from torch::nn::Module.

Public Member Functions

 LinearImpl (int64_t in, int64_t out)
 
 LinearImpl (LinearOptions options)
 
void reset () override
 reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
 
void pretty_print (std::ostream &stream) const override
 Pretty prints the Linear module into the given stream.
 
Tensor forward (const Tensor &input)
 Transforms the input tensor by multiplying it with the weight and, if with_bias is true in the options, adding the bias.
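The two constructors suggest that LinearImpl(in, out) is shorthand for building the module from a LinearOptions with default settings. The sketch below mirrors that pattern in plain C++ with hypothetical SketchLinearOptions and SketchLinear types; the field names, the chainable with_bias setter, and the default of with_bias = true are assumptions for illustration, not a copy of the real linear.h.

```cpp
#include <cstdint>

// Hypothetical stand-in for LinearOptions; NOT the real libtorch type.
struct SketchLinearOptions {
  SketchLinearOptions(std::int64_t in, std::int64_t out) : in_(in), out_(out) {}

  // Chainable setter: returns *this so calls can be strung together.
  SketchLinearOptions& with_bias(bool value) {
    with_bias_ = value;
    return *this;
  }

  std::int64_t in_;         // number of input features
  std::int64_t out_;        // number of output features
  bool with_bias_ = true;   // assumed default: bias enabled
};

// Hypothetical module showing the two construction paths.
struct SketchLinear {
  // (in, out) simply delegates to the options-based constructor.
  SketchLinear(std::int64_t in, std::int64_t out)
      : SketchLinear(SketchLinearOptions(in, out)) {}

  explicit SketchLinear(SketchLinearOptions opts) : options(opts) {}

  SketchLinearOptions options;  // the configuration this module was built with
};
```

For example, SketchLinear layer(SketchLinearOptions(3, 4).with_bias(false)); describes a 3-to-4 layer without a bias term. Delegating the (in, out) constructor to the options-based one keeps a single place where configuration is applied.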
 
- Public Member Functions inherited from torch::nn::Cloneable< LinearImpl >
std::shared_ptr< Module > clone (const optional< Device > &device=nullopt) const override
 Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module.
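clone() is provided by the CRTP base Cloneable< LinearImpl >: because the base is templated on the concrete class, it can copy-construct the concrete module rather than a bare Module. The sketch below shows that idea together with the deep-copy requirement, using hypothetical SketchModule, SketchCloneable, and SketchLinear types whose weight lives behind a std::shared_ptr (reference semantics). The deep_copy_storage() hook is an invention of this sketch; libtorch's own mechanism is not reproduced here.

```cpp
#include <memory>
#include <vector>

// Hypothetical base, standing in for torch::nn::Module.
struct SketchModule {
  virtual ~SketchModule() = default;
  virtual std::shared_ptr<SketchModule> clone() const = 0;
};

// CRTP helper: it knows the concrete Derived type, so it can copy it.
template <typename Derived>
struct SketchCloneable : SketchModule {
  std::shared_ptr<SketchModule> clone() const override {
    // Copy-construct the derived object (this copies shared_ptrs, i.e. aliases),
    // then ask it to re-own its storage so nothing is shared with *this.
    auto copy = std::make_shared<Derived>(static_cast<const Derived&>(*this));
    copy->deep_copy_storage();
    return copy;
  }
};

struct SketchLinear : SketchCloneable<SketchLinear> {
  // Reference semantics: a plain copy of SketchLinear would alias this buffer.
  std::shared_ptr<std::vector<float>> weight =
      std::make_shared<std::vector<float>>(4, 0.5f);

  // Hypothetical hook: replace the aliased buffer with an independent copy.
  void deep_copy_storage() {
    weight = std::make_shared<std::vector<float>>(*weight);
  }
};
```

Without the hook, the copy-constructed clone would share its buffer with the original, which is exactly the shallow copy that clone() is documented to avoid.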
 
- Public Member Functions inherited from torch::nn::Module
 Module (std::string name)
 Tells the base Module about the name of the submodule.
 
 Module ()
 Constructs the module without immediate knowledge of the submodule's name.
 
const std::string & name () const noexcept
 Returns the name of the Module.
 
void apply (const ModuleApplyFunction &function)
 Applies the function to the Module and recursively to every submodule.
 
void apply (const ConstModuleApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule.
 
void apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
 Applies the function to the Module and recursively to every submodule.
 
void apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule.
 
void apply (const ModulePointerApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule.
 
void apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule.
 
std::vector< Tensor > parameters (bool recurse=true) const
 Returns the parameters of this Module and, if recurse is true, recursively those of every submodule.
 
OrderedDict< std::string, Tensor > named_parameters (bool recurse=true) const
 Returns an OrderedDict with the parameters of this Module along with their keys and, if recurse is true, recursively those of every submodule.
 
std::vector< Tensor > buffers (bool recurse=true) const
 Returns the buffers of this Module and, if recurse is true, recursively those of every submodule.
 
OrderedDict< std::string, Tensor > named_buffers (bool recurse=true) const
 Returns an OrderedDict with the buffers of this Module along with their keys and, if recurse is true, recursively those of every submodule.
 
std::vector< std::shared_ptr< Module > > modules (bool include_self=true) const
 Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const
 Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys and, if include_self is true, also inserts a shared_ptr to this module in the first position.
 
std::vector< std::shared_ptr< Module > > children () const
 Returns the direct submodules of this Module.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_children () const
 Returns an OrderedDict of the direct submodules of this Module and their keys.
 
virtual void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode.
 
virtual bool is_training () const noexcept
 True if the module is in training mode.
 
virtual void to (torch::Device device, torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype and device.
 
virtual void to (torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype.
 
virtual void to (torch::Device device, bool non_blocking=false)
 Recursively moves all parameters to the given device.
 
virtual void zero_grad ()
 Recursively zeros out the grad value of each registered parameter.
 
template<typename ModuleType >
ModuleType::ContainedType * as () noexcept
 Attempts to cast this Module to the given ModuleType.
 
template<typename ModuleType >
const ModuleType::ContainedType * as () const noexcept
 Attempts to cast this Module to the given ModuleType.
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType * as () noexcept
 Attempts to cast this Module to the given ModuleType.
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType * as () const noexcept
 Attempts to cast this Module to the given ModuleType.
 
virtual void save (serialize::OutputArchive &archive) const
 Serializes the Module into the given OutputArchive.
 
virtual void load (serialize::InputArchive &archive)
 Deserializes the Module from the given InputArchive.
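The apply() overloads, train(), and eval() all walk the module tree recursively. The sketch below models that traversal with a hypothetical SketchNode type that holds named children; the pre-order visit (module first, then children) and dotted names such as "fc.inner" are assumptions for illustration and may differ in detail from torch::nn::Module.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical module-tree node; NOT torch::nn::Module.
struct SketchNode {
  bool training = true;
  std::vector<std::pair<std::string, std::shared_ptr<SketchNode>>> children;

  // Visit this node first, then every descendant, passing a dotted name
  // built from the keys along the path (the root gets the given prefix).
  void apply(const std::function<void(const std::string&, SketchNode&)>& fn,
             const std::string& prefix = "") {
    fn(prefix, *this);
    for (auto& [name, child] : children) {
      child->apply(fn, prefix.empty() ? name : prefix + "." + name);
    }
  }

  // train(on) sets the flag on the whole subtree; eval() is train(false).
  void train(bool on = true) {
    apply([on](const std::string&, SketchNode& node) { node.training = on; });
  }
  void eval() { train(false); }
};
```

Passing the dotted prefix down the recursion lets one traversal serve both the plain and the named apply() variants.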
 

Data Fields

LinearOptions options
 The options used to configure this module.
 
Tensor weight
 The learned weight.
 
Tensor bias
 The learned bias.
 

Additional Inherited Members

- Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function< void(Module &)>
 
using ConstModuleApplyFunction = std::function< void(const Module &)>
 
using NamedModuleApplyFunction = std::function< void(const std::string &, Module &)>
 
using ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)>
 
using ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)>
 
using NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)>
 
- Protected Member Functions inherited from torch::nn::Module
Tensor & register_parameter (std::string name, Tensor tensor, bool requires_grad=true)
 Registers a parameter with this Module.
 
Tensor & register_buffer (std::string name, Tensor tensor)
 Registers a buffer with this Module.
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, std::shared_ptr< ModuleType > module)
 Registers a submodule with this Module.
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, ModuleHolder< ModuleType > module_holder)
 Registers a submodule with this Module.
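register_parameter() and register_module() record tensors and submodules under string keys, which is what allows named_parameters() to report names across the whole hierarchy. The sketch below models that bookkeeping with a hypothetical SketchRegistry class that stores parameters as std::vector<float> and keeps insertion order in vectors of pairs (loosely imitating OrderedDict). The "child.name" key format and the ordering (own parameters before those of submodules) are assumptions.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a module's parameter/submodule bookkeeping.
class SketchRegistry {
 public:
  using Named = std::vector<std::pair<std::string, std::vector<float>>>;

  // Store a parameter under a key; insertion order is preserved.
  void register_parameter(std::string name, std::vector<float> value) {
    parameters_.emplace_back(std::move(name), std::move(value));
  }

  // Store a child module under a key and hand it back to the caller.
  std::shared_ptr<SketchRegistry> register_module(
      std::string name, std::shared_ptr<SketchRegistry> child) {
    children_.emplace_back(std::move(name), child);
    return child;
  }

  // Collect (dotted name, value) pairs for this module and, if recurse
  // is true, for every registered descendant.
  Named named_parameters(bool recurse = true,
                         const std::string& prefix = "") const {
    Named out;
    for (const auto& [name, value] : parameters_) {
      out.emplace_back(prefix + name, value);
    }
    if (recurse) {
      for (const auto& [name, child] : children_) {
        Named nested = child->named_parameters(true, prefix + name + ".");
        out.insert(out.end(), nested.begin(), nested.end());
      }
    }
    return out;
  }

 private:
  Named parameters_;
  std::vector<std::pair<std::string, std::shared_ptr<SketchRegistry>>> children_;
};
```

Returning the child from register_module() mirrors the signatures listed above, so a caller can register a submodule and keep a handle to it in one statement.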
 

Detailed Description

Applies a linear transformation with optional bias.

Definition at line 25 of file linear.h.
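Written out, the transformation on a batch of N rows is sketched below, where d_in and d_out stand for the in and out constructor arguments. The (d_out, d_in) weight shape is an assumption borrowed from the convention of torch::linear, which multiplies by the transposed weight; when with_bias is false, the + b term is dropped.

```latex
y = x W^{\top} + b, \qquad
x \in \mathbb{R}^{N \times d_{\text{in}}},\;
W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
b \in \mathbb{R}^{d_{\text{out}}},\;
y \in \mathbb{R}^{N \times d_{\text{out}}}
```

Under that assumption the layer learns d_out * d_in weights plus d_out bias entries; for in = 3 and out = 4 that is 12 + 4 = 16 values.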

Member Function Documentation

Tensor torch::nn::LinearImpl::forward (const Tensor & input)

Transforms the input tensor by multiplying it with the weight and, if with_bias is true in the options, adding the bias.

Definition at line 37 of file linear.cpp.
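Ignoring autograd and devices, forward() computes a matrix product plus an optional bias. The sketch below spells that out with nested loops over std::vector<float> in a hypothetical sketch_linear_forward function. It assumes the weight is stored as out x in and multiplied transposed, which is the convention of torch::linear; this page does not state the layout.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Hypothetical plain-C++ model of the forward pass: y = x * W^T (+ b).
// x is N x in, weight is out x in (assumed layout), bias has out entries.
// Pass bias == nullptr to model with_bias == false (undefined bias).
Matrix sketch_linear_forward(const Matrix& x, const Matrix& weight,
                             const std::vector<float>* bias) {
  Matrix y(x.size(), std::vector<float>(weight.size(), 0.0f));
  for (std::size_t n = 0; n < x.size(); ++n) {
    for (std::size_t o = 0; o < weight.size(); ++o) {
      float acc = bias ? (*bias)[o] : 0.0f;  // start from the bias if present
      for (std::size_t i = 0; i < weight[o].size(); ++i) {
        acc += x[n][i] * weight[o][i];  // row o of W is column o of W^T
      }
      y[n][o] = acc;
    }
  }
  return y;
}
```

Passing nullptr as the bias models the case documented for the bias field, where the tensor is undefined because with_bias is false.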

void torch::nn::LinearImpl::reset ( )
[override, virtual]

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

Implements torch::nn::Cloneable< LinearImpl >.

Definition at line 17 of file linear.cpp.
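reset() is where a module (re)creates its parameters. This page does not say how LinearImpl fills them; the sketch below assumes a common rule for linear layers, drawing every weight and bias entry uniformly from [-1/sqrt(in), 1/sqrt(in)), and applies it with the standard <random> facilities to hypothetical plain-vector storage in SketchLinearParams. The seed argument exists only to make the result reproducible.

```cpp
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical parameter storage for a linear layer (not libtorch types).
struct SketchLinearParams {
  std::vector<std::vector<float>> weight;  // out x in (assumed layout)
  std::vector<float> bias;                 // empty models an undefined bias

  // Re-create and initialize every parameter, as reset() is meant to.
  // Assumes in > 0 so the bound is finite.
  void reset(std::int64_t in, std::int64_t out, bool with_bias, unsigned seed) {
    const float bound = 1.0f / std::sqrt(static_cast<float>(in));
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-bound, bound);  // [-bound, bound)

    weight.assign(out, std::vector<float>(in));
    for (auto& row : weight) {
      for (auto& w : row) w = dist(gen);
    }

    bias.clear();
    if (with_bias) {
      bias.resize(out);
      for (auto& b : bias) b = dist(gen);
    }
  }
};
```

Because reset() rebuilds the storage from scratch rather than mutating it in place, calling it on a copy never touches the original's buffers.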

Field Documentation

Tensor torch::nn::LinearImpl::bias

The learned bias.

If with_bias is false in the options, this tensor is undefined.

Definition at line 47 of file linear.h.

