Caffe2 - Python API
A deep learning, cross-platform ML framework
parallel_workers.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

# @package parallel_workers
# Module caffe2.python.parallel_workers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


'''
This module provides a Python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

The first argument is the function to run in a loop on potentially multiple
threads. It has the call signature
   worker_fun(worker_id)

The argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''
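
# Editor's illustration (not part of the original module): a fuller sketch of
# the usage described in the docstring above. 'my_worker_fun' and 'my_init_fun'
# are placeholder names for user-supplied callables.
#
#     from caffe2.python import parallel_workers
#
#     def my_worker_fun(worker_id):
#         # called repeatedly on each worker thread while the coordinator is active
#         print("worker {} doing one unit of work".format(worker_id))
#
#     def my_init_fun(worker_coordinator, global_coordinator):
#         # runs once, serially, before any worker thread is started
#         print("initializing workers for", worker_coordinator._worker_name)
#
#     coordinator = parallel_workers.init_workers(
#         my_worker_fun,
#         num_worker_threads=2,
#         worker_name="train",
#         init_fun=my_init_fun,
#     )
#     coordinator.start()
#     ...
#     coordinator.stop()  # joins the worker threads and reports success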

import logging
import threading
import atexit
import time
import collections
import six
import traceback

from abc import ABCMeta, abstractmethod

log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None,
):
    global global_coordinator

    metrics = Metrics(external_loggers)

    # Create coordinator object
    coordinator = WorkerCoordinator(
        worker_name, init_fun, shutdown_fun=shutdown_fun)

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class Metrics(object):
    def __init__(self, external_loggers):
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1
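
# Editor's note (not part of the original module): the objects passed as
# 'external_loggers' to init_workers only need a log() method that accepts the
# metrics dictionary, as used by Metrics.log_metrics() above. A minimal sketch,
# with a hypothetical name:
#
#     class PrintLogger(object):
#         def log(self, metrics):
#             print(dict(metrics))
#
#     coordinator = parallel_workers.init_workers(
#         my_worker_fun, external_loggers=[PrintLogger()])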


@six.add_metaclass(ABCMeta)
class State(object):
    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass
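
# Editor's note (not part of the original module): a concrete State supplies
# the start/stop/cleanup hooks invoked by WorkerCoordinator, and is passed in
# via its 'state' argument. A minimal sketch, with a hypothetical name:
#
#     class QueueState(State):
#         def __init__(self, queue):
#             self._queue = queue
#
#         def start(self):
#             pass  # e.g. open input resources
#
#         def stop(self):
#             pass  # e.g. signal producers to stop
#
#         def cleanup(self):
#             self._queue = None  # release scratch memory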


class WorkerCoordinator(object):
    def __init__(self, worker_name, init_fun, state=None, shutdown_fun=None):
        self._active = True
        self._started = False
        self._workers = []
        self._worker_name = worker_name
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success


class GlobalWorkerCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self.register_shutdown_handler()

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def start(self):
        # run init and start in separate for loops to
        # ensure init happens serially before threads are spawned.
        for c in self._coordinators:
            c.init(self)
        for c in self._coordinators:
            c._start()

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


class Worker(object):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun=None,
        metrics=None
    ):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        self._start_time = time.time()

    def run(self):
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        traceback.print_exc()
        logging.exception("Exception in worker: %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()


global_coordinator = GlobalWorkerCoordinator()


def run_worker(coordinator, worker):
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            worker.finish()