def convert_sync_batchnorm(module, process_group=None):
    r"""Helper function to convert `torch.nn.BatchNormND` layers in the model
    to `torch.nn.SyncBatchNorm` layers.

    Args:
        module (nn.Module): containing module
        process_group (optional): process group to scope synchronization,
            default is the whole world

    Returns:
        The original module with the converted `torch.nn.SyncBatchNorm` layers.

    Example::

        >>> # Network with nn.BatchNorm layer
        >>> module = torch.nn.Sequential(
        >>>     torch.nn.Linear(20, 100),
        >>>     torch.nn.BatchNorm1d(100)
        >>> )
        >>> # creating process group (optional)
        >>> # process_ids is a list of int identifying rank ids.
        >>> process_group = torch.distributed.new_group(process_ids)
        >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
    """
    module_output = module
    # Only BatchNorm layers themselves are converted; every other module is
    # kept as-is and its children are recursed into below.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.SyncBatchNorm(module.num_features,
                                               module.eps,
                                               module.momentum,
                                               module.affine,
                                               module.track_running_stats,
                                               process_group)
        if module.affine:
            # Copy the learned affine parameters; clone().detach() so the new
            # module owns its own tensors with no autograd history.
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
        # Running statistics are buffers — hand the existing tensors over
        # rather than copying, matching the original behavior.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        # Fix: propagate process_group into the recursion (the original call
        # dropped it, so nested BN layers would sync over the whole world).
        module_output.add_module(name,
                                 convert_sync_batchnorm(child, process_group))
    del module
    return module_output