Caffe2 - C++ API
A deep learning, cross platform ML framework
batchnorm.cpp
1 #include <torch/nn/modules/batchnorm.h>
2 
3 #include <torch/cuda.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <cstddef>
9 #include <ostream>
10 #include <utility>
11 #include <vector>
12 
13 namespace torch {
14 namespace nn {
15 BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
16 
17 BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) {
18  reset();
19 }
20 
21 void BatchNormImpl::reset() {
22  if (options.affine_) {
23  weight = register_parameter(
24  "weight", torch::empty({options.features_}).uniform_());
25  bias = register_parameter("bias", torch::zeros({options.features_}));
26  }
27 
28  if (options.stateful_) {
29  running_mean =
30  register_buffer("running_mean", torch::zeros({options.features_}));
31  running_var =
32  register_buffer("running_var", torch::ones({options.features_}));
33  }
34 }
35 
36 void BatchNormImpl::pretty_print(std::ostream& stream) const {
37  stream << std::boolalpha
38  << "torch::nn::BatchNorm(features=" << options.features_
39  << ", eps=" << options.eps_ << ", momentum=" << options.momentum_
40  << ", affine=" << options.affine_ << ", stateful=" << options.stateful_
41  << ")";
42 }
43 
44 Tensor BatchNormImpl::forward(const Tensor& input) {
45  AT_CHECK(
46  options.stateful_,
47  "Calling BatchNorm::forward is only permitted when "
48  "the 'stateful' option is true (was false). "
49  "Use BatchNorm::pure_forward instead.");
50  return pure_forward(input, running_mean, running_var);
51 }
52 
53 Tensor BatchNormImpl::pure_forward(
54  const Tensor& input,
55  const Tensor& mean,
56  const Tensor& variance) {
57  if (is_training()) {
58  const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
59  AT_CHECK(
60  input.numel() / num_channels > 1,
61  "BatchNorm expected more than 1 value per channel when training!");
62  }
63 
64  return torch::batch_norm(
65  input,
66  weight,
67  bias,
68  mean,
69  variance,
70  is_training(),
71  options.momentum_,
72  options.eps_,
73  torch::cuda::cudnn_is_available());
74 }
75 
76 } // namespace nn
77 } // namespace torch
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the internal state of the sampler.
Definition: stream.cpp:23
Definition: jit_type.h:17