Caffe2 - Python API
A deep learning, cross platform ML framework
cached_reader.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package cached_reader
# Module caffe2.python.cached_reader
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from caffe2.python import core
from caffe2.python.dataio import Reader
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.task import Cluster, TaskGroup

class CachedReader(Reader):
    """
    Reader with persistent in-file cache.

    Example usage:

        cached_reader = CachedReader(reader)
        build_cache_step = cached_reader.build_cache('/tmp/cache.db')
        with LocalSession() as session:
            session.run(build_cache_step)

    Every time a new reader is created, build_cache is expected to be called
    before setup_ex and before the reader is used. build_cache checks whether
    the provided file path exists; if it is missing, the cache is initialized
    by reading all data from the original reader. Every subsequent read is
    served from the cache and ignores the original reader (i.e. no additional
    data is read from it).
    """

    def __init__(self, reader, db_type='leveldb', name='cached_reader'):
        super(CachedReader, self).__init__(reader.schema())
        self.original_reader = reader
        self.cache_path = None
        self.ds_reader = None
        self.ds = Dataset(self._schema, name)
        self.db_type = db_type
        self.name = name
        self.field_names = self._schema.field_names()

    def setup_ex(self, init_net, finish_net):
        assert self.cache_path, 'build_cache must be called first'
        # Load the cached data into the dataset and serve all subsequent
        # reads from the in-memory dataset reader.
        self._init_dataset(init_net)
        self._load_from_file(init_net)
        self.ds_reader = self.ds.reader(init_net, batch_size=100)

    def read(self, read_net):
        assert self.ds_reader, 'setup must be called first'
        return self.ds_reader.read(read_net)

    def has_cache(self):
        return self.cache_path and os.path.exists(self.cache_path)

    def build_cache(self, cache_path, overwrite=False):
        if not self.has_cache() or overwrite:
            self.cache_path = cache_path
        if self.has_cache() and not overwrite:
            # Cache already exists, no need to rebuild it.
            return core.execution_step('build_step', [])

        init_net = core.Net('init')
        self._init_dataset(init_net)
        # Copy the original reader into the dataset with a multi-threaded
        # pipeline, then snapshot the dataset to the cache file.
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            copy_step = copy_tg.to_task().get_step()
        save_net = core.Net('save')
        self._save_to_file(save_net)

        return core.execution_step('build_cache', [init_net, copy_step, save_net])

    def _init_dataset(self, init_net):
        with core.NameScope(self.name):
            self.ds.init_empty(init_net)

    def _save_to_file(self, net):
        # Save the dataset blobs under their plain field names so the cache
        # file does not depend on this reader's name scope.
        net.Save(
            self.ds.content().field_blobs(),
            [],
            db=self.cache_path,
            db_type=self.db_type,
            blob_name_overrides=self.field_names,
            absolute_path=True,
        )

    def _load_from_file(self, net):
        # Load the cached blobs back into the (name-scoped) dataset blobs,
        # mapping from the plain field names used at save time.
        net.Load(
            [],
            self.ds.content().field_blobs(),
            db=self.cache_path,
            db_type=self.db_type,
            absolute_path=True,
            source_blob_names=self.field_names,
        )
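
For context, here is a short end-to-end sketch of the workflow the class
docstring describes: build the cache once, then serve every read from it. The
source_reader, cache path, thread count, and the Dataset used as a sink are
illustrative assumptions for this sketch, not part of cached_reader.py; any
caffe2.python.dataio.Reader with a schema would work.

# Hypothetical usage sketch (not part of cached_reader.py). Assumes
# 'source_reader' is an existing caffe2.python.dataio.Reader with a schema.
from caffe2.python import core
from caffe2.python.cached_reader import CachedReader
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup

cached_reader = CachedReader(source_reader, db_type='leveldb')

# First run: copies all data from source_reader into /tmp/cache.db.
# Later runs find the file on disk and get back an empty (no-op) step.
build_cache_step = cached_reader.build_cache('/tmp/cache.db')
with LocalSession() as session:
    session.run(build_cache_step)

# Reads are now served from the cache; source_reader is not touched again.
# As an example sink, pipe the cached records into a fresh in-memory Dataset.
dst = Dataset(cached_reader.schema(), name='dst')
dst_init = core.Net('dst_init')
dst.init_empty(dst_init)
with TaskGroup() as tg:
    pipe(cached_reader, dst.writer(), num_threads=4)
with LocalSession() as session:
    session.run(core.execution_step('dst_init', [dst_init]))
    session.run(tg)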