Caffe2 - Python API
A deep learning, cross-platform ML framework
data_parallel.py
import operator
import torch
import warnings
from itertools import chain
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index


def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return
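

# Illustrative sketch, not part of the upstream file: a minimal re-statement of
# the ratio test that warn_imbalance applies above. Assuming a hypothetical box
# pairing an 8 GB card with a 16 GB card, min/max = 0.5 < 0.75, so the warning
# would fire and suggest excluding the smaller device (or trimming
# CUDA_VISIBLE_DEVICES). The helper name is hypothetical.
def _example_imbalance_ratio(totals=(8 * 1024 ** 3, 16 * 1024 ** 3)):
    # Same min/max bookkeeping as warn_imbalance, on plain integers instead of
    # torch.cuda.get_device_properties results.
    min_pos, min_val = min(enumerate(totals), key=operator.itemgetter(1))
    max_pos, max_val = max(enumerate(totals), key=operator.itemgetter(1))
    return min_pos, max_pos, min_val / max_val < 0.75  # -> (0, 1, True)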


class DataParallel(Module):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given :attr:`module` by
    splitting the input across the specified devices by chunking in the batch
    dimension (other objects will be copied once per device). In the forward
    pass, the module is replicated on each device, and each replica handles a
    portion of the input. During the backwards pass, gradients from each replica
    are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    See also: :ref:`cuda-nn-dataparallel-instead`

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel but some types are specially handled. tensors will be
    **scattered** on dim specified (default 0). tuple, list and dict types will
    be shallow copied. The other types will be shared among different threads
    and can be corrupted if written to in the model's forward pass.

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device[0]`` will have its parameters and buffers sharing storage with
        the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. Particularly, the hooks are only guaranteed to be
        executed in correct order with respect to operations on corresponding
        devices. For example, it is not guaranteed that hooks set via
        :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
        that each such hook be executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        number of devices used in data parallelism, containing the result from
        each device.

    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
        details.


    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)
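

# Illustrative sketch, not part of the upstream file: the scatter -> replicate
# -> parallel_apply -> gather pipeline that DataParallel.forward runs, spelled
# out with the module-level helpers imported above. Assumes two visible GPUs
# and a toy torch.nn.Linear standing in for a real model; the helper name is
# hypothetical.
def _example_manual_pipeline():
    device_ids = [0, 1]
    model = torch.nn.Linear(10, 5).cuda(device_ids[0])   # params must live on device_ids[0]
    inputs = (torch.randn(8, 10),)                       # batch of 8, chunked along dim 0
    scattered, kwargs = scatter_kwargs(inputs, {}, device_ids, dim=0)
    replicas = replicate(model, device_ids[:len(scattered)])
    outputs = parallel_apply(replicas, scattered, kwargs, device_ids[:len(scattered)])
    return gather(outputs, device_ids[0], dim=0)         # (8, 5) result gathered on cuda:0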


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
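

# Illustrative sketch, not part of the upstream file: a minimal use of the
# functional form above, assuming two visible GPUs and a toy module. Unlike the
# DataParallel wrapper, data_parallel replicates the module for this one call
# only, which can be convenient when the device set changes between calls. The
# helper name is hypothetical.
def _example_functional_call():
    model = torch.nn.Linear(10, 5).cuda(0)   # parameters must sit on device_ids[0]
    batch = torch.randn(8, 10)               # chunked along dim 0 across the GPUs
    return data_parallel(model, batch, device_ids=[0, 1], output_device=0)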