Caffe2 - Python API
A deep learning, cross-platform ML framework
distributed_cpu.py
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed.deprecated as dist
from torch.nn.modules import Module
from collections import defaultdict
from torch.autograd import Variable
import torch.utils.hooks


class DistributedDataParallelCPU(Module):
    r"""Implements distributed data parallelism for CPU at the module level.

    This module supports the ``mpi``, ``gloo``, and ``tcp`` backends.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine, and each such replica
    handles a portion of the input. During the backwards pass, gradients from
    each node are averaged.

    This module can be used in conjunction with the DistributedSampler
    (see :class:`torch.utils.data.distributed.DistributedSampler`),
    which loads a subset of the original dataset for each node with the same
    batch size. Strong scaling should therefore be configured like this
    (see the sketch below):

        n = 1, batch size = 128
        n = 2, batch size = 64
        n = 4, batch size = 32
        n = 8, batch size = 16
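
    For example, the per-node data loader could be set up like this (a sketch;
    ``dataset`` and the global batch size of 128 are placeholders)::

        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=128 // torch.distributed.deprecated.get_world_size(),
            sampler=sampler)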

    Creation of this class requires the distributed package to be already
    initialized in the process group mode
    (see :func:`torch.distributed.deprecated.init_process_group`).

    .. warning::
        The constructor, the forward method, and differentiation of the output
        (or of a function of the output of this module) are distributed
        synchronization points. Take that into account in case different nodes
        might be executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added or removed later.

    .. warning::
        This module assumes all gradients are dense.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. note::
        Parameters are broadcast between nodes in the ``__init__()`` function.
        The module performs an all-reduce step on gradients and assumes that
        they will be modified by the optimizer in all nodes in the same way.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        won't be invoked anymore, unless the hooks are initialized in the
        :meth:`forward` method.

    Args:
        module: module to be parallelized

    Example::

        >>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...')
        >>> net = torch.nn.DistributedDataParallelCPU(model)
    """

    def __init__(self, module):
        super(DistributedDataParallelCPU, self).__init__()
        self.module = module
        self.sync_parameters()

        def allreduce_params():
            # Run once per backward pass: average gradients across all nodes.
            if self.needs_reduction:
                self.needs_reduction = False
                # Group parameters by tensor type so each bucket can be
                # coalesced into a single contiguous tensor.
                buckets = defaultdict(list)
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        buckets[tp].append(param)

                for bucket in buckets.values():
                    grads = [param.grad.data for param in bucket]
                    # Flatten the gradients, all-reduce them in one call,
                    # average by world size, and copy the results back.
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            @torch.utils.hooks.unserializable_hook
            def allreduce_hook(*unused):
                # Defer the all-reduce until the backward pass has finished
                # accumulating gradients for every parameter.
                Variable._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def sync_parameters(self):
        # Broadcast parameters from rank 0 so every node starts out identical.
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):
        # Mark that the next backward pass should trigger gradient averaging.
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
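
A minimal usage sketch under stated assumptions: the class is importable as
torch.nn.parallel.DistributedDataParallelCPU, the launcher provides RANK and
WORLD_SIZE environment variables, and the model, dataset, and TCP address below
are placeholders, not part of this file.

    import os
    import torch
    import torch.distributed.deprecated as dist
    from torch.nn.parallel import DistributedDataParallelCPU
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # Join the process group (gloo backend; address and port are placeholders).
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:23456',
                            rank=int(os.environ['RANK']),
                            world_size=int(os.environ['WORLD_SIZE']))

    model = torch.nn.Linear(10, 2)                    # placeholder model
    dataset = TensorDataset(torch.randn(128, 10),     # placeholder dataset
                            torch.randint(0, 2, (128,)))

    net = DistributedDataParallelCPU(model)           # broadcasts parameters from rank 0
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset,
                        batch_size=128 // dist.get_world_size(),
                        sampler=sampler)

    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()      # gradients are all-reduced and averaged across nodes
        optimizer.step()     # each node applies the same averaged gradients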