Caffe2 - Python API
A deep learning, cross-platform ML framework
queue_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package queue_util
17 # Module caffe2.python.queue_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, dataio
24 from caffe2.python.task import TaskGroup
25 
26 
def __init__(self, wrapper, num_dequeue_records=1):
    """Create a reader that dequeues records from ``wrapper``'s queue.

    Args:
        wrapper: the queue wrapper to read from. It must provide
            ``schema()``, ``queue()`` and ``_new_reader(init_net)``.
        num_dequeue_records: value forwarded as ``num_records`` to the
            SafeDequeueBlobs op built in ``read_ex``.

    Raises:
        AssertionError: if ``wrapper.schema()`` returns None.
    """
    # Call schema() rather than testing the bound method itself: a bound
    # method object is never None, so such a check could never fire.
    schema = wrapper.schema()
    assert schema is not None, (
        'Queue needs a schema in order to be read from.')
    dataio.Reader.__init__(self, schema)
    self._wrapper = wrapper
    self._num_dequeue_records = num_dequeue_records
34 
def setup_ex(self, init_net, exit_net):
    """Add a CloseBlobsQueue op for the wrapped queue to ``exit_net``.

    ``init_net`` is not used by this reader.
    """
    queue_blob = self._wrapper.queue()
    exit_net.CloseBlobsQueue([queue_blob], 0)
37 
def read_ex(self, local_init_net, local_finish_net):
    """Build the net that dequeues one batch for this reader.

    Registers the reader with the wrapper on ``local_init_net``. Then
    creates a fresh net named 'dequeue' holding a SafeDequeueBlobs op with
    one output per schema field plus a status output.
    ``local_finish_net`` is not used.

    Returns:
        A tuple ``([dequeue_net], status_blob, fields)``.
    """
    self._wrapper._new_reader(local_init_net)
    net = core.Net('dequeue')
    queue_blob = self._wrapper.queue()
    names = self.schema().field_names()
    fields, status_blob = dequeue(
        net,
        queue_blob,
        len(names),
        field_names=names,
        num_records=self._num_dequeue_records)
    return [net], status_blob, fields
48 
def read(self, net):
    """Run ``read_ex`` with ``net`` as the local init net.

    Returns:
        ``(nets, fields)``: the net list and fields produced by
        ``read_ex``. The status blob returned by ``read_ex`` is dropped.
    """
    # NOTE(review): the first element is the list of nets from read_ex,
    # not a status/should-stop blob -- confirm callers expect this.
    nets, _status, fields = self.read_ex(net, None)
    return nets, fields
52 
53 
def __init__(self, wrapper):
    """Create a writer that enqueues into ``wrapper``'s queue.

    Only stores ``wrapper``. No base-class initializer is called here.
    """
    # NOTE(review): confirm the base class needs no initialisation of its own.
    self._wrapper = wrapper
57 
def setup_ex(self, init_net, exit_net):
    """Close the wrapped queue from ``exit_net``; ``init_net`` is unused."""
    to_close = [self._wrapper.queue()]
    exit_net.CloseBlobsQueue(to_close, 0)
60 
def write_ex(self, fields, local_init_net, local_finish_net, status):
    """Build the net that enqueues ``fields`` for this writer.

    Registers the writer (with its schema) with the wrapper on
    ``local_init_net``. Then creates a fresh net named 'enqueue' holding a
    SafeEnqueueBlobs op for ``fields``, with ``status`` as its status
    output. ``local_finish_net`` is not used.

    Returns:
        A one-element list containing the enqueue net.
    """
    wrapper = self._wrapper
    wrapper._new_writer(self.schema(), local_init_net)
    net = core.Net('enqueue')
    enqueue(net, wrapper.queue(), fields, status)
    return [net]
66 
67 
def __init__(self, handler, schema=None, num_dequeue_records=1):
    """Wrap an existing queue blob as a ``dataio.Pipe``.

    Args:
        handler: the queue blob. It is stored as ``self._queue``.
        schema: optional schema forwarded to ``dataio.Pipe.__init__``.
        num_dequeue_records: forwarded to every reader made by ``reader()``.
    """
    dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
    # Attributes are set after the base initializer, as in the original order.
    self._queue, self._num_dequeue_records = handler, num_dequeue_records
73 
def reader(self):
    """Return a new ``_QueueReader`` bound to this wrapper."""
    per_dequeue = self._num_dequeue_records
    return _QueueReader(self, num_dequeue_records=per_dequeue)
77 
def writer(self):
    """Return a new ``_QueueWriter`` bound to this wrapper."""
    new_writer = _QueueWriter(self)
    return new_writer
80 
def queue(self):
    """Return the blob holding the underlying queue."""
    handle = self._queue
    return handle
83 
84 
def __init__(self, capacity, schema=None, name='queue',
             num_dequeue_records=1):
    """Declare a queue blob. The queue is created later by ``setup()``.

    Args:
        capacity: forwarded as ``capacity`` to CreateBlobsQueue in ``setup()``.
        schema: schema of the queued records (``setup()`` requires one).
        name: name of the scratch net used to pick a unique blob name.
        num_dequeue_records: forwarded to readers of this queue.
    """
    # A throwaway net exists only to generate a unique external-input name
    # for the queue handle.
    name_net = core.Net(name)
    handle = name_net.AddExternalInput(name_net.NextName('handler'))
    QueueWrapper.__init__(
        self, handle, schema, num_dequeue_records=num_dequeue_records)
    self.capacity = capacity
    self._setup_done = False
95 
96  def setup(self, global_init_net):
97  assert self._schema, 'This queue does not have a schema.'
98  self._setup_done = True
99  global_init_net.CreateBlobsQueue(
100  [],
101  [self._queue],
102  capacity=self.capacity,
103  num_blobs=len(self._schema.field_names()),
104  field_names=self._schema.field_names())
105 
106 
def enqueue(net, queue, data_blobs, status=None):
    """Append a SafeEnqueueBlobs op that pushes ``data_blobs`` onto ``queue``.

    The op takes ``[queue] + data_blobs`` as inputs and
    ``data_blobs + [status]`` as outputs.

    Args:
        net: net the op is appended to.
        queue: the queue blob.
        data_blobs: list or tuple of blobs to enqueue.
        status: status output of the op. When None, a fresh name is
            generated with ``net.NextName('status')``.

    Returns:
        The last output of the op, i.e. the status blob.
    """
    # ``[queue] + data_blobs`` raises TypeError for a tuple, so normalise
    # tuples to a list. Other types are passed through unchanged.
    if isinstance(data_blobs, tuple):
        data_blobs = list(data_blobs)
    if status is None:
        status = net.NextName('status')
    results = net.SafeEnqueueBlobs([queue] + data_blobs, data_blobs + [status])
    return results[-1]
112 
113 
def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
    """Append a SafeDequeueBlobs op that pops ``num_blobs`` blobs off ``queue``.

    Output names come from ``net.NextName``: one per entry of
    ``field_names`` when given, otherwise ``NextName('data', i)`` for each
    index. They are followed by the status output.

    Args:
        net: net the op is appended to.
        queue: the queue blob.
        num_blobs: number of data outputs.
        status: status output; generated with ``NextName('status')`` if None.
        field_names: optional names for the data outputs.
        num_records: forwarded as ``num_records`` to SafeDequeueBlobs.

    Returns:
        ``(fields, status_blob)``: a list of the data outputs and the status
        output.

    Raises:
        AssertionError: if ``field_names`` is given and its length is not
            ``num_blobs``.
    """
    if field_names is None:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    else:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(name) for name in field_names]
    status_name = net.NextName('status') if status is None else status
    outputs = list(net.SafeDequeueBlobs(
        queue, data_names + [status_name], num_records=num_records))
    return outputs[:-1], outputs[-1]
128 
129 
def close_queue(step, *queues):
    """Wrap ``step`` together with a step that closes ``queues``.

    Builds a net named "close_queue_net" with one CloseBlobsQueue op per
    queue.

    Returns:
        An execution step whose substeps are ``step`` followed by a step
        running that close net.
    """
    close_net = core.Net("close_queue_net")
    for q in queues:
        close_net.CloseBlobsQueue([q], 0)
    net_name = str(close_net)
    close_step = core.execution_step("%s_step" % net_name, close_net)
    # The "wraper" spelling is kept as-is so existing step names do not change.
    return core.execution_step("%s_wraper_step" % net_name, [step, close_step])
Definition: setup.py:1
def read_ex(self, local_init_net, local_finish_net)
Definition: dataio.py:74