Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_nccl.cpp
#include <torch/csrc/cuda/python_nccl.h>

#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/cuda/THCP.h>
#include <ATen/core/functional.h>

#include <c10/cuda/CUDAGuard.h>

#include <nccl.h>

#include <sstream>
#include <unordered_map>

using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;

static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";

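// Python binding: returns the NCCL library version as a plain integer.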
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  return PyInt_FromLong(version());
}

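// Python binding: generates a fresh ncclUniqueId and returns it to Python as
// a bytes object of NCCL_UNIQUE_ID_BYTES length, for distribution to the
// other ranks.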
PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId id;
  NCCL_CHECK(ncclGetUniqueId(&id));
  return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}

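// Extracts the raw ncclComm_t stored in a PyCapsule created by
// THCPModule_nccl_init_rank; propagates the pending Python error on failure.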
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  ncclComm_t comm =
      (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (!comm)
    throw python_error();
  return comm;
}

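// PyCapsule destructor for communicator capsules. See the TODO below: the
// ncclCommDestroy call is currently disabled, so this is effectively a no-op.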
static void destroy_nccl_comm(PyObject* capsule) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while the program is exiting is undefined
   * according to Nvidia, and leads to a segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in the destructor to avoid the segfault.
   * Following up with Nvidia for a long-term solution.
   */
  return;

  HANDLE_TH_ERRORS
  ncclComm_t comm = unpack_nccl_comm(capsule);
  with_no_gil([&] { ncclCommDestroy(comm); });
  END_HANDLE_TH_ERRORS_RET()
}

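// Converts an optional Python sequence of CUDA streams into a per-input list
// of optional CUDAStreams; Py_None yields c10::nullopt (i.e. "use the current
// stream") for every input.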
static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<c10::optional<at::cuda::CUDAStream>>(size, c10::nullopt);
  }
  auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
  if (streams.size() != size) {
    throw std::runtime_error(
        "number of streams is not equal to number of inputs");
  }
  return streams;
}

static std::vector<at::Tensor> extract_tensors(PyObject* obj);

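// Accepts None (use library-managed communicators), a single communicator
// capsule, or a sequence of capsules, and returns the corresponding
// ncclComm_t list, checking that its length matches the number of inputs.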
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }
  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms = {unpack_nccl_comm(obj)};
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq)
      throw python_error();
    auto size = PySequence_Fast_GET_SIZE(seq.get());
    comms = std::vector<ncclComm_t>(size);
    for (int64_t i = 0; i < size; i++) {
      comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
    }
  }
  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}

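// Binding for communicator setup: parses (nranks, unique_id bytes, rank),
// calls ncclCommInitRank without the GIL, and wraps the resulting
// communicator in a PyCapsule owned by Python.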
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks;
  const char* id;
  Py_ssize_t id_len;
  int rank;

  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  THPUtils_assert(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected %d bytes, got %zd)",
      NCCL_UNIQUE_ID_BYTES,
      id_len);

  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm;
  with_no_gil(
      [&] { NCCL_CHECK(ncclCommInitRank(&comm, nranks, commId, rank)); });
  return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  END_HANDLE_TH_ERRORS
}

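// Binding for the reduce collective: unpacks the Python arguments and
// forwards to torch::cuda::nccl::reduce outside the GIL.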
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int root, op;

  if (!PyArg_ParseTuple(
          args,
          "OOiiOO",
          &_inputs,
          &_outputs,
          &root,
          &op,
          &_streams,
          &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int root,"
        " int op, sequence[torch.cuda.Stream or None] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  std::vector<c10::optional<at::cuda::CUDAStream>> streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    torch::cuda::nccl::reduce(inputs, outputs, root, op, streams, user_comms);
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

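// Binding for all_reduce. Unlike reduce and broadcast, the collective is
// issued inline here: one ncclAllReduce per device, grouped by AutoNcclGroup
// and serialized against the caching allocator via its free mutex.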
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    _check_inputs(inputs, outputs, 1, 1);
    size_t len = inputs.size();

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel();
    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      NCCL_CHECK(ncclAllReduce(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          (ncclRedOp_t)op,
          comms[i],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

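// Binding for broadcast: validates the root index and forwards to
// torch::cuda::nccl::broadcast outside the GIL.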
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_streams, *_comms;
  int root;

  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_broadcast",
        1,
        "(sequence[Tensor] inputs, int root)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil(
      [&] { torch::cuda::nccl::broadcast(inputs, streams, user_comms); });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

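// Binding for all_gather, issued inline like all_reduce. NCCL 1 and NCCL 2
// use different argument orders for ncclAllGather, hence the #if below.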
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;

  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_gather",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    _check_inputs(inputs, outputs, len, 1);

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel();
    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          comms[i],
          stream));
#else
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          count,
          data_type,
          outputs[i].data_ptr(),
          comms[i],
          stream));
#endif
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

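// Binding for reduce_scatter, also issued inline. Each output receives an
// equal share of the reduced input, so count is the input element count
// divided by the number of participating tensors.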
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce_scatter",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    _check_inputs(inputs, outputs, 1, len);

    ncclDataType_t data_type = _get_data_type(inputs[0]);

    int64_t count = inputs[0].numel() / len;
    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      NCCL_CHECK(ncclReduceScatter(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          (ncclRedOp_t)op,
          comms[i],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

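// Converts a Python sequence of torch.Tensor objects into a
// std::vector<at::Tensor>, raising TypeError for any non-tensor item.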
static std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
  if (!seq)
    throw python_error();

  std::vector<at::Tensor> list;
  Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!THPVariable_Check(item)) {
      throw TypeError(
          "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
    }
    auto var = (THPVariable*)item;
    list.emplace_back(var->cdata.data());
  }
  return list;
}
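
The bindings above are thin wrappers: they unpack the Python arguments into at::Tensor, CUDAStream, and ncclComm_t lists and then either forward to the helpers declared in torch/csrc/cuda/nccl.h or issue the NCCL collective directly. As a rough illustration of how those pieces compose from C++, the sketch below broadcasts one tensor per device with the same torch::cuda::nccl::broadcast call used by THCPModule_nccl_broadcast. It is illustrative only: the name example_broadcast is made up, and the meaning of the empty communicator list and nullopt streams ("use cached communicators and each device's current stream") is taken from the helpers in this file.

// Illustrative sketch, not part of python_nccl.cpp.
static void example_broadcast(std::vector<at::Tensor>& tensors) {
  // nullopt entries select each device's current CUDA stream, matching the
  // stream handling in the collectives above.
  std::vector<c10::optional<at::cuda::CUDAStream>> streams(
      tensors.size(), c10::nullopt);
  // An empty list lets the library fall back to its cached communicators,
  // mirroring unpack_comms(Py_None, ...).
  std::vector<ncclComm_t> comms;
  torch::cuda::nccl::broadcast(tensors, streams, comms);
}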