import warnings

import torch.cuda

__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

SUM = 0  # ncclRedOp_t value for sum reduction


def is_available(tensors):
    # NCCL collectives require dense, contiguous CUDA tensors, each on a
    # distinct device, and a PyTorch build compiled with NCCL support.
    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    return True
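
# Usage sketch (illustrative; assumes at least two visible CUDA devices):
#
#   n = torch.cuda.device_count()
#   tensors = [torch.randn(256).cuda(i) for i in range(n)]
#   if is_available(tensors):
#       all_reduce(tensors)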


def version():
    return torch._C._nccl_version()


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
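
# Multi-process setup sketch (illustrative): `unique_id` is generated by one
# process and shared with the others out of band; `world_size`, `my_rank`,
# and the sharing mechanism are assumptions here, not part of this module.
#
#   uid = unique_id()
#   comm = init_rank(world_size, uid, my_rank)
#   # the resulting communicator can then be supplied to the collectives
#   # below through their `comms` argument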


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    if outputs is None:
        outputs = inputs
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
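
# Example sketch (illustrative, single process, two GPUs): with `outputs`
# left as None the reduction happens in place.
#
#   inputs = [torch.ones(4).cuda(i) for i in range(2)]
#   all_reduce(inputs)
#   # every tensor now holds [2., 2., 2., 2.]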


def reduce(inputs, outputs=None, root=0, op=SUM, streams=None, comms=None):
    if outputs is None:
        outputs = inputs
    torch._C._nccl_reduce(inputs, outputs, root, op, streams, comms)
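
# Example sketch (illustrative): reduce onto the root device; with the default
# in-place behavior only inputs[root] is guaranteed to hold the result.
#
#   inputs = [torch.ones(4).cuda(i) for i in range(2)]
#   reduce(inputs, root=0)
#   # inputs[0] now holds [2., 2., 2., 2.]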


def broadcast(inputs, root=0, streams=None, comms=None):
    torch._C._nccl_broadcast(inputs, root, streams, comms)
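
# Example sketch (illustrative): copy the root device's tensor to the others.
#
#   tensors = [torch.zeros(4).cuda(i) for i in range(2)]
#   tensors[0].fill_(7)
#   broadcast(tensors, root=0)
#   # tensors[1] now also holds [7., 7., 7., 7.]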


def all_gather(inputs, outputs, streams=None, comms=None):
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
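
# Example sketch (illustrative): every device ends up with the concatenation
# of all the input tensors.
#
#   n = torch.cuda.device_count()
#   inputs = [torch.zeros(4).fill_(i).cuda(i) for i in range(n)]
#   outputs = [torch.zeros(n * 4).cuda(i) for i in range(n)]
#   all_gather(inputs, outputs)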


def reduce_scatter(inputs, outputs, op=SUM, streams=None, comms=None):
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
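
# Example sketch (illustrative): the inputs are summed element-wise and each
# device receives one equally sized chunk of the result.
#
#   n = torch.cuda.device_count()
#   inputs = [torch.ones(n * 4).cuda(i) for i in range(n)]
#   outputs = [torch.zeros(4).cuda(i) for i in range(n)]
#   reduce_scatter(inputs, outputs, op=SUM)
#   # outputs[i] holds the i-th chunk of the sum, i.e. four values equal to n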