#include <torch/csrc/python_headers.h>

#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/ProcessGroupGloo.hpp>

#ifdef USE_C10D_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#endif

#ifdef USE_C10D_MPI
#include <c10d/ProcessGroupMPI.hpp>
#endif

#include <c10d/PrefixStore.hpp>
#include <c10d/TCPStore.hpp>
#include <gloo/transport/tcp/device.h>
#include <pybind11/chrono.h>

#include <limits.h>
#include <unistd.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/ddp.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

namespace torch {
namespace distributed {
namespace c10d {

namespace {
// Name of the environment variable that selects the network interface used
// by Gloo TCP devices.
constexpr const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
PyObject* c10d_init(PyObject* _unused) {
  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!c10d_module) {
    throw python_error();
  }

  auto module = py::handle(c10d_module).cast<py::module>();
  py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
``MIN``, and ``MAX``.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.)")
      .value("SUM", ::c10d::ReduceOp::SUM)
      .value("PRODUCT", ::c10d::ReduceOp::PRODUCT)
      .value("MIN", ::c10d::ReduceOp::MIN)
      .value("MAX", ::c10d::ReduceOp::MAX);
  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout);

  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

  py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
      .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

  py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout);

  py::class_<::c10d::GatherOptions>(module, "GatherOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
      .def_readwrite("timeout", &::c10d::GatherOptions::timeout);

  py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank)
      .def_readwrite("timeout", &::c10d::ScatterOptions::timeout);

  py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout);
  auto store =
      shared_ptr_class_<::c10d::Store>(module, "Store")
          // Convert the value from std::string to std::vector<uint8_t>.
          .def(
              "set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) {
                std::vector<uint8_t> value_(value.begin(), value.end());
                store.set(key, value_);
              },
              py::call_guard<py::gil_scoped_release>())
          // Convert the value from std::vector<uint8_t> to py::bytes.
          // The returned value is not guaranteed to be valid UTF-8.
          .def(
              "get",
              [](::c10d::Store& store, const std::string& key) -> py::bytes {
                auto value = store.get(key);
                return py::bytes(
                    reinterpret_cast<char*>(value.data()), value.size());
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "add",
              &::c10d::Store::add,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "set_timeout",
              &::c10d::Store::setTimeout,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "wait",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys) { store.wait(keys); },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "wait",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys,
                 const std::chrono::milliseconds& timeout) {
                store.wait(keys, timeout);
              },
              py::call_guard<py::gil_scoped_release>());
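  // Rough Python-side sketch of the Store API bound above (illustrative
  // only, not a tested example):
  //   store.set("key", "value")                  # stored as raw bytes
  //   data = store.get("key")                    # returns bytes
  //   store.wait(["key"], timedelta(seconds=30)) # blocks until set or timeout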
  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store)
      .def(py::init<const std::string&, int>());

  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
      .def(py::init<const std::string&, int, int, bool>());

  shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store)
      .def(py::init<const std::string&, ::c10d::Store&>());
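  // The TCPStore constructor arguments are bound positionally here; going by
  // the constructor declared in TCPStore.hpp they are the host address, the
  // port, the number of workers, and whether this process acts as the server.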
  auto processGroup =
      shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
          .def("rank", &::c10d::ProcessGroup::getRank)
          .def("size", &::c10d::ProcessGroup::getSize)
          .def(
              "broadcast",
              &::c10d::ProcessGroup::broadcast,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "broadcast",
              [](::c10d::ProcessGroup& pg, at::Tensor& x, int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return pg.broadcast(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              &::c10d::ProcessGroup::allreduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return pg.allreduce(xs, opts);
              },
              py::arg("tensors"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& x,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return pg.allreduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce",
              &::c10d::ProcessGroup::reduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return pg.reduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              &::c10d::ProcessGroup::allgather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return pg.allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "gather",
              &::c10d::ProcessGroup::gather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::GatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "gather",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input,
                 int rootRank) {
                ::c10d::GatherOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return pg.gather(outputs, inputs, opts);
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter",
              &::c10d::ProcessGroup::scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 int rootRank) {
                ::c10d::ScatterOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> inputs = {input};
                std::vector<at::Tensor> outputs = {output};
                return pg.scatter(outputs, inputs, opts);
              },
              py::arg("output_tensor"),
              py::arg("input_tensors"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "send",
              &::c10d::ProcessGroup::send,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv",
              &::c10d::ProcessGroup::recv,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv_anysource",
              &::c10d::ProcessGroup::recvAnysource,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "barrier",
              &::c10d::ProcessGroup::barrier,
              py::arg("opts") = ::c10d::BarrierOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "group_rank",
              &::c10d::ProcessGroup::getGroupRank,
              py::call_guard<py::gil_scoped_release>());
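  // Every collective above follows the same pattern: the canonical overload
  // takes a vector of tensors plus an options struct, and a lambda overload
  // offers a single-tensor convenience form; all of them release the GIL for
  // the duration of the native call. The single-tensor allreduce, for
  // instance, is just sugar for:
  //
  //   ::c10d::AllreduceOptions opts;
  //   opts.reduceOp = op;
  //   std::vector<at::Tensor> xs = {x};
  //   auto work = pg.allreduce(xs, opts);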
  auto processGroupGloo = shared_ptr_class_<::c10d::ProcessGroupGloo>(
      module, "ProcessGroupGloo", processGroup);

  shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");

  shared_ptr_class_<::c10d::ProcessGroupGloo::Options>(
      processGroupGloo, "Options")
      .def(py::init<>())
      .def_readwrite("devices", &::c10d::ProcessGroupGloo::Options::devices)
      .def_readwrite("timeout", &::c10d::ProcessGroupGloo::Options::timeout)
      .def_readwrite("threads", &::c10d::ProcessGroupGloo::Options::threads);

  processGroupGloo.def_static(
      "create_tcp_device",
      [](const std::string& hostname, const std::string& interface)
          -> std::shared_ptr<::gloo::transport::Device> {
        ::gloo::transport::tcp::attr attr;
        if (!hostname.empty()) {
          attr.hostname = hostname;
        } else if (!interface.empty()) {
          attr.iface = interface;
        } else {
          // Neither was specified; leave the attributes empty so Gloo
          // falls back to its default address resolution.
        }
        return ::gloo::transport::tcp::CreateDevice(attr);
      },
      py::arg("hostname") = "",
      py::arg("interface") = "");

  processGroupGloo
      .def(py::init<
           const std::shared_ptr<::c10d::Store>&,
           int,
           int,
           ::c10d::ProcessGroupGloo::Options>())
      .def(
          py::init([](const std::shared_ptr<::c10d::Store>& store,
                      int rank,
                      int size,
                      std::chrono::milliseconds timeout) {
            ::c10d::ProcessGroupGloo::Options options;
            ::gloo::transport::tcp::attr attr;

            // Use the interface named by GLOO_SOCKET_IFNAME, if set.
            char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV);
            if (ifnameEnv) {
              attr.iface = std::string(ifnameEnv);
            } else {
              // Otherwise resolve the local hostname into an address.
              std::array<char, HOST_NAME_MAX> hostname{};
              auto rv = gethostname(hostname.data(), hostname.size());
              if (rv != 0) {
                throw std::system_error(errno, std::system_category());
              }
              attr.hostname = hostname.data();
            }

            options.devices.push_back(
                ::gloo::transport::tcp::CreateDevice(attr));
            options.timeout = timeout;
            return std::make_shared<::c10d::ProcessGroupGloo>(
                store, rank, size, options);
          }),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("timeout") = std::chrono::milliseconds(10 * 1000));
#ifdef USE_C10D_NCCL
  shared_ptr_class_<::c10d::ProcessGroupNCCL>(
      module, "ProcessGroupNCCL", processGroup)
      .def(
          py::init<
              const std::shared_ptr<::c10d::Store>&,
              int,
              int,
              const std::string&>(),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("groupName") = "");
#endif
#ifdef USE_C10D_MPI
  shared_ptr_class_<::c10d::ProcessGroupMPI>(
      module, "ProcessGroupMPI", processGroup)
      .def(py::init([](std::vector<int> ranks) {
        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
      }));
#endif
  shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
      .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted)
      .def("is_success", &::c10d::ProcessGroup::Work::isSuccess)
      .def("exception", &::c10d::ProcessGroup::Work::exception)
      .def("source_rank", &::c10d::ProcessGroup::Work::sourceRank)
      .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
      .def(
          "wait",
          &::c10d::ProcessGroup::Work::wait,
          py::call_guard<py::gil_scoped_release>());
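  // Each collective bound on ProcessGroup above returns one of these Work
  // handles; wait() blocks, with the GIL released, until the underlying
  // operation has completed.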
437 "_dist_bucket_tensors",
438 &::c10d::bucketTensors,
440 py::arg(
"bucket_size"),
441 py::arg(
"fine_grained"),
442 py::call_guard<py::gil_scoped_release>());
445 "_dist_broadcast_coalesced",
446 &::c10d::distBroadcastCoalesced,
447 py::arg(
"process_group"),
449 py::arg(
"buffer_size"),
450 py::arg(
"fine_grained"),
451 py::call_guard<py::gil_scoped_release>());
456 py::arg(
"process_group"),
457 py::arg(
"parameter_data"),
458 py::arg(
"buffer_data"),
460 py::arg(
"broadcast_bucket_size"),
461 py::arg(
"broadcast_buffers"),
462 py::call_guard<py::gil_scoped_release>());
466 &::c10d::queueReduction,
467 py::arg(
"process_group"),
468 py::arg(
"grads_batch"),
470 py::call_guard<py::gil_scoped_release>());
474 &::c10d::syncReduction,
475 py::arg(
"reduction_work"),
476 py::arg(
"grads_batch"),
477 py::arg(
"grads_batch_coalesced"),
478 py::call_guard<py::gil_scoped_release>());
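  // The five private helpers above (declared in
  // torch/csrc/distributed/c10d/ddp.h) back the Python
  // DistributedDataParallel implementation: bucketing gradients,
  // broadcasting coalesced tensors, and queueing/synchronizing the backward
  // reduction.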
  Py_RETURN_TRUE;
}

} // namespace

// c10d methods on torch._C
static PyMethodDef methods[] = {
    {"_c10d_init", (PyCFunction)c10d_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace c10d
} // namespace distributed
} // namespace torch