Caffe2 - C++ API
A deep learning, cross platform ML framework
db.h
1 
17 #ifndef CAFFE2_CORE_DB_H_
18 #define CAFFE2_CORE_DB_H_
19 
20 #include <mutex>
21 
22 #include "caffe2/core/blob_serialization.h"
23 #include "caffe2/core/registry.h"
24 #include "caffe2/proto/caffe2.pb.h"
25 
26 namespace caffe2 {
27 namespace db {
28 
// Mode in which a database is opened. Enumerator order (and therefore the
// underlying values 0, 1, 2) is unchanged.
enum Mode {
  READ,   // passed to CreateDB by DBExists and DBReader::Open
  WRITE,  // presumably: write into a database -- confirm per backend
  NEW     // presumably: create a fresh database -- confirm per backend
};
34 
38 class Cursor {
39  public:
40  Cursor() { }
41  virtual ~Cursor() { }
47  virtual void Seek(const string& key) = 0;
48  virtual bool SupportsSeek() { return false; }
52  virtual void SeekToFirst() = 0;
56  virtual void Next() = 0;
60  virtual string key() = 0;
64  virtual string value() = 0;
69  virtual bool Valid() = 0;
70 
71  DISABLE_COPY_AND_ASSIGN(Cursor);
72 };
73 
77 class Transaction {
78  public:
79  Transaction() { }
80  virtual ~Transaction() { }
84  virtual void Put(const string& key, const string& value) = 0;
88  virtual void Commit() = 0;
89 
90  DISABLE_COPY_AND_ASSIGN(Transaction);
91 };
92 
96 class DB {
97  public:
98  DB(const string& /*source*/, Mode mode) : mode_(mode) {}
99  virtual ~DB() { }
103  virtual void Close() = 0;
108  virtual std::unique_ptr<Cursor> NewCursor() = 0;
113  virtual std::unique_ptr<Transaction> NewTransaction() = 0;
114 
115  protected:
116  Mode mode_;
117 
118  DISABLE_COPY_AND_ASSIGN(DB);
119 };
120 
121 // Database classes are registered by their names so we can do optional
122 // dependencies.
123 CAFFE_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
124 #define REGISTER_CAFFE2_DB(name, ...) \
125  CAFFE_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
126 
133 inline unique_ptr<DB> CreateDB(
134  const string& db_type, const string& source, Mode mode) {
135  auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
136  VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type;
137  return result;
138 }
139 
143 inline bool DBExists(const string& db_type, const string& full_db_name) {
144  // Warning! We assume that creating a DB throws an exception if the DB
145  // does not exist. If the DB constructor does not follow this design
146  // pattern,
147  // the returned output (the existence tensor) can be wrong.
148  try {
149  std::unique_ptr<DB> db(
150  caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
151  return true;
152  } catch (...) {
153  return false;
154  }
155 }
156 
160 class DBReader {
161  public:
162 
163  friend class DBReaderSerializer;
164  DBReader() {}
165 
166  DBReader(
167  const string& db_type,
168  const string& source,
169  const int32_t num_shards = 1,
170  const int32_t shard_id = 0) {
171  Open(db_type, source, num_shards, shard_id);
172  }
173 
174  explicit DBReader(const DBReaderProto& proto) {
175  Open(proto.db_type(), proto.source());
176  if (proto.has_key()) {
177  CAFFE_ENFORCE(cursor_->SupportsSeek(),
178  "Encountering a proto that needs seeking but the db type "
179  "does not support it.");
180  cursor_->Seek(proto.key());
181  }
182  num_shards_ = 1;
183  shard_id_ = 0;
184  }
185 
186  explicit DBReader(std::unique_ptr<DB> db)
187  : db_type_("<memory-type>"),
188  source_("<memory-source>"),
189  db_(std::move(db)) {
190  CAFFE_ENFORCE(db_.get(), "Passed null db");
191  cursor_ = db_->NewCursor();
192  }
193 
194  void Open(
195  const string& db_type,
196  const string& source,
197  const int32_t num_shards = 1,
198  const int32_t shard_id = 0) {
199  // Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
200  // concurrent access is allowed.
201  cursor_.reset();
202  db_.reset();
203  db_type_ = db_type;
204  source_ = source;
205  db_ = CreateDB(db_type_, source_, READ);
206  CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_);
207  InitializeCursor(num_shards, shard_id);
208  }
209 
210  void Open(
211  unique_ptr<DB>&& db,
212  const int32_t num_shards = 1,
213  const int32_t shard_id = 0) {
214  cursor_.reset();
215  db_.reset();
216  db_ = std::move(db);
217  CAFFE_ENFORCE(db_.get(), "Passed null db");
218  InitializeCursor(num_shards, shard_id);
219  }
220 
221  public:
238  void Read(string* key, string* value) const {
239  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
240  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
241  *key = cursor_->key();
242  *value = cursor_->value();
243 
244  // In sharded mode, each read skips num_shards_ records
245  for (int s = 0; s < num_shards_; s++) {
246  cursor_->Next();
247  if (!cursor_->Valid()) {
248  MoveToBeginning();
249  break;
250  }
251  }
252  }
253 
257  void SeekToFirst() const {
258  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
259  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
260  MoveToBeginning();
261  }
262 
270  inline Cursor* cursor() const {
271  LOG(ERROR) << "Usually for a DBReader you should use Read() to be "
272  "thread safe. Consider refactoring your code.";
273  return cursor_.get();
274  }
275 
276  private:
277  void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
278  CAFFE_ENFORCE(num_shards >= 1);
279  CAFFE_ENFORCE(shard_id >= 0);
280  CAFFE_ENFORCE(shard_id < num_shards);
281  num_shards_ = num_shards;
282  shard_id_ = shard_id;
283  cursor_ = db_->NewCursor();
284  SeekToFirst();
285  }
286 
287  void MoveToBeginning() const {
288  cursor_->SeekToFirst();
289  for (auto s = 0; s < shard_id_; s++) {
290  cursor_->Next();
291  CAFFE_ENFORCE(
292  cursor_->Valid(), "Db has less rows than shard id: ", s, shard_id_);
293  }
294  }
295 
296  string db_type_;
297  string source_;
298  unique_ptr<DB> db_;
299  unique_ptr<Cursor> cursor_;
300  mutable std::mutex reader_mutex_;
301  uint32_t num_shards_;
302  uint32_t shard_id_;
303 
304  DISABLE_COPY_AND_ASSIGN(DBReader);
305 };
306 
308  public:
313  void Serialize(
314  const Blob& blob,
315  const string& name,
316  BlobSerializerBase::SerializationAcceptor acceptor) override;
317 };
318 
320  public:
321  void Deserialize(const BlobProto& proto, Blob* blob) override;
322 };
323 
324 } // namespace db
325 } // namespace caffe2
326 
327 #endif // CAFFE2_CORE_DB_H_
virtual bool Valid()=0
Returns whether the current location is valid - for example, it returns false if we have reached the end of the database.
virtual string key()=0
Returns the current key.
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
virtual void Seek(const string &key)=0
Seek to a specific key (or if the key does not exist, seek to the immediate next).
void Read(string *key, string *value) const
Read a key and value from the db and move to the next record.
Definition: db.h:238
An abstract class for the current database transaction while writing.
Definition: db.h:77
An abstract class for the cursor of the database while reading.
Definition: db.h:38
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:160
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Definition: db.h:270
virtual void SeekToFirst()=0
Seek to the first key in the database.
Copyright (c) 2016-present, Facebook, Inc.
An abstract class for accessing a database of key-value pairs.
Definition: db.h:96
void SeekToFirst() const
Seeks to the first key.
Definition: db.h:257
virtual void Next()=0
Go to the next location in the database.
virtual string value()=0
Returns the current value.
BlobSerializerBase is an abstract class that serializes a blob to a string.