// Deprecated THD distributed Python bindings: includes and file-level state.
// NOTE(review): original line numbers are fused into the text and interior lines
// appear to be missing throughout this excerpt (e.g. the map's closing "};") —
// confirm every block against the upstream file before editing.
1 #include <torch/csrc/python_headers.h> 4 #include <unordered_map> 7 #include <torch/csrc/utils/python_strings.h> 8 #include <torch/csrc/distributed/THDP.h> 9 #include <torch/csrc/PythonTypes.h> 10 #include <torch/csrc/autograd/python_variable.h> 13 #include <torch/csrc/cuda/Stream.h> 17 static std::unordered_map<std::string, THDChannelType> name2channel_type = {
18 {
"mpi", THDChannelMPI},
19 {
"tcp", THDChannelTCP},
20 {
"gloo", THDChannelGloo},
21 {
"nccl", THDChannelNccl},
// Lookup tables filled in by THDPModule_initExtension: map the Python
// reduce_op / group singleton objects to their THD enum values.
24 static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
25 static std::unordered_map<PyObject*, THDGroup> obj2group;
// Global CUDA state defined elsewhere; handed to THD in initProcessGroup.
28 extern THCState* state;
// Python binding: initialize the THD process group from
// (string backend, string init_method, int world_size, string group_name, int rank).
// NOTE(review): opening brace, error-handling macros and return statements appear
// to be missing from this excerpt — confirm against the full file.
31 PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *args)
34 if (PyTuple_GET_SIZE(args) != 5 || !THPUtils_checkString(PyTuple_GET_ITEM(args, 0)) ||
35 !THPUtils_checkString(PyTuple_GET_ITEM(args, 1)) ||
36 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 2)) ||
37 !THPUtils_checkString(PyTuple_GET_ITEM(args, 3)) ||
38 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 4))) {
39 THPUtils_invalidArguments(args,
nullptr,
"init_process_group", 1,
"(string backend, string init_method, int world_size, string group_name, int rank)");
// Unpack the five positional arguments.
43 std::string backend_name = THPUtils_unpackString(PyTuple_GET_ITEM(args, 0));
44 std::string init_method = THPUtils_unpackString(PyTuple_GET_ITEM(args, 1));
45 int world_size = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 2));
46 std::string group_name = THPUtils_unpackString(PyTuple_GET_ITEM(args, 3));
47 int rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 4));
// .at() throws std::out_of_range for an unknown backend name.
49 THDChannelType channel_type = name2channel_type.at(backend_name);
52 THDProcessGroupInit(channel_type, init_method, world_size, group_name, rank);
// Hand THD a pointer to the global THCState so it can drive CUDA.
55 THDSetCudaStatePtr(&state);
// Python binding (METH_NOARGS): tear down the THD process group.
61 PyObject* THDPModule_destroyProcessGroup(PyObject *_unused) {
65 THDProcessGroupDestroy();
// Python binding (METH_O): register a torch.cuda.Stream with THD so collectives
// run on it. Rejects any argument that is not a THCPStream.
// NOTE(review): the declaration of `stream` (presumably a cast of _stream to
// THCPStream*) is missing from this excerpt — confirm against the full file.
72 PyObject* THDPModule_registerStream(PyObject *_unused, PyObject *_stream)
75 THPUtils_assert(THCPStream_Check(_stream),
"_register_stream expects a " 76 "torch.cuda.Stream object");
78 THDRegisterCudaStream(stream->cuda_stream);
// Python binding (METH_NOARGS): return this process's rank as a Python int.
84 PyObject* THDPModule_getRank(PyObject *_unused)
87 return PyInt_FromLong(THDGetRank());
// Python binding (METH_NOARGS): return the world size as a Python int.
91 PyObject* THDPModule_getNumProcesses(PyObject *_unused)
94 return PyInt_FromLong(THDGetNumProcesses());
// CUDA tensor class objects defined elsewhere in the extension.
99 extern PyObject* THCPDoubleTensorClass;
100 extern PyObject* THCPFloatTensorClass;
101 extern PyObject* THCPHalfTensorClass;
102 extern PyObject* THCPLongTensorClass;
103 extern PyObject* THCPIntTensorClass;
104 extern PyObject* THCPShortTensorClass;
105 extern PyObject* THCPCharTensorClass;
106 extern PyObject* THCPByteTensorClass;
// NOTE(review): this return looks like the tail of THDPModule_makeDescriptor
// (obj -> at::Tensor from a THPVariable's cdata); its signature and the cast of
// the argument to `var` are missing from this excerpt — confirm upstream.
111 return var->cdata.data();
// Extract the THDRequest* held inside a THPWrapper Python object.
// No type checking here — callers guard with THPWrapper_check first.
114 static THDRequest* _unpackRequest(PyObject *obj)
116 return static_cast<THDRequest*
>(THPWrapper_get(obj));
// Map a Python reduce_op singleton (registered in initExtension) to its
// THDReduceOp enum; throws for any unrecognized object.
// NOTE(review): the success path ("return it->second;") is missing from this
// excerpt — confirm against the full file.
119 static THDReduceOp _getReduceOp(PyObject *obj)
121 auto it = obj2reduceop.find(obj);
122 if (it == obj2reduceop.end()) {
123 throw std::runtime_error(
"op should be a constant from " 124 "torch.distributed.deprecated.reduce_op");
// Map a Python group object to a THDGroup id. Accepts either one of the
// registered group singletons or a plain int (used directly as the group id).
// NOTE(review): the found-in-map return path is missing from this excerpt.
129 static THDGroup _getGroup(PyObject *obj)
131 auto it = obj2group.find(obj);
132 if (it == obj2group.end()) {
133 if (!THPUtils_checkLong(obj))
134 throw std::runtime_error(
"group should be an int or one of the values " 135 "from torch.distributed.deprecated.group");
136 return THPUtils_unpackLong(obj);
// Python binding: clear THD's cached resources for a single group: (group gr).
141 PyObject* THDPModule_clearGroupCache(PyObject *_unused, PyObject *args) {
143 if (PyTuple_GET_SIZE(args) != 1) {
144 THPUtils_invalidArguments(args,
nullptr,
"clear_group_cache", 1,
"(group gr)");
148 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 0));
152 THDClearGroupCache(group);
// Python binding: non-blocking send, (tensor input, int dst_rank).
// Returns a THPWrapper owning the THDRequest*, freed via THDRequest_free.
158 PyObject* THDPModule_isend(PyObject *_unused, PyObject *args)
161 if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
162 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
163 THPUtils_invalidArguments(args,
nullptr,
"isend", 1,
"(tensor input, int dst_rank)");
167 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
168 int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
172 req = THDIsend(desc, dst_rank);
174 return THPWrapper_New(req, (
void(*)(
void*))THDRequest_free);
// Python binding: non-blocking receive, (tensor output, int src_rank).
// Returns a THPWrapper owning the THDRequest*, freed via THDRequest_free.
178 PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args)
181 if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
182 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
183 THPUtils_invalidArguments(args,
nullptr,
"irecv", 1,
"(tensor output, int src_rank)");
187 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
188 int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
192 req = THDIrecv(desc, src_rank);
194 return THPWrapper_New(req, (
void(*)(
void*))THDRequest_free);
// Python binding: blocking send, (tensor input, int dst_rank).
198 PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
201 if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
202 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
203 THPUtils_invalidArguments(args,
nullptr,
"send", 1,
"(tensor input, int dst_rank)");
207 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
208 int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
211 THDSend(desc, dst_rank);
// Python binding (METH_O): blocking receive from any rank into the given
// tensor; returns the sender's rank as a Python int.
217 PyObject* THDPModule_recvAnySource(PyObject *_unused, PyObject *_tensor)
220 if (!THPVariable_Check(_tensor)) {
221 THPUtils_invalidArguments(_tensor,
nullptr,
"recv", 1,
"(tensor output)");
225 auto desc = THDPModule_makeDescriptor(_tensor);
229 sender = THDRecvAnySource(desc);
231 return PyInt_FromLong(sender);
// Python binding: blocking receive, (tensor output, int src_rank).
// Returns the src_rank argument (a new reference via Py_INCREF).
235 PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
238 if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
239 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
240 THPUtils_invalidArguments(args,
nullptr,
"recv", 1,
"(tensor output, int src_rank)");
244 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
245 int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
248 THDRecv(desc, src_rank);
// PyTuple_GET_ITEM returns a borrowed reference; INCREF before returning it.
251 Py_INCREF(PyTuple_GET_ITEM(args, 1));
252 return PyTuple_GET_ITEM(args, 1);
// Python binding: all_reduce over a list of per-GPU tensors,
// (list[tensor] in_out, reduce_op op, group gr). Builds at::Tensor
// descriptors from the sequence, then calls THDAllReduceMultiGPU.
// NOTE(review): declarations of `sequence`, `length`, `group`, `op` and the
// invalid_arguments label/brace lines are missing from this excerpt.
257 PyObject* THDPModule_allReduceMultiGPU(PyObject *_unused, PyObject *args)
260 std::vector<at::Tensor> descriptors;
266 if (PyTuple_GET_SIZE(args) != 3) {
267 goto invalid_arguments;
270 if (!PySequence_Check(PyTuple_GET_ITEM(args, 0))) {
271 goto invalid_arguments;
// PySequence_Fast gives O(1) item access for both lists and tuples.
274 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
275 "expected a sequence"));
276 if (!sequence.get()) {
277 goto invalid_arguments;
280 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
282 descriptors.reserve(length);
284 for (
size_t i = 0; i < length; ++i) {
285 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
286 goto invalid_arguments;
289 descriptors.push_back(
290 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
294 group = _getGroup(PyTuple_GET_ITEM(args, 2));
295 op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
299 THDAllReduceMultiGPU(descriptors.data(), length, op, group);
304 THPUtils_invalidArguments(args,
nullptr,
"all_reduce_multigpu", 1,
305 "(list[tensor] in_out, reduce_op op, group gr)");
// Python binding: reduce over a list of per-GPU tensors,
// (list[tensor] in_out, int dst_rank, reduce_op op, group gr).
// Same sequence-unpacking shape as allReduceMultiGPU.
// NOTE(review): declarations of `sequence`, `length`, `group`, `op`,
// `dst_rank` and the invalid_arguments label are missing from this excerpt.
311 PyObject* THDPModule_reduceMultiGPU(PyObject *_unused, PyObject *args)
316 std::vector<at::Tensor> descriptors;
321 if (PyTuple_GET_SIZE(args) != 4) {
322 goto invalid_arguments;
325 if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
326 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
327 goto invalid_arguments;
330 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
331 "expected a sequence"));
332 if (!sequence.get()) {
333 goto invalid_arguments;
336 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
338 descriptors.reserve(length);
340 for (
size_t i = 0; i < length; ++i) {
341 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
342 goto invalid_arguments;
345 descriptors.push_back(
346 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
350 group = _getGroup(PyTuple_GET_ITEM(args, 3));
351 op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
352 dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
356 THDReduceMultiGPU(descriptors.data(), length, op, dst_rank, group);
361 THPUtils_invalidArguments(args,
nullptr,
"reduce_multigpu", 1,
362 "(list[tensor] in_out, int dst_rank, " 363 "reduce_op op, group gr)");
// Python binding: broadcast over a list of per-GPU tensors,
// (list[tensor] in_out, int src_rank, group gr).
// NOTE(review): declarations of `sequence`, `length`, `group`, `src_rank`
// and the invalid_arguments label are missing from this excerpt.
369 PyObject* THDPModule_broadcastMultiGPU(PyObject *_unused, PyObject *args)
374 std::vector<at::Tensor> descriptors;
378 if (PyTuple_GET_SIZE(args) != 3) {
379 goto invalid_arguments;
382 if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
383 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
384 goto invalid_arguments;
387 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
388 "expected a sequence"));
389 if (!sequence.get()) {
390 goto invalid_arguments;
393 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
395 descriptors.reserve(length);
397 for (
size_t i = 0; i < length; ++i) {
398 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
399 goto invalid_arguments;
402 descriptors.push_back(
403 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
407 group = _getGroup(PyTuple_GET_ITEM(args, 2));
408 src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
412 THDBroadcastMultiGPU(descriptors.data(), length, src_rank, group);
417 THPUtils_invalidArguments(args,
nullptr,
"broadcast_multigpu", 1,
418 "(list[tensor] in_out, int src_rank, group gr)");
// Python binding: all_gather over per-GPU tensor lists,
// (list[list[tensor]] output, list[tensor] input, group gr).
// Unpacks two parallel sequences (outputs then inputs), requires equal
// lengths, then calls THDAllGatherMultiGPU.
// NOTE(review): declarations of `sequence_one`/`sequence_two`,
// `length_one`/`length_two`, `group`, the remaining THDAllGatherMultiGPU
// arguments and the invalid_arguments label are missing from this excerpt.
424 PyObject* THDPModule_allGatherMultiGPU(PyObject *_unused, PyObject *args)
433 std::vector<at::Tensor> output_descriptors;
434 std::vector<at::Tensor> input_descriptors;
438 if (PyTuple_GET_SIZE(args) != 3) {
439 goto invalid_arguments;
442 if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
443 !PySequence_Check(PyTuple_GET_ITEM(args, 1))) {
444 goto invalid_arguments;
447 sequence_one =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
448 "expected a sequence"));
449 sequence_two =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 1),
450 "expected a sequence"));
452 if (!sequence_one.get() || !sequence_two.get()) {
453 goto invalid_arguments;
456 length_one =
static_cast<size_t>(
457 PySequence_Fast_GET_SIZE(sequence_one.get()));
459 length_two =
static_cast<size_t>(
460 PySequence_Fast_GET_SIZE(sequence_two.get()));
// Output and input lists must pair up one-to-one.
462 if (length_one != length_two) {
463 goto invalid_arguments;
466 output_descriptors.reserve(length_one);
467 input_descriptors.reserve(length_two);
470 for (
size_t i = 0; i < length_two; ++i) {
471 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence_two.get(), i)) ||
472 !THPVariable_Check(PySequence_Fast_GET_ITEM(sequence_one.get(), i))) {
473 goto invalid_arguments;
476 input_descriptors.push_back(
477 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence_two.get(), i))
480 output_descriptors.push_back(
481 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence_one.get(), i))
485 group = _getGroup(PyTuple_GET_ITEM(args, 2));
489 THDAllGatherMultiGPU(output_descriptors.data(),
491 input_descriptors.data(),
499 THPUtils_invalidArguments(args,
nullptr,
"all_gather_multigpu", 1,
500 "(list[list[tensor]] output, list[tensor] input, group gr)");
// Python binding: single-tensor all_reduce, (tensor in_out, reduce_op op, group gr).
506 PyObject* THDPModule_allReduce(PyObject *_unused, PyObject *args)
509 if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0))) {
510 THPUtils_invalidArguments(args,
nullptr,
"all_reduce", 1,
"(tensor in_out, reduce_op op, group gr)");
514 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
515 THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
516 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
519 THDAllReduce(desc, op, group);
// Python binding: single-tensor reduce to dst_rank,
// (tensor reduced, int dst_rank, reduce_op op, group gr).
525 PyObject* THDPModule_reduce(PyObject *_unused, PyObject *args)
528 if (PyTuple_GET_SIZE(args) != 4 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
529 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
530 THPUtils_invalidArguments(args,
nullptr,
"reduce", 1,
531 "(tensor reduced, int dst_rank, reduce_op op, group gr)");
535 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 3));
536 THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
537 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
538 int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
541 THDReduce(desc, op, dst_rank, group);
// Python binding: single-tensor broadcast from src_rank,
// (tensor src_dst, int src_rank, group gr).
547 PyObject* THDPModule_broadcast(PyObject *_unused, PyObject *args)
550 if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
551 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
552 THPUtils_invalidArguments(args,
nullptr,
"broadcast", 1,
553 "(tensor src_dst, int src_rank, group gr)");
557 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
558 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
559 int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
562 THDBroadcast(desc, src_rank, group);
// Python binding: all_gather a single input tensor into a list of outputs,
// (list[tensor] output, tensor input, group gr).
// NOTE(review): declarations of `sequence`, `length`, `group`, `desc` and the
// invalid_arguments label are missing from this excerpt.
568 PyObject* THDPModule_allGather(PyObject *_unused, PyObject *args)
573 std::vector<at::Tensor> descriptors;
577 if (PyTuple_GET_SIZE(args) != 3 ||
578 !PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
579 !THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
581 goto invalid_arguments;
584 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
585 "expected a sequence"));
586 if (!sequence.get()) {
587 goto invalid_arguments;
590 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
592 descriptors.reserve(length);
594 for (
size_t i = 0; i < length; ++i) {
595 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
596 goto invalid_arguments;
598 descriptors.push_back(
599 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
603 group = _getGroup(PyTuple_GET_ITEM(args, 2));
604 desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
607 THDAllGather(descriptors.data(), length, desc, group);
612 THPUtils_invalidArguments(args,
nullptr,
"allGather", 1,
613 "(list[tensor] output, tensor input, group gr)");
// Python binding: non-root side of gather — send this rank's tensor to
// dst_rank, (tensor input, int dst_rank, group gr).
618 PyObject* THDPModule_gatherSend(PyObject *_unused, PyObject *args)
621 if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0))) {
622 THPUtils_invalidArguments(args,
nullptr,
"gatherSend", 1,
623 "(tensor input, int dst_rank, group gr)");
627 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
628 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
629 int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
632 THDGatherSend(desc, dst_rank, group);
// Python binding: root side of gather — collect every rank's tensor into the
// output list, (list[tensor] output, tensor input, group gr).
// NOTE(review): declarations of `sequence`, `length`, `desc`, `group` and the
// invalid_arguments label are missing from this excerpt.
638 PyObject* THDPModule_gatherRecv(PyObject *_unused, PyObject *args)
643 std::vector<at::Tensor> descriptors;
647 if (PyTuple_GET_SIZE(args) != 3 ||
648 !PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
649 !THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
650 goto invalid_arguments;
653 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
654 "expected a sequence"));
655 if (!sequence.get()) {
656 goto invalid_arguments;
659 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
661 descriptors.reserve(length);
663 for (
size_t i = 0; i < length; ++i) {
664 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
665 goto invalid_arguments;
667 descriptors.push_back(
668 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
672 desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
673 group = _getGroup(PyTuple_GET_ITEM(args, 2));
676 THDGatherRecv(descriptors.data(), length, desc, group);
681 THPUtils_invalidArguments(args,
nullptr,
"gatherRecv", 1,
682 "(list[tensor] output, tensor input, group gr)");
// Python binding: root side of scatter — distribute the input list across
// ranks, receiving this rank's share into output,
// (list[tensor] input, tensor output, group gr).
// NOTE(review): declarations of `sequence`, `length`, `desc`, `group` and the
// invalid_arguments label are missing from this excerpt.
687 PyObject* THDPModule_scatterSend(PyObject *_unused, PyObject *args)
692 std::vector<at::Tensor> descriptors;
696 if (PyTuple_GET_SIZE(args) != 3 ||
697 !PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
698 !THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
699 goto invalid_arguments;
702 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
703 "expected a sequence"));
704 if (!sequence.get()) {
705 goto invalid_arguments;
708 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
710 descriptors.reserve(length);
712 for (
size_t i = 0; i < length; ++i) {
713 if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
714 goto invalid_arguments;
716 descriptors.push_back(
717 THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
721 desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
722 group = _getGroup(PyTuple_GET_ITEM(args, 2));
725 THDScatterSend(descriptors.data(), length, desc, group);
730 THPUtils_invalidArguments(args,
nullptr,
"scatterSend", 1,
731 "(list[tensor] input, tensor output, group gr)");
// Python binding: non-root side of scatter — receive this rank's share from
// src_rank, (tensor output, int src_rank, group gr).
736 PyObject* THDPModule_scatterRecv(PyObject *_unused, PyObject *args)
739 if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
740 !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
741 THPUtils_invalidArguments(args,
nullptr,
"scatterRecv", 1,
742 "(tensor output, int src_rank, group gr)");
746 THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
747 auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
748 int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
751 THDScatterRecv(desc, src_rank, group);
// Python binding (METH_O): barrier across the given group.
757 PyObject* THDPModule_barrier(PyObject *_unused, PyObject *_group)
762 THDBarrier(_getGroup(_group));
// Python binding: create a new THD group from (list[int] ranks), rejecting
// duplicate ranks; returns the new group id as a Python int.
// NOTE(review): declarations of `sequence`, `length`, `group` and the
// invalid_arguments label are missing from this excerpt.
768 PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args)
773 std::vector<int> ranks;
775 if (PyTuple_GET_SIZE(args) != 1 ||
776 !PySequence_Check(PyTuple_GET_ITEM(args, 0))) {
777 goto invalid_arguments;
780 sequence =
THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
781 "expected a sequence"));
782 if (!sequence.get()) {
783 goto invalid_arguments;
786 length =
static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
788 ranks.reserve(length);
790 for (
size_t i = 0; i < length; ++i) {
791 if (!THPUtils_checkLong(PySequence_Fast_GET_ITEM(sequence.get(), i)))
792 goto invalid_arguments;
794 ranks.push_back(THPUtils_unpackLong(
795 PySequence_Fast_GET_ITEM(sequence.get(), i)));
// O(n^2) duplicate check — fine for the small rank lists expected here.
797 for (
size_t j = 0; j < i; ++j)
798 THPUtils_assert(ranks[i] != ranks[j],
"ranks should be unique");
804 group = THDNewGroup(ranks.data(), length);
806 return PyInt_FromLong(group);
809 THPUtils_invalidArguments(args,
nullptr,
"newGroup", 1,
"(list[int] ranks)");
// Python binding (METH_O): poll an isend/irecv request; returns a Python bool.
814 PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req)
817 if (!THPWrapper_check(_req)) {
818 THPUtils_invalidArguments(_req,
nullptr,
"requestIsCompleted", 1,
"(request req)");
822 return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req)));
// Python binding (METH_O): block until an isend/irecv request completes.
826 PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req)
829 if (!THPWrapper_check(_req)) {
830 THPUtils_invalidArguments(_req,
nullptr,
"requestWait", 1,
"(request req)");
836 THDRequest_wait(_unpackRequest(_req));
// Python binding: one-time extension setup. Populates obj2reduceop from the
// reduce_op namespace object (SUM/PRODUCT/MIN/MAX) and obj2group from the
// group namespace object (WORLD). Args:
// (bool is_master_worker, reduce_op obj, group obj).
// NOTE(review): the declarations of the local THPObjectPtr `reduce_op` /
// `group` holders appear to be missing from this excerpt, and the #define
// lines have been fused onto single physical lines — confirm upstream.
842 PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
843 if (PyTuple_GET_SIZE(args) != 3) {
844 THPUtils_invalidArguments(args,
nullptr,
"initExtension", 1,
"(bool is_master_worker, reduce_op obj, group obj)");
848 PyObject* is_master_worker_obj = PyTuple_GET_ITEM(args, 0);
849 PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 1);
850 PyObject* group_obj = PyTuple_GET_ITEM(args, 2);
852 THPUtils_assert(PyBool_Check(is_master_worker_obj),
"first argument should be a bool");
853 bool is_master_worker = is_master_worker_obj == Py_True;
856 #define REGISTER_REDUCE_OP(NAME) \ 857 reduce_op = PyObject_GetAttrString(reduce_op_obj, #NAME); \ 858 THPUtils_assert(reduce_op, "Missing object for reduce op " #NAME); \ 859 obj2reduceop.emplace(reduce_op.get(), THDReduce##NAME); 860 REGISTER_REDUCE_OP(SUM);
861 REGISTER_REDUCE_OP(PRODUCT);
862 REGISTER_REDUCE_OP(MIN);
863 REGISTER_REDUCE_OP(MAX);
864 #undef REGISTER_REDUCE_OP 867 #define REGISTER_GROUP(NAME) \ 868 group = PyObject_GetAttrString(group_obj, #NAME); \ 869 THPUtils_assert(group, "Missing object for group " #NAME); \ 870 obj2group.emplace(group.get(), THDGroup##NAME); 871 REGISTER_GROUP(WORLD);
// The master_worker mode was removed; requesting it is now a hard error.
872 #undef REGISTER_GROUP 874 if (is_master_worker) {
875 throw std::runtime_error(
"THD master_worker no longer supported");
// Method table exposing the THD bindings to Python under "_dist_*" names.
// NOTE(review): the null sentinel entry and closing "};" appear to be missing
// from this excerpt — confirm against the full file.
880 static struct PyMethodDef _THDPModule_methods[] = {
881 {
"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_VARARGS,
nullptr},
882 {
"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_VARARGS,
nullptr},
883 {
"_dist_destroy_process_group", (PyCFunction)THDPModule_destroyProcessGroup, METH_NOARGS,
nullptr},
884 {
"_dist_clear_group_cache", (PyCFunction)THDPModule_clearGroupCache, METH_VARARGS,
nullptr},
886 {
"_dist_register_stream", (PyCFunction)THDPModule_registerStream, METH_O,
nullptr},
888 {
"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS,
nullptr},
889 {
"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS,
nullptr},
890 {
"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS,
nullptr},
891 {
"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS,
nullptr},
892 {
"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS,
nullptr},
893 {
"_dist_recv_any_source", (PyCFunction)THDPModule_recvAnySource, METH_O,
nullptr},
894 {
"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS,
nullptr},
895 {
"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS,
nullptr},
896 {
"_dist_all_reduce_multigpu", (PyCFunction)THDPModule_allReduceMultiGPU, METH_VARARGS,
nullptr},
897 {
"_dist_reduce", (PyCFunction)THDPModule_reduce, METH_VARARGS,
nullptr},
898 {
"_dist_reduce_multigpu", (PyCFunction)THDPModule_reduceMultiGPU, METH_VARARGS,
nullptr},
899 {
"_dist_broadcast", (PyCFunction)THDPModule_broadcast, METH_VARARGS,
nullptr},
900 {
"_dist_broadcast_multigpu", (PyCFunction)THDPModule_broadcastMultiGPU, METH_VARARGS,
nullptr},
901 {
"_dist_all_gather", (PyCFunction)THDPModule_allGather, METH_VARARGS,
nullptr},
902 {
"_dist_all_gather_multigpu", (PyCFunction)THDPModule_allGatherMultiGPU, METH_VARARGS,
nullptr},
903 {
"_dist_gather_send", (PyCFunction)THDPModule_gatherSend, METH_VARARGS,
nullptr},
904 {
"_dist_gather_recv", (PyCFunction)THDPModule_gatherRecv, METH_VARARGS,
nullptr},
905 {
"_dist_scatter_send", (PyCFunction)THDPModule_scatterSend, METH_VARARGS,
nullptr},
906 {
"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS,
nullptr},
907 {
"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O,
nullptr},
908 {
"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS,
nullptr},
909 {
"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O,
nullptr},
910 {
"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O,
nullptr},
// Accessor used during module init to fetch the method table above.
914 PyMethodDef* THDPModule_methods() {
915 return _THDPModule_methods;