Caffe2 - Python API
A deep learning, cross platform ML framework
queue_util.py
## @package queue_util
# Module caffe2.python.queue_util
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, dataio
from caffe2.python.task import TaskGroup

import logging


logger = logging.getLogger(__name__)


class _QueueReader(dataio.Reader):
    def __init__(self, wrapper, num_dequeue_records=1):
        assert wrapper.schema is not None, (
            'Queue needs a schema in order to be read from.')
        dataio.Reader.__init__(self, wrapper.schema())
        self._wrapper = wrapper
        self._num_dequeue_records = num_dequeue_records

    def setup_ex(self, init_net, exit_net):
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def read_ex(self, local_init_net, local_finish_net):
        self._wrapper._new_reader(local_init_net)
        dequeue_net = core.Net('dequeue')
        fields, status_blob = dequeue(
            dequeue_net,
            self._wrapper.queue(),
            len(self.schema().field_names()),
            field_names=self.schema().field_names(),
            num_records=self._num_dequeue_records)
        return [dequeue_net], status_blob, fields

    def read(self, net):
        net, _, fields = self.read_ex(net, None)
        return net, fields


class _QueueWriter(dataio.Writer):
    def __init__(self, wrapper):
        self._wrapper = wrapper

    def setup_ex(self, init_net, exit_net):
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        self._wrapper._new_writer(self.schema(), local_init_net)
        enqueue_net = core.Net('enqueue')
        enqueue(enqueue_net, self._wrapper.queue(), fields, status)
        return [enqueue_net]


class QueueWrapper(dataio.Pipe):
    def __init__(self, handler, schema=None, num_dequeue_records=1):
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
        self._queue = handler
        self._num_dequeue_records = num_dequeue_records

    def reader(self):
        return _QueueReader(
            self, num_dequeue_records=self._num_dequeue_records)

    def writer(self):
        return _QueueWriter(self)

    def queue(self):
        return self._queue

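As a rough illustration of how QueueWrapper can be used, the sketch below wraps a queue blob created by the same CreateBlobsQueue operator that Queue.setup() emits further down. The one-field schema and all blob and net names are invented for the example and are not part of this module.

    import numpy as np

    from caffe2.python import core, schema, queue_util

    # A one-field record schema; the queue stores one blob per field.
    record = schema.Struct(('x', schema.Scalar(np.float32)))

    init_net = core.Net('example_init')
    # Same operator that Queue.setup() uses; names are arbitrary.
    queue_blob = init_net.CreateBlobsQueue(
        [], ['example_queue'],
        capacity=10,
        num_blobs=len(record.field_names()),
        field_names=record.field_names())

    pipe = queue_util.QueueWrapper(queue_blob, schema=record)
    reader = pipe.reader()   # _QueueReader over this queue
    writer = pipe.writer()   # _QueueWriter over this queue
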
class Queue(QueueWrapper):
    def __init__(self, capacity, schema=None, name='queue',
                 num_dequeue_records=1):
        # find a unique blob name for the queue
        net = core.Net(name)
        queue_blob = net.AddExternalInput(net.NextName('handler'))
        QueueWrapper.__init__(
            self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
        self.capacity = capacity
        self._setup_done = False

    def setup(self, global_init_net):
        assert self._schema, 'This queue does not have a schema.'
        self._setup_done = True
        global_init_net.CreateBlobsQueue(
            [],
            [self._queue],
            capacity=self.capacity,
            num_blobs=len(self._schema.field_names()),
            field_names=self._schema.field_names())

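Queue differs from QueueWrapper in that it allocates its own handler blob and knows how to create the underlying BlobsQueue. A minimal, hypothetical setup might look like the sketch below; the schema and all names are invented for illustration.

    import numpy as np

    from caffe2.python import core, schema, workspace, queue_util

    record = schema.Struct(('x', schema.Scalar(np.float32)))
    q = queue_util.Queue(capacity=8, schema=record, name='example_queue')

    global_init_net = core.Net('example_global_init')
    q.setup(global_init_net)               # emits CreateBlobsQueue into the net
    workspace.RunNetOnce(global_init_net)  # the queue now exists in the workspace

    reader = q.reader()
    writer = q.writer()

In practice the reader and writer are typically consumed by higher-level pipeline code built on dataio.Pipe rather than driven by hand.
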
def enqueue(net, queue, data_blobs, status=None):
    if status is None:
        status = net.NextName('status')
    # Enqueueing moved the data into the queue;
    # duplication will result in data corruption
    queue_blobs = []
    for blob in data_blobs:
        if blob not in queue_blobs:
            queue_blobs.append(blob)
        else:
            logger.warning("Need to copy blob {} to enqueue".format(blob))
            queue_blobs.append(net.Copy(blob))
    results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
    return results[-1]


def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
    if field_names is not None:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(name) for name in field_names]
    else:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    if status is None:
        status = net.NextName('status')
    results = net.SafeDequeueBlobs(
        queue, data_names + [status], num_records=num_records)
    results = list(results)
    status_blob = results.pop(-1)
    return results, status_blob

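To make the contract of the two helpers above concrete, here is a hypothetical single-record round trip driven directly through the workspace. All blob and net names are invented, and the exact shape returned for a single dequeued record is not asserted here.

    import numpy as np

    from caffe2.python import core, workspace, queue_util

    init_net = core.Net('rt_init')
    q = init_net.CreateBlobsQueue([], ['rt_queue'], capacity=4, num_blobs=1)
    workspace.RunNetOnce(init_net)

    # Feed one data blob and push it into the queue.
    workspace.FeedBlob('rt_data', np.array([1.0, 2.0, 3.0], dtype=np.float32))
    enq_net = core.Net('rt_enqueue')
    queue_util.enqueue(enq_net, q, ['rt_data'])
    workspace.RunNetOnce(enq_net)

    # Pop it back out; status_blob signals a failed dequeue
    # (e.g. the queue was closed and drained).
    deq_net = core.Net('rt_dequeue')
    out_blobs, status_blob = queue_util.dequeue(deq_net, q, num_blobs=1)
    workspace.RunNetOnce(deq_net)
    print(workspace.FetchBlob(out_blobs[0]))  # the tensor that was enqueued
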
def close_queue(step, *queues):
    close_net = core.Net("close_queue_net")
    for queue in queues:
        close_net.CloseBlobsQueue([queue], 0)
    close_step = core.execution_step("%s_step" % str(close_net), close_net)
    return core.execution_step(
        "%s_wraper_step" % str(close_net),
        [step, close_step])