Caffe2 - Python API
A deep learning, cross-platform ML framework
record_queue.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package record_queue
# Module caffe2.python.record_queue
"""
Implementation of a queue wrapper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.dataio import Reader, Writer
from caffe2.python.schema import (
    Struct, Field, from_column_list)


class _QueueReader(Reader):
    def __init__(self, blobs_queue, schema, name=None):
        """Don't call this directly. Instead, use dataset.reader()"""
        super(_QueueReader, self).__init__(schema)
        self.blobs_queue = blobs_queue
        self.name = name

    def read(self, read_net):
        with core.NameScope(read_net.NextName(self.name)):
            status = read_net.NextName()
            fields = read_net.SafeDequeueBlobs(
                self.blobs_queue, self._schema.field_names() + [status])
            return (fields[-1], fields[:-1])
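    # Note: SafeDequeueBlobs emits one blob per schema field plus a trailing
    # status blob; the status becomes true once the queue is closed and
    # drained, hence the (status, field_blobs) split in the return above.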


class _QueueWriter(Writer):
    def __init__(self, blobs_queue, schema):
        self.blobs_queue = blobs_queue
        self.schema = schema

    def write(self, writer_net, fields):
        if isinstance(fields, Field):
            fields = fields.field_blobs()
        writer_net.CheckDatasetConsistency(
            fields, [], fields=self.schema.field_names())
        status = writer_net.NextName()
        writer_net.SafeEnqueueBlobs(
            [self.blobs_queue] + fields, fields + [status])
        return status
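    # Note: SafeEnqueueBlobs mirrors the dequeue side; its trailing status
    # blob is set to true when the queue has already been closed, which
    # lets producers detect shutdown instead of blocking forever.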


class RecordQueue(object):
    """Feeds data from a reader into a queue, optionally applying a
    processing function along the way, and provides a reader interface
    for fetching data back out of the queue.
    """
    def __init__(self, fields, name=None, capacity=1,
                 enforce_unique_name=False, num_threads=1):
        assert isinstance(fields, list) or isinstance(fields, Struct), (
            'fields must be either a Struct or a list of raw field names.')
        if isinstance(fields, list):
            fields = from_column_list(fields)
        self.schema = fields
        self.name = name or 'queue'
        self.num_threads = num_threads
        num_blobs = len(self.schema.field_names())
        init_net = core.Net(self.name + '/init_net')
        self.blobs_queue = init_net.CreateBlobsQueue(
            [], 1,
            capacity=capacity,
            num_blobs=num_blobs,
            enforce_unique_name=enforce_unique_name)
        core.workspace.RunNetOnce(init_net)

        self.writer = _QueueWriter(self.blobs_queue, self.schema)
        reader_name = self.name + '_reader'
        self.reader = _QueueReader(self.blobs_queue, self.schema, reader_name)

        exit_net = core.Net(self.name + '/exit_net')
        exit_net.CloseBlobsQueue(self.blobs_queue, 0)
        self.exit_step = core.execution_step(
            '{}_close_step'.format(str(exit_net)),
            exit_net)
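        # Note: running exit_step closes the queue, which unblocks any
        # pending SafeEnqueueBlobs/SafeDequeueBlobs calls by flipping
        # their status blobs to true.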

    def build(self, reader, process=None):
        """
        Build the producer_step to feed data from reader into the queue, and
        return the reader interface.
        Inputs:
            reader: the reader whose data will be stored in the queue.
            process: an optional function to preprocess data before enqueue.
        Outputs:
            reader: reader to fetch the data from the queue.
            producer_step: the step that inserts the data into the queue.
                Should be run together with the consumer step.
            exit_step: the step to close the queue.
            schema: the schema for the reader.
        """
        producer_steps = []
        for i in range(self.num_threads):
            name = 'reader_' + str(i)
            net_reader = core.Net(name)
            should_stop, fields = reader.read_record(net_reader)
            step_read = core.execution_step(name, net_reader)

            name = 'queue_writer' + str(i)
            net_prod = core.Net(name)
            field_blobs = fields.field_blobs()
            if process:
                field_blobs = process(net_prod, fields).field_blobs()

            self.writer.write(net_prod, field_blobs)
            step_prod = core.execution_step(name, net_prod)
            step = core.execution_step(
                'producer_' + str(i),
                [step_read, step_prod],
                should_stop_blob=should_stop)
            producer_steps.append(step)
        producer_step = core.execution_step(
            'producers',
            producer_steps,
            concurrent_substeps=True)
        return self.reader, producer_step, self.exit_step, self.schema
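
A minimal usage sketch, not part of the original module: `my_reader` below is a hypothetical stand-in for any `caffe2.python.dataio.Reader` (for example, one obtained from a Dataset via `dataset.reader()`). The producers run concurrently with the consumer; closing the queue via exit_step is what lets the consumer's should_stop blob fire once all data is drained.

from caffe2.python import core, workspace
from caffe2.python.record_queue import RecordQueue

# my_reader is assumed to be a caffe2.python.dataio.Reader instance.
queue = RecordQueue(fields=['data', 'label'], name='example_queue',
                    capacity=8, num_threads=2)
reader, producer_step, exit_step, schema = queue.build(my_reader)

# Consumer: dequeue one record per iteration until the queue is
# closed and drained (should_stop becomes true).
consume_net = core.Net('consume')
should_stop, record = reader.read_record(consume_net)
consume_step = core.execution_step(
    'consumer', consume_net, should_stop_blob=should_stop)

# Run all producers to completion, then close the queue; the consumer
# runs concurrently and exits once the queue reports it is done.
produce_then_close = core.execution_step(
    'produce_then_close', [producer_step, exit_step])
work = core.execution_step(
    'work', [produce_then_close, consume_step], concurrent_substeps=True)

plan = core.Plan('record_queue_example')
plan.AddStep(work)
workspace.RunPlan(plan)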