1 from __future__ 
import absolute_import, division, print_function, unicode_literals
     4 import multiprocessing.connection
     8 from . 
import _prctl_pr_set_pdeathsig
    11 def _wrap(fn, i, args, error_queue):
    # Per-child entry wrapper. It takes the user callable `fn`, an index `i`,
    # the argument tuple `args` and an `error_queue`. spawn() below builds
    # `args=(fn, i, args, error_queue)` for each mp.Process, so this is
    # presumably the Process target (the `target=` line is not visible).
    # NOTE(review): this excerpt is lossy. The embedded original line numbers
    # jump 11 -> 16 -> 20 -> 25, so several pieces are not shown here:
    # the `try:` header, the call to `fn`, the KeyboardInterrupt handler
    # body, and the `except` clause that encloses line 25.
    16     _prctl_pr_set_pdeathsig(signal.SIGINT)
    # The helper above is imported from the enclosing package
    # (`from . import ...`). By its name it presumably arms
    # prctl(PR_SET_PDEATHSIG) so this child gets SIGINT if its parent dies.
    # Its implementation is not shown; confirm.
    20     except KeyboardInterrupt:
    25         error_queue.put(traceback.format_exc())
    # The line above sends the formatted traceback of the exception currently
    # being handled to the parent via `error_queue`. This is presumably the
    # `original_trace` that join() appends to its error message; the read
    # side is not visible in this excerpt.
    29 def _python_version_check():
    30     if sys.version_info < (3, 4):
    31         raise RuntimeError(
"Requires python 3.4 or higher to use "    32                            "torch.multiprocessing.spawn and "    33                            "torch.multiprocessing.SpawnContext helper "    34                            "to launch multiple processes. If you are using "    35                            "this for distributed training and have a lower "    36                            "version of python, please use "    37                            "torch.distributed.launch instead.")
    41     def __init__(self, processes, error_queues):
    # Constructor of the spawn-context class. Its `class` header (around
    # original line 40) is not visible in this excerpt. It receives
    # `processes` and a parallel list `error_queues`, presumably from spawn().
    # TODO: confirm; the constructing call is not visible.
    42         _python_version_check()
    # Raises RuntimeError on Python < 3.4 (see _python_version_check).
    # The fragment below maps each process's `sentinel` to its position in
    # `processes`. The assignment target (original lines 43-45) is hidden.
    # join() reads `self.sentinels.keys()` and `self.sentinels.pop(...)`,
    # so the target is presumably `self.sentinels`.
    46             process.sentinel: index
    47             for index, process 
in enumerate(processes)
    # The `return` below belongs to a separate method whose `def` line
    # (original line 50) is not visible. It lists the PID of every process in
    # `self.processes`, converted with int().
    51         return [int(process.pid) 
for process 
in self.
processes]
    53     def join(self, timeout=None):
    # Contract, per the docstring text on the next line (its opening and
    # closing quote lines, original 54 and 65, are not visible):
    #   - join whichever processes have finished;
    #   - on a non-zero exit, kill the remaining processes and raise with the
    #     cause of the first failure;
    #   - return True once all processes are joined, False otherwise.
    # NOTE(review): lossy excerpt. Many lines of this method are missing,
    # e.g. the `timeout=` argument to wait(), the assignments to `process`
    # and `error_index`, the return statements, and the raise calls.
    55         Tries to join one or more processes in this spawn context.    56         If one of them exited with a non-zero exit status, this function    57         kills the remaining processes and raises an exception with the cause    58         of the first process exiting.    60         Returns ``True`` if all processes have been joined successfully,    61         ``False`` if there are more processes that need to be joined.    64             timeout (float): Wait this long before giving up on waiting.    71         ready = multiprocessing.connection.wait(
    # multiprocessing.connection.wait() blocks until at least one of the
    # still-pending sentinels is ready (or the timeout expires). It returns
    # the list of ready ones.
    72             self.sentinels.keys(),
    77         for sentinel 
in ready:
    # Each ready sentinel is popped from the pending map, so a later join()
    # call no longer waits on that process.
    78             index = self.sentinels.pop(sentinel)
    # `process` is bound on a line not visible here (original 79),
    # presumably to the process at `index`.
    81             if process.exitcode != 0:
    # Any non-zero exit code counts as a failure.
    86         if error_index 
is None:
    # Taken when no failing index was recorded. The initialisation of
    # `error_index` and the assignment to it are not visible.
    92             if process.is_alive():
    # Survivor check, presumably followed by terminating it, matching the
    # docstring's "kills the remaining processes".
    98             exitcode = self.
processes[error_index].exitcode
    # multiprocessing reports death by signal N as exitcode -N, hence the
    # negation below to recover the signal name. It is presumably guarded by
    # an `exitcode < 0` check on original line 99, which is not visible.
    # NOTE(review): signal.Signals(n) raises ValueError for a number that has
    # no Signals member (e.g. some real-time signals); confirm acceptable.
   100                 name = signal.Signals(-exitcode).name
   102                     "process %d terminated with signal %s" %
   107                     "process %d terminated with exit code %d" %
   108                     (error_index, exitcode)
    # Header for the failure message. The child's traceback text is appended
    # next.
   112         msg = 
"\n\n-- Process %d terminated with the following error:\n" % error_index
    # `original_trace` is assigned on a line not visible here (original 111),
    # presumably read from error_queues[error_index]; confirm.
   113         msg += original_trace
   117 def spawn(fn, args=(), nprocs=1, join=
True, daemon=
False):
   118     r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.   120     If one of the processes exits with a non-zero exit status, the   121     remaining processes are killed and an exception is raised with the   122     cause of termination. In the case an exception was caught in the   123     child process, it is forwarded and its traceback is included in   124     the exception raised in the parent process.   127         fn (function): Function is called as the entrypoint of the   128             spawned process. This function must be defined at the top   129             level of a module so it can be pickled and spawned. This   130             is a requirement imposed by multiprocessing.   132             The function is called as ``fn(i, *args)``, where ``i`` is   133             the process index and ``args`` is the passed through tuple   136         args (tuple): Arguments passed to ``fn``.   137         nprocs (int): Number of processes to spawn.   138         join (bool): Perform a blocking join on all processes.   139         daemon (bool): The spawned processes' daemon flag. If set to True,   140                        daemonic processes will be created.   143         None if ``join`` is ``True``,   144         :class:`~SpawnContext` if ``join`` is ``False``   147     _python_version_check()
   148     mp = multiprocessing.get_context(
'spawn')
   151     for i 
in range(nprocs):
   152         error_queue = mp.SimpleQueue()
   153         process = mp.Process(
   155             args=(fn, i, args, error_queue),
   159         error_queues.append(error_queue)
   160         processes.append(process)
   167     while not spawn_context.join():