1 #include <torch/csrc/utils/pybind.h> 2 #include <torch/csrc/cuda/comm.h> 3 #include <torch/csrc/cuda/Stream.h> 4 #include <torch/csrc/cuda/THCP.h> 5 #include <torch/csrc/utils/auto_gil.h> 6 #include <ATen/core/functional.h> 15 namespace torch {
namespace cuda {
namespace python {
16 void initCommMethods(PyObject *module) {
17 auto m = py::cast<py::module>(module);
19 "_broadcast_coalesced",
20 [](std::vector<at::Tensor>& tensors,
21 std::vector<int64_t> devices,
23 return broadcast_coalesced(tensors, devices, buffer_size);
27 py::arg(
"buffer_size"),
28 py::call_guard<py::gil_scoped_release>())
31 [](
at::Tensor& tensor, std::vector<int64_t> devices) {
32 return broadcast(tensor, devices);
34 py::call_guard<py::gil_scoped_release>())
38 std::vector<int64_t>& devices,
44 py::handle handle = *py_streams;
45 streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
49 return scatter(tensor, devices, chunk_sizes, dim, streams);
53 py::arg(
"chunk_sizes"),
58 [](std::vector<at::Tensor>& tensors,
61 return gather(tensors, dim, destination_index);
65 py::arg(
"destination_index"),
66 py::call_guard<py::gil_scoped_release>());