import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed.deprecated as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from collections import defaultdict
import torch.utils.hooks
11 r"""Implements distributed data parallelism for CPU at the module level. 13 This module support the ``mpi``, ``gloo``, ``tcp`` backends. 15 This container parallelizes the application of the given module by 16 splitting the input across the specified devices by chunking in the batch 17 dimension. The module is replicated on each machine, and each such replica 18 handles a portion of the input. During the backwards pass, gradients from 19 each node are averaged. 21 This module could be used in conjunction with the DistributedSampler, 22 (see :class `torch.utils.data.distributed.DistributedSampler`) 23 which will load a subset of the original dataset for each node with the same 24 batch size. So strong scaling should be configured like this: 25 n = 1, batch size = 128 26 n = 2, batch size = 64 27 n = 4, batch size = 32 28 n = 8, batch size = 16 30 Creation of this class requires the distributed package to be already 31 initialized in the process group mode 32 (see :func:`torch.distributed.deprecated.init_process_group`). 35 Constructor, forward method, and differentiation of the output (or a 36 function of the output of this module) is a distributed synchronization 37 point. Take that into account in case different node might be 38 executing different code. 41 This module assumes all parameters are registered in the model by the 42 time it is created. No parameters should be added nor removed later. 45 This module assumes all gradients are dense. 48 This module doesn't work with :func:`torch.autograd.grad` (i.e. it will 49 only work if gradients are to be accumulated in ``.grad`` attributes of 53 Parameters are broadcast between nodes in the __init__() function. The 54 module performs an all-reduce step on gradients and assumes that they 55 will be modified by the optimizer in all nodes in the same way. 58 Forward and backward hooks defined on :attr:`module` and its submodules 59 won't be invoked anymore, unless the hooks are initialized in the 60 :meth:`forward` method. 63 module: module to be parallelized 67 >>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...') 68 >>> net = torch.nn.DistributedDataParallelCPU(model) 71 def __init__(self, module):
        super(DistributedDataParallelCPU, self).__init__()
        self.module = module
        # Broadcast the parameters from rank 0 so every node starts from the same weights.
        self.sync_parameters()

        def allreduce_params():
            # This callback is queued once per parameter hook; the flag ensures
            # the all-reduce itself runs only once per backward pass.
            if self.needs_reduction:
                self.needs_reduction = False
                # Bucket parameters by tensor type so each bucket can be
                # flattened into a single contiguous buffer.
                buckets = defaultdict(list)
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        buckets[tp].append(param)
                for bucket in buckets.values():
                    grads = [param.grad.data for param in bucket]
                    # One all-reduce per bucket on the flattened gradients,
                    # then average by the number of nodes.
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            @torch.utils.hooks.unserializable_hook
            def allreduce_hook(*unused):
                # Defer the all-reduce until the end of the backward pass.
                Variable._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def sync_parameters(self):
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):
        # Arm the all-reduce for the upcoming backward pass.
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
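

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the docstring above
# pairs this class with DistributedSampler and shrinks the per-node batch size
# as the number of nodes grows (strong scaling). This is a minimal sketch of
# how that configuration could look; the helper name and the default global
# batch size of 128 are assumptions made for illustration only.
def _example_strong_scaling_loader(dataset, global_batch_size=128):
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    world_size = dist.get_world_size()
    # Keep the global batch size fixed: with 4 nodes each node processes
    # 128 // 4 == 32 samples per step, matching the table in the docstring.
    per_node_batch_size = global_batch_size // world_size
    # Each node iterates over a disjoint subset of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=dist.get_rank())
    return DataLoader(dataset, batch_size=per_node_batch_size, sampler=sampler)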
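

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: one way a training
# loop could drive DistributedDataParallelCPU. The cross-entropy loss, the SGD
# optimizer, and the learning rate are assumptions made for illustration; any
# loss/optimizer pair works as long as every node applies the same update.
def _example_training_loop(net, loader, epochs=1):
    import torch.nn.functional as F
    import torch.optim as optim

    optimizer = optim.SGD(net.parameters(), lr=0.01)
    for _ in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(net(inputs), targets)
            # backward() fires the per-parameter hooks registered in __init__,
            # which queue allreduce_params; gradients are summed across nodes
            # and divided by the world size before the optimizer step.
            loss.backward()
            # Identical averaged gradients on every node -> identical update.
            optimizer.step()
# Typical call pattern, after torch.distributed.deprecated.init_process_group(...):
#     net = DistributedDataParallelCPU(model)
#     _example_training_loop(net, _example_strong_scaling_loader(dataset))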