import torch
import warnings
from torch._six import string_classes
from datetime import timedelta

from .rendezvous import rendezvous, register_rendezvous_handler
from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
    ScatterOptions, GatherOptions
from . import ReduceOp
from . import PrefixStore
from . import ProcessGroupGloo


_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True


try:
    from . import ProcessGroupMPI
except ImportError:
    _MPI_AVAILABLE = False

try:
    from . import ProcessGroupNCCL
except ImportError:
    _NCCL_AVAILABLE = False


class Backend(object):
    """
    An enum-like class of available backends: GLOO, NCCL, and MPI.

    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
    be accessed as attributes, e.g., ``Backend.NCCL``.

    This class can be directly called to parse the string, e.g.,
    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
    return the parsed lowercase string if so. It also accepts uppercase strings,
    e.g., ``Backend("GLOO")`` returns ``"gloo"``.

    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
              initial value of some fields. Users should neither use it
              directly nor assume its existence.
    """
    UNDEFINED = "undefined"
    GLOO = "gloo"
    NCCL = "nccl"
    MPI = "mpi"
    TCP = "tcp"

    def __new__(cls, name):
        if not isinstance(name, string_classes):
            raise ValueError("Backend name must be a string, but got: {}".format(name))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)

        if value == Backend.TCP:
            raise ValueError("TCP backend has been deprecated. Please use "
                             "Gloo or MPI backend for collective operations "
                             "on CPU tensors.")
        elif value == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        return value

_backend = Backend.UNDEFINED
dist_backend = Backend

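# Usage sketch (editor's addition, illustrative only): parsing a backend
# string through the Backend class defined above, as described in its
# docstring.
#
#     Backend("GLOO")   # -> "gloo"
#     Backend("nccl")   # -> "nccl"
#     Backend("tcp")    # raises ValueError: TCP backend has been deprecated
#     Backend("foo")    # raises ValueError: Invalid backend: 'foo'
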
class reduce_op(object):
    """
    Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
    ``MIN``, and ``MAX``.

    :class:`~torch.distributed.ReduceOp` is recommended to use instead.
    """

    def __init__(self):
        # Mirror every member of the ReduceOp enum onto this deprecated class.
        for k, v in ReduceOp.__members__.items():
            setattr(self, k, v)
        self.__members__ = ReduceOp.__members__

    def __getattribute__(self, key):
        warnings.warn("torch.distributed.reduce_op is deprecated, please use "
                      "torch.distributed.ReduceOp instead")
        return object.__getattribute__(self, key)


reduce_op = reduce_op()

class group(object):
    WORLD = object()


class GroupMember(object):
    # Alias of group.WORLD, plus a sentinel for ranks outside a group
    WORLD = group.WORLD
    NON_GROUP_MEMBER = object()


# Cached process-group state, keyed by ProcessGroup object:
#   _pg_map maps a ProcessGroup to (backend, store) for gloo/nccl groups and
#   to (backend, in_group_bool) for MPI groups; _pg_names maps it to its name;
#   _pg_group_ranks maps it to a dict from global rank to group-local rank.
_pg_map = {}
_pg_names = {}
_pg_group_ranks = {}

# Default process group and how it was initialized
_default_pg = None
_default_pg_init_method = None

# Default timeout for operations executed against a process group
# (only applicable to the gloo backend)
_default_pg_timeout = timedelta(minutes=30)

# Counter used to auto-name newly created process groups
_group_count = 0

def _rank_not_in_group(group):
    """
    Helper that checks if the current process's rank is not in a given group

    """
    default_backend, _ = _pg_map[_get_default_group()]
    if default_backend != Backend.MPI:
        return group == GroupMember.NON_GROUP_MEMBER
    else:
        if group == GroupMember.WORLD:
            return False
        else:
            _, in_group = _pg_map[group]
            return not in_group

def _get_group_rank(group, rank):
    """
    Helper that gets a given group's local rank in the group from a given
    global rank

    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    try:
        group_rank = _pg_group_ranks[group][rank]
    except KeyError:
        raise RuntimeError("The global rank is not part of the group")
    return group_rank

def _get_global_rank(group, group_rank):
    """
    Helper that gets a given group's global rank from a given local rank in
    the group

    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    group_rank_map = _pg_group_ranks[group]
    for rank, grp_rank in group_rank_map.items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError("The group rank is not part of the group")

def _check_default_pg():
    """
    Helper that checks if the default ProcessGroup has been initialized, with
    assertion

    """
    assert _default_pg is not None, \
        "Default process group is not initialized"


def _get_group_size(group):
    """
    Helper that gets a given group's world size

    """
    if group is GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.size()
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    return len(_pg_group_ranks[group])

def _check_single_tensor(param, param_name):
    """
    Helper that checks that the parameter ``param_name`` is a single Tensor

    """
    if not isinstance(param, torch.Tensor):
        raise RuntimeError("Invalid function argument. Expecting parameter: {} "
                           "to be a torch.Tensor type".format(param_name))


def _check_tensor_list(param, param_name):
    """
    Helper that checks that the parameter ``param_name`` is a list of Tensors

    """
    wrong_type = False
    if isinstance(param, list):
        for p in param:
            if not isinstance(p, torch.Tensor):
                wrong_type = True
                break
    else:
        wrong_type = True
    if wrong_type:
        raise RuntimeError("Invalid function argument. Expecting parameter: {} "
                           "to be a List[torch.Tensor] type".format(param_name))

def is_mpi_available():
    """
    Checks if the MPI backend is available

    """
    return _MPI_AVAILABLE


def is_nccl_available():
    """
    Checks if the NCCL backend is available

    """
    return _NCCL_AVAILABLE


def is_initialized():
    """
    Checks if the default process group has been initialized

    """
    return _default_pg is not None


def _get_default_group():
    """
    Gets the default process group created by init_process_group

    """
    if not is_initialized():
        raise RuntimeError("Default process group has not been initialized, "
                           "please make sure to call init_process_group.")
    return _default_pg

def get_backend(group=group.WORLD):
    """
    Returns the backend of the given process group.

    Arguments:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`.

    Returns:
        The backend of the given process group as a lower case string.

    """
    _check_default_pg()

    if group == GroupMember.WORLD:
        pg = _default_pg
    else:
        pg = group
    if _rank_not_in_group(pg):
        raise RuntimeError("Invalid process group specified")
    return _pg_map.get(pg, None)[0]

def init_process_group(backend,
                       init_method="env://",
                       timeout=_default_pg_timeout,
                       **kwargs):
    """
    Initializes the default distributed process group, and this will also
    initialize the distributed package.

    Arguments:
        backend (str or Backend): The backend to use. Depending on
            build-time configurations, valid values include ``mpi``, ``gloo``,
            and ``nccl``. This field should be given as a lowercase string
            (e.g., ``"gloo"``), which can also be accessed via
            :class:`Backend` attributes (e.g., ``Backend.GLOO``).
        init_method (str, optional): URL specifying how to initialize the
                                     process group.
        world_size (int, optional): Number of processes participating in
                                    the job.
        rank (int, optional): Rank of the current process.
        store (Store, optional): Rendezvous key/value store as an alternative
                                 to other init methods.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is only applicable for the ``gloo`` backend.
        group_name (str, optional, deprecated): Group name.

    To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
    on a system that supports MPI. The same applies to NCCL as well.

    """
    global _pg_map
    global _pg_names
    global _backend
    global _default_pg
    global _default_pg_init_method

    if not isinstance(timeout, timedelta):
        raise RuntimeError("Expected timeout argument to be of type "
                           "datetime.timedelta")

    if _default_pg is not None:
        raise RuntimeError("trying to initialize the default process group "
                           "twice!")

    world_size = kwargs.pop('world_size', -1)
    group_name = kwargs.pop('group_name', '')
    rank = kwargs.pop('rank', -1)
    store = kwargs.pop('store', None)
    if store is not None:
        assert world_size > 0, 'world_size needs to be positive'
        assert rank >= 0, 'rank needs to be non-negative'
    assert len(kwargs) == 0, \
        "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())

    backend = Backend(backend)

    if backend == Backend.MPI:
        if not is_mpi_available():
            raise RuntimeError("Distributed package doesn't have MPI built in")

        _default_pg = ProcessGroupMPI([])
        _pg_map[_default_pg] = (Backend.MPI, True)
        _pg_names[_default_pg] = group_name
    else:
        # Run the rendezvous to obtain a store unless one was given explicitly.
        url = init_method
        if world_size != -1 and rank != -1:
            url += "?rank={}&world_size={}".format(rank, world_size)
        elif rank != -1:
            url += "?rank={}".format(rank)
        elif world_size != -1:
            url += "?world_size={}".format(world_size)

        if store is None:
            store, rank, world_size = next(rendezvous(url))
        if backend == Backend.GLOO:
            _default_pg = ProcessGroupGloo(
                store,
                rank,
                world_size,
                timeout=timeout)
            _pg_map[_default_pg] = (Backend.GLOO, store)
            _pg_names[_default_pg] = group_name
        elif backend == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL "
                                   "built in")
            _default_pg = ProcessGroupNCCL(store, rank, world_size)
            _pg_map[_default_pg] = (Backend.NCCL, store)
            _pg_names[_default_pg] = group_name

    _backend = _pg_map[_default_pg][0]
    _default_pg_init_method = init_method

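# Usage sketch (editor's addition, illustrative only; assumes two processes
# launched with the environment variables MASTER_ADDR, MASTER_PORT, RANK and
# WORLD_SIZE set, as expected by the default "env://" init method):
#
#     init_process_group(backend="gloo", init_method="env://")
#     print(get_rank(), get_world_size())
#     destroy_process_group()
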
def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              in_group=True,
                              group_name=None,
                              timeout=_default_pg_timeout):
    """
    Create a new distributed process group. The new process group can be
    used to perform collective operations.

    """
    global _pg_map
    global _group_count
    global _pg_names

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    if group_name in _pg_names.values():
        raise RuntimeError("The specified group name has already been "
                           "created, please use a different group name")

    if not isinstance(timeout, timedelta):
        raise RuntimeError("Expected timeout argument to be of type "
                           "datetime.timedelta")

    default_backend, default_store = _pg_map[_default_pg]

    if default_backend == Backend.MPI:
        if not is_mpi_available():
            raise RuntimeError("Distributed package doesn't have MPI built in")
        pg = ProcessGroupMPI(group_ranks)
        _pg_map[pg] = (Backend.MPI, in_group)
        _pg_names[pg] = group_name
    else:
        # Use a prefixed view of the default store so that keys of different
        # process groups do not collide.
        store = PrefixStore(group_name, default_store)

        if default_backend == Backend.GLOO:
            pg = ProcessGroupGloo(
                store,
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif default_backend == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL "
                                   "built in")
            pg = ProcessGroupNCCL(store, rank, world_size, group_name)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            raise RuntimeError("Unsupported distributed backend by group")
    return pg

def destroy_process_group(group=group.WORLD):
    """
    Destroy a given process group, and deinitialize the distributed package.

    Arguments:
        group (ProcessGroup, optional): The process group to be destroyed. If
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.
    """
    global _pg_map
    global _pg_names
    global _pg_group_ranks
    global _default_pg
    global _default_pg_init_method

    default_backend, _ = _pg_map[_get_default_group()]
    if (default_backend != Backend.MPI and
            group == GroupMember.NON_GROUP_MEMBER):
        return

    if group == GroupMember.WORLD:
        pg = _default_pg
    else:
        pg = group

    if _pg_map.get(pg, None) is None:
        raise RuntimeError("Invalid process group specified")

    if group == GroupMember.WORLD:
        _default_pg = None
        _default_pg_init_method = None
        _pg_map.clear()
        _pg_names.clear()
        _pg_group_ranks.clear()
    else:
        del _pg_map[pg]
        del _pg_names[pg]
        del _pg_group_ranks[pg]

def get_rank(group=group.WORLD):
    """
    Returns the rank of the current process in the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Arguments:
        group (ProcessGroup, optional): The process group to work on

    Returns:
        The rank of the process group
        -1, if not part of the group

    """
    if _rank_not_in_group(group):
        return -1

    _check_default_pg()
    if group == GroupMember.WORLD:
        return _default_pg.rank()

    return _get_group_rank(group, _default_pg.rank())


def get_world_size(group=group.WORLD):
    """
    Returns the number of processes in the given process group.

    Arguments:
        group (ProcessGroup, optional): The process group to work on

    Returns:
        The world size of the process group
        -1, if not part of the group

    """
    if _rank_not_in_group(group):
        return -1

    return _get_group_size(group)

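# Usage sketch (editor's addition, illustrative only): rank / world size
# queries after init_process_group() has been called on every process.
#
#     rank = get_rank()              # e.g. 0 on the first process
#     world_size = get_world_size()  # total number of processes in the job
#     if rank == 0:
#         print("running with", world_size, "processes")
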
def isend(tensor, dst, group=group.WORLD, tag=0):
    """
    Sends a tensor asynchronously.

    Arguments:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match send with remote recv

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.send([tensor], dst, tag)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        return group.send([tensor], group_dst_rank, tag)


def irecv(tensor, src, group=group.WORLD, tag=0):
    """
    Receives a tensor asynchronously.

    Arguments:
        tensor (Tensor): Tensor to fill with received data.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match recv with remote send

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.recv([tensor], src, tag)
    else:
        group_src_rank = _get_group_rank(group, src)
        return group.recv([tensor], group_src_rank, tag)

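# Usage sketch (editor's addition, illustrative only): non-blocking
# point-to-point exchange between rank 0 and rank 1 of a 2-process job.
#
#     t = torch.zeros(4)
#     if get_rank() == 0:
#         req = isend(t.fill_(1.0), dst=1, tag=7)
#     else:
#         req = irecv(t, src=0, tag=7)
#     req.wait()        # block until the transfer has completed
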
def send(tensor, dst, group=group.WORLD, tag=0):
    """
    Sends a tensor synchronously.

    Arguments:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match send with remote recv

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        _default_pg.send([tensor], dst, tag).wait()
    else:
        group_dst_rank = _get_group_rank(group, dst)
        group.send([tensor], group_dst_rank, tag).wait()


def recv(tensor, src=None, group=group.WORLD, tag=0):
    """
    Receives a tensor synchronously.

    Arguments:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): Source rank. Will receive from any
            process if unspecified.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match recv with remote send

    Returns:
        Sender rank
        -1, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return -1

    if group == GroupMember.WORLD:
        _check_default_pg()
        pg = _default_pg
    else:
        pg = group

    if src is None:
        work = pg.recv_anysource([tensor], tag)
        work.wait()
        src_rank = work.source_rank()
        if group == GroupMember.WORLD:
            return src_rank
        else:
            return _get_global_rank(pg, src_rank)
    else:
        if group == GroupMember.WORLD:
            pg.recv([tensor], src, tag).wait()
        else:
            group_src_rank = _get_group_rank(pg, src)
            pg.recv([tensor], group_src_rank, tag).wait()
        return src

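# Usage sketch (editor's addition, illustrative only): blocking send/recv,
# with recv() returning the rank of the sender when src is left unspecified.
#
#     t = torch.zeros(4)
#     if get_rank() == 0:
#         send(torch.ones(4), dst=1)
#     else:
#         sender = recv(t)          # t now holds the received data
#         print("received from rank", sender)
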
def broadcast_multigpu(tensor_list,
                       src,
                       group=group.WORLD,
                       async_op=False,
                       src_tensor=0):
    """
    Broadcasts the tensor to the whole group with multiple GPU tensors
    per node.

    ``tensor`` must have the same number of elements in all the GPUs from
    all processes participating in the collective. Each tensor in the list
    must be on a different GPU.

    Only the nccl and gloo backends are currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): Tensors that participate in the collective
            operation. If ``src`` is the rank, then the specified ``src_tensor``
            element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
            broadcast to all other tensors (on different GPUs) in the src
            process and all tensors in ``tensor_list`` of other non-src
            processes. You also need to make sure that ``len(tensor_list)``
            is the same for all the distributed processes calling this
            function.

        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op
        src_tensor (int, optional): Source tensor rank within ``tensor_list``

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = src_tensor

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.broadcast(tensor_list, opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.broadcast(tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()

def broadcast(tensor,
              src,
              group=group.WORLD,
              async_op=False):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process, and tensor to be used to save received data otherwise.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = 0

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.broadcast([tensor], opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.broadcast([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()

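# Usage sketch (editor's addition, illustrative only): rank 0 broadcasts a
# tensor and every other rank receives it in place.
#
#     t = torch.arange(4.) if get_rank() == 0 else torch.zeros(4)
#     broadcast(t, src=0)
#     # every rank now holds tensor([0., 1., 2., 3.])
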
def all_reduce_multigpu(tensor_list,
                        op=ReduceOp.SUM,
                        group=group.WORLD,
                        async_op=False):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result. This function reduces a number of tensors on every node,
    while each tensor resides on a different GPU.
    Therefore, the input tensors in the tensor list need to be GPU tensors.
    Also, each tensor in the tensor list needs to reside on a different GPU.

    After the call, all tensors in ``tensor_list`` are going to be bitwise
    identical in all processes.

    Only the nccl and gloo backends are currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): List of input and output tensors of
            the collective. The function operates in-place and requires
            each tensor to be a GPU tensor on a different GPU.
            You also need to make sure that ``len(tensor_list)`` is the same
            for all the distributed processes calling this function.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    opts = AllreduceOptions()
    opts.reduceOp = op
    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allreduce(tensor_list, opts)
    else:
        work = group.allreduce(tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()

def all_reduce(tensor,
               op=ReduceOp.SUM,
               group=group.WORLD,
               async_op=False):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    After the call ``tensor`` is going to be bitwise identical in all
    processes.

    Arguments:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = AllreduceOptions()
    opts.reduceOp = op
    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allreduce([tensor], opts)
    else:
        work = group.allreduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()

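# Usage sketch (editor's addition, illustrative only): summing a tensor
# across every process in the default group.
#
#     t = torch.full((2,), float(get_rank()))
#     all_reduce(t, op=ReduceOp.SUM)
#     # on every rank, t now holds the element-wise sum of all ranks' tensors
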
def reduce_multigpu(tensor_list,
                    dst,
                    op=ReduceOp.SUM,
                    group=group.WORLD,
                    async_op=False,
                    dst_tensor=0):
    """
    Reduces the tensor data on multiple GPUs across all machines. Each tensor
    in ``tensor_list`` should reside on a separate GPU.

    Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank
    ``dst`` is going to receive the final result.

    Only the nccl backend is currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): Input and output GPU tensors of the
            collective. The function operates in-place.
            You also need to make sure that ``len(tensor_list)`` is the same
            for all the distributed processes calling this function.
        dst (int): Destination rank
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op
        dst_tensor (int, optional): Destination tensor rank within
                                    ``tensor_list``

    Returns:
        Async work handle, if async_op is set to True.
        None, otherwise

    """
    if _rank_not_in_group(group):
        return

    opts = ReduceOptions()
    opts.reduceOp = op
    opts.rootRank = dst
    opts.rootTensor = dst_tensor

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.reduce(tensor_list, opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.reduce(tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()

def reduce(tensor,
           dst,
           op=ReduceOp.SUM,
           group=group.WORLD,
           async_op=False):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        dst (int): Destination rank
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = ReduceOptions()
    opts.reduceOp = op
    opts.rootRank = dst

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.reduce([tensor], opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.reduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()

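# Usage sketch (editor's addition, illustrative only): reducing onto rank 0,
# which is the only rank guaranteed to hold the final sum afterwards.
#
#     t = torch.ones(3)
#     reduce(t, dst=0, op=ReduceOp.SUM)
#     if get_rank() == 0:
#         print(t)   # world_size * torch.ones(3)
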
def all_gather_multigpu(output_tensor_lists,
                        input_tensor_list,
                        group=group.WORLD,
                        async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Each tensor in ``input_tensor_list`` should reside on a separate GPU.

    Only the nccl backend is currently supported;
    tensors should only be GPU tensors.

    Arguments:
        output_tensor_lists (List[List[Tensor]]): Output lists. It should
            contain correctly-sized tensors on each GPU to be used for output
            of the collective.
            e.g. ``output_tensor_lists[i]`` contains the all_gather
            result that resides on the GPU of ``input_tensor_list[i]``.
            Note that each element of ``output_tensor_lists[i]`` has the size
            of ``world_size * len(input_tensor_list)``, since the function all
            gathers the result from every single GPU in the group. To interpret
            each element of ``output_tensor_lists[i]``, note that
            ``input_tensor_list[j]`` of rank k will appear in
            ``output_tensor_lists[i][rank * world_size + j]``.
            Also note that ``len(output_tensor_lists)``, and the size of each
            element in ``output_tensor_lists`` (each element is a list,
            therefore ``len(output_tensor_lists[i])``) need to be the same
            for all the distributed processes calling this function.

        input_tensor_list (List[Tensor]): List of tensors (on different GPUs)
            to be broadcast from the current process.
            Note that ``len(input_tensor_list)`` needs to be the same for
            all the distributed processes calling this function.

        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
    else:
        work = group.allgather(output_tensor_lists, input_tensor_list)

    if async_op:
        return work
    else:
        work.wait()

def all_gather(tensor_list,
               tensor,
               group=group.WORLD,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from the current process.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()

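# Usage sketch (editor's addition, illustrative only): every rank contributes
# one tensor and receives the tensors of all ranks.
#
#     world_size = get_world_size()
#     outputs = [torch.zeros(2) for _ in range(world_size)]
#     all_gather(outputs, torch.full((2,), float(get_rank())))
#     # outputs[i] now holds rank i's tensor on every process
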
def gather(tensor,
           gather_list,
           dst,
           group=group.WORLD,
           async_op=False):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        gather_list (list[Tensor]): List of appropriately-sized tensors to
            use for received data. Required only in the receiving process.
        dst (int): Destination rank. Required in all processes except the one
            that is receiving the data.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(gather_list, "gather_list")
    if _rank_not_in_group(group):
        return

    my_rank = get_rank()
    if dst == my_rank:
        if gather_list is None:
            raise RuntimeError("gather_list is a required argument in gather "
                               "destination")
        input_tensors = [tensor]
        output_tensors = [gather_list]
    else:
        if gather_list:
            raise RuntimeError("non-empty gather_list can be given only "
                               "to gather destination")
        input_tensors = [tensor]
        output_tensors = []

    opts = GatherOptions()
    opts.rootRank = dst

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.gather(output_tensors, input_tensors, opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.gather(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()

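# Usage sketch (editor's addition, illustrative only): rank 0 collects one
# tensor from every rank; non-destination ranks pass an empty gather_list.
#
#     t = torch.full((2,), float(get_rank()))
#     if get_rank() == 0:
#         bucket = [torch.zeros(2) for _ in range(get_world_size())]
#         gather(t, bucket, dst=0)     # bucket[i] holds rank i's tensor
#     else:
#         gather(t, [], dst=0)
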
def scatter(tensor,
            scatter_list,
            src,
            group=group.WORLD,
            async_op=False):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Arguments:
        tensor (Tensor): Output tensor.
        scatter_list (list[Tensor]): List of tensors to scatter. Required only
            in the process that is sending the data.
        src (int): Source rank. Required in all processes except the one that
            is sending the data.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(scatter_list, "scatter_list")
    if _rank_not_in_group(group):
        return

    my_rank = get_rank()
    if src == my_rank:
        if scatter_list is None:
            raise RuntimeError("scatter_list is a required argument in "
                               "scatter source")
        input_tensors = [scatter_list]
        output_tensors = [tensor]
    else:
        if scatter_list:
            raise RuntimeError("non-empty scatter_list can be given only "
                               "to scatter source")
        input_tensors = []
        output_tensors = [tensor]

    opts = ScatterOptions()
    opts.rootRank = src

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.scatter(output_tensors, input_tensors, opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.scatter(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()

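# Usage sketch (editor's addition, illustrative only): rank 0 scatters one
# tensor to each rank; non-source ranks pass an empty scatter_list.
#
#     out = torch.zeros(2)
#     if get_rank() == 0:
#         chunks = [torch.full((2,), float(i)) for i in range(get_world_size())]
#         scatter(out, chunks, src=0)
#     else:
#         scatter(out, [], src=0)
#     # out on rank i now holds tensor([i., i.])
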
def barrier(group=group.WORLD,
            async_op=False):
    """
    Synchronizes all processes.

    This collective blocks processes until the whole group enters this
    function, if async_op is False, or until wait() is called on the returned
    async work handle.

    Arguments:
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
    """
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.barrier()
    else:
        work = group.barrier()

    if async_op:
        return work
    else:
        work.wait()

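# Usage sketch (editor's addition, illustrative only): make every rank wait
# until rank 0 has finished writing a checkpoint before reading it back.
# save_checkpoint() and load_checkpoint() are hypothetical helpers, not part
# of this module.
#
#     if get_rank() == 0:
#         save_checkpoint()
#     barrier()
#     state = load_checkpoint()
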
def new_group(ranks=None, timeout=_default_pg_timeout):
    """
    Creates a new distributed group.

    This function requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    Arguments:
        ranks (list[int]): List of ranks of group members.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is only applicable for the ``gloo`` backend.

    Returns:
        A handle of distributed group that can be given to collective calls.

    """
    _check_default_pg()

    global _pg_group_ranks
    global _group_count
    global _pg_names

    group_name = str(_group_count)
    _group_count += 1

    if group_name in _pg_names.values():
        raise RuntimeError("The specified group name has already been "
                           "created, please use a different group name")

    default_backend, _ = _pg_map[_default_pg]
    global_rank = _default_pg.rank()
    global_world_size = _default_pg.size()

    # Check the input ranks
    if ranks is not None:
        input_ranks = list(ranks)
        group_world_size = len(ranks)
        if group_world_size > global_world_size:
            raise RuntimeError("the new group's world size should be less "
                               "than or equal to the world size set by "
                               "init_process_group")
        for rank in ranks:
            if rank < 0 or rank >= global_world_size:
                raise RuntimeError("The new group's rank should be within "
                                   "the world_size set by init_process_group")
        if global_rank in ranks:
            group_rank = ranks.index(global_rank)
        else:
            group_rank = None
    else:
        input_ranks = []
        ranks = list(range(global_world_size))
        group_world_size = global_world_size
        group_rank = global_rank

    if default_backend == Backend.MPI:
        in_group = global_rank in ranks
        pg = _new_process_group_helper(group_world_size,
                                       group_rank,
                                       input_ranks,
                                       in_group,
                                       group_name,
                                       timeout=timeout)
    else:
        # Release ranks that are not members of the new group
        if global_rank not in ranks:
            return GroupMember.NON_GROUP_MEMBER

    if default_backend != Backend.MPI:
        pg = _new_process_group_helper(group_world_size,
                                       group_rank,
                                       input_ranks,
                                       group_name=group_name,
                                       timeout=timeout)

    # Create the global rank to group-local rank mapping
    _pg_group_ranks[pg] = {}
    if default_backend == Backend.MPI:
        _pg_group_ranks[pg] = pg.group_ranks()
    else:
        for rank in range(global_world_size):
            if rank in ranks:
                _pg_group_ranks[pg][rank] = ranks.index(rank)

    return pg

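# Usage sketch (editor's addition, illustrative only): every rank must call
# new_group(), even ranks that will not be members of the new group.
#
#     evens = new_group(ranks=[r for r in range(get_world_size()) if r % 2 == 0])
#     t = torch.ones(1)
#     if get_rank() % 2 == 0:
#         all_reduce(t, group=evens)   # reduces only across the even ranks
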