1 #include <torch/csrc/cuda/nccl.h> 2 #include <torch/csrc/cuda/device_set.h> 3 #include <ATen/core/functional.h> 4 #include <torch/csrc/utils/hash.h> 7 #include <c10/cuda/CUDAGuard.h> 8 #include <c10/util/Exception.h> 14 #include <type_traits> 15 #include <unordered_map> 25 void throw_nccl_error(ncclResult_t status) {
26 std::ostringstream err;
27 err <<
"NCCL Error " << status <<
": " << ncclGetErrorString(status);
28 throw std::runtime_error(err.str());
// NOTE(review): fragment of an RAII wrapper (NcclCommList) around a group of
// NCCL communicators; the struct header, some members, and parts of the
// destructor are not visible in this chunk.
// Owned array of communicators, one per participating device.
32 std::unique_ptr<ncclComm_t[]> comms;
// Constructor: allocates one ncclComm_t slot per device index, then
// initializes all of them together with a single ncclCommInitAll call.
35 : comms(
new ncclComm_t[devices.size()]), ndevices(devices.size()) {
36 NCCL_CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
// Destructor fragment: walk every communicator we own...
51 for (
int i = 0; i < ndevices; i++) {
// Probe the CUDA runtime; if cudaGetDevice fails the driver is presumably
// already unloading, and the branch body (not visible here) handles that
// case -- TODO confirm against the full file.
53 if (cudaGetDevice(&dummy_var) != cudaSuccess) {
// Release the i-th communicator back to NCCL.
59 ncclCommDestroy(comms[i]);
// Ordered list of CUDA device indices; used as the cache key below.
68 using device_list = std::vector<int>;
// Cache of communicator groups keyed by the exact device set, so repeated
// collectives over the same devices reuse one NcclCommList. NOTE(review):
// the declaration continues past this line (variable name not visible here).
70 static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>>
// NOTE(review): body fragment of _get_communicators; the function signature
// and the lambda's return statement are not visible in this chunk.
// Extracts the device index of a tensor (lambda body elided from view).
74 static auto get_device = [](
const at::Tensor& t) ->
int {
// Build the cache key: the device index of each input, in input order.
77 device_list devices = fmap(inputs, get_device);
// Look up an existing communicator group for this exact device set;
// create and cache one on first use (emplace constructs NcclCommList
// from the device list).
78 auto it = _communicators.find(devices);
79 if (it == _communicators.end())
80 std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
81 return it->second.ref();
// Map an ATen tensor's scalar type to the corresponding ncclDataType_t.
// Throws std::runtime_error for non-CUDA tensors and for scalar types NCCL
// has no equivalent for. NOTE(review): the switch's case table (original
// lines ~89-103) is elided from this view.
84 ncclDataType_t _get_data_type(
const Tensor& t) {
// Only CUDA tensors can be handed to NCCL.
85 if (t.type().backend() != Backend::CUDA) {
86 throw std::runtime_error(
"Unconvertible NCCL type");
88 switch (t.scalar_type()) {
// Fall-through for scalar types with no NCCL mapping.
104 throw std::runtime_error(
"Unconvertible NCCL type");
// NOTE(review): validation helper (_check_inputs) shared by the collectives
// below; the signature head (inputs/outputs parameters) and several closing
// braces are not visible in this chunk. The multipliers express the expected
// output-size ratio per collective (e.g. all_gather scales output by the
// number of participants).
111 int input_multiplier,
112 int output_multiplier) {
114 size_t len = inputs.
size();
// Reject empty input sequences outright.
117 throw std::runtime_error(
"input sequence can't be empty");
// Inputs and outputs must pair up one-to-one.
120 if (len != outputs.
size()) {
121 std::stringstream err;
122 err <<
"inputs and outputs sequences have to be of the same length, but got input of length " 123 << len <<
" and output of length " << outputs.
size();
124 throw std::runtime_error(err.str());
// Reference element count and type, taken from the first input; every
// other tensor is checked against these.
128 int64_t numel = inputs[0].numel();
129 auto& type = inputs[0].type();
131 for (
size_t i = 0; i < len; i++) {
132 auto input = inputs[i];
133 auto output = outputs[i];
// NCCL operates on dense CUDA buffers only.
135 if (!(input.is_cuda() && !input.is_sparse() &&
136 output.is_cuda() && !output.is_sparse())) {
137 throw std::runtime_error(
138 "input and output elements have to be cuda dense Tensors");
// Mixed dtypes across the group are not allowed.
141 if (!(type == input.type() && type == output.type())) {
142 throw std::runtime_error(
143 "all inputs and outputs must be of the same Tensor type");
// NCCL needs contiguous memory; no stride handling is done here.
146 if (!input.is_contiguous() || !output.is_contiguous()) {
147 throw std::runtime_error(
"all inputs and outputs have to be contiguous");
150 auto input_device = input.get_device();
// Each input must live on a distinct device (devices is a bitset of
// device indices seen so far).
152 if (devices.test(input_device)) {
153 throw std::runtime_error(
"inputs must be on unique devices");
155 devices.set(input_device);
// Paired input/output must share a device.
158 if (input_device != output.get_device()) {
159 throw std::runtime_error(
"input and output must be on the same device");
// All inputs carry the same element count...
163 if (input.numel() != numel) {
164 throw std::runtime_error(
165 "all inputs must have the same number of elements");
// ...and each output's size honors the collective's scaling ratio:
// output.numel() / input.numel() == input_multiplier / output_multiplier.
168 if (output.numel() * output_multiplier != numel * input_multiplier) {
169 throw std::runtime_error(
170 "output must be of size input_size * size_multiplier");
// NOTE(review): loop fragment from an enclosing function not visible here
// (presumably an eligibility check such as is_available) -- it scans each
// tensor, and the bodies of the early-exit branches (original lines 183,
// 185, 187-188) are elided from this view.
180 for (
auto& tensor : tensors) {
181 auto& type = tensor.type();
// Non-CUDA or sparse tensors disqualify the set (branch body not visible).
182 if (!type.is_cuda() || type.is_sparse())
// Non-contiguous tensors likewise (branch body not visible).
184 if (!tensor.is_contiguous())
186 auto device = tensor.get_device();
// Record this tensor's device in the device bitset.
189 devices[device] =
true;
// Report the NCCL version compiled against, packed as
// MAJOR*1000 + MINOR*100 + PATCH. NOTE(review): the #elif/#else branches
// and #endif (for builds without NCCL_MAJOR) are elided from this view.
197 std::uint64_t version() {
198 #if defined(NCCL_MAJOR) 199 return NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH;
// Compile-time helper: extract the (decayed) type of a function's second
// parameter. Primary template is declared only; the partial specialization
// below matches plain function types.
200 #elif defined(USE_NCCL) 213 template <
typename T>
214 struct GetSecondArgType;
216 template <
typename R,
typename Arg0,
typename Arg1,
typename... Args>
217 struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
218 typedef typename std::decay<Arg1>::type type;
// Largest element count a single NCCL call accepts: the max of whatever
// integer type ncclBcast's `count` parameter uses (differs across NCCL
// versions, hence the trait instead of a hard-coded type).
221 constexpr
auto count_max =
222 std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
// NOTE(review): body not visible in this chunk; presumably returns the
// count_max limit defined above -- TODO confirm against the full file.
225 size_t get_max_count() {
// NOTE(review): body fragment of broadcast(); the signature head (the
// tensors parameter), the else-branch of the comms selection, and several
// lines of the size-check message are not visible in this chunk.
// `streams`/`user_comms` are optional per-device overrides; empty entries
// fall back to the current stream / cached communicators.
231 const stream_list& streams,
232 const comm_list& user_comms) {
// In-place broadcast: tensors serve as both inputs and outputs (1:1 sizes).
235 _check_inputs(tensors, tensors, 1, 1);
236 ncclDataType_t data_type = _get_data_type(tensors[0]);
237 int64_t numel = tensors[0].numel();
// Serialize against the CUDA caching allocator to avoid deadlocking with
// its free path while NCCL kernels are being enqueued.
239 std::lock_guard<std::mutex> free_mutex(
240 *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
241 const auto comms = user_comms.empty() ? _get_communicators(tensors)
// One ncclBcast per device, issued under that device's guard and stream.
246 for (
size_t i = 0, num_tensors = tensors.
size(); i < num_tensors; i++) {
247 int device = tensors[i].get_device();
248 device_guard.set_index(device);
// Per-tensor stream override, else the device's current CUDA stream.
250 const auto stream = (streams.empty() || !streams[i])
251 ? at::cuda::getCurrentCUDAStream(device).stream()
252 : streams[i]->stream();
// Guard against overflowing NCCL's count argument (see count_max above).
254 static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
255 "Broadcast tensor has ",
257 " elements, which exceeds the " 258 "maximum NCCL supports (",
// Root rank is fixed at 0: tensor 0's data is replicated to all others.
261 NCCL_CHECK(ncclBcast(
262 tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
// Non-NCCL build fallback (other preprocessor branch).
265 AT_ERROR(
"PyTorch built without NCCL support");
// NOTE(review): body fragment of reduce(); the function name line, the
// root/op parameters, the else-branch of the comms selection, and the tail
// of the ncclReduce argument list are not visible in this chunk.
270 const std::vector<at::Tensor>& inputs,
271 std::vector<at::Tensor>& outputs,
274 const stream_list& streams,
275 const comm_list& user_comms) {
// The root index must name one of the participating inputs.
279 root >= 0 && static_cast<size_t>(root) < inputs.size(),
"invalid root");
// Element-for-element reduce: inputs and outputs sized 1:1.
281 _check_inputs(inputs, outputs, 1, 1);
282 const auto len = inputs.size();
284 ncclDataType_t data_type = _get_data_type(inputs[0]);
286 const auto count = inputs[0].numel();
// Serialize against the CUDA caching allocator (same rationale as
// broadcast above).
287 std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
288 auto comms_ref = user_comms.empty() ? _get_communicators(inputs)
// One ncclReduce per device, under that device's guard and stream.
293 for (
size_t i = 0; i < len; i++) {
294 int device = inputs[i].device().index();
295 device_guard.set_index(device);
// Per-tensor stream override, else the device's current CUDA stream.
297 const auto stream = (streams.empty() || !streams[i])
298 ? at::cuda::getCurrentCUDAStream(device).stream()
299 : streams[i]->stream();
// Remaining arguments (count, data_type, op, root, comm, stream) are
// elided from this view.
301 NCCL_CHECK(ncclReduce(
302 inputs[i].data_ptr(),
303 outputs[i].data_ptr(),
// Non-NCCL build fallback (other preprocessor branch).
312 AT_ERROR(
"PyTorch built without NCCL support");
// In-place convenience overload: reduces each input into itself by
// forwarding to the two-sequence reduce() above. NOTE(review): the function
// name line and the root/op parameters are not visible in this chunk.
317 std::vector<at::Tensor>& inputs,
320 const stream_list& streams,
321 const comm_list& user_comms) {
322 reduce(inputs, inputs, root, op, streams, user_comms);
// NOTE(review): the seven lines below are IDE hover-tooltip text (for
// CUDAGuard, Tensor::get_device, ArrayRef::size, etc.) accidentally captured
// during extraction; they are not part of the source. Commented out to keep
// the file parseable while preserving the captured text:
// A variant of OptionalDeviceGuard that is specialized for CUDA.
// int64_t get_device() const
// Returns a Tensor's device index.
// constexpr size_t size() const
// size - Get the array size.
// ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
// Flush-To-Zero and Denormals-Are-Zero mode.