r"""Signal handling for multiprocessing data loading.

NOTE [ Signal handling in multiprocessing data loading ]

In cases like DataLoader, if a worker process dies due to a bus error or
segfault, or simply hangs, the main process will hang waiting for data. This is
difficult to avoid on the PyTorch side, as it can be caused by limited shm or by
other libraries that users call in the workers. In this file and
`DataLoader.cpp`, we make our best effort to provide some error message to users
when such unfortunate events happen.

When a _DataLoaderIter starts worker processes, their pids are registered in a
map defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
via `_set_worker_pids`.

When an error happens in a worker process, the main process receives a SIGCHLD,
and Python will eventually call the handler registered below
(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
call checks all registered worker pids and raises an appropriate error to
prevent the main process from hanging while waiting for data from the worker.

Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
`_set_worker_signal_handlers` is called to register critical signal handlers
(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just print an error
message to stderr before triggering the default handler. So a message will also
be printed from the worker process when it is killed by such signals.

See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning behind
this signal handling design and the other mechanisms we implement to make our
multiprocessing data loading robust to errors.
"""

import signal
import threading

from torch._C import _set_worker_pids, _remove_worker_pids, \
    _error_if_any_worker_fails, _set_worker_signal_handlers
from . import IS_WINDOWS


_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    # Windows doesn't support SIGCHLD.
    if IS_WINDOWS:
        return
    # Signal handlers can only be set from the main thread.
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # SIG_DFL and SIG_IGN are not callable, so there is nothing to chain.
        previous_handler = None

    def handler(signum, frame):
        # Raise if any registered worker has failed, then chain to the
        # previously installed handler, if any.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
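

r"""NOTE [ Sketch: chaining a signal handler in pure Python ]

The install-once, chain-to-previous pattern used by `_set_SIGCHLD_handler` can
be tried out without torch. What follows is a minimal sketch under stated
assumptions, not the implementation used by DataLoader: the names
`make_chained_handler`, `install_once`, and `_installed_signals` are
hypothetical, and `check` is a stand-in for `_error_if_any_worker_fails`.

```python
import signal
import threading

# Hypothetical registry of signals that install_once has already handled.
_installed_signals = set()


def make_chained_handler(check, previous):
    # Non-callable values (signal.SIG_DFL, signal.SIG_IGN, None) cannot be
    # chained, so treat them as "no previous handler".
    if not callable(previous):
        previous = None

    def handler(signum, frame):
        # If check() raises, the previous handler is not called.
        check()
        if previous is not None:
            previous(signum, frame)

    return handler


def install_once(signum, check):
    # Python only allows the main thread to set signal handlers.
    if threading.current_thread() is not threading.main_thread():
        return False
    if signum in _installed_signals:
        return False
    previous = signal.getsignal(signum)
    signal.signal(signum, make_chained_handler(check, previous))
    _installed_signals.add(signum)
    return True
```

Building the handler separately from installing it means the chaining logic can
be exercised by calling `handler(signum, None)` directly, without delivering a
real signal.
"""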