from __future__ import division

import torch
from ._functions import SyncBatchNorm as sync_batch_norm
from .module import Module
from ..parameter import Parameter
from .. import functional as F
from .. import init
from ..._jit_internal import weak_module, weak_script_method


@weak_module
class _BatchNorm(Module):
    _version = 2

    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # learnable per-channel scale (gamma) and shift (beta)
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            # running estimates are buffers: saved in the state dict, but not
            # returned by parameters() and never updated by the optimizer
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.uniform_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    @weak_script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
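
    # Running-stat update sketch (comment only, matching the docstrings below):
    # the factor computed in forward() drives
    #   running_stat <- (1 - factor) * running_stat + factor * batch_stat
    # With the default momentum of 0.1 this is an exponential moving average;
    # with momentum=None the factor becomes 1/num_batches_tracked, i.e. a plain
    # cumulative average over all batches seen so far.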

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
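
    # With the default arguments the repr above renders as, e.g.:
    #   '100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True'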

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2 the `num_batches_tracked` buffer was added; give it
            # a default value of 0 when loading older checkpoints
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
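
    # Backward-compatibility sketch (comment only): a checkpoint saved before
    # version 2 of this module has no '<prefix>num_batches_tracked' entry, so
    # the hook above inserts a zero tensor before delegating to the base-class
    # loader. As a result
    #   model.load_state_dict(old_state_dict)
    # succeeds without a missing-key error and cumulative-average bookkeeping
    # simply restarts from zero.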


@weak_module
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
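

# Illustrative sketch, not part of the original module: verifies the
# normalization formula from the BatchNorm1d docstring against the layer
# itself with affine disabled, i.e. y = (x - E[x]) / sqrt(Var[x] + eps).
# The helper name below is hypothetical and exists only for demonstration.
def _batchnorm1d_formula_demo():
    torch.manual_seed(0)
    x = torch.randn(20, 100)                    # (N, C) input
    m = BatchNorm1d(100, affine=False)          # no learnable gamma/beta
    y = m(x)                                    # training mode: batch statistics
    mean = x.mean(dim=0)                        # per-channel mean over the batch
    var = x.var(dim=0, unbiased=False)          # biased variance, as used for normalization
    expected = (x - mean) / torch.sqrt(var + m.eps)
    assert torch.allclose(y, expected, atol=1e-5)
    return y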


@weak_module
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
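

# Illustrative sketch, not part of the original module: shows that BatchNorm2d
# normalizes with batch statistics in training mode and with the tracked
# running estimates in eval mode, as described in the docstring above.
# The helper name below is hypothetical and exists only for demonstration.
def _batchnorm2d_running_stats_demo():
    torch.manual_seed(0)
    m = BatchNorm2d(3)
    x = torch.randn(8, 3, 4, 4)
    m.train()
    m(x)                                        # updates running_mean / running_var
    m.eval()
    y = m(x)                                    # uses the running estimates
    shape = (1, -1, 1, 1)                       # broadcast per-channel stats over (N, H, W)
    expected = ((x - m.running_mean.view(shape)) /
                torch.sqrt(m.running_var.view(shape) + m.eps) *
                m.weight.view(shape) + m.bias.view(shape))
    assert torch.allclose(y, expected, atol=1e-5)
    return y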


@weak_module
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
    torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # process_ids is a list of int identifying rank ids.
        >>> process_group = torch.distributed.new_group(process_ids)
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = torch.nn.utils.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        self.process_group = process_group
        self.ddp_gpu_size = None

    def _check_input_dim(self, input):
        if input.dim() <= 2:
            raise ValueError('expected at least 3D input (got {}D input)'
                             .format(input.dim()))

    def _specify_ddp_gpu_num(self, gpu_size):
        if gpu_size > 1:
            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
        self.ddp_gpu_size = gpu_size

    def forward(self, input):
        # currently only GPU input is supported
        if not input.is_cuda:
            raise ValueError('expected input tensor to be on GPU')

        if not self.ddp_gpu_size:
            raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')

        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        process_group = torch.distributed.group.WORLD
        if self.process_group:
            process_group = self.process_group
        world_size = torch.distributed.get_world_size(process_group)

        # fall back to the framework batch norm when synchronization is not needed
        if world_size == 1:
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps)
        else:
            return sync_batch_norm.apply(
                input, self.weight, self.bias, self.running_mean, self.running_var,
                self.eps, exponential_average_factor, process_group, world_size)
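

# Illustrative sketch, not part of the original module: the per-process setup
# described in the SyncBatchNorm docstring. It assumes torch.distributed has
# already been initialized with one GPU per process (e.g. via
# torch.distributed.launch); `local_rank` and the helper name are placeholders
# for demonstration only.
def _sync_batchnorm_ddp_sketch(local_rank):
    torch.cuda.set_device(local_rank)
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        SyncBatchNorm(8),               # batch stats synced across the process group
    ).cuda(local_rank)
    # wrapping in DDP registers the single-GPU-per-process constraint on the layer
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)
    return ddp_model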