|
using | ModuleApplyFunction = std::function< void(Module &)> |
|
using | ConstModuleApplyFunction = std::function< void(const Module &)> |
|
using | NamedModuleApplyFunction = std::function< void(const std::string &, Module &)> |
|
using | ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)> |
|
using | ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)> |
|
using | NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)> |
|
| Module (std::string name) |
| Tells the base Module about the name of the submodule.
|
|
| Module () |
| Constructs the module without immediate knowledge of the submodule's name. More...
|
|
const std::string & | name () const noexcept |
| Returns the name of the Module. More...
|
|
virtual std::shared_ptr< Module > | clone (const optional< Device > &device=nullopt) const |
| Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules. More...
|
|
void | apply (const ModuleApplyFunction &function) |
| Applies the function to the Module and recursively to every submodule. More...
|
|
void | apply (const ConstModuleApplyFunction &function) const |
| Applies the function to the Module and recursively to every submodule. More...
|
|
void | apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) |
| Applies the function to the Module and recursively to every submodule. More...
|
|
void | apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const |
| Applies the function to the Module and recursively to every submodule. More...
|
|
void | apply (const ModulePointerApplyFunction &function) const |
| Applies the function to the Module and recursively to every submodule. More...
|
|
void | apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const |
| Applies the function to the Module and recursively to every submodule. More...
|
|
std::vector< Tensor > | parameters (bool recurse=true) const |
| Returns the parameters of this Module and if recurse is true, also recursively of every submodule. More...
|
|
OrderedDict< std::string, Tensor > | named_parameters (bool recurse=true) const |
| Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
|
|
std::vector< Tensor > | buffers (bool recurse=true) const |
| Returns the buffers of this Module and if recurse is true, also recursively of every submodule. More...
|
|
OrderedDict< std::string, Tensor > | named_buffers (bool recurse=true) const |
| Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
|
|
std::vector< std::shared_ptr< Module > > | modules (bool include_self=true) const |
| Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
|
|
OrderedDict< std::string, std::shared_ptr< Module > > | named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const |
| Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
|
|
std::vector< std::shared_ptr< Module > > | children () const |
| Returns the direct submodules of this Module.
|
|
OrderedDict< std::string, std::shared_ptr< Module > > | named_children () const |
| Returns an OrderedDict of the direct submodules of this Module and their keys. More...
|
|
virtual void | train (bool on=true) |
| Enables "training" mode.
|
|
void | eval () |
| Calls train(false) to enable "eval" mode. More...
|
|
virtual bool | is_training () const noexcept |
| True if the module is in training mode. More...
|
|
virtual void | to (torch::Device device, torch::Dtype dtype, bool non_blocking=false) |
| Recursively casts all parameters to the given dtype and device. More...
|
|
virtual void | to (torch::Dtype dtype, bool non_blocking=false) |
| Recursively casts all parameters to the given dtype. More...
|
|
virtual void | to (torch::Device device, bool non_blocking=false) |
| Recursively moves all parameters to the given device. More...
|
|
virtual void | zero_grad () |
| Recursively zeros out the grad value of each registered parameter.
|
|
template<typename ModuleType > |
ModuleType::ContainedType * | as () noexcept |
| Attempts to cast this Module to the given ModuleType. More...
|
|
template<typename ModuleType > |
const ModuleType::ContainedType * | as () const noexcept |
| Attempts to cast this Module to the given ModuleType. More...
|
|
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
ModuleType * | as () noexcept |
| Attempts to cast this Module to the given ModuleType. More...
|
|
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
const ModuleType * | as () const noexcept |
| Attempts to cast this Module to the given ModuleType. More...
|
|
virtual void | save (serialize::OutputArchive &archive) const |
| Serializes the Module into the given OutputArchive.
|
|
virtual void | load (serialize::InputArchive &archive) |
| Deserializes the Module from the given InputArchive.
|
|
virtual void | pretty_print (std::ostream &stream) const |
| Streams a pretty representation of the Module into the given stream. More...
|
|
Tensor & | register_parameter (std::string name, Tensor tensor, bool requires_grad=true) |
| Registers a parameter with this Module. More...
|
|
Tensor & | register_buffer (std::string name, Tensor tensor) |
| Registers a buffer with this Module. More...
|
|
template<typename ModuleType > |
std::shared_ptr< ModuleType > | register_module (std::string name, std::shared_ptr< ModuleType > module) |
| Registers a submodule with this Module. More...
|
|
template<typename ModuleType > |
std::shared_ptr< ModuleType > | register_module (std::string name, ModuleHolder< ModuleType > module_holder) |
| Registers a submodule with this Module. More...
|
|
Definition at line 18 of file modules.cpp.