Applies Batch Normalization to an input.
#include <batchnorm.h>
Public Member Functions
BatchNormImpl(int64_t features)
BatchNormImpl(BatchNormOptions options)
void reset() override
  reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
void pretty_print(std::ostream& stream) const override
  Pretty prints the BatchNorm module into the given stream.
Tensor forward(const Tensor& input)
  Applies batch normalization on the input using the stored mean and variance.
Tensor pure_forward(const Tensor& input, const Tensor& mean, const Tensor& variance)
  Applies batch normalization on the input using the given mean and variance statistics.
Public Member Functions inherited from torch::nn::Cloneable<BatchNormImpl>
std::shared_ptr<Module> clone(const optional<Device>& device = nullopt) const override
  Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module.
Public Member Functions inherited from torch::nn::Module
Module(std::string name)
  Tells the base Module about the name of the submodule.
Module()
  Constructs the module without immediate knowledge of the submodule's name.
const std::string& name() const noexcept
  Returns the name of the Module.
void apply(const ModuleApplyFunction& function)
  Applies the function to the Module and recursively to every submodule.
void apply(const ConstModuleApplyFunction& function) const
  Applies the function to the Module and recursively to every submodule.
void apply(const NamedModuleApplyFunction& function, const std::string& name_prefix = std::string())
  Applies the function to the Module and recursively to every submodule.
void apply(const ConstNamedModuleApplyFunction& function, const std::string& name_prefix = std::string()) const
  Applies the function to the Module and recursively to every submodule.
void apply(const ModulePointerApplyFunction& function) const
  Applies the function to the Module and recursively to every submodule.
void apply(const NamedModulePointerApplyFunction& function, const std::string& name_prefix = std::string()) const
  Applies the function to the Module and recursively to every submodule.
std::vector<Tensor> parameters(bool recurse = true) const
  Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const
  Returns an OrderedDict with the parameters of this Module along with their keys, and, if recurse is true, also recursively of every submodule.
std::vector<Tensor> buffers(bool recurse = true) const
  Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const
  Returns an OrderedDict with the buffers of this Module along with their keys, and, if recurse is true, also recursively of every submodule.
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const
  Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string& name_prefix = std::string(), bool include_self = true) const
  Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and, if include_self is true, also inserts a shared_ptr to this module in the first position.
std::vector<std::shared_ptr<Module>> children() const
  Returns the direct submodules of this Module.
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const
  Returns an OrderedDict of the direct submodules of this Module and their keys.
virtual void train(bool on = true)
  Enables "training" mode.
void eval()
  Calls train(false) to enable "eval" mode.
virtual bool is_training() const noexcept
  True if the module is in training mode.
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
  Recursively casts all parameters to the given dtype and device.
virtual void to(torch::Dtype dtype, bool non_blocking = false)
  Recursively casts all parameters to the given dtype.
virtual void to(torch::Device device, bool non_blocking = false)
  Recursively moves all parameters to the given device.
virtual void zero_grad()
  Recursively zeros out the grad value of each registered parameter.
template <typename ModuleType>
ModuleType::ContainedType* as() noexcept
  Attempts to cast this Module to the given ModuleType.
template <typename ModuleType>
const ModuleType::ContainedType* as() const noexcept
  Attempts to cast this Module to the given ModuleType.
template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* as() noexcept
  Attempts to cast this Module to the given ModuleType.
template <typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType* as() const noexcept
  Attempts to cast this Module to the given ModuleType.
virtual void save(serialize::OutputArchive& archive) const
  Serializes the Module into the given OutputArchive.
virtual void load(serialize::InputArchive& archive)
  Deserializes the Module from the given InputArchive.
Data Fields
BatchNormOptions options
  The options with which this module was constructed.
Tensor weight
  The learned weight.
Tensor bias
  The learned bias.
Tensor running_mean
  The running mean.
Tensor running_var
  The running variance.
Additional Inherited Members
Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function<void(Module&)>
using ConstModuleApplyFunction = std::function<void(const Module&)>
using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>
using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module>&)>
using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>
Protected Member Functions inherited from torch::nn::Module
Tensor& register_parameter(std::string name, Tensor tensor, bool requires_grad = true)
  Registers a parameter with this Module.
Tensor& register_buffer(std::string name, Tensor tensor)
  Registers a buffer with this Module.
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)
  Registers a submodule with this Module.
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)
  Registers a submodule with this Module.
Applies Batch Normalization to an input.
Refer to the documentation for BatchNorm1d in PyTorch to learn more about the exact semantics of this module, but see the note below regarding differences between the Python and C++ APIs.
Attention: In the Python API, there are separate implementations for 1-D, 2-D and 3-D BatchNorm. In C++, there is only one BatchNorm module, which works for any of these dimensions.
Definition at line 48 of file batchnorm.h.
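One way to see why a single module can serve the 1-D, 2-D and 3-D cases: batch normalization reduces over every dimension except the channel dimension (dim 1), so any contiguous (N, C, ...) input can be viewed as (N, C, S), with S the product of the trailing dimensions. The standalone sketch below is not libtorch code; channel_stats and ChannelStats are hypothetical names, and it computes the biased variance that PyTorch uses when normalizing a batch during training.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: one mean and one biased variance per channel.
struct ChannelStats {
  std::vector<double> mean;
  std::vector<double> variance;
};

// Per-channel statistics of a contiguous row-major input of shape (N, C, S).
// S == 1 covers an (N, C) input; (N, C, L), (N, C, H, W) and (N, C, D, H, W)
// inputs flatten to the same layout, so one routine handles all of them.
ChannelStats channel_stats(const std::vector<double>& input, std::size_t n,
                           std::size_t c, std::size_t s) {
  assert(input.size() == n * c * s);
  ChannelStats stats{std::vector<double>(c, 0.0), std::vector<double>(c, 0.0)};
  const double count = static_cast<double>(n * s);
  for (std::size_t ch = 0; ch < c; ++ch) {
    // Reduce over the batch and all trailing dimensions, never over channels.
    double sum = 0.0;
    for (std::size_t b = 0; b < n; ++b)
      for (std::size_t i = 0; i < s; ++i) sum += input[(b * c + ch) * s + i];
    const double mean = sum / count;
    double sq = 0.0;
    for (std::size_t b = 0; b < n; ++b)
      for (std::size_t i = 0; i < s; ++i) {
        const double d = input[(b * c + ch) * s + i] - mean;
        sq += d * d;
      }
    stats.mean[ch] = mean;
    stats.variance[ch] = sq / count;  // biased: divide by element count
  }
  return stats;
}
```

For example, the (N, C) batch {1, 10, 3, 30} with n = 2, c = 2, s = 1 has channel means 2 and 20 and variances 1 and 100; the (N, C, L) batch {1, 3, 10, 30} with n = 1, c = 2, s = 2 gives the same values.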
Tensor torch::nn::BatchNormImpl::forward(const Tensor& input)
Applies batch normalization on the input using the stored mean and variance.
The module must have been constructed with stateful = true to call this method; otherwise it does not store running statistics. To supply the mean and variance yourself, use pure_forward.
Definition at line 44 of file batchnorm.cpp.
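The contract described above (forward() requires a stateful module and normalizes with the stored statistics, while pure_forward() takes them as arguments) can be sketched without libtorch. SketchBatchNorm is a hypothetical stand-in operating on a row-major (N, C) input stored in a std::vector<double>. It models only what the description states: it ignores the training/eval distinction, omits the affine weight and bias, assumes an eps default of 1e-5 (the Python API's default), and uses std::logic_error in place of whatever error the real module raises.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for BatchNormImpl (not the libtorch class).
class SketchBatchNorm {
 public:
  SketchBatchNorm(std::size_t features, bool stateful, double eps = 1e-5)
      : features_(features), stateful_(stateful), eps_(eps) {
    if (stateful_) {
      // Only a stateful module stores running statistics.
      running_mean.assign(features_, 0.0);
      running_var.assign(features_, 1.0);
    }
  }

  // Normalizes each channel with caller-supplied statistics
  // (affine weight and bias omitted for brevity).
  std::vector<double> pure_forward(const std::vector<double>& input,
                                   const std::vector<double>& mean,
                                   const std::vector<double>& variance) const {
    assert(input.size() % features_ == 0);
    assert(mean.size() == features_ && variance.size() == features_);
    std::vector<double> out(input.size());
    for (std::size_t i = 0; i < input.size(); ++i) {
      const std::size_t ch = i % features_;  // channel of element i in (N, C)
      out[i] = (input[i] - mean[ch]) / std::sqrt(variance[ch] + eps_);
    }
    return out;
  }

  // Uses the stored statistics; rejected unless constructed with stateful.
  std::vector<double> forward(const std::vector<double>& input) const {
    if (!stateful_) {
      throw std::logic_error(
          "forward() requires stateful = true; use pure_forward() instead");
    }
    return pure_forward(input, running_mean, running_var);
  }

  std::vector<double> running_mean;  // empty unless stateful
  std::vector<double> running_var;   // empty unless stateful

 private:
  std::size_t features_;
  bool stateful_;
  double eps_;
};
```

A non-stateful SketchBatchNorm can still normalize through pure_forward(), which mirrors the advice to use pure_forward when supplying the statistics yourself.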
Tensor torch::nn::BatchNormImpl::pure_forward(const Tensor& input, const Tensor& mean, const Tensor& variance)
Applies batch normalization on the input using the given mean and variance statistics.
Definition at line 53 of file batchnorm.cpp.
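For reference, the per-channel computation corresponds to the formula in the PyTorch BatchNorm documentation, written here with this method's argument names. The index c runs over the channel dimension (dim 1), ε is a small constant for numerical stability (1e-5 by default in the Python API; the matching C++ option is assumed), and γ_c and β_c are the entries of weight and bias, which apply only when the module was constructed with affine = true.

```latex
y_{n,c,\dots} = \frac{x_{n,c,\dots} - \mathrm{mean}_c}{\sqrt{\mathrm{variance}_c + \epsilon}}\,\gamma_c + \beta_c
```

Without affine parameters the γ_c and β_c terms drop out, leaving plain per-channel standardization.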
void torch::nn::BatchNormImpl::reset() [override, virtual]
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Implements torch::nn::Cloneable<BatchNormImpl>.
Definition at line 21 of file batchnorm.cpp.
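A minimal sketch of what reset() has to cover for this module, written without libtorch. SketchOptions and SketchBatchNormState are hypothetical types, and an empty std::vector stands in for an undefined Tensor; the affine and stateful flags decide which members exist, matching the data field descriptions of this class. The initial values (weight 1, bias 0, running mean 0, running variance 1) follow the Python BatchNorm convention and are an assumption about this C++ version, as are the defaults affine = true and stateful = false.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical options type; the defaults shown are assumptions.
struct SketchOptions {
  std::size_t features = 0;
  bool affine = true;
  bool stateful = false;
};

// Hypothetical module state; an empty vector plays the role of an
// undefined Tensor.
struct SketchBatchNormState {
  SketchOptions options;
  std::vector<double> weight, bias, running_mean, running_var;

  explicit SketchBatchNormState(SketchOptions opts) : options(opts) {
    assert(opts.features > 0);
    reset();
  }

  // (Re)creates every member with reference semantics from the options.
  void reset() {
    weight.clear();
    bias.clear();
    running_mean.clear();
    running_var.clear();
    if (options.affine) {  // learnable per-channel scale and shift
      weight.assign(options.features, 1.0);
      bias.assign(options.features, 0.0);
    }
    if (options.stateful) {  // running statistics buffers
      running_mean.assign(options.features, 0.0);
      running_var.assign(options.features, 1.0);
    }
  }
};
```

With affine = true and stateful = false, only weight and bias are populated; with affine = false and stateful = true, only the running statistics are.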
Tensor torch::nn::BatchNormImpl::bias
The learned bias.
Only defined if the affine option was true upon construction.
Definition at line 83 of file batchnorm.h.
Tensor torch::nn::BatchNormImpl::running_mean
The running mean.
Only defined if the stateful option was true upon construction.
Definition at line 87 of file batchnorm.h.
Tensor torch::nn::BatchNormImpl::running_var
The running variance.
Only defined if the stateful option was true upon construction.
Definition at line 91 of file batchnorm.h.
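How the running statistics are updated is not spelled out here. The Python BatchNorm documentation describes an exponential moving average, running = (1 - momentum) * running + momentum * observed, with momentum defaulting to 0.1 and the unbiased batch variance feeding running_var. The standalone sketch below illustrates that rule; ema_update and unbiased_variance are hypothetical helpers, and whether this C++ version exposes the same momentum option is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not libtorch API): one exponential-moving-average step,
//   running = (1 - momentum) * running + momentum * observed,
// applied per channel. The 0.1 default is taken from the Python docs.
std::vector<double> ema_update(const std::vector<double>& running,
                               const std::vector<double>& observed,
                               double momentum = 0.1) {
  assert(running.size() == observed.size());
  std::vector<double> next(running.size());
  for (std::size_t c = 0; c < running.size(); ++c) {
    next[c] = (1.0 - momentum) * running[c] + momentum * observed[c];
  }
  return next;
}

// Hypothetical helper: unbiased (n - 1) variance of one channel's values,
// the estimate PyTorch folds into running_var.
double unbiased_variance(const std::vector<double>& values) {
  assert(values.size() > 1);
  double mean = 0.0;
  for (double v : values) mean += v;
  mean /= static_cast<double>(values.size());
  double sq = 0.0;
  for (double v : values) sq += (v - mean) * (v - mean);
  return sq / static_cast<double>(values.size() - 1);
}
```

For instance, ema_update({1.0}, {3.0}, 0.5) yields {2.0}, and unbiased_variance({1.0, 3.0}) is 2.0, twice the biased variance of 1.0 for the same two samples.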
Tensor torch::nn::BatchNormImpl::weight
The learned weight.
Only defined if the affine option was true upon construction.
Definition at line 79 of file batchnorm.h.