Caffe2 - Python API
A deep learning, cross platform ML framework
spawn.py
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 
3 import multiprocessing
4 import multiprocessing.connection
5 import signal
6 import sys
7 
8 from . import _prctl_pr_set_pdeathsig
9 
10 
def _wrap(fn, i, args, error_queue):
    """Entry point executed inside each spawned child process.

    Runs ``fn(i, *args)`` and, if it raises, forwards the formatted
    traceback to the parent via ``error_queue`` before exiting with
    status 1.
    """
    # prctl(2) is a Linux-specific system call; elsewhere this call is a
    # no-op. It makes non-daemonic children receive SIGINT when their
    # parent dies before they do, so they can terminate.
    _prctl_pr_set_pdeathsig(signal.SIGINT)
    try:
        fn(i, *args)
    except KeyboardInterrupt:
        # SIGINT: the parent killed us deliberately — exit quietly.
        pass
    except Exception:
        # Hand the original traceback text to the parent so it can
        # re-raise with full context.
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)
27 
28 
29 def _python_version_check():
30  if sys.version_info < (3, 4):
31  raise RuntimeError("Requires python 3.4 or higher to use "
32  "torch.multiprocessing.spawn and "
33  "torch.multiprocessing.SpawnContext helper "
34  "to launch multiple processes. If you are using "
35  "this for distributed training and have a lower "
36  "version of python, please use "
37  "torch.distributed.launch instead.")
38 
39 
41  def __init__(self, processes, error_queues):
42  _python_version_check()
43  self.error_queues = error_queues
44  self.processes = processes
45  self.sentinels = {
46  process.sentinel: index
47  for index, process in enumerate(processes)
48  }
49 
50  def pids(self):
51  return [int(process.pid) for process in self.processes]
52 
    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.

        Raises:
            Exception: if any process exited with a non-zero status; the
                message carries either the child's forwarded traceback,
                the terminating signal name, or the raw exit code.
        """
        # Ensure this function can be called even when we're done.
        # (self.sentinels shrinks as processes are joined below.)
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        # With timeout=None this blocks until at least one sentinel is ready.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            # pop() marks the process as joined so later calls skip it.
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                # Record only the first failure; the rest are killed below.
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed
        # (e.g. segfault or kill by signal) rather than raising in Python.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # Negative exitcode means the child died from a signal.
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        # _wrap() put the child's formatted traceback on the queue;
        # re-raise it in the parent with the process index prepended.
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
115 
116 
def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Arguments:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.

    Returns:
        None if ``join`` is ``True``,
        :class:`~SpawnContext` if ``join`` is ``False``
    """
    _python_version_check()
    # Always use the 'spawn' start method so children get a clean
    # interpreter state regardless of the platform default.
    ctx = multiprocessing.get_context('spawn')

    error_queues = []
    processes = []
    for rank in range(nprocs):
        # One SimpleQueue per child for forwarding tracebacks to the parent.
        queue = ctx.SimpleQueue()
        proc = ctx.Process(
            target=_wrap,
            args=(fn, rank, args, queue),
            daemon=daemon,
        )
        proc.start()
        error_queues.append(queue)
        processes.append(proc)

    context = SpawnContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass