Caffe2 - Python API
A deep learning, cross-platform ML framework
test_multiprocessing_spawn.py
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 
3 import os
4 import random
5 import signal
6 import sys
7 import time
8 import unittest
9 
10 from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
11 import torch.multiprocessing as mp
12 
13 
14 def test_success_func(i):
15  pass
16 
17 
18 def test_success_single_arg_func(i, arg):
19  if arg:
20  arg.put(i)
21 
22 
23 def test_exception_single_func(i, arg):
24  if i == arg:
25  raise ValueError("legitimate exception from process %d" % i)
26  time.sleep(1.0)
27 
28 
29 def test_exception_all_func(i):
30  time.sleep(random.random() / 10)
31  raise ValueError("legitimate exception from process %d" % i)
32 
33 
def test_terminate_signal_func(i):
    """Worker 0 sends SIGABRT to its own process; all workers then sleep 1s.

    The parent test expects process 0 to be reported as terminated by a
    signal.
    """
    is_victim = i == 0
    if is_victim:
        own_pid = os.getpid()
        os.kill(own_pid, signal.SIGABRT)
    time.sleep(1.0)
38 
39 
40 def test_terminate_exit_func(i, arg):
41  if i == 0:
42  sys.exit(arg)
43  time.sleep(1.0)
44 
45 
46 def test_success_first_then_exception_func(i, arg):
47  if i == 0:
48  return
49  time.sleep(0.1)
50  raise ValueError("legitimate exception")
51 
52 
53 def test_nested_child_body(i, ready_queue, nested_child_sleep):
54  ready_queue.put(None)
55  time.sleep(nested_child_sleep)
56 
57 
def test_nested_spawn(i, pids_queue, nested_child_sleep):
    """Spawn two non-daemon children, publish their PIDs, then SIGTERM self.

    The PIDs are put on ``pids_queue`` as a single item. The parent test
    polls those PIDs to verify the children do not outlive this process.
    """
    num_children = 2
    ready = mp.get_context("spawn").Queue()
    children = mp.spawn(
        fn=test_nested_child_body,
        args=(ready, nested_child_sleep),
        nprocs=num_children,
        join=False,
        daemon=False,
    )
    pids_queue.put(children.pids())

    # Block until every child has reported in, so each has had the chance
    # to register a parent death signal via prctl(2) before we die.
    for _ in range(num_children):
        ready.get()

    # Kill this process; the children are expected to be taken down too.
    os.kill(os.getpid(), signal.SIGTERM)
78 
79 
@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase):
    """Exercise ``mp.spawn`` success, failure and termination paths."""

    def test_success(self):
        """Two well-behaved workers complete without raising."""
        mp.spawn(test_success_func, nprocs=2)

    def test_success_non_blocking(self):
        """With ``join=False`` the caller joins; the last join returns True."""
        spawn_context = mp.spawn(test_success_func, nprocs=2, join=False)

        # After all processes (nprocs=2) have joined it must return True.
        spawn_context.join(timeout=None)
        spawn_context.join(timeout=None)
        self.assertTrue(spawn_context.join(timeout=None))

    def test_first_argument_index(self):
        """Each worker receives its own process index as the first argument."""
        context = mp.get_context("spawn")
        queue = context.SimpleQueue()
        mp.spawn(test_success_single_arg_func, args=(queue,), nprocs=2)
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        """An exception in one worker is re-raised in the parent with its traceback."""
        nprocs = 2
        for i in range(nprocs):
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.spawn(test_exception_single_func, args=(i,), nprocs=nprocs)

    def test_exception_all(self):
        """When every worker raises, one worker's exception is surfaced."""
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.spawn(test_exception_all_func, nprocs=2)

    def test_terminate_signal(self):
        """A worker killed by a signal is reported with the signal name."""
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination through a signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.spawn(test_terminate_signal_func, nprocs=2)

    def test_terminate_exit(self):
        """A worker calling ``sys.exit`` is reported with its exit code."""
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.spawn(test_terminate_exit_func, args=(exitcode,), nprocs=2)

    def test_success_first_then_exception(self):
        """A late exception is still surfaced after another worker succeeded."""
        # The worker ignores this argument; it only fills the args= slot.
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.spawn(test_success_first_then_exception_func, args=(exitcode,), nprocs=2)

    @staticmethod
    def _pid_exists(pid):
        """Return True unless probing ``pid`` with signal 0 raises ProcessLookupError."""
        try:
            # Signal 0 performs the existence/permission check only; nothing
            # is delivered to the target process.
            os.kill(pid, 0)
        except ProcessLookupError:
            return False
        return True

    @unittest.skipIf(
        # startswith() also matches "linux2", the value Python 2 reports.
        not sys.platform.startswith("linux"),
        "Only runs on Linux; requires prctl(2)",
    )
    def test_nested_spawn(self):
        """Grandchildren must die shortly after their spawning parent is killed."""
        context = mp.get_context("spawn")
        pids_queue = context.Queue()
        nested_child_sleep = 20.0
        spawn_context = mp.spawn(
            fn=test_nested_spawn,
            args=(pids_queue, nested_child_sleep),
            nprocs=1,
            join=False,
            daemon=False,
        )

        # Wait for nested children to terminate in time
        pids = pids_queue.get()
        start = time.time()
        while True:
            # Drop every pid that no longer exists in one pass instead of
            # removing them one per polling round.
            pids = [pid for pid in pids if self._pid_exists(pid)]
            if not pids:
                break

            # This assert fails if any nested child process is still
            # alive after (nested_child_sleep / 2) seconds. By
            # extension, this test times out with an assertion error
            # after (nested_child_sleep / 2) seconds.
            self.assertLess(time.time() - start, nested_child_sleep / 2)
            time.sleep(0.1)
181 
182 
if __name__ == "__main__":
    # Executing this file directly hands control to the shared test runner.
    run_tests()