3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python import core, scope, workspace, _import_c_extension
as C
18 default_name_suffix =
'db_file_reader' 20 """Reader reads from a DB file. 23 db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB') 27 db_type: str. DB type of file. A db_type is registered by 28 `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`. 29 name: str or None. Name of DBFileReader. 30 Optional name to prepend to blobs that will store the data. 31 Default to '<db_name>_<default_name_suffix>'. 33 How many examples are read each time the read_net is run. 35 If True is given, will go through examples in random order endlessly. 36 field_names: List[str]. If the schema.field_names() are not in 37 alphabetic order, it must be specified. 38 Otherwise, schema will be automatically restored with 39 schema.field_names() sorted in alphabetic order. 50 assert db_path
is not None,
"db_path can't be None." 51 assert db_type
in C.registered_dbs(), \
52 "db_type [{db_type}] is not available. \n" \
53 "Choose one of these: {registered_dbs}.".format(
55 registered_dbs=C.registered_dbs(),
58 self.
db_path = os.path.expanduser(db_path)
60 self.
name = name
or '{db_name}_{default_name_suffix}'.format(
73 def _init_name(self, name):
77 def _init_reader_schema(self, field_names=None):
78 """Restore a reader schema from the DB file. 80 If `field_names` is given, restore the schema according to it. 82 Otherwise, load blobs from the DB file into the workspace, 83 and restore schema from these blob names. 84 It is also assumed that: 85 1). Each field of the schema has corresponding blobs 86 stored in the DB file. 87 2). Each blob loaded from the DB file corresponds to 88 a field of the schema. 89 3). field_names in the original schema are in alphabetic order, 90 since blob names loaded to the workspace from the DB file 91 will be in alphabetic order. 93 Load a set of blobs from a DB file. From names of these blobs, 94 restore the DB file schema using `from_column_list(...)`. 97 schema: schema.Struct. Used in Reader.__init__(...). 100 return from_column_list(field_names)
102 assert os.path.exists(self.
db_path), \
103 'db_path [{db_path}] does not exist'.format(db_path=self.
db_path)
104 with core.NameScope(self.
name):
106 blob_prefix = scope.CurrentNameScope()
107 workspace.RunOperatorOnce(
116 add_prefix=blob_prefix,
120 blob_name[len(blob_prefix):]
for blob_name
in workspace.Blobs()
121 if blob_name.startswith(blob_prefix)
123 schema = from_column_list(col_names)
127 """From the Dataset, create a _DatasetReader and set up an init_net. 129 Make sure the _init_field_blobs_as_empty(...) is only called once. 131 Because the underlying NewRecord(...) creates blobs by calling 132 NextScopedBlob(...), so that references to previously-initiated 133 empty blobs will be lost, causing accessibility issues. 136 self.ds_reader.setup_ex(init_net, finish_net)
145 self.ds_reader.sort_and_shuffle(init_net)
146 self.ds_reader.computeoffset(init_net)
def read(self, read_net):
    """Delegate a read to the dataset reader held in ``self.ds_reader``.

    Args:
        read_net: net forwarded unchanged to ``self.ds_reader.read``.

    Returns:
        Whatever ``self.ds_reader.read(read_net)`` returns.

    Raises:
        AssertionError: if ``self.ds_reader`` is falsy, which the error
            message attributes to ``setup_ex`` not having been called.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would turn this precondition into an
    # obscure attribute error on the next line. The exception type and
    # message are kept so existing callers see no difference.
    if not self.ds_reader:
        raise AssertionError('setup_ex must be called first')
    return self.ds_reader.read(read_net)
def _init_field_blobs_as_empty(self, init_net):
    """Initialize dataset field blobs by creating an empty record.

    Args:
        init_net: net handed to ``self.ds.init_empty``.
    """
    # The empty record is created while this reader's name scope
    # (``self.name``) is active.
    with core.NameScope(self.name):
        self.ds.init_empty(init_net)
157 def _feed_field_blobs_from_db_file(self, net):
158 """Load from the DB file at db_path and feed dataset field blobs""" 159 assert os.path.exists(self.
db_path), \
160 'db_path [{db_path}] does not exist'.format(db_path=self.
db_path)
167 source_blob_names=self.ds.field_names(),
def _extract_db_name_from_db_path(self):
    """Extract the DB name from ``self.db_path``.

    The name is the last path component with its final extension removed.
    E.g. given self.db_path=`/tmp/sample.db`, it returns `sample`.

    Trailing path separators are ignored, so `/tmp/sample.db/` also yields
    `sample` instead of an empty string.

    Returns:
        str: the DB name; empty if the path has no final component.
    """
    # os.path.basename returns '' for a path ending in a separator, so strip
    # trailing separators first (os.altsep is None on POSIX).
    separators = os.sep + (os.altsep or '')
    last_component = os.path.basename(self.db_path.rstrip(separators))
    # Drop only the last extension: 'a.b.db' -> 'a.b'.
    return last_component.rsplit('.', 1)[0]
def setup_ex(self, init_net, finish_net)
def _feed_field_blobs_from_db_file(self, net)
def _init_field_blobs_as_empty(self, init_net)
def _init_reader_schema(self, field_names=None)
def _extract_db_name_from_db_path(self)
string default_name_suffix