Caffe2 - Python API
A deep learning, cross platform ML framework
data_workers.py
1 ## @package data_workers
2 # Module caffe2.python.data_workers
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
9 '''
10 This module provides a python-land multithreaded data input mechanism
11 for Caffe2 nets.
12 
13 Basic usage is as follows:
14  coordinator = data_workers.init_data_input_workers(
15  net,
16  ["data", "label"],
17  my_fetch_fun,
18  batch_size=32,
19  input_source_name="train",
20  dont_rebatch=False
21  )
22  ...
23  coordinator.start()
24 
25 First argument is the Caffe2 net (or model helper), and second argument
26 is list of input blobs that are to be fed.
27 
28 Argument 'input_source_name' is used to distinguish different sources of data,
29 such as train or test data. This is to ensure the data does not get mixed up,
30 although two nets would share blobs.
31 
32 To do the actual data loading, one defines a "fetcher function"
33 that has call signature
34  my_fetch_fun(worker_id, batch_size)
35 
36 Optionally, one can define an "init function" that is called once before
37 threads start, and has call signature:
38  my_init_fun(data_coordinator, global_coordinator)
39 
40 If dont_rebatch is set to True, the data input is not batched into equal sized
41 chunks but data directly provided by fetchers is used.
42 
43 'batch_columns' can be used to specify which dimension is the batch dimension,
44 for each of the inputs. Default is 0 for all inputs.
45 
46 'timeout' is the timeout in seconds after which if no data is available, the
47 net will fail (default 600s = 10 mins).
48 
49 The fetcher function returns a list of numpy arrays corresponding to the different
50 input blobs. In the example above, it would return two arrays, one for the
51 data blob and another for the labels. These arrays can have arbitrary number
52 of elements (i.e they do not need to match the batch size). The batch size
53 is provided for the function as a hint only.
54 
55 For example, fetcher function could download images from a remote service or
56 load random images from a directory on a file system.
57 
58 For a dummy example, see the data_workers_test unit test.
59 
60 Note that for data_parallel_models, init_data_input_workers will be called
61 for each GPU. Note that the 'coordinator' returned by the function is same
62 each time.
63 '''
64 
65 try:
66  import Queue
67 except ImportError:
68  # Py3
69  import queue as Queue
70 from itertools import chain
71 import logging
72 import threading
73 import numpy as np
74 import time
75 
76 from caffe2.python import workspace, core, scope, utils
77 from caffe2.proto import caffe2_pb2
78 from caffe2.python.parallel_workers import Metrics, State, \
79  WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
80 
# Module-level logger shared by the feeder, the coordinator and the workers.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)

# Minimum number of seconds between two periodic throughput logs, and between
# two "data loading lagging behind" warnings.
LOG_INT_SECS = 60
84 
85 
def get_worker_ids(num_workers):
    '''
    Return the worker ids 0 .. num_workers - 1 as a list.
    '''
    worker_ids = range(num_workers)
    return list(worker_ids)
88 
89 
def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    '''
    Build the fetcher threads and the enqueuer thread that feed
    `input_blob_names` of `net`, and register them with the module-level
    global coordinator. The threads are created here but not started.

    Args:
        net: Caffe2 net (or model helper) that receives the dequeue ops.
        input_blob_names: names of the input blobs to be fed.
        fetch_fun: fetcher, called as fetch_fun(worker_id, batch_size).
        batch_size: batch size handed to the feeder and to the fetchers.
        num_worker_threads: number of fetcher threads to create.
        input_source_name: name of the data source; also selects the
            python-side queue.
        max_buffered_batches: maxsize of the python-side queue.
        init_fun: optional init function passed to the WorkerCoordinator.
        external_loggers: passed through to Metrics.
        dont_rebatch: if True, fetched data is enqueued without rebatching.
        batch_columns: batch dimension of each input (None means 0 for all).
        timeout: timeout in seconds for the dequeue ops.

    Returns:
        The module-level global coordinator.
    '''
    global global_coordinator

    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        # No device scope is active: default to CPU.
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)
    namescope = scope.CurrentNameScope()
    python_queue = global_coordinator.get_queue(
        input_source_name, max_buffered_batches
    )
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        namescope,
        input_source_name,
        python_queue,
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Reserve one globally unique id per fetcher thread.
    worker_ids = []
    for _ in range(num_worker_threads):
        worker_ids.append(global_coordinator.get_new_worker_id())

    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    # One fetcher thread per worker id ...
    workers = []
    for worker_id in worker_ids:
        data_worker = DataWorker(
            coordinator, worker_id, fetch_fun, metrics, batch_size,
            batch_feeder)
        fetcher_thread = threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, data_worker],
        )
        workers.append(fetcher_thread)

    # ... plus a single thread moving batches into the Caffe2 queues.
    enqueuer_thread = threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(
            input_source_name, scope.CurrentNameScope()),
        args=[coordinator, batch_feeder])
    workers.append(enqueuer_thread)

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator
152 
153 
def __init__(self, net, input_blob_names, batch_size,
             device_option, namescope, input_source_name, queue,
             metrics, dont_rebatch, batch_columns, timeout=600):
    '''
    Store the feeder configuration, create the Caffe2-side queues and
    dequeue ops on `net`, and feed the scratch blobs.

    `queue` is the python-side queue the fetchers write into.
    `batch_columns` gives the batch dimension of each input; None means
    dimension 0 for every input.
    '''
    # Plain configuration and bookkeeping state.
    self._input_blob_names = input_blob_names
    self._batch_size = batch_size
    self._device_option = device_option
    self._namescope = namescope
    self._input_source_name = input_source_name
    self._internal_queue = queue
    self._metrics = metrics
    self._dont_rebatch = dont_rebatch
    self._timeout = timeout
    self._c2_queue_capacity = 4
    self._counter = 0
    self._inputs = 0
    self._prev_seconds = 0
    self._last_warning = time.time()
    self._queues = []

    if batch_columns is None:
        batch_columns = [0] * len(input_blob_names)
    self._batch_columns = batch_columns

    # Caffe2-side set-up: queues first, then the ops that dequeue from
    # them, then the scratch blobs used when enqueueing.
    self._create_caffe2_queues(net)
    self._create_caffe2_ops(net)
    self._init_scratch()
180 
def start(self):
    '''
    Restart throughput accounting: the measuring interval begins now and
    no inputs have been counted yet.
    '''
    self._prev_seconds = time.time()
    self._inputs = 0
184 
def stop(self):
    '''
    Close every Caffe2-side blobs queue. A final, forced throughput log is
    emitted even if closing a queue raises.
    '''
    try:
        for c2_queue in self._queues:
            close_op = core.CreateOperator("CloseBlobsQueue", [c2_queue], [])
            workspace.RunOperatorOnce(close_op)
    finally:
        self._log_inputs_per_interval(0, force=True)
193 
def cleanup(self):
    '''
    Reset the scratch data blobs, then the scratch status blobs.
    '''
    for blob_refs in (self._scratch_blob, self._scratch_status):
        utils.ResetBlobs(blob_refs.values())
197 
def _get(self, data_input_coordinator):
    '''
    Wait for the next item on the python-side queue.

    Polls in 0.5 second steps for as long as the coordinator is active and
    warns at most once every 10 seconds while nothing arrives. Returns the
    item, or None if the coordinator became inactive first.
    '''
    wait_started = time.time()
    last_warned = time.time()
    while data_input_coordinator.is_active():
        try:
            item = self._internal_queue.get(block=True, timeout=0.5)
        except Queue.Empty:
            if time.time() - last_warned > 10.0:
                log.warning("** Data input is slow: (still) no data in {} secs.".format(
                    time.time() - wait_started))
                last_warned = time.time()
        else:
            return item
    return None
211 
def _validate_chunk(self, chunk):
    '''
    Check one fetcher result before it is queued.

    Returns False (after logging a warning) when the fetcher produced
    nothing usable, i.e. None or an empty sequence, so that the caller can
    skip it. Returns True for a well-formed chunk.

    Raises:
        AssertionError: if the chunk does not hold one numpy array per
            input blob, or, unless dont_rebatch is set, if the arrays
            disagree on the number of samples along their batch columns.
    '''
    if chunk is None:
        log.warning("Fetcher function returned None")
        return False

    # This guard has to run before the length assertion below; placed
    # after it, it can only be reached when there are no input blobs at
    # all, and an empty fetch would crash the worker instead of being
    # skipped.
    if len(chunk) == 0:
        log.warning("Worker provided zero length input")
        return False

    assert len(chunk) == len(self._input_blob_names), \
        "Expecting data blob for each input"
    for d in chunk:
        assert isinstance(d, np.ndarray), \
            "Fetcher function must return a numpy array"

    if not self._dont_rebatch:
        # Rebatching concatenates along each input's batch column, so all
        # inputs must carry as many samples as the first one.
        for j, d in enumerate(chunk[1:], 1):
            assert d.shape[self._batch_columns[j]] == \
                chunk[0].shape[self._batch_columns[0]], \
                "Each returned input must have equal number of samples"

    return True
235 
def put(self, chunk, data_input_coordinator):
    '''
    Validate a fetched chunk and add it to the python-side queue.

    Invalid chunks are dropped. Otherwise the call retries in 0.5 second
    steps while the queue is full and the coordinator is active, and
    returns without queueing if the coordinator becomes inactive first.
    '''
    if not self._validate_chunk(chunk):
        return

    while data_input_coordinator.is_active():
        qsize = self._internal_queue.qsize()
        if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
            log.warning(
                "Warning, data loading lagging behind: "
                "queue size={}, name={}".format(
                    qsize, self._input_source_name))
            self._last_warning = time.time()

        try:
            self._internal_queue.put(chunk, block=True, timeout=0.5)
        except Queue.Full:
            log.debug("Queue full: stalling fetchers...")
            continue

        # Count the chunk once it is actually queued, not once per retry.
        self._counter += 1
        # The number of samples is the extent of the first input along its
        # batch column, which is not necessarily dimension 0.
        self._log_inputs_per_interval(
            chunk[0].shape[self._batch_columns[0]])
        return
254 
def _enqueue_batch_direct(self, data_input_coordinator):
    '''
    Take one fetcher result from the python-side queue and enqueue its
    arrays unchanged, one per input blob. Nothing is enqueued if no data
    arrived or the coordinator is no longer active.
    '''
    fetched = self._get(data_input_coordinator)
    if fetched is None or not data_input_coordinator.is_active():
        return
    for blob_name, c2_queue, array in zip(
        self._input_blob_names, self._queues, fetched
    ):
        self._enqueue(blob_name, c2_queue, array)
262 
def _enqueue_batch(self, data_input_coordinator):
    '''
    This pulls data from the python-side queue and collects them
    into batch-sized pieces, unless dont_rebatch is set to true.

    Samples beyond the batch size are put back on the python-side queue.
    If the coordinator stops before a full batch has been collected, the
    partial batch is discarded and nothing is enqueued.
    '''
    if self._dont_rebatch:
        self._enqueue_batch_direct(data_input_coordinator)
        return

    cur_batch = [np.array([]) for _ in self._input_blob_names]
    first_batch_col = self._batch_columns[0]

    # Collect data until we have a full batch size. The shape[0] == 0 test
    # recognises the empty 1-D placeholder, which may not even have a
    # dimension `first_batch_col`.
    while (
        cur_batch[0].shape[0] == 0 or
        cur_batch[0].shape[first_batch_col] < self._batch_size
    ) and data_input_coordinator.is_active():
        chunk = self._get(data_input_coordinator)
        if chunk is None:
            continue

        for j, chunk_elem in enumerate(chunk):
            if cur_batch[j].shape[0] == 0:
                cur_batch[j] = chunk_elem.copy()
            else:
                cur_batch[j] = np.append(
                    cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                )

    # The loop also ends when the coordinator is stopped. In that case the
    # batch is empty or short, and carrying on would trip the size
    # assertion below (or index a missing dimension) during shutdown.
    if (
        cur_batch[0].shape[0] == 0 or
        cur_batch[0].shape[first_batch_col] < self._batch_size
    ):
        return

    start_time = time.time()
    try:
        # Return data over the batch size back to queue
        if cur_batch[0].shape[first_batch_col] > self._batch_size:
            leftover = []
            trimmed_batch = []
            for j, b in enumerate(cur_batch):
                [c, l] = np.split(
                    b, [self._batch_size], axis=self._batch_columns[j]
                )
                leftover.append(l)
                trimmed_batch.append(c)
            cur_batch = trimmed_batch
            try:
                self._internal_queue.put(leftover, block=False)
            except Queue.Full:
                # Best effort: never block the enqueuer on its own input
                # queue, but do not lose samples silently either.
                log.warning(
                    "Internal queue full: dropping {} leftover samples, "
                    "name={}".format(
                        leftover[0].shape[first_batch_col],
                        self._input_source_name))

        assert cur_batch[0].shape[first_batch_col] == self._batch_size

        if data_input_coordinator.is_active():
            for b, q, c in zip(
                self._input_blob_names, self._queues, cur_batch
            ):
                self._enqueue(b, q, c)
    finally:
        self._metrics.put_metric('enqueue_time', time.time() - start_time)
321 
def _init_scratch(self):
    '''
    Create, for every input blob, a reference to a scratch data blob and
    to a scratch status blob, and feed an empty float32 array into each.
    '''
    self._scratch_blob = {}
    self._scratch_status = {}
    for blob_name in self._input_blob_names:
        scratch_name = self._namescope + blob_name + \
            "_scratch_" + self._input_source_name
        status_name = scratch_name + "_status"
        self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
        self._scratch_status[blob_name] = core.BlobReference(status_name)

    # Feed empty arrays to the scratch blobs here, so that there won't be
    # race conditions when calling FeedBlob (which calls workspace
    # CreateBlob()) from enqueue threads
    for blob_refs in (self._scratch_blob, self._scratch_status):
        for blob_ref in blob_refs.values():
            workspace.FeedBlob(
                blob_ref,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )
344 
def _enqueue(self, blob_name, queue, data_arr):
    '''
    Enqueue the correctly sized batch arrays to Caffe2's queue.

    The array is first fed into the scratch blob of `blob_name`; a
    SafeEnqueueBlobs operator then moves it from there into `queue` and
    writes its status into the matching scratch status blob.
    '''
    scratch = self._scratch_blob[blob_name]
    status = self._scratch_status[blob_name]

    workspace.FeedBlob(scratch, data_arr, device_option=self._device_option)

    enqueue_op = core.CreateOperator(
        "SafeEnqueueBlobs",
        [queue, scratch],
        [scratch, status],
        device_option=self._device_option
    )
    workspace.RunOperatorOnce(enqueue_op)
362 
def _create_caffe2_queues(self, net):
    '''
    Creates queues on caffe2 side: one single-blob queue per input blob,
    named "<blob>_c2queue_<input source name>", each with capacity
    `self._c2_queue_capacity`. The queue references are appended to
    `self._queues` in input blob order.

    `net` is not used by this method.
    '''
    def create_queue(queue_name, num_blobs, capacity):
        # Forward num_blobs to the operator; it used to be hard-coded to 1
        # here, silently ignoring the argument.
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "CreateBlobsQueue",
                [], [queue_name],
                num_blobs=num_blobs,
                capacity=capacity))
        return core.ScopedBlobReference(queue_name)

    for blob_name in self._input_blob_names:
        qname = blob_name + "_c2queue" + "_" + self._input_source_name
        q = create_queue(
            qname, num_blobs=1, capacity=self._c2_queue_capacity
        )
        self._queues.append(q)
382 
def _create_caffe2_ops(self, net):
    '''
    Creates dequeue-ops on caffe2 side: for each input blob, `net` gets a
    DequeueBlobs op that reads from the blob's queue into the blob, with
    the configured timeout in seconds.
    '''
    for c2_queue, target_blob in zip(self._queues, self._input_blob_names):
        net.DequeueBlobs(
            c2_queue,
            target_blob,
            timeout_secs=float(self._timeout),
        )
390 
def _log_inputs_per_interval(self, inputs, force=False):
    '''
    Add `inputs` to the running input count and, once LOG_INT_SECS seconds
    have passed since the previous report (or immediately if `force` is
    set), log and record the throughput and start a new interval.
    '''
    self._inputs += inputs
    current_seconds = time.time()
    delta_seconds = current_seconds - self._prev_seconds
    if delta_seconds >= LOG_INT_SECS or force:
        # A forced report can land in the same clock tick as start();
        # report a rate of 0 then instead of dividing by zero.
        if delta_seconds > 0:
            inputs_per_sec = int(self._inputs / delta_seconds)
        else:
            inputs_per_sec = 0
        qsize = self._internal_queue.qsize()
        log.info("{}/{}: {} inputs/sec".format(
            self._input_source_name,
            self._namescope,
            inputs_per_sec,
        ))
        log.info("-- queue: {} batches".format(qsize))
        # log and reset perf metrics
        self._metrics.put_metric(
            'inputs_per_sec', inputs_per_sec, False)
        self._metrics.put_metric('queue_size', qsize, False)
        self._metrics.put_metric(
            'time_elapsed', delta_seconds, False)
        self._metrics.log_metrics()
        self._metrics.reset_metrics()
        self._inputs = 0
        self._prev_seconds = current_seconds
414 
415 
class GlobalCoordinator(GlobalWorkerCoordinator):
    '''
    Global worker coordinator that additionally keeps the python-side
    queues, one per queue name.
    '''

    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        # queue name -> Queue.Queue
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        '''
        Return the queue registered under `queue_name`, creating it with
        maxsize `max_buffered_batches` on first use. For an existing queue
        `max_buffered_batches` is only type-checked.
        '''
        assert isinstance(max_buffered_batches, int)
        try:
            return self._queues[queue_name]
        except KeyError:
            new_queue = Queue.Queue(maxsize=max_buffered_batches)
            self._queues[queue_name] = new_queue
            return new_queue

    def reset_data_input(self, namescope, name, net, batch_size):
        '''
        For every coordinator whose worker name is `name` and whose state
        has name scope `namescope`, set the state's batch size to
        `batch_size` and create its dequeue ops on `net`.
        '''
        log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
        for coordinator in self._coordinators:
            if coordinator._worker_name != name:
                continue
            state = coordinator._state
            if state._namescope != namescope:
                continue
            state._batch_size = batch_size
            state._create_caffe2_ops(net)
433 
434 
class DataWorker(Worker):
    '''
    Worker that calls the fetcher function and hands the result to the
    batch feeder.
    '''

    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(
            self,
            coordinator,
            worker_id,
            worker_fun=worker_fun,
            metrics=metrics,
        )
        self._batch_feeder = batch_feeder
        self._batch_size = batch_size

    def run(self):
        '''
        Fetch one chunk with worker_fun(worker_id, batch_size) and put it
        on the batch feeder.
        '''
        fetched = self._worker_fun(self._worker_id, self._batch_size)
        self._batch_feeder.put(fetched, self._coordinator)

    def finish(self):
        '''
        Record the 'fetcher_time' metric as the time elapsed since
        self._start_time.
        '''
        # NOTE(review): _start_time is not set in this class; presumably
        # the Worker base class maintains it -- confirm.
        elapsed = time.time() - self._start_time
        self._metrics.put_metric('fetcher_time', elapsed)
458 
459 
# Single module-level coordinator; init_data_input_workers() registers every
# new coordinator with it and returns it.
global_coordinator = GlobalCoordinator()
461 
462 
def enqueuer(coordinator, batch_feeder):
    '''
    Thread body of the enqueuer: keep asking the batch feeder to enqueue
    one batch until the coordinator is no longer active.
    '''
    while True:
        if not coordinator.is_active():
            break
        batch_feeder._enqueue_batch(coordinator)
def _enqueue_batch_direct(self, data_input_coordinator)
def _enqueue(self, blob_name, queue, data_arr)
def _log_inputs_per_interval(self, inputs, force=False)
def _get(self, data_input_coordinator)