1 #ifndef CAFFE2_CORE_DB_H_ 2 #define CAFFE2_CORE_DB_H_ 6 #include "c10/util/Registry.h" 7 #include "caffe2/core/blob_serialization.h" 8 #include "caffe2/proto/caffe2_pb.h" 17 enum Mode { READ, WRITE, NEW };
// Pure virtual. Presumably repositions the cursor relative to `key`; the exact
// matching rule is left to the concrete implementation and is not visible here.
// Within this file, Seek() is only invoked after SupportsSeek() has been
// enforced to be true.
31 virtual void Seek(
const string& key) = 0;
// Capability query that callers check before using Seek(). This default
// implementation returns false; being virtual, a concrete cursor can override
// it (no override is visible in this file).
32 virtual bool SupportsSeek() {
return false; }
// Pure virtual. Presumably moves the cursor to the first entry of the
// database. MoveToBeginning() in this file calls it before stepping forward
// to the reader's shard.
36 virtual void SeekToFirst() = 0;
// Pure virtual. Presumably advances the cursor by one entry -- confirm
// against a concrete implementation; no call site is visible in this view.
40 virtual void Next() = 0;
// Pure virtual. Returns a key as a string by value (presumably the key of the
// entry the cursor currently points at). Read() in this file copies the result
// into its `key` out-parameter.
44 virtual string key() = 0;
// Pure virtual. Returns a value as a string by value (presumably the value of
// the entry the cursor currently points at). Read() in this file copies the
// result into its `value` out-parameter.
48 virtual string value() = 0;
// Pure virtual. Reports whether the cursor is in a usable state. Code in this
// file tests it after moving the cursor, e.g. the "Db has fewer rows than
// shard id" enforcement. What exactly makes a cursor invalid is up to the
// implementation.
53 virtual bool Valid() = 0;
// Macro defined outside this file; going by its name it removes the copy
// constructor and copy assignment for Cursor -- confirm in the c10 macros.
55 C10_DISABLE_COPY_AND_ASSIGN(
Cursor);
// Pure virtual. Hands a key/value pair (both passed by const reference) to the
// implementation. NOTE(review): whether the pair becomes durable before
// Commit() is called cannot be determined from this file.
68 virtual void Put(
const string& key,
const string& value) = 0;
// Pure virtual. Presumably finalizes the pairs previously passed to Put();
// the exact flushing semantics are implementation-defined.
72 virtual void Commit() = 0;
// Base constructor: stores `mode` in mode_ and does nothing else. The first
// string parameter is intentionally unnamed and unused here; the signature
// matches the (const string&, Mode) argument list that Caffe2DBRegistry is
// declared with in this file.
82 DB(
const string& , Mode mode) : mode_(mode) {}
// Pure virtual. Presumably releases the underlying database; what happens to
// outstanding cursors/transactions afterwards is not specified in this file.
87 virtual void Close() = 0;
// Pure virtual. Produces a Cursor whose ownership is transferred to the caller
// through std::unique_ptr. DBReader in this file stores the result in its
// cursor_ member.
92 virtual std::unique_ptr<Cursor> NewCursor() = 0;
// Pure virtual. Produces a Transaction whose ownership is transferred to the
// caller through std::unique_ptr. No call site is visible in this view.
97 virtual std::unique_ptr<Transaction> NewTransaction() = 0;
// Macro defined outside this file; going by its name it removes the copy
// constructor and copy assignment for DB -- confirm in the c10 macros.
102 C10_DISABLE_COPY_AND_ASSIGN(
DB);
// Declares the registry named Caffe2DBRegistry for DB objects, with
// (const string&, Mode) as the registered argument types. CreateDB() in this
// file calls Caffe2DBRegistry()->Create(db_type, source, mode) on it.
107 C10_DECLARE_REGISTRY(Caffe2DBRegistry,
DB,
const string&, Mode);
108 #define REGISTER_CAFFE2_DB(name, ...) \ 109 C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__) 117 inline unique_ptr<DB> CreateDB(
118 const string& db_type,
const string& source, Mode mode) {
119 auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
120 VLOG(1) << ((!result) ?
"not found db " :
"found db ") << db_type;
127 inline bool DBExists(
const string& db_type,
const string& full_db_name) {
133 std::unique_ptr<DB> db(
134 caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
151 const string& db_type,
152 const string& source,
153 const int32_t num_shards = 1,
154 const int32_t shard_id = 0) {
155 Open(db_type, source, num_shards, shard_id);
158 explicit DBReader(
const DBReaderProto& proto) {
159 Open(proto.db_type(), proto.source());
160 if (proto.has_key()) {
161 CAFFE_ENFORCE(cursor_->SupportsSeek(),
162 "Encountering a proto that needs seeking but the db type " 163 "does not support it.");
164 cursor_->Seek(proto.key());
170 explicit DBReader(std::unique_ptr<DB> db)
171 : db_type_(
"<memory-type>"),
172 source_(
"<memory-source>"),
174 CAFFE_ENFORCE(db_.get(),
"Passed null db");
175 cursor_ = db_->NewCursor();
179 const string& db_type,
180 const string& source,
181 const int32_t num_shards = 1,
182 const int32_t shard_id = 0) {
189 db_ = CreateDB(db_type_, source_, READ);
192 "Cannot find db implementation of type ",
194 " (while trying to open ",
197 InitializeCursor(num_shards, shard_id);
202 const int32_t num_shards = 1,
203 const int32_t shard_id = 0) {
207 CAFFE_ENFORCE(db_.get(),
"Passed null db");
208 InitializeCursor(num_shards, shard_id);
228 void Read(
string* key,
string* value)
const {
229 CAFFE_ENFORCE(cursor_ !=
nullptr,
"Reader not initialized.");
230 std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
231 *key = cursor_->key();
232 *value = cursor_->value();
235 for (uint32_t s = 0; s < num_shards_; s++) {
237 if (!cursor_->Valid()) {
248 CAFFE_ENFORCE(cursor_ !=
nullptr,
"Reader not initialized.");
249 std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
261 VLOG(1) <<
"Usually for a DBReader you should use Read() to be " 262 "thread safe. Consider refactoring your code.";
263 return cursor_.get();
267 void InitializeCursor(
const int32_t num_shards,
const int32_t shard_id) {
268 CAFFE_ENFORCE(num_shards >= 1);
269 CAFFE_ENFORCE(shard_id >= 0);
270 CAFFE_ENFORCE(shard_id < num_shards);
271 num_shards_ = num_shards;
272 shard_id_ = shard_id;
273 cursor_ = db_->NewCursor();
277 void MoveToBeginning()
const {
278 cursor_->SeekToFirst();
279 for (uint32_t s = 0; s < shard_id_; s++) {
282 cursor_->Valid(),
"Db has fewer rows than shard id: ", s, shard_id_);
// Cursor owned by the reader; assigned from db_->NewCursor().
289 unique_ptr<Cursor> cursor_;
// Locked in Read() around cursor_ access. Declared mutable so that const
// member functions such as Read() can take the lock.
290 mutable std::mutex reader_mutex_;
// Shard count; value-initialized to 0 and set by InitializeCursor(), which
// enforces num_shards >= 1 before assigning.
291 uint32_t num_shards_{};
// Index of this reader's shard; value-initialized to 0 and set by
// InitializeCursor(), which enforces 0 <= shard_id < num_shards.
292 uint32_t shard_id_{};
// Macro defined outside this file; going by its name it removes the copy
// constructor and copy assignment for DBReader -- confirm in the c10 macros.
294 C10_DISABLE_COPY_AND_ASSIGN(
DBReader);
307 BlobSerializerBase::SerializationAcceptor acceptor)
override;
// Declaration only; the definition is not part of this header section.
// Overrides a base-class virtual (note `override`). Takes the serialized
// BlobProto by const reference and a pointer to the destination Blob.
// NOTE(review): presumably reconstructs a DBReader inside `blob` -- confirm
// against the out-of-line definition.
312 void Deserialize(
const BlobProto& proto,
Blob* blob)
override;
318 #endif // CAFFE2_CORE_DB_H_ Blob is a general container that hosts a typed pointer.
void Read(string *key, string *value) const
Reads one key/value pair from the db, then advances the cursor.
An abstract class for the current database transaction while writing.
An abstract class for the cursor of the database while reading.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A reader wrapper for DB that also allows us to serialize it.
Cursor * cursor() const
Returns the underlying cursor of the db reader.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
An abstract class for accessing a database of key-value pairs.
void SeekToFirst() const
Seeks to the first key.
BlobSerializerBase is an abstract class that serializes a blob to a string.