from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train",
      dont_rebatch=False
   )
   ...
   coordinator.start()

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

Argument 'input_source_name' is used to distinguish different sources of data,
such as train or test data. This is to ensure the data does not get mixed up,
even though the two nets would share blobs.

To do the actual data loading, one defines a "fetcher function"
that has the call signature
   my_fetch_fun(worker_id, batch_size)

The fetcher function returns a list of numpy arrays corresponding to the
different input blobs. In the example above, it would return two arrays, one
for the data blob and another for the labels. These arrays can have an
arbitrary number of elements (i.e. they do not need to match the batch size).
The batch size is provided to the function as a hint only. For example, the
fetcher function could download images from a remote service or load random
images from a directory on a file system.

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)

If dont_rebatch is set to True, the data input is not batched into equal
sized chunks; the data provided directly by the fetchers is used as-is.

'batch_columns' can be used to specify which dimension is the batch dimension,
for each of the inputs. Default is 0 for all inputs.

'timeout' is the timeout in seconds after which the net will fail if no data
is available (default 600s = 10 mins).

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''

try:
    import Queue
except ImportError:
    # Python 3
    import queue as Queue
from itertools import chain
import logging
import threading
import numpy as np
import time
from caffe2.python import workspace, core, scope, utils
from caffe2.proto import caffe2_pb2
from caffe2.python.parallel_workers import Metrics, State, \
    WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def get_worker_ids(num_workers):
    return list(range(0, num_workers))
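

# Illustrative sketch (not part of the original API): a minimal fetcher and
# init function of the shape the module docstring describes. The blob list
# and array shapes below are assumptions for illustration only.
def _example_fetch_fun(worker_id, batch_size):
    # Return one numpy array per input blob (e.g. "data" and "label").
    # The arrays may have any number of rows -- batch_size is only a hint;
    # BatchFeeder rebatches to the exact size unless dont_rebatch is set.
    data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
    labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
    return [data, labels]


def _example_init_fun(data_coordinator, global_coordinator):
    # Called once per input source before the fetcher threads start; a
    # natural place to open shared file handles or connections.
    pass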


def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
    metrics = Metrics(external_loggers)
    batch_feeder = BatchFeeder(
        net, input_blob_names, batch_size, device_option,
        scope.CurrentNameScope(), input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        metrics, dont_rebatch, batch_columns, timeout=timeout
    )
    # Assign a globally unique id to each fetcher thread
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]

    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    # Fetcher threads run the user-provided fetch function via run_worker
    workers = [
        threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator,
                  DataWorker(coordinator, worker_id, fetch_fun, metrics,
                             batch_size, batch_feeder)],
        ) for worker_id in worker_ids
    ]
    # One additional thread moves batches from the python queue into Caffe2
    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(
            input_source_name, scope.CurrentNameScope()),
        args=[coordinator, batch_feeder]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator
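

# Usage sketch for the function above, mirroring the module docstring
# (assumes the hypothetical _example_fetch_fun defined earlier and a net
# with "data" and "label" input blobs):
#
#     coordinator = init_data_input_workers(
#         net, ["data", "label"], _example_fetch_fun, batch_size=32,
#         input_source_name="train")
#     coordinator.start()
#
# The same global coordinator is returned for every call, so starting it
# once starts the workers for all registered input sources.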


class BatchFeeder(State):
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._init_scratch()
        self._metrics = metrics
        self._dont_rebatch = dont_rebatch

        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns
    def stop(self):
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())
    def _get(self, data_input_coordinator):
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                # Poll with a short timeout so we notice promptly when the
                # coordinator has been stopped
                if time.time() - last_warning > 10.0:
                    log.warning(
                        "** Data input is slow: (still) no data in {} secs.".format(
                            time.time() - start_time))
                    last_warning = time.time()
        return None
    def _validate_chunk(self, chunk):
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            for j, d in enumerate(chunk):
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"

        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        return True
    def put(self, chunk, data_input_coordinator):
        if not self._validate_chunk(chunk):
            return

        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and \
                        (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "name={}".format(self._input_source_name))
                    self._last_warning = time.time()
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
    def _enqueue_batch_direct(self, data_input_coordinator):
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)
    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects it
        into batch-sized pieces, unless dont_rebatch is set to True.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for _ in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect chunks until we have accumulated a full batch
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem,
                        axis=self._batch_columns[j])

        start_time = time.time()
        try:
            # Any rows beyond the batch size are split off and returned
            # to the internal queue as leftover for the next batch
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                    first_batch_col] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [c, l] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j])
                    leftover.append(l)
                    trimmed_batch.append(c)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    pass

                assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric(
                'enqueue_time', time.time() - start_time)
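
    # Worked example of the rebatching above: with batch_size=4 and fetcher
    # chunks of 3 rows each, two chunks are appended into 6 rows; np.split
    # at index 4 then yields a 4-row batch to enqueue plus a 2-row leftover
    # that is pushed back onto the internal queue for the next batch.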
    def _init_scratch(self):
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status")

        # Feed empty arrays to the scratch blobs up front, so there are no
        # race conditions when the enqueue threads call FeedBlob
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )
    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        blob = self._scratch_blob[blob_name]
        status = self._scratch_status[blob_name]
        workspace.FeedBlob(blob, data_arr, device_option=self._device_option)

        # SafeEnqueueBlobs also outputs a status blob, which indicates
        # whether the queue was closed while enqueueing
        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, blob],
            [blob, status],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)
    def _create_caffe2_queues(self, net):
        '''
        Creates queues on the caffe2 side.
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity)
            self._queues.append(q)
    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on the caffe2 side.
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add a DequeueBlobs op to the Caffe2 net for each input blob
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))
    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # Log and reset the perf metrics for the next interval
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds


class GlobalCoordinator(GlobalWorkerCoordinator):
    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        self._queues = {}
    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(
                maxsize=max_buffered_batches)
        return self._queues[queue_name]
    def reset_data_input(self, namescope, name, net, batch_size):
        log.info("Reset data input {}, batch size {}: ".format(
            name, batch_size))
        for c in self._coordinators:
            if c._worker_name == name and c._state._namescope == namescope:
                c._state._batch_size = batch_size
                c._state._create_caffe2_ops(net)
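
    # Runtime reset sketch (assumes a coordinator registered under
    # input_source_name "train" in the current namescope):
    #
    #     global_coordinator.reset_data_input(
    #         scope.CurrentNameScope(), "train", net, 64)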


class DataWorker(Worker):
    def __init__(self, coordinator, worker_id, worker_fun,
                 metrics, batch_size, batch_feeder):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        input_data = self._worker_fun(self._worker_id, self._batch_size)

        self._batch_feeder.put(input_data, self._coordinator)

    def finish(self):
        self._metrics.put_metric(
            'fetcher_time', time.time() - self._start_time)


global_coordinator = GlobalCoordinator()


def enqueuer(coordinator, batch_feeder):
    while coordinator.is_active():
        batch_feeder._enqueue_batch(coordinator)