1 from __future__
import absolute_import, division, print_function, unicode_literals
4 import multiprocessing.connection
8 from .
import _prctl_pr_set_pdeathsig
11 def _wrap(fn, i, args, error_queue):
16 _prctl_pr_set_pdeathsig(signal.SIGINT)
20 except KeyboardInterrupt:
25 error_queue.put(traceback.format_exc())
def _python_version_check():
    """Ensure the running interpreter is Python 3.4 or newer.

    Compares ``sys.version_info`` against ``(3, 4)``. Returns ``None``
    when the interpreter is new enough.

    Raises:
        RuntimeError: if ``sys.version_info`` is lower than ``(3, 4)``.
            The message points the user at ``torch.distributed.launch``
            as the alternative for older interpreters.
    """
    # Guard clause: nothing to do on a supported interpreter.
    if sys.version_info >= (3, 4):
        return

    message = (
        "Requires python 3.4 or higher to use "
        "torch.multiprocessing.spawn and "
        "torch.multiprocessing.SpawnContext helper "
        "to launch multiple processes. If you are using "
        "this for distributed training and have a lower "
        "version of python, please use "
        "torch.distributed.launch instead."
    )
    raise RuntimeError(message)
41 def __init__(self, processes, error_queues):
42 _python_version_check()
46 process.sentinel: index
47 for index, process
in enumerate(processes)
51 return [int(process.pid)
for process
in self.
processes]
53 def join(self, timeout=None):
55 Tries to join one or more processes in this spawn context. 56 If one of them exited with a non-zero exit status, this function 57 kills the remaining processes and raises an exception with the cause 58 of the first process exiting. 60 Returns ``True`` if all processes have been joined successfully, 61 ``False`` if there are more processes that need to be joined. 64 timeout (float): Wait this long before giving up on waiting. 71 ready = multiprocessing.connection.wait(
72 self.sentinels.keys(),
77 for sentinel
in ready:
78 index = self.sentinels.pop(sentinel)
81 if process.exitcode != 0:
86 if error_index
is None:
92 if process.is_alive():
98 exitcode = self.
processes[error_index].exitcode
100 name = signal.Signals(-exitcode).name
102 "process %d terminated with signal %s" %
107 "process %d terminated with exit code %d" %
108 (error_index, exitcode)
112 msg =
"\n\n-- Process %d terminated with the following error:\n" % error_index
113 msg += original_trace
117 def spawn(fn, args=(), nprocs=1, join=
True, daemon=
False):
118 r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. 120 If one of the processes exits with a non-zero exit status, the 121 remaining processes are killed and an exception is raised with the 122 cause of termination. In the case an exception was caught in the 123 child process, it is forwarded and its traceback is included in 124 the exception raised in the parent process. 127 fn (function): Function is called as the entrypoint of the 128 spawned process. This function must be defined at the top 129 level of a module so it can be pickled and spawned. This 130 is a requirement imposed by multiprocessing. 132 The function is called as ``fn(i, *args)``, where ``i`` is 133 the process index and ``args`` is the passed through tuple 136 args (tuple): Arguments passed to ``fn``. 137 nprocs (int): Number of processes to spawn. 138 join (bool): Perform a blocking join on all processes. 139 daemon (bool): The spawned processes' daemon flag. If set to True, 140 daemonic processes will be created. 143 None if ``join`` is ``True``, 144 :class:`~SpawnContext` if ``join`` is ``False`` 147 _python_version_check()
148 mp = multiprocessing.get_context(
'spawn')
151 for i
in range(nprocs):
152 error_queue = mp.SimpleQueue()
153 process = mp.Process(
155 args=(fn, i, args, error_queue),
159 error_queues.append(error_queue)
160 processes.append(process)
167 while not spawn_context.join():