4 #include <ATen/cuda/CUDAContext.h> 6 #include <c10/util/Optional.h> 21 void throw_nccl_error(ncclResult_t status);
23 static inline void NCCL_CHECK(ncclResult_t status) {
24 if (status != ncclSuccess) {
25 throw_nccl_error(status);
31 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) 32 NCCL_CHECK(ncclGroupStart());
36 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) 37 NCCL_CHECK(ncclGroupEnd());
47 int output_multiplier);
48 ncclDataType_t _get_data_type(
const at::Tensor& t);
52 using comm_list = std::vector<ncclComm_t>;
53 using stream_list = std::vector<c10::optional<at::cuda::CUDAStream>>;
55 std::uint64_t version();
61 const stream_list& streams = {},
62 const comm_list& user_comms = {});
64 size_t get_max_count();
67 const std::vector<at::Tensor>& inputs,
68 std::vector<at::Tensor>& outputs,
71 const stream_list& streams = {},
72 const comm_list& user_comms = {});
75 std::vector<at::Tensor>& inputs,
78 const stream_list& streams = {},
79 const comm_list& user_comms = {});
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...