# Caffe2 - Python API
# A deep learning, cross platform ML framework
# record_queue.py
## @package record_queue
# Module caffe2.python.record_queue
"""
Implementation of a queue wrapper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.dataio import Reader, Writer
from caffe2.python.schema import (
    Struct, Field, from_column_list)
15 
16 
class _QueueReader(Reader):
    """Reader that dequeues records from an underlying BlobsQueue."""

    def __init__(self, blobs_queue, schema, name=None):
        """Don't call this directly. Instead, use dataset.reader()"""
        super(_QueueReader, self).__init__(schema)
        self.blobs_queue = blobs_queue  # the BlobsQueue blob to dequeue from
        self.name = name  # optional prefix used when naming the read ops

    def read(self, read_net):
        """Add a SafeDequeueBlobs op to `read_net`.

        Returns a `(should_stop, fields)` pair: `should_stop` is the dequeue
        status blob (used as the stop signal once the queue is drained/closed
        — semantics come from the SafeDequeueBlobs operator), and `fields`
        are the dequeued record blobs, one per schema field.
        """
        with core.NameScope(read_net.NextName(self.name)):
            status = read_net.NextName()
            fields = read_net.SafeDequeueBlobs(
                self.blobs_queue, self._schema.field_names() + [status])
            # SafeDequeueBlobs outputs the record blobs followed by the
            # status blob; split them back apart for the caller.
            return (fields[-1], fields[:-1])
30 
31 
class _QueueWriter(Writer):
    """Writer that enqueues records into an underlying BlobsQueue."""

    def __init__(self, blobs_queue, schema):
        self.blobs_queue = blobs_queue  # the BlobsQueue blob to enqueue into
        self.schema = schema  # schema describing the record fields

    def write(self, writer_net, fields):
        """Add ops to `writer_net` that enqueue `fields` into the queue.

        `fields` may be a schema Field (its blobs are extracted) or a plain
        list of blobs. The blobs are first checked against the schema via
        CheckDatasetConsistency. Returns the enqueue status blob produced by
        SafeEnqueueBlobs.
        """
        if isinstance(fields, Field):
            fields = fields.field_blobs()
        writer_net.CheckDatasetConsistency(
            fields, [], fields=self.schema.field_names())
        status = writer_net.NextName()
        writer_net.SafeEnqueueBlobs(
            [self.blobs_queue] + fields, fields + [status])
        return status
46 
47 
class RecordQueue(object):
    """Feed records from a reader (optionally preprocessed) into a
    BlobsQueue, and expose a reader interface for fetching them back out.
    """
    def __init__(self, fields, name=None, capacity=1,
                 enforce_unique_name=False, num_threads=1):
        """Create the queue and its reader/writer wrappers.

        Args:
            fields: a Struct, or a list of raw field names describing
                each record.
            name: base name for the queue's nets (defaults to 'queue').
            capacity: maximum number of records held by the queue.
            enforce_unique_name: forwarded to CreateBlobsQueue.
            num_threads: number of concurrent producer pipelines built
                by `build()`.
        """
        assert isinstance(fields, (list, Struct)), (
            'fields must be either a Struct or a list of raw field names.')
        self.schema = from_column_list(fields) if isinstance(fields, list) \
            else fields
        self.name = name or 'queue'
        self.num_threads = num_threads

        # The queue itself is created eagerly, outside of any net the
        # caller will run, so it exists before producers/consumers start.
        creation_net = core.Net(self.name + '/init_net')
        self.blobs_queue = creation_net.CreateBlobsQueue(
            [], 1,
            capacity=capacity,
            num_blobs=len(self.schema.field_names()),
            enforce_unique_name=enforce_unique_name)
        core.workspace.RunNetOnce(creation_net)

        self.writer = _QueueWriter(self.blobs_queue, self.schema)
        self.reader = _QueueReader(
            self.blobs_queue, self.schema, self.name + '_reader')

        # Closing the queue unblocks any pending dequeues, letting
        # consumers terminate cleanly.
        closing_net = core.Net(self.name + '/exit_net')
        closing_net.CloseBlobsQueue(self.blobs_queue, 0)
        self.exit_step = core.execution_step(
            '{}_close_step'.format(str(closing_net)),
            closing_net)

    def build(self, reader, process=None):
        """
        Build the producer_step to feed data from reader into the queue, and
        return the reader interface.
        Inputs:
            reader: read data which will be stored in the queue.
            process: preprocess data before enqueue.
        Outputs:
            reader: reader to fetch the data from the queue.
            producer_step: the step that inserts the data into the queue.
                Should be run together with the consume step.
            exit_step: the step to close the queue.
            schema: the schema for the reader.
        """
        producers = []
        for idx in range(self.num_threads):
            # Stage 1: pull one record from the source reader.
            read_name = 'reader_' + str(idx)
            read_net = core.Net(read_name)
            should_stop, fields = reader.read_record(read_net)
            read_step = core.execution_step(read_name, read_net)

            # Stage 2: optionally transform, then enqueue the record.
            write_name = 'queue_writer' + str(idx)
            write_net = core.Net(write_name)
            blobs = fields.field_blobs()
            if process:
                blobs = process(write_net, fields).field_blobs()
            self.writer.write(write_net, blobs)
            write_step = core.execution_step(write_name, write_net)

            # Loop read+write until the source reader signals exhaustion.
            producers.append(core.execution_step(
                'producer_' + str(idx),
                [read_step, write_step],
                should_stop_blob=should_stop))

        producer_step = core.execution_step(
            'producers',
            producers,
            concurrent_substeps=True)
        return self.reader, producer_step, self.exit_step, self.schema
# (doc-extraction cross-reference residue, commented out to keep the
#  module importable)
# def build(self, reader, process=None)  -- Definition: record_queue.py:80
# def __init__(self, blobs_queue, schema, name=None)  -- Definition: record_queue.py:18