Caffe2 - Python API
A deep learning, cross-platform ML framework
parallel_apply.py
import threading
import torch
from torch.cuda._utils import _get_device_index


def get_a_var(obj):
    """Return the first tensor found in ``obj`` (searching nested lists,
    tuples, and dicts recursively), or ``None`` if there is none."""
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


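# Illustration (not part of the original file): get_a_var walks a nested
# structure and returns the first tensor it encounters; parallel_apply uses
# this below to infer a default device from a module's input.
#
#     >>> t = torch.zeros(2)
#     >>> get_a_var({'a': 1, 'b': (t,)}) is t
#     True
#     >>> get_a_var([1, 'x']) is None
#     True

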
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (tuple of dict, optional): keyword arguments for each module
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only
    argument to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        # autograd state is thread-local, so propagate the caller's setting
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # store the exception; it is re-raised in the parent thread below
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    # collect results in input order, re-raising the first stored exception
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
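
Usage sketch (not part of the original file; the two-GPU setup, device ids,
and shapes are illustrative assumptions): `replicate` copies a module onto
each device, and `parallel_apply` then runs every replica on its own input in
a separate thread, returning the outputs in order.

import torch
import torch.nn as nn
from torch.nn.parallel import parallel_apply, replicate

if torch.cuda.device_count() >= 2:
    module = nn.Linear(4, 2).cuda(0)
    replicas = replicate(module, [0, 1])           # one copy of the module per device
    inputs = [torch.randn(8, 4, device='cuda:0'),  # one (single-object) input per replica
              torch.randn(8, 4, device='cuda:1')]
    outputs = parallel_apply(replicas, inputs, devices=[0, 1])
    # outputs[i] == replicas[i](inputs[i]), computed on devices[i]
    print([tuple(o.shape) for o in outputs])       # [(8, 2), (8, 2)]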