3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# Module-level logger, named after this module. enqueue() uses it to warn
# when a duplicated blob has to be copied before it can be enqueued.
logger = logging.getLogger(__name__)
18 def __init__(self, wrapper, num_dequeue_records=1):
19 assert wrapper.schema
is not None, (
20 'Queue needs a schema in order to be read from.')
21 dataio.Reader.__init__(self, wrapper.schema())
def setup_ex(self, init_net, exit_net):
    """Add a ``CloseBlobsQueue`` op for the wrapped queue to ``exit_net``.

    Args:
        init_net: Not used by this method.
        exit_net: Net that receives the ``CloseBlobsQueue`` op. The op's
            single input is the queue blob returned by
            ``self._wrapper.queue()``.
    """
    # NOTE(review): the trailing 0 is presumably the number of outputs to
    # create (none) under core.Net's op-call convention -- confirm.
    exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
28 def read_ex(self, local_init_net, local_finish_net):
29 self._wrapper._new_reader(local_init_net)
31 fields, status_blob = dequeue(
33 self._wrapper.queue(),
34 len(self.
schema().field_names()),
35 field_names=self.
schema().field_names(),
37 return [dequeue_net], status_blob, fields
40 net, _, fields = self.
read_ex(net,
None)
45 def __init__(self, wrapper):
def setup_ex(self, init_net, exit_net):
    """Add a ``CloseBlobsQueue`` op for the wrapped queue to ``exit_net``.

    Args:
        init_net: Not used by this method.
        exit_net: Net that receives the ``CloseBlobsQueue`` op. The op's
            single input is the queue blob returned by
            ``self._wrapper.queue()``.
    """
    # NOTE(review): the trailing 0 is presumably the number of outputs to
    # create (none) under core.Net's op-call convention -- confirm.
    exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
51 def write_ex(self, fields, local_init_net, local_finish_net, status):
52 self._wrapper._new_writer(self.
schema(), local_init_net)
54 enqueue(enqueue_net, self._wrapper.queue(), fields, status)
59 def __init__(self, handler, schema=None, num_dequeue_records=1):
60 dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
76 def __init__(self, capacity, schema=None, name='queue',
77 num_dequeue_records=1):
80 queue_blob = net.AddExternalInput(net.NextName(
'handler'))
81 QueueWrapper.__init__(
82 self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
86 def setup(self, global_init_net):
87 assert self.
_schema,
'This queue does not have a schema.' 89 global_init_net.CreateBlobsQueue(
93 num_blobs=len(self._schema.field_names()),
94 field_names=self._schema.field_names())
97 def enqueue(net, queue, data_blobs, status=None):
99 status = net.NextName(
'status')
103 for blob
in data_blobs:
104 if blob
not in queue_blobs:
105 queue_blobs.append(blob)
107 logger.warning(
"Need to copy blob {} to enqueue".format(blob))
108 queue_blobs.append(net.Copy(blob))
109 results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
113 def dequeue(net, queue, num_blobs, status=None, field_names=None,
115 if field_names
is not None:
116 assert len(field_names) == num_blobs
117 data_names = [net.NextName(name)
for name
in field_names]
119 data_names = [net.NextName(
'data', i)
for i
in range(num_blobs)]
121 status = net.NextName(
'status')
122 results = net.SafeDequeueBlobs(
123 queue, data_names + [status], num_records=num_records)
124 results = list(results)
125 status_blob = results.pop(-1)
126 return results, status_blob
129 def close_queue(step, *queues):
130 close_net =
core.Net(
"close_queue_net")
132 close_net.CloseBlobsQueue([queue], 0)
133 close_step = core.execution_step(
"%s_step" % str(close_net), close_net)
134 return core.execution_step(
135 "%s_wraper_step" % str(close_net),
def read_ex(self, local_init_net, local_finish_net)