import torch

from ._functions import Scatter, Gather
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and distributes them
    across given GPUs. Duplicates references to objects that are not
    tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor leaves are not sliced; the same reference is handed
        # to every target GPU.
        return [obj for _ in target_gpus]

    # scatter_map is recursive, so its closure holds a reference to itself;
    # clearing the name afterwards breaks the reference cycle.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
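

# Illustrative usage (a minimal sketch, not part of this module; assumes
# CUDA devices 0 and 1 are available). A (6, 2) tensor is split along
# dim 0 into two (3, 2) chunks, one per GPU; inside a tuple, each element
# is scattered and the per-device pieces are re-zipped, so non-tensor
# leaves such as the string below are simply duplicated:
#
#   >>> x = torch.arange(12.).view(6, 2)
#   >>> a, b = scatter(x, target_gpus=[0, 1])
#   >>> a.shape, b.shape
#   (torch.Size([3, 2]), torch.Size([3, 2]))
#   >>> pairs = scatter(("tag", x), target_gpus=[0, 1])
#   >>> # pairs == [("tag", chunk_on_gpu0), ("tag", chunk_on_gpu1)]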
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    # Pad the shorter side with empty containers so every GPU receives
    # a matching (inputs, kwargs) pair.
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
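

# Illustrative usage (a sketch, not part of this module; assumes CUDA
# devices 0 and 1). Positional and keyword arguments are scattered in
# lockstep, and whichever side comes up short is padded with empty
# containers, so the result is always two equal-length tuples:
#
#   >>> x = torch.randn(4, 8)
#   >>> ins, kws = scatter_kwargs((x,), {}, target_gpus=[0, 1])
#   >>> # ins == ((chunk_on_gpu0,), (chunk_on_gpu1,)); kws == ({}, {})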
def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))
        return type(out)(map(gather_map, zip(*outputs)))

    # gather_map is recursive, so clear the name afterwards to break the
    # resulting reference cycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
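

# Illustrative round trip (a sketch, not part of this module; assumes
# CUDA devices 0 and 1). Gathering the scattered chunks back onto
# device 0 concatenates them along dim 0, reconstructing the original
# shape; dicts are gathered key by key, preserving structure:
#
#   >>> x = torch.arange(12.).view(6, 2)
#   >>> chunks = scatter(x, target_gpus=[0, 1])
#   >>> y = gather(chunks, target_device=0)
#   >>> y.shape
#   torch.Size([6, 2])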