Caffe2 - Python API
A deep learning, cross-platform ML framework
scatter_gather.py
import torch
from ._functions import Scatter, Gather


def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor leaves are duplicated by reference, once per target GPU.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
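

# Usage sketch, not part of the original module: a hypothetical helper that
# assumes at least two visible CUDA devices. Tensors are sliced along `dim`,
# while non-tensor leaves (the string 'demo' below) are duplicated as-is.
def _scatter_demo():
    batch = torch.randn(8, 4, device='cuda:0')
    obj = (batch, {'labels': torch.arange(8, device='cuda:0'), 'note': 'demo'})
    chunks = scatter(obj, target_gpus=[0, 1], dim=0)
    # chunks[0] -> (batch[0:4] on cuda:0, {'labels': labels[0:4], 'note': 'demo'})
    # chunks[1] -> (batch[4:8] on cuda:1, {'labels': labels[4:8], 'note': 'demo'})
    return chunks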


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
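

# Usage sketch, not part of the original module: a hypothetical helper showing
# how the shorter of the two scattered sequences is padded so both tuples end
# up with exactly one entry per target GPU.
def _scatter_kwargs_demo():
    x = torch.randn(6, 3, device='cuda:0')
    inputs, kwargs = scatter_kwargs((x,), {}, target_gpus=[0, 1], dim=0)
    # inputs -> ((x[0:3],), (x[3:6],))  one positional tuple per GPU
    # kwargs -> ({}, {})                empty kwargs padded to match inputs
    return inputs, kwargs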


def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in outputs]))
                             for k in out)
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
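

# Usage sketch, not part of the original module: a hypothetical helper that
# assumes two CUDA devices. target_device=0 gathers onto cuda:0, while
# target_device=-1 gathers onto the CPU, as the docstring above notes.
def _gather_demo():
    outputs = [torch.randn(4, 2, device='cuda:0'),
               torch.randn(4, 2, device='cuda:1')]
    merged = gather(outputs, target_device=0, dim=0)   # shape (8, 2) on cuda:0
    on_cpu = gather(outputs, target_device=-1, dim=0)  # shape (8, 2) on CPU
    return merged, on_cpu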