Caffe2 - C++ API
A deep learning, cross platform ML framework
db.h
1 #ifndef CAFFE2_CORE_DB_H_
2 #define CAFFE2_CORE_DB_H_
3 
4 #include <mutex>
5 
6 #include "c10/util/Registry.h"
7 #include "caffe2/core/blob_serialization.h"
8 #include "caffe2/proto/caffe2_pb.h"
9 
10 namespace caffe2 {
11 namespace db {
12 
// Access mode handed to a DB when it is constructed (see DB's constructor).
// Within this header only READ is ever passed explicitly (by DBReader::Open
// and DBExists).
enum Mode {
  READ, // used when opening a database for DBReader / existence checks
  WRITE, // presumably: open an existing database for writing -- confirm
  NEW, // presumably: create a fresh database -- confirm
};
18 
22 class CAFFE2_API Cursor {
23  public:
24  Cursor() { }
25  virtual ~Cursor() { }
31  virtual void Seek(const string& key) = 0;
32  virtual bool SupportsSeek() { return false; }
36  virtual void SeekToFirst() = 0;
40  virtual void Next() = 0;
44  virtual string key() = 0;
48  virtual string value() = 0;
53  virtual bool Valid() = 0;
54 
55  C10_DISABLE_COPY_AND_ASSIGN(Cursor);
56 };
57 
61 class CAFFE2_API Transaction {
62  public:
63  Transaction() { }
64  virtual ~Transaction() { }
68  virtual void Put(const string& key, const string& value) = 0;
72  virtual void Commit() = 0;
73 
74  C10_DISABLE_COPY_AND_ASSIGN(Transaction);
75 };
76 
80 class CAFFE2_API DB {
81  public:
82  DB(const string& /*source*/, Mode mode) : mode_(mode) {}
83  virtual ~DB() { }
87  virtual void Close() = 0;
92  virtual std::unique_ptr<Cursor> NewCursor() = 0;
97  virtual std::unique_ptr<Transaction> NewTransaction() = 0;
98 
99  protected:
100  Mode mode_;
101 
102  C10_DISABLE_COPY_AND_ASSIGN(DB);
103 };
104 
105 // Database classes are registered by their names so we can do optional
106 // dependencies.
107 C10_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
108 #define REGISTER_CAFFE2_DB(name, ...) \
109  C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
110 
117 inline unique_ptr<DB> CreateDB(
118  const string& db_type, const string& source, Mode mode) {
119  auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
120  VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type;
121  return result;
122 }
123 
127 inline bool DBExists(const string& db_type, const string& full_db_name) {
128  // Warning! We assume that creating a DB throws an exception if the DB
129  // does not exist. If the DB constructor does not follow this design
130  // pattern,
131  // the returned output (the existence tensor) can be wrong.
132  try {
133  std::unique_ptr<DB> db(
134  caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
135  return true;
136  } catch (...) {
137  return false;
138  }
139 }
140 
144 class CAFFE2_API DBReader {
145  public:
146 
147  friend class DBReaderSerializer;
148  DBReader() {}
149 
150  DBReader(
151  const string& db_type,
152  const string& source,
153  const int32_t num_shards = 1,
154  const int32_t shard_id = 0) {
155  Open(db_type, source, num_shards, shard_id);
156  }
157 
158  explicit DBReader(const DBReaderProto& proto) {
159  Open(proto.db_type(), proto.source());
160  if (proto.has_key()) {
161  CAFFE_ENFORCE(cursor_->SupportsSeek(),
162  "Encountering a proto that needs seeking but the db type "
163  "does not support it.");
164  cursor_->Seek(proto.key());
165  }
166  num_shards_ = 1;
167  shard_id_ = 0;
168  }
169 
170  explicit DBReader(std::unique_ptr<DB> db)
171  : db_type_("<memory-type>"),
172  source_("<memory-source>"),
173  db_(std::move(db)) {
174  CAFFE_ENFORCE(db_.get(), "Passed null db");
175  cursor_ = db_->NewCursor();
176  }
177 
178  void Open(
179  const string& db_type,
180  const string& source,
181  const int32_t num_shards = 1,
182  const int32_t shard_id = 0) {
183  // Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
184  // concurrent access is allowed.
185  cursor_.reset();
186  db_.reset();
187  db_type_ = db_type;
188  source_ = source;
189  db_ = CreateDB(db_type_, source_, READ);
190  CAFFE_ENFORCE(
191  db_,
192  "Cannot find db implementation of type ",
193  db_type,
194  " (while trying to open ",
195  source_,
196  ")");
197  InitializeCursor(num_shards, shard_id);
198  }
199 
200  void Open(
201  unique_ptr<DB>&& db,
202  const int32_t num_shards = 1,
203  const int32_t shard_id = 0) {
204  cursor_.reset();
205  db_.reset();
206  db_ = std::move(db);
207  CAFFE_ENFORCE(db_.get(), "Passed null db");
208  InitializeCursor(num_shards, shard_id);
209  }
210 
211  public:
228  void Read(string* key, string* value) const {
229  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
230  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
231  *key = cursor_->key();
232  *value = cursor_->value();
233 
234  // In sharded mode, each read skips num_shards_ records
235  for (uint32_t s = 0; s < num_shards_; s++) {
236  cursor_->Next();
237  if (!cursor_->Valid()) {
238  MoveToBeginning();
239  break;
240  }
241  }
242  }
243 
247  void SeekToFirst() const {
248  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
249  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
250  MoveToBeginning();
251  }
252 
260  inline Cursor* cursor() const {
261  VLOG(1) << "Usually for a DBReader you should use Read() to be "
262  "thread safe. Consider refactoring your code.";
263  return cursor_.get();
264  }
265 
266  private:
267  void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
268  CAFFE_ENFORCE(num_shards >= 1);
269  CAFFE_ENFORCE(shard_id >= 0);
270  CAFFE_ENFORCE(shard_id < num_shards);
271  num_shards_ = num_shards;
272  shard_id_ = shard_id;
273  cursor_ = db_->NewCursor();
274  SeekToFirst();
275  }
276 
277  void MoveToBeginning() const {
278  cursor_->SeekToFirst();
279  for (uint32_t s = 0; s < shard_id_; s++) {
280  cursor_->Next();
281  CAFFE_ENFORCE(
282  cursor_->Valid(), "Db has fewer rows than shard id: ", s, shard_id_);
283  }
284  }
285 
286  string db_type_;
287  string source_;
288  unique_ptr<DB> db_;
289  unique_ptr<Cursor> cursor_;
290  mutable std::mutex reader_mutex_;
291  uint32_t num_shards_{};
292  uint32_t shard_id_{};
293 
294  C10_DISABLE_COPY_AND_ASSIGN(DBReader);
295 };
296 
297 class CAFFE2_API DBReaderSerializer : public BlobSerializerBase {
298  public:
303  void Serialize(
304  const void* pointer,
305  TypeMeta typeMeta,
306  const string& name,
307  BlobSerializerBase::SerializationAcceptor acceptor) override;
308 };
309 
310 class CAFFE2_API DBReaderDeserializer : public BlobDeserializerBase {
311  public:
312  void Deserialize(const BlobProto& proto, Blob* blob) override;
313 };
314 
315 } // namespace db
316 } // namespace caffe2
317 
318 #endif // CAFFE2_CORE_DB_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
void Read(string *key, string *value) const
Read a set of key and value from the db and move to next.
Definition: db.h:228
An abstract class for the current database transaction while writing.
Definition: db.h:61
An abstract class for the cursor of the database while reading.
Definition: db.h:22
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:144
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Definition: db.h:260
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
void SeekToFirst() const
Seeks to the first key.
Definition: db.h:247
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.