1 #include <torch/nn/modules/batchnorm.h> 3 #include <torch/cuda.h> 4 #include <torch/types.h> 6 #include <c10/util/Exception.h> 15 BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
17 BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) {
21 void BatchNormImpl::reset() {
22 if (options.affine_) {
23 weight = register_parameter(
24 "weight", torch::empty({options.features_}).uniform_());
25 bias = register_parameter(
"bias", torch::zeros({options.features_}));
28 if (options.stateful_) {
30 register_buffer(
"running_mean", torch::zeros({options.features_}));
32 register_buffer(
"running_var", torch::ones({options.features_}));
36 void BatchNormImpl::pretty_print(std::ostream& stream)
const {
37 stream << std::boolalpha
38 <<
"torch::nn::BatchNorm(features=" << options.features_
39 <<
", eps=" << options.eps_ <<
", momentum=" << options.momentum_
40 <<
", affine=" << options.affine_ <<
", stateful=" << options.stateful_
47 "Calling BatchNorm::forward is only permitted when " 48 "the 'stateful' option is true (was false). " 49 "Use BatchNorm::pure_forward instead.");
50 return pure_forward(input, running_mean, running_var);
58 const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
60 input.numel() / num_channels > 1,
61 "BatchNorm expected more than 1 value per channel when training!");
64 return torch::batch_norm(
73 torch::cuda::cudnn_is_available());
// NOTE(review): the following text is documentation residue from an unrelated
// sampler header and does not belong in this translation unit — remove or
// relocate it:
//   TORCH_API void reset(optional<size_t> new_size = nullopt) override
//   Resets the internal state of the sampler.