Caffe2 - Python API
A deep learning, cross-platform ML framework
distributed_c10d.py
import torch
import warnings
from torch._six import string_classes
from datetime import timedelta

from .rendezvous import rendezvous, register_rendezvous_handler
from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
    ScatterOptions, GatherOptions
from . import ReduceOp
from . import PrefixStore
from . import ProcessGroupGloo


_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True


try:
    from . import ProcessGroupMPI
except ImportError:
    _MPI_AVAILABLE = False

try:
    from . import ProcessGroupNCCL
except ImportError:
    _NCCL_AVAILABLE = False


class Backend(object):
    """
    An enum-like class of available backends: GLOO, NCCL, and MPI.

    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
    be accessed as attributes, e.g., ``Backend.NCCL``.

    This class can be directly called to parse the string, e.g.,
    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
    return the parsed lowercase string if so. It also accepts uppercase strings,
    e.g., ``Backend("GLOO")`` returns ``"gloo"``.

    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as the
              initial value of some fields. Users should neither use it
              directly nor assume its existence.
    """
    UNDEFINED = "undefined"
    GLOO = "gloo"
    NCCL = "nccl"
    MPI = "mpi"
    TCP = "tcp"

    def __new__(cls, name):
        if not isinstance(name, string_classes):
            raise ValueError("Backend name must be a string, but got: {}".format(name))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)

        if value == Backend.TCP:
            raise ValueError("TCP backend has been deprecated. Please use "
                             "Gloo or MPI backend for collective operations "
                             "on CPU tensors.")
        elif value == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        return value

# `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward
# compatibility with the pre-c10d distributed package.
# TODO: remove them when users are ready to take a hard dependency on PyTorch 1.
_backend = Backend.UNDEFINED
dist_backend = Backend


class reduce_op(object):
    r"""
    Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
    ``MIN``, and ``MAX``.

    Use :class:`~torch.distributed.ReduceOp` instead.
    """

    def __init__(self):
        # __members__ is a dict storing key-value pairs for enum classes
        for k, v in ReduceOp.__members__.items():
            setattr(self, k, v)
        self.__members__ = ReduceOp.__members__

    def __getattribute__(self, key):
        warnings.warn("torch.distributed.reduce_op is deprecated, please use "
                      "torch.distributed.ReduceOp instead")
        return object.__getattribute__(self, key)

reduce_op = reduce_op()


class group(object):
    WORLD = object()


class GroupMember(object):
    # Alias to group.WORLD for backward compatibility
    WORLD = group.WORLD
    NON_GROUP_MEMBER = object()


# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, Bool), where bool
# represents if the ProcessGroup object is part of the group
_pg_map = {}
# Process group's names, map from ProcessGroup to str
_pg_names = {}
# Process group's global rank to local rank mapping
_pg_group_ranks = {}

# Default process group state
_default_pg = None
_default_pg_init_method = None

# Default process group wide timeout, if applicable.
# This currently only applies to the gloo backend. To make an attempt at
# backwards compatibility with THD, we use an extraordinarily high default
# timeout, given that THD did not have timeouts.
_default_pg_timeout = timedelta(minutes=30)

# Process group count for default naming
_group_count = 0


def _rank_not_in_group(group):
    """
    Helper that checks if the current process's rank is not in a given group.

    """
    default_backend, _ = _pg_map[_get_default_group()]
    if default_backend != Backend.MPI:
        return group == GroupMember.NON_GROUP_MEMBER
    else:
        if group == GroupMember.WORLD:
            return False
        else:
            _, in_group = _pg_map[group]
            return not in_group


def _get_group_rank(group, rank):
    """
    Helper that gets a given group's local rank in the group from a given global
    rank.

    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    try:
        group_rank = _pg_group_ranks[group][rank]
    except KeyError:
        raise RuntimeError("The global rank is not part of the group")
    return group_rank


def _get_global_rank(group, group_rank):
    """
    Helper that gets a given group's global rank from a given local rank in the
    group.

    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    group_rank_map = _pg_group_ranks[group]
    for rank, grp_rank in group_rank_map.items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError("The group rank is not part of the group")


def _check_default_pg():
    """
    Helper that asserts that the default ProcessGroup has been initialized.

    """
    assert _default_pg is not None, \
        "Default process group is not initialized"


def _get_group_size(group):
    """
    Helper that gets a given group's world size.

    """
    if group is GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.size()
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    return len(_pg_group_ranks[group])


def _check_single_tensor(param, param_name):
    """
    Helper that checks that the parameter ``param_name`` is a single tensor.

    """
    if not isinstance(param, torch.Tensor):
        raise RuntimeError("Invalid function argument. Expecting parameter: {} "
                           "to be a torch.Tensor type".format(param_name))


def _check_tensor_list(param, param_name):
    """
    Helper that checks that the parameter ``param_name`` is a list of tensors.

    """
    wrong_type = False
    if isinstance(param, list):
        for p in param:
            if not isinstance(p, torch.Tensor):
                wrong_type = True
                break
    else:
        wrong_type = True
    if wrong_type:
        raise RuntimeError("Invalid function argument. Expecting parameter: {} "
                           "to be a List[torch.Tensor] type".format(param_name))


def is_mpi_available():
    """
    Checks if MPI is available.

    """
    return _MPI_AVAILABLE


def is_nccl_available():
    """
    Checks if NCCL is available.

    """
    return _NCCL_AVAILABLE


def is_initialized():
    """
    Checks if the default process group has been initialized.

    """
    return _default_pg is not None


def _get_default_group():
    """
    Gets the default process group created by init_process_group.

    """
    if not is_initialized():
        raise RuntimeError("Default process group has not been initialized, "
                           "please make sure to call init_process_group.")
    return _default_pg


def get_backend(group=group.WORLD):
    """
    Returns the backend of the given process group.

    Arguments:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific group
            is specified, the calling process must be part of :attr:`group`.

    Returns:
        The backend of the given process group as a lowercase string.

    """
    _check_default_pg()

    if group == GroupMember.WORLD:
        pg = _default_pg
    else:
        pg = group
    if _rank_not_in_group(pg):
        raise RuntimeError("Invalid process group specified")
    return _pg_map.get(pg, None)[0]


def init_process_group(backend,
                       init_method="env://",
                       timeout=_default_pg_timeout,
                       **kwargs):
    """
    Initializes the default distributed process group, and this will also
    initialize the distributed package.

    Arguments:
        backend (str or Backend): The backend to use. Depending on
            build-time configurations, valid values include ``mpi``, ``gloo``,
            and ``nccl``. This field should be given as a lowercase string
            (e.g., ``"gloo"``), which can also be accessed via
            :class:`Backend` attributes (e.g., ``Backend.GLOO``).
        init_method (str, optional): URL specifying how to initialize the
            process group.
        world_size (int, optional): Number of processes participating in
            the job.
        rank (int, optional): Rank of the current process.
        store (Store, optional): Rendezvous key/value store as an alternative
            to other init methods.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is only applicable for the ``gloo`` backend.
        group_name (str, optional, deprecated): Group name.

    To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
    on a system that supports MPI. The same applies to NCCL as well.

    """
    global _pg_map
    global _pg_names
    global _backend
    global _default_pg
    global _default_pg_init_method

    if not isinstance(timeout, timedelta):
        raise RuntimeError("Expected timeout argument to be of type "
                           "datetime.timedelta")

    if _default_pg is not None:
        raise RuntimeError("trying to initialize the default process group "
                           "twice!")

    world_size = kwargs.pop('world_size', -1)
    group_name = kwargs.pop('group_name', '')
    rank = kwargs.pop('rank', -1)
    store = kwargs.pop('store', None)
    if store is not None:
        assert world_size > 0, 'world_size needs to be positive'
        assert rank >= 0, 'rank needs to be non-negative'
    assert len(kwargs) == 0, \
        "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())

    backend = Backend(backend)

    if backend == Backend.MPI:
        if not is_mpi_available():
            raise RuntimeError("Distributed package doesn't have MPI built in")

        _default_pg = ProcessGroupMPI([])
        _pg_map[_default_pg] = (Backend.MPI, True)
        _pg_names[_default_pg] = group_name
    else:
        # backward compatible API
        url = init_method
        if world_size != -1 and rank != -1:
            url += "?rank={}&world_size={}".format(rank, world_size)
        elif rank != -1:
            url += "?rank={}".format(rank)
        elif world_size != -1:
            url += "?world_size={}".format(world_size)

        if store is None:
            store, rank, world_size = next(rendezvous(url))
        if backend == Backend.GLOO:
            _default_pg = ProcessGroupGloo(
                store,
                rank,
                world_size,
                timeout=timeout)
            _pg_map[_default_pg] = (Backend.GLOO, store)
            _pg_names[_default_pg] = group_name
        elif backend == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL "
                                   "built in")
            _default_pg = ProcessGroupNCCL(store, rank, world_size)
            _pg_map[_default_pg] = (Backend.NCCL, store)
            _pg_names[_default_pg] = group_name

    _backend = _pg_map[_default_pg][0]
    _default_pg_init_method = init_method


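# Illustrative usage sketch (not part of the original module): one way the
# default process group might be brought up with the gloo backend. The helper
# name ``_example_init_process_group`` and the address, port, rank, and
# world_size values are hypothetical placeholders; in practice each process
# passes its own rank, usually supplied by the job launcher.
def _example_init_process_group():
    init_process_group(backend=Backend.GLOO,
                       init_method="tcp://127.0.0.1:23456",
                       world_size=2,
                       rank=0)  # rank 0 shown here; each process uses its own
    assert is_initialized()
    assert get_backend() == Backend.GLOO

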
def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              in_group,
                              group_name,
                              timeout=_default_pg_timeout):
    """
    Creates a new distributed process group that can be used to perform
    collective operations.

    """
    global _pg_map
    global _group_count
    global _pg_names

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    if group_name in _pg_names.values():
        raise RuntimeError("The specified group name has already been "
                           "created, please use a different group name")

    if not isinstance(timeout, timedelta):
        raise RuntimeError("Expected timeout argument to be of type "
                           "datetime.timedelta")

    default_backend, default_store = _pg_map[_default_pg]

    if default_backend == Backend.MPI:
        if not is_mpi_available():
            raise RuntimeError("Distributed package doesn't have MPI built in")
        pg = ProcessGroupMPI(group_ranks)
        _pg_map[pg] = (Backend.MPI, in_group)
        _pg_names[pg] = group_name
    else:
        # Create the prefix store
        store = PrefixStore(group_name, default_store)

        if default_backend == Backend.GLOO:
            pg = ProcessGroupGloo(
                store,
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif default_backend == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL "
                                   "built in")
            pg = ProcessGroupNCCL(store, rank, world_size, group_name)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            raise RuntimeError("Unsupported distributed backend by group")
    return pg


def destroy_process_group(group=group.WORLD):
    """
    Destroys a given process group, and deinitializes the distributed package.

    Arguments:
        group (ProcessGroup, optional): The process group to be destroyed. If
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.
    """
    global _pg_map
    global _pg_names
    global _pg_group_ranks
    global _default_pg
    global _default_pg_init_method

    default_backend, _ = _pg_map[_get_default_group()]
    if (default_backend != Backend.MPI and
            group == GroupMember.NON_GROUP_MEMBER):
        return

    if group == GroupMember.WORLD:
        pg = _default_pg
    else:
        pg = group
        if _pg_map.get(pg, None) is None:
            raise RuntimeError("Invalid process group specified")

    if group == GroupMember.WORLD:
        _default_pg = None
        _default_pg_init_method = None
        _pg_map.clear()
        _pg_names.clear()
        _pg_group_ranks.clear()
    else:
        del _pg_map[pg]
        del _pg_names[pg]
        del _pg_group_ranks[pg]


def get_rank(group=group.WORLD):
    """
    Returns the rank of the current process in the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. Ranks are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Arguments:
        group (ProcessGroup, optional): The process group to work on

    Returns:
        The rank of the process group
        -1, if not part of the group

    """
    if _rank_not_in_group(group):
        return -1

    _check_default_pg()
    if group == GroupMember.WORLD:
        return _default_pg.rank()

    return _get_group_rank(group, _default_pg.rank())


def get_world_size(group=group.WORLD):
    """
    Returns the number of processes in the current process group.

    Arguments:
        group (ProcessGroup, optional): The process group to work on

    Returns:
        The world size of the process group
        -1, if not part of the group

    """
    if _rank_not_in_group(group):
        return -1

    return _get_group_size(group)


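# Illustrative usage sketch (not part of the original module): reading the
# calling process's rank and the default group's size after the process group
# has been initialized. The helper name is a hypothetical placeholder.
def _example_rank_and_world_size():
    rank = get_rank()              # global rank of this process
    world_size = get_world_size()  # number of processes in the default group
    print("rank {} of {} processes".format(rank, world_size))

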
def isend(tensor,
          dst,
          group=group.WORLD,
          tag=0):
    """
    Sends a tensor asynchronously.

    Arguments:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match send with remote recv

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.send([tensor], dst, tag)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        return group.send([tensor], group_dst_rank, tag)


def irecv(tensor,
          src,
          group=group.WORLD,
          tag=0):
    """
    Receives a tensor asynchronously.

    Arguments:
        tensor (Tensor): Tensor to fill with received data.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match recv with remote send

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        return _default_pg.recv([tensor], src, tag)
    else:
        group_src_rank = _get_group_rank(group, src)
        return group.recv([tensor], group_src_rank, tag)


def send(tensor,
         dst,
         group=group.WORLD,
         tag=0):
    """
    Sends a tensor synchronously.

    Arguments:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match send with remote recv

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        _default_pg.send([tensor], dst, tag).wait()
    else:
        group_dst_rank = _get_group_rank(group, dst)
        group.send([tensor], group_dst_rank, tag).wait()


def recv(tensor,
         src=None,
         group=group.WORLD,
         tag=0):
    """
    Receives a tensor synchronously.

    Arguments:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): Source rank. Will receive from any
            process if unspecified.
        group (ProcessGroup, optional): The process group to work on
        tag (int, optional): Tag to match recv with remote send

    Returns:
        Sender rank
        -1, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return -1

    if group == GroupMember.WORLD:
        _check_default_pg()
        pg = _default_pg
    else:
        pg = group

    if src is None:
        work = pg.recv_anysource([tensor], tag)
        work.wait()
        src_rank = work.source_rank()
        if group == GroupMember.WORLD:
            return src_rank
        else:
            return _get_global_rank(pg, src_rank)
    else:
        if group == GroupMember.WORLD:
            pg.recv([tensor], src, tag).wait()
        else:
            group_src_rank = _get_group_rank(pg, src)
            pg.recv([tensor], group_src_rank, tag).wait()
        return src


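# Illustrative usage sketch (not part of the original module): a blocking
# point-to-point exchange between rank 0 and rank 1 of the default group, using
# CPU tensors. The helper name is a hypothetical placeholder.
def _example_send_recv():
    t = torch.zeros(4)
    if get_rank() == 0:
        t += 1.0
        send(t, dst=1)            # blocks until the message has been sent
    elif get_rank() == 1:
        sender = recv(t, src=0)   # fills ``t``; returns the sender's rank (0)
        assert sender == 0

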
def broadcast_multigpu(tensor_list,
                       src,
                       group=group.WORLD,
                       async_op=False,
                       src_tensor=0):
    """
    Broadcasts the tensor to the whole group with multiple GPU tensors
    per node.

    ``tensor`` must have the same number of elements in all the GPUs from
    all processes participating in the collective. Each tensor in the list must
    be on a different GPU.

    Only the nccl and gloo backends are currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): Tensors that participate in the collective
            operation. If ``src`` is the rank, then the specified ``src_tensor``
            element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
            broadcast to all other tensors (on different GPUs) in the src process
            and all tensors in ``tensor_list`` of other non-src processes.
            You also need to make sure that ``len(tensor_list)`` is the same
            for all the distributed processes calling this function.

        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op
        src_tensor (int, optional): Source tensor rank within ``tensor_list``

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = src_tensor

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.broadcast(tensor_list, opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.broadcast(tensor_list, opts)
    if async_op:
        return work
    else:
        work.wait()


def broadcast(tensor,
              src,
              group=group.WORLD,
              async_op=False):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of the current
            process, and tensor to be used to save received data otherwise.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = BroadcastOptions()
    opts.rootRank = src
    opts.rootTensor = 0

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.broadcast([tensor], opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.broadcast([tensor], opts)
    if async_op:
        return work
    else:
        work.wait()


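# Illustrative usage sketch (not part of the original module): broadcasting a
# tensor from rank 0 to every process in the default group. The helper name is
# a hypothetical placeholder.
def _example_broadcast():
    t = torch.arange(4.) if get_rank() == 0 else torch.zeros(4)
    broadcast(t, src=0)
    # After the call, ``t`` holds [0., 1., 2., 3.] on every rank.

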
def all_reduce_multigpu(tensor_list,
                        op=ReduceOp.SUM,
                        group=group.WORLD,
                        async_op=False):
    r"""
    Reduces the tensor data across all machines in such a way that all get
    the final result. This function reduces a number of tensors on every node,
    while each tensor resides on a different GPU.
    Therefore, the input tensors in the tensor list need to be GPU tensors.
    Also, each tensor in the tensor list needs to reside on a different GPU.

    After the call, all tensors in ``tensor_list`` are going to be bitwise
    identical in all processes.

    Only the nccl and gloo backends are currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): List of input and output tensors of
            the collective. The function operates in-place and requires
            each tensor to be a GPU tensor on a different GPU.
            You also need to make sure that ``len(tensor_list)`` is the same for
            all the distributed processes calling this function.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    opts = AllreduceOptions()
    opts.reduceOp = op
    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allreduce(tensor_list, opts)
    else:
        work = group.allreduce(tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()


def all_reduce(tensor,
               op=ReduceOp.SUM,
               group=group.WORLD,
               async_op=False):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    After the call ``tensor`` is going to be bitwise identical in all processes.

    Arguments:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = AllreduceOptions()
    opts.reduceOp = op
    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allreduce([tensor], opts)
    else:
        work = group.allreduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()


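# Illustrative usage sketch (not part of the original module): summing a tensor
# across all ranks in place. The helper name is a hypothetical placeholder and
# the arithmetic in the comment assumes a world size of 2.
def _example_all_reduce():
    t = torch.ones(2) * (get_rank() + 1)
    all_reduce(t, op=ReduceOp.SUM)
    # With two processes, every rank now holds [3., 3.] (1 + 2 per element).

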
def reduce_multigpu(tensor_list,
                    dst,
                    op=ReduceOp.SUM,
                    group=group.WORLD,
                    async_op=False,
                    dst_tensor=0):
    """
    Reduces the tensor data on multiple GPUs across all machines. Each tensor
    in ``tensor_list`` should reside on a separate GPU.

    Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
    is going to receive the final result.

    Only the nccl backend is currently supported;
    tensors should only be GPU tensors.

    Arguments:
        tensor_list (List[Tensor]): Input and output GPU tensors of the
            collective. The function operates in-place.
            You also need to make sure that ``len(tensor_list)`` is the same for
            all the distributed processes calling this function.
        dst (int): Destination rank
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op
        dst_tensor (int, optional): Destination tensor rank within
            ``tensor_list``

    Returns:
        Async work handle, if async_op is set to True.
        None, otherwise

    """
    if _rank_not_in_group(group):
        return

    opts = ReduceOptions()
    opts.reduceOp = op
    opts.rootRank = dst
    opts.rootTensor = dst_tensor

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.reduce(tensor_list, opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.reduce(tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()


def reduce(tensor,
           dst,
           op=ReduceOp.SUM,
           group=group.WORLD,
           async_op=False):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        dst (int): Destination rank
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    opts = ReduceOptions()
    opts.reduceOp = op
    opts.rootRank = dst

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.reduce([tensor], opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.reduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()


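# Illustrative usage sketch (not part of the original module): like all_reduce,
# except that only the destination rank is guaranteed to hold the reduced
# result. The helper name is a hypothetical placeholder.
def _example_reduce():
    t = torch.ones(2) * (get_rank() + 1)
    reduce(t, dst=0, op=ReduceOp.SUM)
    # Only rank 0 is guaranteed to end up with the element-wise sum; the
    # contents of ``t`` on the other ranks after the call are unspecified.

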
def all_gather_multigpu(output_tensor_lists,
                        input_tensor_list,
                        group=group.WORLD,
                        async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Each tensor in ``input_tensor_list`` should reside on a separate GPU.

    Only the nccl backend is currently supported;
    tensors should only be GPU tensors.

    Arguments:
        output_tensor_lists (List[List[Tensor]]): Output lists. They should
            contain correctly-sized tensors on each GPU to be used for output of
            the collective.
            e.g. ``output_tensor_lists[i]`` contains the all_gather
            result that resides on the GPU of ``input_tensor_list[i]``.
            Note that each element of ``output_tensor_lists[i]`` has the size of
            ``world_size * len(input_tensor_list)``, since the function all
            gathers the result from every single GPU in the group. To interpret
            each element of ``output_tensor_lists[i]``, note that
            ``input_tensor_list[j]`` of rank k will appear in
            ``output_tensor_lists[i][rank * world_size + j]``.
            Also note that ``len(output_tensor_lists)``, and the size of each
            element in ``output_tensor_lists`` (each element is a list,
            therefore ``len(output_tensor_lists[i])``) need to be the same
            for all the distributed processes calling this function.

        input_tensor_list (List[Tensor]): List of tensors (on different GPUs) to
            be broadcast from the current process.
            Note that ``len(input_tensor_list)`` needs to be the same for
            all the distributed processes calling this function.

        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
    else:
        work = group.allgather(output_tensor_lists, input_tensor_list)

    if async_op:
        return work
    else:
        work.wait()


def all_gather(tensor_list,
               tensor,
               group=group.WORLD,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()


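# Illustrative usage sketch (not part of the original module): gathering one
# CPU tensor from every rank into a list that ends up identical on all ranks.
# The helper name is a hypothetical placeholder.
def _example_all_gather():
    t = torch.tensor([float(get_rank())])
    gathered = [torch.zeros(1) for _ in range(get_world_size())]
    all_gather(gathered, t)
    # ``gathered[i]`` now holds rank i's tensor on every process.

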
def gather(tensor,
           gather_list,
           dst,
           group=group.WORLD,
           async_op=False):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        gather_list (list[Tensor]): List of appropriately-sized tensors to
            use for received data. Required only in the receiving process.
        dst (int): Destination rank. Required in all processes except the one
            that is receiving the data.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(gather_list, "gather_list")
    if _rank_not_in_group(group):
        return

    my_rank = get_rank()
    if dst == my_rank:
        if gather_list is None:
            raise RuntimeError("gather_list is a required argument in gather "
                               "destination")
        input_tensors = [tensor]
        output_tensors = [gather_list]
    else:
        if gather_list:
            raise RuntimeError("non-empty gather_list can be given only "
                               "to gather destination")
        input_tensors = [tensor]
        output_tensors = []

    opts = GatherOptions()
    opts.rootRank = dst

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.gather(output_tensors, input_tensors, opts)
    else:
        group_dst_rank = _get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.gather(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()


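# Illustrative usage sketch (not part of the original module): collecting a
# tensor from every rank onto rank 0 only. Non-destination ranks pass an empty
# gather_list. The helper name is a hypothetical placeholder.
def _example_gather():
    t = torch.tensor([float(get_rank())])
    if get_rank() == 0:
        gather_list = [torch.zeros(1) for _ in range(get_world_size())]
        gather(t, gather_list, dst=0)
        # ``gather_list[i]`` holds rank i's tensor, but only on rank 0.
    else:
        gather(t, [], dst=0)

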
def scatter(tensor,
            scatter_list,
            src,
            group=group.WORLD,
            async_op=False):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Arguments:
        tensor (Tensor): Output tensor.
        scatter_list (list[Tensor]): List of tensors to scatter. Required only
            in the process that is sending the data.
        src (int): Source rank. Required in all processes except the one that
            is sending the data.
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(scatter_list, "scatter_list")
    if _rank_not_in_group(group):
        return

    my_rank = get_rank()
    if src == my_rank:
        if scatter_list is None:
            raise RuntimeError("scatter_list is a required argument in "
                               "scatter source")
        input_tensors = [scatter_list]
        output_tensors = [tensor]
    else:
        if scatter_list:
            raise RuntimeError("non-empty scatter_list can be given only "
                               "to the scatter source")
        input_tensors = []
        output_tensors = [tensor]

    opts = ScatterOptions()
    opts.rootRank = src

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.scatter(output_tensors, input_tensors, opts)
    else:
        group_src_rank = _get_group_rank(group, src)
        opts.rootRank = group_src_rank
        work = group.scatter(output_tensors, input_tensors, opts)

    if async_op:
        return work
    else:
        work.wait()


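# Illustrative usage sketch (not part of the original module): rank 0 scatters
# one tensor to every rank; each process receives its slice in ``out``. The
# helper name is a hypothetical placeholder.
def _example_scatter():
    out = torch.zeros(1)
    if get_rank() == 0:
        chunks = [torch.tensor([float(i)]) for i in range(get_world_size())]
        scatter(out, chunks, src=0)
    else:
        scatter(out, [], src=0)
    # ``out`` now holds the value of the calling process's rank.

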
def barrier(group=group.WORLD,
            async_op=False):
    """
    Synchronizes all processes.

    This collective blocks processes until the whole group enters this function,
    if async_op is False, or until wait() is called on the async work handle.

    Arguments:
        group (ProcessGroup, optional): The process group to work on
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
    """
    if _rank_not_in_group(group):
        return

    if group == GroupMember.WORLD:
        _check_default_pg()
        work = _default_pg.barrier()
    else:
        work = group.barrier()

    if async_op:
        return work
    else:
        work.wait()


def new_group(ranks=None, timeout=_default_pg_timeout):
    """
    Creates a new distributed group.

    This function requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    Arguments:
        ranks (list[int]): List of ranks of group members.
        timeout (timedelta, optional): Timeout for operations executed against
            the process group. Default value equals 30 minutes.
            This is only applicable for the ``gloo`` backend.

    Returns:
        A handle of the distributed group that can be given to collective calls.
    """

    _check_default_pg()

    global _pg_group_ranks
    global _group_count
    global _pg_names

    group_name = str(_group_count)
    _group_count += 1

    if group_name in _pg_names.values():
        raise RuntimeError("The specified group name has already been "
                           "created, please use a different group name")

    default_backend, _ = _pg_map[_default_pg]
    global_rank = _default_pg.rank()
    global_world_size = _default_pg.size()

    # checks the input ranks
    if ranks is not None:
        input_ranks = list(ranks)
        group_world_size = len(ranks)
        if group_world_size > global_world_size:
            raise RuntimeError("the new group's world size should be less than "
                               "or equal to the world size set by "
                               "init_process_group")
        # check ranks' sanity
        for rank in ranks:
            if rank < 0 or rank >= global_world_size:
                raise RuntimeError("The new group's rank should be within "
                                   "the world_size set by init_process_group")
        if global_rank in ranks:
            group_rank = ranks.index(global_rank)
        else:
            group_rank = None
    else:
        input_ranks = []
        ranks = list(range(global_world_size))
        group_world_size = global_world_size
        group_rank = global_rank

    if default_backend == Backend.MPI:
        in_group = global_rank in ranks
        pg = _new_process_group_helper(group_world_size,
                                       group_rank,
                                       input_ranks,
                                       in_group,
                                       group_name,
                                       timeout=timeout)
    else:
        # Release ranks not in the group
        if global_rank not in ranks:
            return GroupMember.NON_GROUP_MEMBER

        pg = _new_process_group_helper(group_world_size,
                                       group_rank,
                                       input_ranks,
                                       True,
                                       group_name,
                                       timeout=timeout)

    # Create the global rank to group rank mapping
    _pg_group_ranks[pg] = {}
    if default_backend == Backend.MPI:
        _pg_group_ranks[pg] = pg.group_ranks()
    else:
        for rank in range(global_world_size):
            if rank in ranks:
                _pg_group_ranks[pg][rank] = ranks.index(rank)
    return pg
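

# Illustrative usage sketch (not part of the original module): creating a
# sub-group of ranks 0 and 1 and running collectives on it. Every process must
# call new_group, even those that are not members. The helper name and the
# choice of ranks are hypothetical placeholders.
def _example_new_group():
    subgroup = new_group(ranks=[0, 1])
    if get_rank() in (0, 1):
        barrier(group=subgroup)              # synchronize only the sub-group
        t = torch.ones(1)
        all_reduce(t, op=ReduceOp.SUM, group=subgroup)
        # ``t`` now holds [2.] on ranks 0 and 1.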