Caffe2 - Python API
A deep learning, cross platform ML framework
db_file_reader.py
## @package db_file_reader
# Module caffe2.python.db_file_reader
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, scope, workspace, _import_c_extension as C
from caffe2.python.dataio import Reader
from caffe2.python.dataset import Dataset
from caffe2.python.schema import from_column_list

import os


class DBFileReader(Reader):

    default_name_suffix = 'db_file_reader'

    """Reader reads from a DB file.

    Example usage:
    db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')

    Args:
        db_path: str.
        db_type: str. DB type of the file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
        name: str or None. Name of the DBFileReader.
            Optional name to prepend to blobs that will store the data.
            Defaults to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read each time the read_net is run.
        loop_over: bool.
            If True, go through the examples in random order endlessly.
        field_names: List[str]. If schema.field_names() are not in
            alphabetical order, they must be specified here.
            Otherwise, the schema will be automatically restored with
            schema.field_names() sorted in alphabetical order.
    """
    def __init__(
        self,
        db_path,
        db_type,
        name=None,
        batch_size=100,
        loop_over=False,
        field_names=None,
    ):
        assert db_path is not None, "db_path can't be None."
        assert db_type in C.registered_dbs(), \
            "db_type [{db_type}] is not available. \n" \
            "Choose one of these: {registered_dbs}.".format(
                db_type=db_type,
                registered_dbs=C.registered_dbs(),
            )

        self.db_path = os.path.expanduser(db_path)
        self.db_type = db_type
        self.name = name or '{db_name}_{default_name_suffix}'.format(
            db_name=self._extract_db_name_from_db_path(),
            default_name_suffix=self.default_name_suffix,
        )
        self.batch_size = batch_size
        self.loop_over = loop_over

        # Before self._init_reader_schema(...),
        # self.db_path and self.db_type are required to be set.
        super(DBFileReader, self).__init__(self._init_reader_schema(field_names))
        self.ds = Dataset(self._schema, self.name + '_dataset')
        self.ds_reader = None

    def _init_name(self, name):
        return name or self._extract_db_name_from_db_path(
        ) + '_db_file_reader'

    def _init_reader_schema(self, field_names=None):
        """Restore a reader schema from the DB file.

        If `field_names` is given, restore the schema according to it.

        Otherwise, load blobs from the DB file into the workspace,
        and restore the schema from these blob names.
        It is also assumed that:
        1). Each field of the schema has a corresponding blob
            stored in the DB file.
        2). Each blob loaded from the DB file corresponds to
            a field of the schema.
        3). field_names in the original schema are in alphabetical order,
            since blob names loaded into the workspace from the DB file
            will be in alphabetical order.

        Load a set of blobs from the DB file. From the names of these blobs,
        restore the DB file schema using `from_column_list(...)`.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        if field_names:
            return from_column_list(field_names)

        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        with core.NameScope(self.name):
            # blob_prefix is for avoiding name conflicts in the workspace
            blob_prefix = scope.CurrentNameScope()
        workspace.RunOperatorOnce(
            core.CreateOperator(
                'Load',
                [],
                [],
                absolute_path=True,
                db=self.db_path,
                db_type=self.db_type,
                load_all=True,
                add_prefix=blob_prefix,
            )
        )
        col_names = [
            blob_name[len(blob_prefix):] for blob_name in workspace.Blobs()
            if blob_name.startswith(blob_prefix)
        ]
        schema = from_column_list(col_names)
        return schema
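
    # Illustrative sketch (an assumption, not part of the original source):
    # if '/tmp/sample.db' stores two blobs named 'feature' and 'label', the
    # restored schema lists its fields in alphabetical order:
    #
    #   reader = DBFileReader(db_path='/tmp/sample.db', db_type='LevelDB')
    #   reader.schema().field_names()  # -> ['feature', 'label']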

    def setup_ex(self, init_net, finish_net):
        """From the Dataset, create a _DatasetReader and set up the init_net.

        Make sure _init_field_blobs_as_empty(...) is only called once:
        the underlying NewRecord(...) creates blobs by calling
        NextScopedBlob(...), so references to previously-initialized
        empty blobs would be lost, causing accessibility issues.
        """
        if self.ds_reader:
            self.ds_reader.setup_ex(init_net, finish_net)
        else:
            self._init_field_blobs_as_empty(init_net)
            self._feed_field_blobs_from_db_file(init_net)
            self.ds_reader = self.ds.random_reader(
                init_net,
                batch_size=self.batch_size,
                loop_over=self.loop_over,
            )
            self.ds_reader.sort_and_shuffle(init_net)
            self.ds_reader.computeoffset(init_net)
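
    # Hypothetical sketch (not in the original source) of the call pattern
    # the guard above supports: a second setup_ex call, e.g. for a separate
    # epoch net, reuses the existing ds_reader instead of recreating blobs:
    #
    #   reader.setup_ex(init_net_1, finish_net_1)  # builds ds_reader
    #   reader.setup_ex(init_net_2, finish_net_2)  # reuses the same ds_reader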

    def read(self, read_net):
        assert self.ds_reader, 'setup_ex must be called first'
        return self.ds_reader.read(read_net)

    def _init_field_blobs_as_empty(self, init_net):
        """Initialize dataset field blobs by creating an empty record."""
        with core.NameScope(self.name):
            self.ds.init_empty(init_net)

    def _feed_field_blobs_from_db_file(self, net):
        """Load from the DB file at db_path and feed dataset field blobs."""
        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        net.Load(
            [],
            self.ds.get_blobs(),
            db=self.db_path,
            db_type=self.db_type,
            absolute_path=True,
            source_blob_names=self.ds.field_names(),
        )

    def _extract_db_name_from_db_path(self):
        """Extract the DB name from the DB path.

        E.g. given self.db_path=`/tmp/sample.db`,
        it returns `sample`.

        Returns:
            db_name: str.
        """
        return os.path.basename(self.db_path).rsplit('.', 1)[0]
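
A minimal end-to-end usage sketch (an illustration, not part of the module;
it assumes '/tmp/cache.db' is a LevelDB file whose blobs were saved with
names matching the schema fields, and uses a throwaway finish net):

from caffe2.python import core, workspace
from caffe2.python.db_file_reader import DBFileReader

reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB',
                      batch_size=32)
init_net = core.Net('init')
read_net = core.Net('read')
reader.setup_ex(init_net, core.Net('finish'))
should_stop, fields = reader.read(read_net)

workspace.RunNetOnce(init_net)
workspace.CreateNet(read_net)
workspace.RunNet(read_net)  # each run reads one batch of up to 32 examples
# Fetch the first field of the batch that was just read:
print(workspace.FetchBlob(fields[0]))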