Caffe2 - Python API
A deep learning, cross platform ML framework
sync_batch_norm.py
1 import torch
2 
3 
def convert_sync_batchnorm(module, process_group=None):
    r"""Recursively replace BatchNorm layers in ``module`` with
    `torch.nn.SyncBatchNorm` layers.

    Every submodule that is an instance of
    ``torch.nn.modules.batchnorm._BatchNorm`` is swapped for a new
    `torch.nn.SyncBatchNorm` built from the same hyper-parameters. The affine
    parameters are copied; the running statistics are carried over by
    reference. The train/eval mode of each converted layer is preserved.

    Args:
        module (nn.Module): containing module
        process_group (optional): process group to scope synchronization,
            default is the whole world. The same group is applied to every
            converted layer, including layers nested in child modules.

    Returns:
        The original module with the converted `torch.nn.SyncBatchNorm`
        layers. If ``module`` itself is a BatchNorm layer, a new
        `torch.nn.SyncBatchNorm` instance is returned instead.

    Example::

        >>> # Network with nn.BatchNorm layer
        >>> module = torch.nn.Sequential(
        ...     torch.nn.Linear(20, 100),
        ...     torch.nn.BatchNorm1d(100)
        ... ).cuda()
        >>> # creating process group (optional)
        >>> # process_ids is a list of int identifying rank ids.
        >>> process_group = torch.distributed.new_group(process_ids)
        >>> sync_bn_module = convert_sync_batchnorm(module, process_group)

    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.SyncBatchNorm(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
            process_group=process_group,
        )
        if module.affine:
            # Copy (not alias) the learnable parameters into the new layer.
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
        # Running statistics are None when track_running_stats is False;
        # plain assignment handles both cases.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly constructed layer defaults to training mode; keep the
        # mode of the layer being replaced so eval() models stay in eval.
        module_output.training = module.training
    for name, child in module.named_children():
        # Forward process_group so nested layers synchronize over the same
        # group rather than silently falling back to the whole world.
        module_output.add_module(
            name, convert_sync_batchnorm(child, process_group)
        )
    del module
    return module_output