Caffe2 - Python API
A deep learning, cross-platform ML framework
parallel_workers.py
# @package parallel_workers
# Module caffe2.python.parallel_workers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


'''
This module provides a Python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

The first argument is the function to run in a loop on potentially multiple
threads. It has the call signature
   worker_fun(worker_id)

The 'worker_name' argument is used to distinguish different groups of workers,
such as workers processing train data and workers processing test data.

Optionally, one can define an "init function" that is called once before the
threads start, and has the call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called for each GPU,
and that the 'coordinator' returned by the function is the same each time.
'''

import logging
import threading
import atexit
import time
import collections
import six
import traceback

from abc import ABCMeta, abstractmethod

log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None,
):
    global global_coordinator

    metrics = Metrics(external_loggers)

    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]

    # Create coordinator object
    coordinator = WorkerCoordinator(
        worker_name, worker_ids, init_fun, shutdown_fun=shutdown_fun)

    # Launch worker threads
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


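# Metrics accumulates per-worker counters (such as the 'worker_time' recorded
# in Worker.finish) and forwards them to the external loggers, if any, that
# were passed to init_workers().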
class Metrics(object):
    def __init__(self, external_loggers):
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1


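# Optional abstract state object attached to a WorkerCoordinator, with
# start/stop/cleanup hooks (cleanup() is used to release scratch blobs once
# all workers have terminated).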
@six.add_metaclass(ABCMeta)
class State(object):
    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass


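# WorkerCoordinator manages one named group of worker threads: it runs the
# optional init and shutdown functions and starts, stops and joins the
# threads.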
class WorkerCoordinator(object):
    def __init__(
        self, worker_name, worker_ids, init_fun,
        state=None, shutdown_fun=None
    ):
        self._active = True
        self._started = False
        self._workers = []
        self._worker_name = worker_name
        self._worker_ids = worker_ids
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success

    def get_worker_ids(self):
        return self._worker_ids


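# GlobalWorkerCoordinator tracks every WorkerCoordinator created via
# init_workers() and starts/stops them together; a single module-level
# instance is created below.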
class GlobalWorkerCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self.register_shutdown_handler()

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def start(self):
        # Run init and start in separate for-loops to ensure that init
        # happens serially before any threads are spawned.
        for c in self._coordinators:
            c.init(self)
        for c in self._coordinators:
            c._start()

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


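# Worker wraps the per-thread state: it invokes worker_fun, records timing
# metrics and reports exceptions back to its coordinator.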
class Worker(object):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun=None,
        metrics=None
    ):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        self._start_time = time.time()

    def run(self):
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        traceback.print_exc()
        logging.exception("Exception in worker %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()


global_coordinator = GlobalWorkerCoordinator()


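# Thread target: repeatedly run the worker until its coordinator is stopped;
# any exception raised by worker_fun stops the coordinator.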
def run_worker(coordinator, worker):
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            worker.finish()
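
A minimal usage sketch, based on the module docstring above; my_worker_fun and
my_init_fun follow the call signatures described there, and the worker body,
thread count and sleep durations are illustrative placeholders:

import time

from caffe2.python import parallel_workers


def my_worker_fun(worker_id):
    # Called repeatedly by run_worker() while the coordinator is active.
    print("worker {} doing one unit of work".format(worker_id))
    time.sleep(0.1)


def my_init_fun(worker_coordinator, global_coordinator):
    # Called once per coordinator, before its worker threads are started.
    print("init for workers {}".format(worker_coordinator.get_worker_ids()))


coordinator = parallel_workers.init_workers(
    my_worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=my_init_fun,
)
coordinator.start()     # runs my_init_fun, then spawns the worker threads
time.sleep(1.0)         # let the workers run for a while
coordinator.stop()      # stops the coordinators and joins their threads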