batchnorm.py
from __future__ import division

import torch
from ._functions import SyncBatchNorm as sync_batch_norm
from .module import Module
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from ..._jit_internal import weak_module, weak_script_method
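# ``weak_module`` / ``weak_script_method`` mark this module and selected methods
# so that the TorchScript compiler can compile them lazily when they are used
# from scripted code.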


# TODO: check contiguous in THNN
# TODO: use separate backend functions?
@weak_module
class _BatchNorm(Module):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']
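    # ``_version`` is stored in checkpoint metadata so ``_load_from_state_dict``
    # below can migrate state_dicts saved by older releases; ``__constants__``
    # lists attributes treated as constants when this weak module is compiled.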

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
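        # Running statistics are registered as buffers rather than parameters:
        # they are saved in the state_dict and moved with the module, but are
        # never updated by the optimizer. ``num_batches_tracked`` is only
        # consumed when ``momentum`` is ``None`` (cumulative moving average).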
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.uniform_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    @weak_script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

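        # F.batch_norm uses the current batch statistics when the training flag
        # below is True (training mode, or running stats not tracked) and the
        # stored running estimates otherwise; ``exponential_average_factor``
        # controls how strongly this batch updates the running estimates.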
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

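        # Checkpoints written before the counter existed (version < 2) are
        # patched above with a zero ``num_batches_tracked`` entry; the actual
        # loading is then delegated to the generic Module implementation.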
        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

@weak_module
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
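
# Rough illustration (not part of the upstream file) of how the running
# statistics evolve with the default momentum of 0.1, assuming a single
# training-mode forward pass on a freshly constructed layer:
#
#     bn = BatchNorm1d(3)      # running_mean starts at 0, running_var at 1
#     x = torch.randn(16, 3)
#     bn(x)                    # modules are in training mode by default
#     # bn.running_mean ~= 0.9 * 0 + 0.1 * x.mean(dim=0)
#     # bn.running_var  ~= 0.9 * 1 + 0.1 * x.var(dim=0)   (unbiased batch variance)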


@weak_module
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


@weak_module
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Currently SyncBatchNorm only supports DistributedDataParallel (DDP) with a single
    GPU per process. Use torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert
    BatchNorm layers to SyncBatchNorm before wrapping the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. The default behavior is synchronization across the whole
            world.

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # process_ids is a list of int identifying rank ids.
        >>> process_group = torch.distributed.new_group(process_ids)
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is a module containing nn.BatchNorm layers
        >>> sync_bn_network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        self.process_group = process_group
        # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is
        # used under a supported condition (single GPU per process)
        self.ddp_gpu_size = None

    def _check_input_dim(self, input):
        if input.dim() <= 2:
            raise ValueError('expected at least 3D input (got {}D input)'
                             .format(input.dim()))

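    # DistributedDataParallel calls this hook when it wraps the module; until
    # then ``ddp_gpu_size`` stays ``None`` and ``forward`` refuses to run.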
    def _specify_ddp_gpu_num(self, gpu_size):
        if gpu_size > 1:
            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
        self.ddp_gpu_size = gpu_size

    def forward(self, input):
        # currently only GPU input is supported
        if not input.is_cuda:
            raise ValueError('expected input tensor to be on GPU')

        if not self.ddp_gpu_size:
            raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')

        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

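        # Figure out how many processes take part in the synchronization; when
        # no explicit process_group is given, statistics are reduced across the
        # default (world) group.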
        world_size = 1
        process_group = torch.distributed.group.WORLD
        if self.process_group:
            process_group = self.process_group
        world_size = torch.distributed.get_world_size(process_group)

        # fallback to framework BN when synchronization is not necessary
        if world_size == 1 or (not self.training and self.track_running_stats):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps)
        else:
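            # The autograd Function aggregates the per-GPU mean and variance
            # across ``process_group``, so every rank normalizes with identical
            # global batch statistics.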
            return sync_batch_norm.apply(
                input, self.weight, self.bias, self.running_mean, self.running_var,
                self.eps, exponential_average_factor, process_group, world_size)