Caffe2 - Python API
A deep learning, cross-platform ML framework
rendezvous.py
1 try:
2  from urllib.parse import urlparse
3 except ImportError:
4  from urlparse import urlparse
5 
6 import os
7 from . import FileStore, TCPStore
8 
9 
# Registry mapping a URL scheme to its rendezvous handler. Entries are added
# by register_rendezvous_handler() and looked up by rendezvous().
_rendezvous_handlers = {}
11 
12 
def register_rendezvous_handler(scheme, handler):
    """Register a rendezvous handler under a URL scheme.

    Rendezvous is the step in which the participating processes find
    each other and exchange the information they need in order to
    communicate; it has to happen before collective algorithms can run.
    Its outcome is a triplet made up of a shared key/value store, the
    rank of the process, and the total number of participating
    processes.

    When none of the bundled rendezvous methods fits the execution
    environment, a custom handler can be registered here. Pick a unique
    name and use it as the URL scheme when calling `rendezvous()`.

    Arguments:
        scheme (str): URL scheme that identifies the handler.
        handler (function): Callable invoked when `rendezvous()` is
            called with a URL that uses ``scheme``. It must be a
            generator function that yields the triplet.

    Raises:
        RuntimeError: if a handler is already registered for ``scheme``.
    """
    already_registered = scheme in _rendezvous_handlers
    if already_registered:
        message = "Rendezvous handler for {}:// already registered".format(scheme)
        raise RuntimeError(message)
    _rendezvous_handlers[scheme] = handler
42 
43 
def rendezvous(url, **kwargs):
    """Dispatch a rendezvous URL to the handler registered for its scheme.

    Arguments:
        url (str): Rendezvous URL; its scheme selects the handler.
        **kwargs: Passed through unchanged to the selected handler.

    Returns:
        Whatever the selected handler returns when called with ``url``
        and ``kwargs``.

    Raises:
        RuntimeError: if no handler is registered for the URL's scheme.
    """
    scheme = urlparse(url).scheme
    if scheme in _rendezvous_handlers:
        return _rendezvous_handlers[scheme](url, **kwargs)
    raise RuntimeError("No rendezvous handler for {}://".format(scheme))
50 
51 
52 def _rendezvous_error(msg):
53  return ValueError("Error initializing torch.distributed using " + msg)
54 
55 
def _file_rendezvous_handler(url):
    """Rendezvous through a file-backed store (``file://`` URLs).

    The URL's path is handed to ``FileStore``; the query string must carry
    ``rank`` and ``world_size``, for example
    ``file:///some/path?rank=0&world_size=2``.

    Arguments:
        url (str): Rendezvous URL.

    Yields:
        tuple: ``(store, rank, world_size)`` where ``store`` is the
        ``FileStore`` built from the URL path and the world size.

    Raises:
        ValueError: if the path, ``rank`` or ``world_size`` is missing, if a
            query parameter is not of the form ``key=value``, or if
            ``rank`` / ``world_size`` cannot be converted to ``int``.
        RuntimeError: if the generator is resumed after its first yield.
    """
    def _error(msg):
        return _rendezvous_error("file:// rendezvous: " + msg)

    result = urlparse(url)
    path = result.path
    if not path:
        raise _error("path missing")

    query = {}
    for pair in filter(None, result.query.split("&")):
        # Split on the first "=" only so a value may itself contain "=".
        # A pair with no "=" at all is reported explicitly instead of
        # surfacing as an opaque dict-construction error.
        key, sep, value = pair.partition("=")
        if not sep:
            raise _error(
                "malformed query parameter `{}` (expected key=value)".format(pair)
            )
        query[key] = value

    if "rank" not in query:
        raise _error("rank parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")
77 
78 
def _tcp_rendezvous_handler(url):
    """Rendezvous through a TCP store (``tcp://`` URLs).

    The URL supplies the host name and port passed to ``TCPStore``; the
    query string must carry ``rank`` and ``world_size``, for example
    ``tcp://host:23456?rank=0&world_size=2``.

    Arguments:
        url (str): Rendezvous URL.

    Yields:
        tuple: ``(store, rank, world_size)`` where ``store`` is the
        ``TCPStore`` built from the URL's host name and port.

    Raises:
        ValueError: if the port, ``rank`` or ``world_size`` is missing, if a
            query parameter is not of the form ``key=value``, or if
            ``rank`` / ``world_size`` cannot be converted to ``int``.
        RuntimeError: if the generator is resumed after its first yield.
    """
    def _error(msg):
        return _rendezvous_error("tcp:// rendezvous: " + msg)

    result = urlparse(url)
    if not result.port:
        raise _error("port number missing")

    query = {}
    for pair in filter(None, result.query.split("&")):
        # Split on the first "=" only so a value may itself contain "=".
        # A pair with no "=" at all is reported explicitly instead of
        # surfacing as an opaque dict-construction error.
        key, sep, value = pair.partition("=")
        if not sep:
            raise _error(
                "malformed query parameter `{}` (expected key=value)".format(pair)
            )
        query[key] = value

    if "rank" not in query:
        raise _error("rank parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    # Only rank 0 asks TCPStore to start the daemon.
    start_daemon = rank == 0
    store = TCPStore(result.hostname, result.port, world_size, start_daemon)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
100 
101 
def _env_rendezvous_handler(url):
    """Rendezvous configured through the environment (``env://`` URLs).

    ``rank`` and ``world_size`` are taken from the URL's query string when
    present there, otherwise from the ``RANK`` and ``WORLD_SIZE``
    environment variables. ``MASTER_ADDR`` and ``MASTER_PORT`` are always
    read from the environment and passed to ``TCPStore``.

    Arguments:
        url (str): Rendezvous URL; must start with ``env://`` and may carry
            a query string such as ``env://?rank=0&world_size=2``.

    Yields:
        tuple: ``(store, rank, world_size)`` where ``store`` is the
        ``TCPStore`` built from ``MASTER_ADDR`` and ``MASTER_PORT``.

    Raises:
        ValueError: if the URL does not start with ``env://``, if a query
            parameter is not of the form ``key=value``, if a required
            environment variable is not set, or if a numeric setting
            cannot be converted to ``int``.
        RuntimeError: if the generator is resumed after its first yield.
    """
    def _error(msg):
        return _rendezvous_error("env:// rendezvous: " + msg)

    def _env_error(var):
        return _error("environment variable %s expected, but not set" % var)

    def _required_env(var):
        # Return the raw string value of a mandatory environment variable.
        value = os.environ.get(var, None)
        if value is None:
            raise _env_error(var)
        return value

    if not url.startswith("env://"):
        raise _error("url must be equal to `env://`")
    result = urlparse(url)

    query = {}
    for pair in filter(None, result.query.split("&")):
        # Split on the first "=" only so a value may itself contain "=".
        # A pair with no "=" at all is reported explicitly instead of
        # surfacing as an opaque dict-construction error.
        key, sep, value = pair.partition("=")
        if not sep:
            raise _error(
                "malformed query parameter `{}` (expected key=value)".format(pair)
            )
        query[key] = value

    # Query-string values take precedence over the environment.
    if "rank" in query:
        rank = int(query["rank"])
    else:
        rank = _required_env("RANK")

    if "world_size" in query:
        world_size = int(query["world_size"])
    else:
        world_size = _required_env("WORLD_SIZE")

    master_addr = _required_env("MASTER_ADDR")
    master_port = _required_env("MASTER_PORT")

    # Converting before creating the store
    rank = int(rank)
    world_size = int(world_size)
    master_port = int(master_port)

    # Now start the TCP store daemon on the rank 0
    start_daemon = rank == 0
    store = TCPStore(master_addr, master_port, world_size, start_daemon)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using env:// method")
148 
149 
# Register the bundled rendezvous handlers under their URL schemes.
for _scheme, _handler in (
    ("file", _file_rendezvous_handler),
    ("tcp", _tcp_rendezvous_handler),
    ("env", _env_rendezvous_handler),
):
    register_rendezvous_handler(_scheme, _handler)
# Drop the loop temporaries so the module namespace is unchanged.
del _scheme, _handler