Caffe2 - C++ API
A deep learning, cross-platform ML framework
torch::nn::BatchNormImpl Class Reference

Applies Batch Normalization to an input. More...

#include <batchnorm.h>

Inheritance diagram for torch::nn::BatchNormImpl:
torch::nn::Module
  └── torch::nn::Cloneable< BatchNormImpl >
        └── torch::nn::BatchNormImpl

Public Member Functions

 BatchNormImpl (int64_t features)
 
 BatchNormImpl (BatchNormOptions options)
 
void reset () override
 reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules. More...
 
void pretty_print (std::ostream &stream) const override
 Pretty prints the BatchNorm module into the given stream.
 
Tensor forward (const Tensor &input)
 Applies batch normalization on the input using the stored mean and variance. More...
 
Tensor pure_forward (const Tensor &input, const Tensor &mean, const Tensor &variance)
 Applies batch normalization on the input using the given mean and variance statistics. More...
 
- Public Member Functions inherited from torch::nn::Cloneable< BatchNormImpl >
std::shared_ptr< Module > clone (const optional< Device > &device=nullopt) const override
 Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned module are different from those in the original module. More...
 
- Public Member Functions inherited from torch::nn::Module
 Module (std::string name)
 Tells the base Module about the name of the submodule.
 
 Module ()
 Constructs the module without immediate knowledge of the submodule's name. More...
 
const std::string & name () const noexcept
 Returns the name of the Module. More...
 
void apply (const ModuleApplyFunction &function)
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstModuleApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModuleApplyFunction &function, const std::string &name_prefix=std::string())
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ConstNamedModuleApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const ModulePointerApplyFunction &function) const
 Applies the function to the Module and recursively to every submodule. More...
 
void apply (const NamedModulePointerApplyFunction &function, const std::string &name_prefix=std::string()) const
 Applies the function to the Module and recursively to every submodule. More...
 
std::vector< Tensor > parameters (bool recurse=true) const
 Returns the parameters of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_parameters (bool recurse=true) const
 Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< Tensor > buffers (bool recurse=true) const
 Returns the buffers of this Module and if recurse is true, also recursively of every submodule. More...
 
OrderedDict< std::string, Tensor > named_buffers (bool recurse=true) const
 Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule. More...
 
std::vector< std::shared_ptr< Module > > modules (bool include_self=true) const
 Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
OrderedDict< std::string, std::shared_ptr< Module > > named_modules (const std::string &name_prefix=std::string(), bool include_self=true) const
 Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. More...
 
std::vector< std::shared_ptr< Module > > children () const
 Returns the direct submodules of this Module.
 
OrderedDict< std::string, std::shared_ptr< Module > > named_children () const
 Returns an OrderedDict of the direct submodules of this Module and their keys. More...
 
virtual void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode. More...
 
virtual bool is_training () const noexcept
 True if the module is in training mode. More...
 
virtual void to (torch::Device device, torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype and device. More...
 
virtual void to (torch::Dtype dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype. More...
 
virtual void to (torch::Device device, bool non_blocking=false)
 Recursively moves all parameters to the given device. More...
 
virtual void zero_grad ()
 Recursively zeros out the grad value of each registered parameter.
 
template<typename ModuleType >
ModuleType::ContainedType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType >
const ModuleType::ContainedType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType * as () noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
template<typename ModuleType , typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType * as () const noexcept
 Attempts to cast this Module to the given ModuleType. More...
 
virtual void save (serialize::OutputArchive &archive) const
 Serializes the Module into the given OutputArchive.
 
virtual void load (serialize::InputArchive &archive)
 Deserializes the Module from the given InputArchive.
 

Data Fields

BatchNormOptions options
 The options with which this module was constructed.
 
Tensor weight
 The learned weight. More...
 
Tensor bias
 The learned bias. More...
 
Tensor running_mean
 The running mean. More...
 
Tensor running_var
 The running variance. More...
 

Additional Inherited Members

- Public Types inherited from torch::nn::Module
using ModuleApplyFunction = std::function< void(Module &)>
 
using ConstModuleApplyFunction = std::function< void(const Module &)>
 
using NamedModuleApplyFunction = std::function< void(const std::string &, Module &)>
 
using ConstNamedModuleApplyFunction = std::function< void(const std::string &, const Module &)>
 
using ModulePointerApplyFunction = std::function< void(const std::shared_ptr< Module > &)>
 
using NamedModulePointerApplyFunction = std::function< void(const std::string &, const std::shared_ptr< Module > &)>
 
- Protected Member Functions inherited from torch::nn::Module
Tensor & register_parameter (std::string name, Tensor tensor, bool requires_grad=true)
 Registers a parameter with this Module. More...
 
Tensor & register_buffer (std::string name, Tensor tensor)
 Registers a buffer with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, std::shared_ptr< ModuleType > module)
 Registers a submodule with this Module. More...
 
template<typename ModuleType >
std::shared_ptr< ModuleType > register_module (std::string name, ModuleHolder< ModuleType > module_holder)
 Registers a submodule with this Module. More...
 

Detailed Description

Applies Batch Normalization to an input.

Refer to the documentation for BatchNorm1d in PyTorch to learn more about the exact semantics of this module, but see the note below regarding differences between the Python and C++ APIs.

Attention: In the Python API, there are separate implementations for 1-D, 2-D and 3-D batch normalization (BatchNorm1d, BatchNorm2d and BatchNorm3d). In C++, there is a single BatchNorm module, which works for any of these dimensions.

Definition at line 48 of file batchnorm.h.

Member Function Documentation

Tensor torch::nn::BatchNormImpl::forward (const Tensor &input)

Applies batch normalization on the input using the stored mean and variance.

The module must be constructed with stateful = true when calling this method, as the module will otherwise not store running statistics. If you want to supply the mean and variance yourself, use pure_forward.

Definition at line 44 of file batchnorm.cpp.

Tensor torch::nn::BatchNormImpl::pure_forward (const Tensor &input,
const Tensor &mean,
const Tensor &variance)

Applies batch normalization on the input using the given mean and variance statistics.

Definition at line 53 of file batchnorm.cpp.

void torch::nn::BatchNormImpl::reset ( )
override virtual

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

Implements torch::nn::Cloneable< BatchNormImpl >.

Definition at line 21 of file batchnorm.cpp.

Field Documentation

Tensor torch::nn::BatchNormImpl::bias

The learned bias.

Only defined if the affine option was true upon construction.

Definition at line 83 of file batchnorm.h.

Tensor torch::nn::BatchNormImpl::running_mean

The running mean.

Only defined if the stateful option was true upon construction.

Definition at line 87 of file batchnorm.h.

Tensor torch::nn::BatchNormImpl::running_var

The running variance.

Only defined if the stateful option was true upon construction.

Definition at line 91 of file batchnorm.h.

Tensor torch::nn::BatchNormImpl::weight

The learned weight.

Only defined if the affine option was true upon construction.

Definition at line 79 of file batchnorm.h.


The documentation for this class was generated from the following files:

batchnorm.h
batchnorm.cpp