Caffe2 - C++ API
A deep learning, cross-platform ML framework
batchnorm.h
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/pimpl.h>
5 #include <torch/types.h>
6 
7 #include <cstdint>
8 
9 namespace torch {
10 namespace nn {
11 
13 struct TORCH_API BatchNormOptions { // Options for the BatchNorm module; BatchNormImpl is constructed from an instance of this struct.
14  /* implicit */ BatchNormOptions(int64_t features); // Deliberately non-explicit, so an int64_t feature count converts implicitly to options.
17  TORCH_ARG(int64_t, features); // Number of input features. NOTE(review): presumably initialized from the constructor argument; TORCH_ARG presumably also generates accessors — confirm.
21  TORCH_ARG(bool, affine) = true; // Defaults to true. NOTE(review): presumably controls whether BatchNormImpl learns `weight`/`bias` — confirm.
26  TORCH_ARG(bool, stateful) = true; // Defaults to true. NOTE(review): presumably controls whether `running_mean`/`running_var` are tracked (pure_forward takes statistics explicitly) — confirm.
29  TORCH_ARG(double, eps) = 1e-5; // Defaults to 1e-5. NOTE(review): presumably the numerical-stability term added to the variance — confirm.
32  TORCH_ARG(double, momentum) = 0.1; // Defaults to 0.1. NOTE(review): presumably the update factor for the running statistics — confirm.
33 };
34 
48 class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> { // Applies Batch Normalization to an input.
49  public:
50  explicit BatchNormImpl(int64_t features) ///< Convenience constructor; delegates to BatchNormImpl(BatchNormOptions(features)).
51  : BatchNormImpl(BatchNormOptions(features)) {}
52  explicit BatchNormImpl(BatchNormOptions options); ///< Constructs the module from a full set of options.
53 
54  void reset() override; ///< NOTE(review): presumably (re)creates weight/bias/running statistics according to `options` — confirm.
55 
57  void pretty_print(std::ostream& stream) const override; ///< Pretty-prints this module to `stream`.
58 
65  Tensor forward(const Tensor& input); ///< Applies batch normalization to `input`. NOTE(review): which statistics are used presumably depends on `options.stateful` — confirm.
66 
69  Tensor pure_forward( ///< Applies batch normalization to `input` using the caller-supplied `mean` and `variance` (presumably instead of the stored running statistics — confirm).
70  const Tensor& input,
71  const Tensor& mean,
72  const Tensor& variance);
73 
75  BatchNormOptions options; ///< The options with which this module was constructed.
76 
79  Tensor weight; ///< The learned weight.
80 
83  Tensor bias; ///< The learned bias.
84 
87  Tensor running_mean; ///< The running mean.
88 
91  Tensor running_var; ///< The running variance.
92 };
93 
98 TORCH_MODULE(BatchNorm); // NOTE(review): presumably declares `BatchNorm` as the holder type wrapping BatchNormImpl (see torch/nn/pimpl.h) — confirm.
99 
100 } // namespace nn
101 } // namespace torch
Tensor bias
The learned bias.
Definition: batchnorm.h:83
Tensor running_var
The running variance.
Definition: batchnorm.h:91
Tensor running_mean
The running mean.
Definition: batchnorm.h:87
BatchNormImpl
Applies Batch Normalization to an input.
Definition: batchnorm.h:48
BatchNormOptions options
The options with which this module was constructed.
Definition: batchnorm.h:75
torch::nn::Cloneable
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
Definition: cloneable.h:23
Definition: jit_type.h:17
BatchNormOptions
Options for the BatchNorm module.
Definition: batchnorm.h:13
Tensor weight
The learned weight.
Definition: batchnorm.h:79