#include <torch/csrc/cuda/python_nccl.h>

#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/cuda/THCP.h>
#include <ATen/core/functional.h>

#include <c10/cuda/CUDAGuard.h>

#include <nccl.h>

#include <unordered_map>

using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;
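
// NCCL communicators are handed to Python as opaque PyCapsule objects. The
// capsule name below tags them so unpack_nccl_comm() can verify that a given
// capsule really wraps a torch.cuda.nccl.Communicator.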
static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  return PyInt_FromLong(version());
}
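
// Wraps ncclGetUniqueId(): returns the NCCL_UNIQUE_ID_BYTES-byte unique id as
// a Python bytes object, to be shared with the other ranks before they call
// nccl_init_rank().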
PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId id;
  NCCL_CHECK(ncclGetUniqueId(&id));
  return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}
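
// Capsule helpers: unpack_nccl_comm() checks the capsule name and extracts
// the raw ncclComm_t; destroy_nccl_comm() is installed as the capsule
// destructor by nccl_init_rank() and drops the GIL while ncclCommDestroy()
// runs.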
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  ncclComm_t comm =
      (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (!comm)
    throw python_error();
  return comm;
}

static void destroy_nccl_comm(PyObject* capsule) {
  HANDLE_TH_ERRORS
  ncclComm_t comm = unpack_nccl_comm(capsule);
  with_no_gil([&] { ncclCommDestroy(comm); });
  END_HANDLE_TH_ERRORS_RET()
}
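
// Turns an optional Python sequence of torch.cuda.Stream objects into a list
// of optional CUDA streams, one per input tensor. Passing None produces
// c10::nullopt entries, which the collectives below resolve to each device's
// current stream.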
static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams(
    PyObject* obj,
    size_t size) {
  if (obj == Py_None) {
    return std::vector<c10::optional<at::cuda::CUDAStream>>(size, c10::nullopt);
  }
  auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
  if (streams.size() != size) {
    throw std::runtime_error(
        "number of streams is not equal to number of inputs");
  }
  return streams;
}

static std::vector<at::Tensor> extract_tensors(PyObject* obj);
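
// Accepts None, a single communicator capsule, or a sequence of capsules and
// returns the wrapped ncclComm_t handles. An empty result tells the callers
// below to fall back to the cached communicators from _get_communicators().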
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }
  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms = {unpack_nccl_comm(obj)};
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq)
      throw python_error();
    auto seq_size = PySequence_Fast_GET_SIZE(seq.get());
    comms = std::vector<ncclComm_t>(seq_size);
    for (int64_t i = 0; i < seq_size; i++) {
      comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
    }
  }
  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks;
  const char* id;
  Py_ssize_t id_len;
  int rank;

  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  THPUtils_assert(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected %d bytes, got %zd)",
      NCCL_UNIQUE_ID_BYTES,
      id_len);

  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm;
  with_no_gil(
      [&] { NCCL_CHECK(ncclCommInitRank(&comm, nranks, commId, rank)); });
  return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  END_HANDLE_TH_ERRORS
}
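
// Reduce: combines every input tensor with `op` into the output tensor on
// device `root`. Unlike the collectives below, the per-device loop lives in
// torch::cuda::nccl::reduce(), so this binding only unpacks its arguments.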
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int root, op;

  if (!PyArg_ParseTuple(
          args, "OOiiOO", &_inputs, &_outputs, &root, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args, nullptr, "nccl_reduce", 1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int root,"
        " int op, sequence[torch.cuda.Stream or None])");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  std::vector<c10::optional<at::cuda::CUDAStream>> streams =
      unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    torch::cuda::nccl::reduce(inputs, outputs, root, op, streams, user_comms);
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
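
// All-reduce: after the call, every outputs[i] holds the element-wise
// reduction of all inputs. The loop issues one ncclAllReduce per device on
// that device's (or the user-supplied) stream, and holds the caching
// allocator's free mutex so a concurrent cudaFree cannot deadlock the
// collective.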
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args, nullptr, "nccl_all_reduce", 1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    _check_inputs(inputs, outputs, 1, 1);
    size_t len = inputs.size();

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel();
    std::lock_guard<std::mutex> lock(
        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      NCCL_CHECK(ncclAllReduce(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          (ncclRedOp_t)op,
          comms[i],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
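
// Broadcast: validates `root` against the number of inputs and defers to
// torch::cuda::nccl::broadcast(), which copies the source tensor into every
// other tensor of `inputs` in place.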
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_streams, *_comms;
  int root;

  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args, nullptr, "nccl_broadcast", 1,
        "(sequence[Tensor] inputs, int root)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil(
      [&] { torch::cuda::nccl::broadcast(inputs, streams, user_comms); });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
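
// All-gather: every outputs[i] receives the concatenation of all inputs. The
// argument order of ncclAllGather changed in NCCL 2, hence the NCCL_MAJOR
// check inside the loop.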
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;

  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args, nullptr, "nccl_all_gather", 1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    _check_inputs(inputs, outputs, len, 1);

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel();
    std::lock_guard<std::mutex> lock(
        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          comms[i],
          stream));
#else
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          count,
          data_type,
          outputs[i].data_ptr(),
          comms[i],
          stream));
#endif
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
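
// Reduce-scatter: each input is viewed as `len` contiguous chunks of
// count = numel / len elements; the reduction of chunk i across devices lands
// in outputs[i].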
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args, nullptr, "nccl_reduce_scatter", 1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    _check_inputs(inputs, outputs, 1, len);

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel() / len;
    std::lock_guard<std::mutex> lock(
        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      NCCL_CHECK(ncclReduceScatter(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          (ncclRedOp_t)op,
          comms[i],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
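
// Converts a Python sequence of torch.Tensor objects into a
// std::vector<at::Tensor>, raising TypeError for any element that is not a
// Tensor.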
static std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
  if (!seq)
    throw python_error();

  std::vector<at::Tensor> list;
  Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!THPVariable_Check(item)) {
      throw TypeError(
          "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
    }
    auto var = (THPVariable*)item;
    list.emplace_back(var->cdata.data());
  }
  return list;
}
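
// Rough sketch of how these bindings are typically driven from Python through
// the thin wrappers in torch/cuda/nccl.py (signatures assumed from that
// wrapper and may differ between releases):
//
//   import torch
//   import torch.cuda.nccl as nccl
//
//   n = torch.cuda.device_count()
//   xs = [torch.ones(4, device=f"cuda:{i}") for i in range(n)]
//   nccl.all_reduce(xs)   # in place: every xs[i] now holds the sum
//
// The wrappers forward to these bindings, which are registered on the
// torch._C extension module.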