import operator
import torch
import warnings
from itertools import chain
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index

def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

    def warn_imbalance(get_prop):
        # Warn if the least capable device has less than 75% of the given
        # property (memory or SM count) of the most capable one.
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return
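
# A minimal sketch of the heuristic above, using hypothetical numbers rather
# than real ``torch.cuda.get_device_properties`` values: with one 12 GB and one
# 8 GB GPU the memory ratio is 8 / 12 ~= 0.67 < 0.75, so a call such as
#
# >>> _check_balance([0, 1])
#
# would emit the imbalance warning suggesting that the smaller device be
# excluded via ``device_ids`` or ``CUDA_VISIBLE_DEVICES``.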
37 r"""Implements data parallelism at the module level. 39 This container parallelizes the application of the given :attr:`module` by 40 splitting the input across the specified devices by chunking in the batch 41 dimension (other objects will be copied once per device). In the forward 42 pass, the module is replicated on each device, and each replica handles a 43 portion of the input. During the backwards pass, gradients from each replica 44 are summed into the original module. 46 The batch size should be larger than the number of GPUs used. 48 See also: :ref:`cuda-nn-dataparallel-instead` 50 Arbitrary positional and keyword inputs are allowed to be passed into 51 DataParallel but some types are specially handled. tensors will be 52 **scattered** on dim specified (default 0). tuple, list and dict types will 53 be shallow copied. The other types will be shared among different threads 54 and can be corrupted if written to in the model's forward pass. 56 The parallelized :attr:`module` must have its parameters and buffers on 57 ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel` 61 In each forward, :attr:`module` is **replicated** on each device, so any 62 updates to the running module in ``forward`` will be lost. For example, 63 if :attr:`module` has a counter attribute that is incremented in each 64 ``forward``, it will always stay at the initial value because the update 65 is done on the replicas which are destroyed after ``forward``. However, 66 :class:`~torch.nn.DataParallel` guarantees that the replica on 67 ``device[0]`` will have its parameters and buffers sharing storage with 68 the base parallelized :attr:`module`. So **in-place** updates to the 69 parameters or buffers on ``device[0]`` will be recorded. E.g., 70 :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm` 71 rely on this behavior to update the buffers. 74 Forward and backward hooks defined on :attr:`module` and its submodules 75 will be invoked ``len(device_ids)`` times, each with inputs located on 76 a particular device. Particularly, the hooks are only guaranteed to be 77 executed in correct order with respect to operations on corresponding 78 devices. For example, it is not guaranteed that hooks set via 79 :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before 80 `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but 81 that each such hook be executed before the corresponding 82 :meth:`~torch.nn.Module.forward` call of that device. 85 When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in 86 :func:`forward`, this wrapper will return a vector of length equal to 87 number of devices used in data parallelism, containing the result from 91 There is a subtlety in using the 92 ``pack sequence -> recurrent network -> unpack sequence`` pattern in a 93 :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. 94 See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for 99 module (Module): module to be parallelized 100 device_ids (list of int or torch.device): CUDA devices (default: all devices) 101 output_device (int or torch.device): device location of output (default: device_ids[0]) 104 module (Module): the module to be parallelized 108 >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 109 >>> output = net(input_var) # input_var can be on any device, including CPU 114 def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        # Without CUDA, fall back to running the wrapped module directly.
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        # Scatter inputs across devices, run one replica per chunk, then
        # gather the per-device outputs back onto the output device.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)
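
    # Sketch of how scatter() splits positional inputs (assuming dim=0 and two
    # devices): a (64, 10) tensor becomes two (32, 10) chunks, one per replica,
    # while non-tensor arguments are copied to every replica. The names below
    # are illustrative only and require available GPUs to run.
    #
    # >>> x = torch.randn(64, 10)
    # >>> chunks, _ = scatter_kwargs((x,), {}, [0, 1])
    # >>> [c[0].shape for c in chunks]
    # [torch.Size([32, 10]), torch.Size([32, 10])]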

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)
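
# A slightly fuller usage sketch than the docstring example, assuming a machine
# with at least two GPUs; ``model`` here is a hypothetical stand-in for any
# module whose parameters already live on ``device_ids[0]``:
#
# >>> model = torch.nn.Linear(10, 5).cuda(0)        # params on device_ids[0]
# >>> net = DataParallel(model, device_ids=[0, 1])
# >>> output = net(torch.randn(64, 10))             # chunked along dim 0, gathered on cuda:0
# >>> output.sum().backward()                       # gradients accumulate in ``model``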


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
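
# A minimal sketch of the functional form above, again assuming at least two
# GPUs and a hypothetical ``model`` whose parameters already live on cuda:0:
#
# >>> out = data_parallel(model, torch.randn(64, 10),
# ...                     device_ids=[0, 1], output_device=0)
# >>> out.device
# device(type='cuda', index=0)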