    def forward(ctx, target_gpus, *inputs):
        if not all(input.is_cuda for input in inputs):
            raise TypeError('Broadcast function not implemented for CPU tensors')
        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return tuple()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        # Copy every input to every target GPU in one coalesced transfer per device.
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # Copies of inputs that do not require grad are marked non-differentiable
        # so autograd skips them in backward.
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        # Flatten the per-device lists into a single tuple:
        # (gpu0_input0, gpu0_input1, ..., gpu1_input0, gpu1_input1, ...)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        # None for target_gpus; the per-GPU gradients are summed back onto the
        # device the inputs came from.
        return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
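

# Illustrative sketch (not part of the original module): a round trip through
# Broadcast. The helper name, device ids, and shapes are assumptions; running
# it needs at least two CUDA devices.
def _broadcast_sketch():
    x = torch.randn(2, 3, device='cuda:0', requires_grad=True)
    b = torch.randn(3, device='cuda:0', requires_grad=True)
    # Broadcast.apply returns a flat tuple grouped per GPU:
    # (x_on_gpu0, b_on_gpu0, x_on_gpu1, b_on_gpu1)
    x0, b0, x1, b1 = Broadcast.apply([0, 1], x, b)
    # Backward reduces the per-GPU gradients onto cuda:0 via ReduceAddCoalesced,
    # so x.grad accumulates the contribution of both copies.
    loss = x0.sum() + b0.sum() + (x1.sum() + b1.sum()).to('cuda:0')
    loss.backward()
    return x.grad, b.grad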


class ReduceAddCoalesced(Function):

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        # grads arrive grouped per device: num_inputs tensors from the first
        # GPU, then num_inputs from the next, and so on.
        ctx.target_gpus = [grads[i].get_device()
                           for i in range(0, len(grads), num_inputs)]

        grads = [grads[i:i + num_inputs]
                 for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # None for destination and num_inputs; the reduced gradients are
        # broadcast back to the GPUs they came from.
        return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
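

# Illustrative sketch (not part of the original module): summing per-GPU
# gradients onto one device, the way Broadcast.backward uses this function.
# The helper name, device ids, and shapes are assumptions; it needs at least
# two CUDA devices.
def _reduce_add_coalesced_sketch():
    g0_w = torch.ones(2, 2, device='cuda:0')
    g0_b = torch.ones(2, device='cuda:0')
    g1_w = torch.ones(2, 2, device='cuda:1')
    g1_b = torch.ones(2, device='cuda:1')
    # Arguments: destination device, tensors per device, then the gradients
    # grouped per device (num_inputs from GPU 0, then num_inputs from GPU 1).
    w_sum, b_sum = ReduceAddCoalesced.apply(0, 2, g0_w, g0_b, g1_w, g1_b)
    # Both results live on cuda:0 and hold the elementwise sum across GPUs.
    return w_sum, b_sum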


class Gather(Function):

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(map(lambda i: i.is_cuda, inputs))
        target_device = _get_device_index(target_device, True)
        ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            # Zero-dimensional tensors cannot be concatenated, so view each
            # scalar as a one-element vector first.
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn('Was asked to gather along dimension 0, but all '
                          'input tensors were scalars; will instead unsqueeze '
                          'and return a vector.')
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        # Split the incoming gradient back into the original per-GPU chunks.
        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            # Undo the view(1) from forward so scalar inputs get scalar grads.
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
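

# Illustrative sketch (not part of the original module): concatenating per-GPU
# outputs onto one device, as DataParallel does after the replicated forward.
# The helper name, device ids, dim, and shapes are assumptions; it needs at
# least two CUDA devices.
def _gather_sketch():
    y0 = torch.randn(4, 10, device='cuda:0', requires_grad=True)
    y1 = torch.randn(4, 10, device='cuda:1', requires_grad=True)
    # Gather onto device 0 along dim 0 -> an (8, 10) tensor on cuda:0.
    y = Gather.apply(0, 0, y0, y1)
    # Backward scatters gradient chunks of sizes (4, 4) back to cuda:0/cuda:1.
    y.sum().backward()
    return y0.grad, y1.grad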


class Scatter(Function):

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.is_cuda else -1
        streams = None
        if ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(device) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        # None for target_gpus, chunk_sizes, and dim; the chunk gradients are
        # gathered back onto the device the input came from (-1 means CPU).
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
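

# Illustrative sketch (not part of the original module): splitting a CPU batch
# across GPUs, as DataParallel does with its input. The helper name, device
# ids, chunk sizes, and shapes are assumptions; it needs at least two CUDA
# devices.
def _scatter_sketch():
    batch = torch.randn(8, 10)  # CPU tensor, so copies use background streams
    # Split along dim 0 into chunks of 6 and 2 rows placed on cuda:0 and cuda:1;
    # passing chunk_sizes=None would split the batch as evenly as possible.
    part0, part1 = Scatter.apply([0, 1], (6, 2), 0, batch)
    return part0, part1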


# background streams used for copying
_streams = None


def _get_stream(device):
    """Gets a background stream for copying between CPU and GPU"""
    global _streams
    if device == -1:
        return None
    if _streams is None:
        _streams = [None] * torch.cuda.device_count()
    if _streams[device] is None:
        _streams[device] = torch.cuda.Stream(device)
    return _streams[device]
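

# Illustrative sketch (not part of the original module): overlapping a pinned
# CPU -> GPU copy with work on the default stream, the same pattern Scatter
# uses above. The helper name, device id, and size are assumptions.
def _copy_stream_sketch():
    copy_stream = _get_stream(0)
    cpu_tensor = torch.empty(1024, pin_memory=True)
    with torch.cuda.stream(copy_stream):
        gpu_tensor = cpu_tensor.to('cuda:0', non_blocking=True)
    # Make the default stream wait for the copy before the result is used, and
    # record the usage so the allocator keeps the memory alive long enough.
    main_stream = torch.cuda.current_stream(0)
    main_stream.wait_stream(copy_stream)
    gpu_tensor.record_stream(main_stream)
    return gpu_tensor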