from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

'''
This module provides a python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

First argument is the function to run in a loop on potentially multiple threads.
It has the call signature
   worker_fun(worker_id)

Argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
threads start, and has call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is same
each time.
'''

import atexit
import collections
import logging
import threading
import time
import traceback

import six

from abc import ABCMeta, abstractmethod
# Module-level logger, kept at INFO so worker lifecycle errors are visible
# without enabling debug output.
log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None
):
    """
    Create worker threads running `worker_fun` in a loop and register them
    with the process-wide global coordinator.

    Args:
        worker_fun: function executed repeatedly on each worker thread;
            called as worker_fun(worker_id).
        num_worker_threads: number of threads to create for this worker group.
            NOTE(review): default of 2 is an assumption — confirm against
            callers.
        worker_name: label distinguishing worker groups (e.g. 'train' vs
            'test' workers).
        init_fun: optional callable invoked once before threads start, with
            signature init_fun(worker_coordinator, global_coordinator).
        external_loggers: optional sequence of logger objects fed to Metrics.
        shutdown_fun: optional callable invoked when the coordinator stops.

    Returns:
        The global coordinator (same object on every call), so callers can
        start()/stop() all registered worker groups.
    """
    global global_coordinator

    metrics = Metrics(external_loggers)

    # Reserve globally-unique ids for this group's workers.
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]

    # Create coordinator for this group of workers.
    coordinator = WorkerCoordinator(
        worker_name, worker_ids, init_fun, shutdown_fun=shutdown_fun)

    # Create worker threads (not yet started; the coordinator starts them).
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator
class Metrics(object):
    """
    Accumulates numeric metrics from worker threads and forwards them to
    optional external logger objects.

    Each metric key accumulates its values; when `count=True`, a companion
    '<key>_count' metric tracks how many samples were added.
    """

    def __init__(self, external_loggers):
        # Missing keys default to 0 so put_metric can accumulate blindly.
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        """Drop all accumulated metrics, resetting every key to 0."""
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        """
        Push the current metrics to every external logger.

        Logger failures are reported but never propagated, so a broken
        external logger cannot take down a worker thread.
        NOTE(review): assumes each external logger exposes log(metrics) —
        confirm against actual logger implementations.
        """
        if not self._external_loggers:
            return
        for external_logger in self._external_loggers:
            try:
                external_logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        """
        Add `value` to metric `key`; if `count`, also bump '<key>_count'
        so averages can be derived later.
        """
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1
@six.add_metaclass(ABCMeta)
class State(object):
    """
    Abstract interface for per-coordinator state objects.

    NOTE(review): only the ABCMeta decoration and a cleanup() call site
    (WorkerCoordinator._wait_finish calls self._state.cleanup()) are visible
    in this file — the start/stop lifecycle methods below are assumed from
    how WorkerCoordinator manages its own lifecycle; confirm against the
    original implementation.
    """

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass
class WorkerCoordinator(object):
    """
    Coordinates the lifecycle of one named group of worker threads:
    initialization, starting, stopping, and waiting for termination.
    """

    def __init__(
        self, worker_name, worker_ids, init_fun,
        state=None, shutdown_fun=None
    ):
        # Workers spin while _active is True (see run_worker / is_active).
        self._active = True
        self._started = False
        # Populated by init_workers after the threads are created.
        self._workers = []
        self._worker_name = worker_name
        self._worker_ids = worker_ids
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        """Return True while workers should keep looping."""
        return self._active

    def init(self, global_coordinator):
        """Run the user-provided init function once, before threads start."""
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        """Start all worker threads (idempotent once started)."""
        if self._started:
            return
        self._active = True
        self._started = True
        # NOTE(review): starting optional state alongside threads is assumed
        # from the State interface — confirm ordering against callers.
        if self._state:
            self._state.start()

        for w in self._workers:
            # Daemonize so stuck workers cannot block interpreter exit.
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        """
        Signal workers to stop. `reason` is logged when the stop is due to
        an error (e.g. an exception raised inside a worker).
        """
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        """
        Join worker threads and report whether all of them terminated.
        `cleanup` is kept for interface compatibility (unused here).

        Returns:
            True when every worker thread exited, False otherwise.
        """
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                # Bounded join: a worker may be blocked in I/O forever.
                w.join(5.0)
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Only release state resources once all workers are gone.
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success

    def get_worker_ids(self):
        """Return the ids assigned to this coordinator's workers."""
        return self._worker_ids
class GlobalWorkerCoordinator(object):
    """
    Process-wide registry of WorkerCoordinators. Hands out unique worker
    ids and can start/stop every registered coordinator at once.
    """

    def __init__(self):
        self._coordinators = []
        # Monotonic sequence for globally-unique worker ids.
        self._worker_id_seq = 0
        self._worker_ids = []
        self.register_shutdown_handler()

    def add(self, coordinator):
        """Register a coordinator so start()/stop() manage it."""
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        """Allocate and record the next unique worker id."""
        worker_id = self._worker_id_seq
        self._worker_ids.append(worker_id)
        self._worker_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        """Return every worker id allocated so far."""
        return self._worker_ids

    def start(self):
        """Start all registered coordinators (idempotent per coordinator)."""
        for c in self._coordinators:
            c._start()

    def stop(self):
        """
        Stop all coordinators and wait for their workers to exit.

        Returns:
            True only if every coordinator's workers terminated cleanly.
        """
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        """
        Stop a specific coordinator, identified by its worker_name, and
        remove it from the registry.
        """
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        """Ensure all workers are stopped when the interpreter exits."""
        def cleanup():
            self.stop()

        atexit.register(cleanup)


# Single shared instance; init_workers returns this same object every call.
global_coordinator = GlobalWorkerCoordinator()
class Worker(object):
    """
    Per-thread wrapper around the user's worker function. Tracks timing
    metrics and routes worker exceptions to the coordinator.

    NOTE(review): only handle_exception and the finish() metric calls are
    visible in the source fragment — the constructor and start/run methods
    are reconstructed from the call site
    Worker(coordinator, worker_id, worker_fun, metrics); confirm against
    the original implementation.
    """

    def __init__(self, coordinator, worker_id, worker_fun=None, metrics=None):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        # Mark the start of one loop iteration for the worker_time metric.
        self._start_time = time.time()

    def run(self):
        # The worker function receives this worker's id (per module usage).
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        """Log the worker's exception and shut the coordinator down."""
        traceback.print_exc()
        # BUGFIX: the original passed `e` as a positional logging arg with
        # no %s placeholder in the message, which makes the logging module
        # itself raise a formatting error; use lazy %-formatting instead.
        logging.exception("Exception in worker: %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        # NOTE(review): metric key 'worker_time' assumed — confirm.
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()
def run_worker(coordinator, worker):
    """
    Thread target: loop the worker until its coordinator deactivates.

    Each iteration brackets the user function with start()/finish() so
    timing metrics are recorded even when run() raises; exceptions are
    handed to the worker, which stops the coordinator.
    """
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            worker.finish()
def register_shutdown_handler(self)
def stop_coordinator(self, worker_name)