Caffe2 - Python API
A deep learning, cross-platform ML framework
data_workers.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package data_workers
# Module caffe2.python.data_workers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train",
      dont_rebatch=False
   )
   ...
   coordinator.start()

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

The argument 'input_source_name' is used to distinguish different sources
of data, such as train or test data, so that the data does not get mixed
up even if two nets share blobs.

To do the actual data loading, one defines a "fetcher function"
that has the call signature
   my_fetch_fun(worker_id, batch_size)

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)
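
For illustration, a minimal init function might look like this (the body
is a hypothetical sketch, not part of this module):
   def my_init_fun(data_coordinator, global_coordinator):
       # e.g. seed shared random state or open resources the fetchers use
       np.random.seed(0)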

If dont_rebatch is set to True, the data input is not rebatched into
equal-sized chunks; the data provided by the fetchers is used directly.

'batch_columns' can be used to specify which dimension is the batch
dimension, for each of the inputs. The default is 0 for all inputs.

'timeout' is the timeout in seconds after which the net will fail if no
data is available (default 600s = 10 mins).

The fetcher function returns a list of numpy arrays corresponding to the
different input blobs. In the example above, it would return two arrays,
one for the data blob and another for the labels. These arrays can have
an arbitrary number of elements (i.e. they do not need to match the batch
size). The batch size is provided to the function as a hint only.
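
For illustration, a minimal fetcher returning random data might look like
this (a hypothetical sketch; a real fetcher would load actual data):
   def my_fetch_fun(worker_id, batch_size):
       data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
       labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
       return [data, labels]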

For example, the fetcher function could download images from a remote
service or load random images from a directory on a file system.

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Also note that the 'coordinator' returned by the function is
the same each time.
'''

try:
    import Queue
except ImportError:
    # Py3
    import queue as Queue
from itertools import chain
import logging
import threading
import numpy as np
import time

from caffe2.python import workspace, core, scope, utils
from caffe2.proto import caffe2_pb2
from caffe2.python.parallel_workers import Metrics, State, \
    WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker

log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def get_worker_ids(num_workers):
    return list(range(0, num_workers))

def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Create coordinator object
    coordinator = WorkerCoordinator(
        input_source_name, init_fun, batch_feeder)

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator,
                  DataWorker(coordinator, worker_id, fetch_fun, metrics,
                             batch_size, batch_feeder)],
        ) for worker_id in worker_ids
    ]

    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
        args=[coordinator, batch_feeder]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class BatchFeeder(State):
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._dont_rebatch = dont_rebatch
        self._init_scratch()
        self._metrics = metrics

        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def start(self):
        self._inputs = 0
        self._prev_seconds = time.time()

    def stop(self):
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())

    def _get(self, data_input_coordinator):
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                if time.time() - last_warning > 10.0:
                    log.warning("** Data input is slow: (still) no data in {} secs.".format(
                        time.time() - start_time))
                    last_warning = time.time()
                continue
        return None

    def _validate_chunk(self, chunk):
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            j = 1
            for d in chunk[1:]:
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"
                j += 1

        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        return True

    def put(self, chunk, data_input_coordinator):
        if not self._validate_chunk(chunk):
            return

        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}, name={}".format(
                                    qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._counter += 1
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
                continue

    def _enqueue_batch_direct(self, data_input_coordinator):
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.
        '''
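        # For example, with batch_size=32 and fetchers returning chunks of
        # 20 rows each, two chunks (40 rows) get concatenated; 32 rows are
        # enqueued and the remaining 8 are put back on the python-side queue
        # as leftover for the next batch.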
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for d in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [c, l] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(l)
                    trimmed_batch.append(c)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    pass

                assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric('enqueue_time', time.time() - start_time)

    def _init_scratch(self):
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status"
            )

        # Feed empty arrays to the scratch blobs here, so that there won't be
        # race conditions when calling FeedBlob (which calls workspace
        # CreateBlob()) from enqueue threads
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        workspace.FeedBlob(
            self._scratch_blob[blob_name],
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, self._scratch_blob[blob_name]],
            [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))

    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds

class GlobalCoordinator(GlobalWorkerCoordinator):
    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
        return self._queues[queue_name]

    def reset_data_input(self, namescope, name, net, batch_size):
        log.info("Reset data input {}, batch size {}".format(name, batch_size))
        for c in self._coordinators:
            if c._worker_name == name and c._state._namescope == namescope:
                c._state._batch_size = batch_size
                c._state._create_caffe2_ops(net)

class DataWorker(Worker):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        input_data = self._worker_fun(self._worker_id, self._batch_size)

        self._batch_feeder.put(input_data, self._coordinator)

    def finish(self):
        self._metrics.put_metric(
            'fetcher_time', time.time() - self._start_time)

global_coordinator = GlobalCoordinator()


def enqueuer(coordinator, batch_feeder):
    while coordinator.is_active():
        batch_feeder._enqueue_batch(coordinator)
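
Putting it together, a minimal end-to-end sketch (modeled loosely on the
data_workers_test unit test mentioned in the module docstring; the net setup
and fetcher below are illustrative, not part of this module):

   import numpy as np
   from caffe2.python import core, workspace, data_workers

   def my_fetch_fun(worker_id, batch_size):
       # Dummy fetcher: random "images" and labels
       data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
       labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
       return [data, labels]

   net = core.Net("test_net")
   coordinator = data_workers.init_data_input_workers(
       net, ["data", "label"], my_fetch_fun,
       batch_size=32, input_source_name="train",
   )
   coordinator.start()

   workspace.CreateNet(net)
   workspace.RunNet(net)  # each run dequeues one batch into "data"/"label"

   coordinator.stop()  # stop fetcher and enqueuer threads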