Caffe2 - C++ API
A deep learning, cross-platform ML framework
init.cpp
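This file implements the _c10d_init function exposed as a method on torch._C: it imports the torch.distributed module and registers the c10d types (stores, collective options, process groups, and Work handles) on it through pybind11. The stores, process groups, and Work handles are bound with a std::shared_ptr holder via the shared_ptr_class_ alias defined below. As a rough illustration of that binding pattern only (the Counter type and the example module are hypothetical, not part of c10d), a minimal sketch:

#include <memory>
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical type used only to illustrate the holder-based binding style.
struct Counter {
  int value = 0;
  void add(int n) { value += n; }
};

PYBIND11_MODULE(example, m) {
  // Mirrors shared_ptr_class_<T>(module, "Name") from the listing below:
  // py::class_ with a std::shared_ptr<T> holder, then chained .def calls.
  py::class_<Counter, std::shared_ptr<Counter>>(m, "Counter")
      .def(py::init<>())
      .def("add", &Counter::add, py::arg("n"))
      .def_readwrite("value", &Counter::value);
}

The registration code itself follows.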
#include <torch/csrc/python_headers.h>

#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/ProcessGroupGloo.hpp>

#ifdef USE_C10D_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#endif

#ifdef USE_C10D_MPI
#include <c10d/ProcessGroupMPI.hpp>
#endif

#include <c10d/PrefixStore.hpp>
#include <c10d/TCPStore.hpp>
#include <gloo/transport/tcp/device.h>
#include <pybind11/chrono.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/ddp.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

namespace torch {
namespace distributed {
namespace c10d {

namespace {

constexpr const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* c10d_init(PyObject* _unused) {
  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!c10d_module) {
    throw python_error();
  }

  auto module = py::handle(c10d_module).cast<py::module>();

  py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
``MIN``, and ``MAX``.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.)")
      .value("SUM", ::c10d::ReduceOp::SUM)
      .value("PRODUCT", ::c10d::ReduceOp::PRODUCT)
      .value("MIN", ::c10d::ReduceOp::MIN)
      .value("MAX", ::c10d::ReduceOp::MAX);
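  // From Python these values surface as torch.distributed.ReduceOp; an
  // illustrative call is dist.all_reduce(tensor, op=dist.ReduceOp.SUM).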

  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout);

  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

  py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
      .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

  py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout);

  py::class_<::c10d::GatherOptions>(module, "GatherOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
      .def_readwrite("timeout", &::c10d::GatherOptions::timeout);

  py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank)
      .def_readwrite("timeout", &::c10d::ScatterOptions::timeout);

  py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout);

  auto store =
      shared_ptr_class_<::c10d::Store>(module, "Store")
          // Convert from std::string to std::vector<uint8_t>.
          .def(
              "set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) {
                std::vector<uint8_t> value_(value.begin(), value.end());
                store.set(key, value_);
              },
              py::call_guard<py::gil_scoped_release>())
          // Convert from std::vector<uint8_t> to py::bytes.
          // The returned value is not guaranteed to be valid UTF-8.
          .def(
              "get",
              [](::c10d::Store& store, const std::string& key) -> py::bytes {
                auto value = store.get(key);
                return py::bytes(
                    reinterpret_cast<char*>(value.data()), value.size());
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "add",
              &::c10d::Store::add,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "set_timeout",
              &::c10d::Store::setTimeout,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "wait",
              [](::c10d::Store& store, const std::vector<std::string>& keys) {
                store.wait(keys);
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "wait",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys,
                 const std::chrono::milliseconds& timeout) {
                store.wait(keys, timeout);
              },
              py::call_guard<py::gil_scoped_release>());

  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store)
      .def(py::init<const std::string&, int>());

  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
      .def(py::init<const std::string&, int, int, bool>());

  shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store)
      .def(py::init<const std::string&, ::c10d::Store&>());

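  // The three concrete stores above share the Store interface: FileStore(path,
  // num_workers), TCPStore(host, port, num_workers, is_server), and
  // PrefixStore(prefix, store), which namespaces the keys of an existing
  // store.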
  auto processGroup =
      shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
          .def("rank", &::c10d::ProcessGroup::getRank)
          .def("size", &::c10d::ProcessGroup::getSize)

          .def(
              "broadcast",
              &::c10d::ProcessGroup::broadcast,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast",
              [](::c10d::ProcessGroup& pg, at::Tensor& x, int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return pg.broadcast(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allreduce",
              &::c10d::ProcessGroup::allreduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allreduce",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return pg.allreduce(xs, opts);
              },
              py::arg("tensors"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allreduce",
              [](::c10d::ProcessGroup& pg, at::Tensor& x, ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return pg.allreduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              &::c10d::ProcessGroup::reduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return pg.reduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allgather",
              &::c10d::ProcessGroup::allgather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allgather",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return pg.allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "gather",
              &::c10d::ProcessGroup::gather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::GatherOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "gather",
              [](::c10d::ProcessGroup& pg,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input,
                 int rootRank) {
                ::c10d::GatherOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return pg.gather(outputs, inputs, opts);
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              &::c10d::ProcessGroup::scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ScatterOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 int rootRank) {
                ::c10d::ScatterOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> inputs = {input};
                std::vector<at::Tensor> outputs = {output};
                return pg.scatter(outputs, inputs, opts);
              },
              py::arg("output_tensor"),
              py::arg("input_tensors"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send",
              &::c10d::ProcessGroup::send,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv",
              &::c10d::ProcessGroup::recv,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_anysource",
              &::c10d::ProcessGroup::recvAnysource,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "barrier",
              &::c10d::ProcessGroup::barrier,
              py::arg("opts") = ::c10d::BarrierOptions(),
              py::call_guard<py::gil_scoped_release>());

  auto processGroupGloo = shared_ptr_class_<::c10d::ProcessGroupGloo>(
      module, "ProcessGroupGloo", processGroup);

  shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");

  shared_ptr_class_<::c10d::ProcessGroupGloo::Options>(
      processGroupGloo, "Options")
      .def(py::init<>())
      .def_readwrite("devices", &::c10d::ProcessGroupGloo::Options::devices)
      .def_readwrite("timeout", &::c10d::ProcessGroupGloo::Options::timeout)
      .def_readwrite("threads", &::c10d::ProcessGroupGloo::Options::threads);

  processGroupGloo.def_static(
      "create_tcp_device",
      [](const std::string& hostname, const std::string& interface)
          -> std::shared_ptr<::gloo::transport::Device> {
        ::gloo::transport::tcp::attr attr;
        if (!hostname.empty()) {
          attr.hostname = hostname;
        } else if (!interface.empty()) {
          attr.iface = interface;
        } else {
          // Neither argument was specified; Gloo itself will fall back to
          // the machine's hostname.
        }
        return ::gloo::transport::tcp::CreateDevice(attr);
      },
      py::arg("hostname") = "",
      py::arg("interface") = "");

  processGroupGloo
      .def(py::init<
           const std::shared_ptr<::c10d::Store>&,
           int,
           int,
           ::c10d::ProcessGroupGloo::Options>())
      .def(
          py::init([](const std::shared_ptr<::c10d::Store>& store,
                      int rank,
                      int size,
                      std::chrono::milliseconds timeout) {
            ::c10d::ProcessGroupGloo::Options options;
            ::gloo::transport::tcp::attr attr;
            // First, check the "GLOO_SOCKET_IFNAME" environment variable,
            // which can be set by the user.
            char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV);
            if (ifnameEnv) {
              attr.iface = std::string(ifnameEnv);
            } else {
              // Use the hostname to resolve the network address to use.
              // Note: if the hostname does not resolve to an address (e.g.
              // because of a misconfigured /etc/hosts file), this will not
              // work.
              std::array<char, HOST_NAME_MAX> hostname{};
              auto rv = gethostname(hostname.data(), hostname.size());
              if (rv != 0) {
                throw std::system_error(errno, std::system_category());
              }
              attr.hostname = hostname.data();
            }
            options.devices.push_back(
                ::gloo::transport::tcp::CreateDevice(attr));
            options.timeout = timeout;
            return std::make_shared<::c10d::ProcessGroupGloo>(
                store, rank, size, options);
          }),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("timeout") = std::chrono::milliseconds(10 * 1000));

#ifdef USE_C10D_NCCL
  shared_ptr_class_<::c10d::ProcessGroupNCCL>(
      module, "ProcessGroupNCCL", processGroup)
      .def(
          py::init<
              const std::shared_ptr<::c10d::Store>&,
              int,
              int,
              const std::string&>(),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("groupName") = "");
#endif

#ifdef USE_C10D_MPI
  shared_ptr_class_<::c10d::ProcessGroupMPI>(
      module, "ProcessGroupMPI", processGroup)
      .def(py::init([](std::vector<int> ranks) {
        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
      }));
#endif

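  // Work is the asynchronous handle returned by the collective calls bound on
  // ProcessGroup above; wait() blocks (with the GIL released) until the
  // operation has finished.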
  shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
      .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted)
      .def("is_success", &::c10d::ProcessGroup::Work::isSuccess)
      .def("exception", &::c10d::ProcessGroup::Work::exception)
      .def("source_rank", &::c10d::ProcessGroup::Work::sourceRank)
      .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
      .def(
          "wait",
          &::c10d::ProcessGroup::Work::wait,
          py::call_guard<py::gil_scoped_release>());

#ifdef USE_CUDA
  module.def(
      "_dist_bucket_tensors",
      &::c10d::bucketTensors,
      py::arg("tensors"),
      py::arg("bucket_size"),
      py::arg("fine_grained"),
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_dist_broadcast_coalesced",
      &::c10d::distBroadcastCoalesced,
      py::arg("process_group"),
      py::arg("tensors"),
      py::arg("buffer_size"),
      py::arg("fine_grained"),
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_sync_params",
      &::c10d::syncParams,
      py::arg("process_group"),
      py::arg("parameter_data"),
      py::arg("buffer_data"),
      py::arg("devices"),
      py::arg("broadcast_bucket_size"),
      py::arg("broadcast_buffers"),
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_queue_reduction",
      &::c10d::queueReduction,
      py::arg("process_group"),
      py::arg("grads_batch"),
      py::arg("devices"),
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_sync_reduction",
      &::c10d::syncReduction,
      py::arg("reduction_work"),
      py::arg("grads_batch"),
      py::arg("grads_batch_coalesced"),
      py::call_guard<py::gil_scoped_release>());
#endif

  Py_RETURN_TRUE;
}

} // namespace

// c10d methods on torch._C
static PyMethodDef methods[] = {
    {"_c10d_init", (PyCFunction)c10d_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace c10d
} // namespace distributed
} // namespace torch
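For orientation, the snippet below sketches how the classes bound above can be driven directly from C++ with the Gloo backend: a FileStore shared by all processes, a ProcessGroupGloo built on top of it, and a blocking allreduce. It is an illustrative sketch only, not part of init.cpp; the store path, the "localhost" address, and the world size of 2 are assumptions made for the example.

#include <memory>
#include <vector>

#include <ATen/ATen.h>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupGloo.hpp>
#include <gloo/transport/tcp/device.h>

// Run once per process, with rank 0 and rank 1 (world size 2 is assumed).
void allreduceExample(int rank) {
  const int size = 2;
  auto store = std::make_shared<::c10d::FileStore>("/tmp/c10d_example", size);

  // Mirror the timeout-based constructor above: pick a TCP device, then
  // build the options the process group is created with.
  ::gloo::transport::tcp::attr attr;
  attr.hostname = "localhost"; // assumed reachable address for the example
  ::c10d::ProcessGroupGloo::Options options;
  options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));

  ::c10d::ProcessGroupGloo pg(store, rank, size, options);

  // Each rank contributes a tensor filled with its own rank value.
  std::vector<at::Tensor> tensors = {at::ones({4}) * rank};
  ::c10d::AllreduceOptions opts;
  opts.reduceOp = ::c10d::ReduceOp::SUM;

  auto work = pg.allreduce(tensors, opts);
  work->wait(); // tensors[0] now holds the element-wise sum over all ranks
}

The Python bindings registered by _c10d_init expose the same constructors and collectives, which is how torch.distributed drives them from Python.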