import threading

import torch
from torch.cuda._utils import _get_device_index


def get_a_var(obj):
    # Return the first tensor found in `obj` (searching lists, tuples and
    # dict items), or None; used below to infer the target device.
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, list) or isinstance(obj, tuple):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()
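
    # Worker body for one replica: re-apply the caller's grad mode (grad mode
    # is thread-local and does not carry over into threads spawned here), run
    # the module on its device, and record the output (or the raised exception)
    # under the replica index so results keep the input order.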
    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                if not isinstance(input, (list, tuple)):
                    input = (input,)  # also avoids accidental slicing of a Tensor input
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

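    # Use one thread per replica when there is more than one module; with a
    # single module, run the worker inline in the calling thread.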
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
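

# Hedged usage sketch (illustration only, not part of this module's API):
# a minimal example of how parallel_apply is usually driven, i.e. replicate
# a module, scatter the input across devices, apply each replica to its
# chunk, then gather. The toy Linear module, tensor sizes and the
# two-device assumption below are made up for illustration.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.nn.parallel import replicate, scatter, gather

    if torch.cuda.device_count() >= 2:
        devices_ = [0, 1]
        module_ = nn.Linear(8, 4).to("cuda:0")
        inputs_ = torch.randn(16, 8, device="cuda:0")
        replicas_ = replicate(module_, devices_)
        chunks_ = scatter(inputs_, devices_)  # one tensor chunk per device
        outputs_ = parallel_apply(replicas_, chunks_, devices=devices_)
        print(gather(outputs_, target_device=0).shape)  # torch.Size([16, 4])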