Caffe2 - C++ API
A deep learning, cross platform ML framework
db.cc
1 
17 #include "caffe2/core/db.h"
18 
19 #include <mutex>
20 
21 #include "caffe2/core/blob_serialization.h"
22 #include "caffe2/core/logging.h"
23 
24 namespace caffe2 {
25 
// Declare DBReader and Cursor as known types (CAFFE_KNOWN_TYPE) so they get
// TypeMeta ids; the serializer registration further down in this file keys on
// TypeMeta::Id<DBReader>().
26 CAFFE_KNOWN_TYPE(db::DBReader);
27 CAFFE_KNOWN_TYPE(db::Cursor);
28 
29 namespace db {
30 
// Registry of DB backends. Each backend is constructed from
// (const string& source, Mode mode); MiniDB registers itself further down in
// this file under the names "MiniDB" and "minidb".
31 CAFFE_DEFINE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
32 
33 // Below, we provide a bare minimum database "minidb" as a reference
34 // implementation as well as a portable choice to store data.
35 // Note that the MiniDB classes are not exposed via a header file - they should
36 // be created directly via the db interface. See MiniDB for details.
37 
38 class MiniDBCursor : public Cursor {
39  public:
40  explicit MiniDBCursor(FILE* f, std::mutex* mutex)
41  : file_(f), lock_(*mutex), valid_(true) {
42  // We call Next() to read in the first entry.
43  Next();
44  }
45  ~MiniDBCursor() {}
46 
47  void Seek(const string& /*key*/) override {
48  LOG(FATAL) << "MiniDB does not support seeking to a specific key.";
49  }
50 
51  void SeekToFirst() override {
52  fseek(file_, 0, SEEK_SET);
53  CAFFE_ENFORCE(!feof(file_), "Hmm, empty file?");
54  // Read the first item.
55  valid_ = true;
56  Next();
57  }
58 
59  void Next() override {
60  // First, read in the key and value length.
61  if (fread(&key_len_, sizeof(int), 1, file_) == 0) {
62  // Reaching EOF.
63  VLOG(1) << "EOF reached, setting valid to false";
64  valid_ = false;
65  return;
66  }
67  CAFFE_ENFORCE_EQ(fread(&value_len_, sizeof(int), 1, file_), 1);
68  CAFFE_ENFORCE_GT(key_len_, 0);
69  CAFFE_ENFORCE_GT(value_len_, 0);
70  // Resize if the key and value len is larger than the current one.
71  if (key_len_ > key_.size()) {
72  key_.resize(key_len_);
73  }
74  if (value_len_ > value_.size()) {
75  value_.resize(value_len_);
76  }
77  // Actually read in the contents.
78  CAFFE_ENFORCE_EQ(
79  fread(key_.data(), sizeof(char), key_len_, file_), key_len_);
80  CAFFE_ENFORCE_EQ(
81  fread(value_.data(), sizeof(char), value_len_, file_), value_len_);
82  // Note(Yangqing): as we read the file, the cursor naturally moves to the
83  // beginning of the next entry.
84  }
85 
86  string key() override {
87  CAFFE_ENFORCE(valid_, "Cursor is at invalid location!");
88  return string(key_.data(), key_len_);
89  }
90 
91  string value() override {
92  CAFFE_ENFORCE(valid_, "Cursor is at invalid location!");
93  return string(value_.data(), value_len_);
94  }
95 
96  bool Valid() override { return valid_; }
97 
98  private:
99  FILE* file_;
100  std::lock_guard<std::mutex> lock_;
101  bool valid_;
102  int key_len_;
103  vector<char> key_;
104  int value_len_;
105  vector<char> value_;
106 };
107 
109  public:
110  explicit MiniDBTransaction(FILE* f, std::mutex* mutex)
111  : file_(f), lock_(*mutex) {}
112  ~MiniDBTransaction() {
113  Commit();
114  }
115 
116  void Put(const string& key, const string& value) override {
117  int key_len = key.size();
118  int value_len = value.size();
119  CAFFE_ENFORCE_EQ(fwrite(&key_len, sizeof(int), 1, file_), 1);
120  CAFFE_ENFORCE_EQ(fwrite(&value_len, sizeof(int), 1, file_), 1);
121  CAFFE_ENFORCE_EQ(
122  fwrite(key.c_str(), sizeof(char), key_len, file_), key_len);
123  CAFFE_ENFORCE_EQ(
124  fwrite(value.c_str(), sizeof(char), value_len, file_), value_len);
125  }
126 
127  void Commit() override {
128  if (file_ != nullptr) {
129  CAFFE_ENFORCE_EQ(fflush(file_), 0);
130  file_ = nullptr;
131  }
132  }
133 
134  private:
135  FILE* file_;
136  std::lock_guard<std::mutex> lock_;
137 
138  DISABLE_COPY_AND_ASSIGN(MiniDBTransaction);
139 };
140 
141 class MiniDB : public DB {
142  public:
143  MiniDB(const string& source, Mode mode) : DB(source, mode), file_(nullptr) {
144  switch (mode) {
145  case NEW:
146  file_ = fopen(source.c_str(), "wb");
147  break;
148  case WRITE:
149  file_ = fopen(source.c_str(), "ab");
150  fseek(file_, 0, SEEK_END);
151  break;
152  case READ:
153  file_ = fopen(source.c_str(), "rb");
154  break;
155  }
156  CAFFE_ENFORCE(file_, "Cannot open file: " + source);
157  VLOG(1) << "Opened MiniDB " << source;
158  }
159  ~MiniDB() { Close(); }
160 
161  void Close() override {
162  if (file_) {
163  fclose(file_);
164  }
165  file_ = nullptr;
166  }
167 
168  unique_ptr<Cursor> NewCursor() override {
169  CAFFE_ENFORCE_EQ(this->mode_, READ);
170  return make_unique<MiniDBCursor>(file_, &file_access_mutex_);
171  }
172 
173  unique_ptr<Transaction> NewTransaction() override {
174  CAFFE_ENFORCE(this->mode_ == NEW || this->mode_ == WRITE);
175  return make_unique<MiniDBTransaction>(file_, &file_access_mutex_);
176  }
177 
178  private:
179  FILE* file_;
180  // access mutex makes sure we don't have multiple cursors/transactions
181  // reading the same file.
182  std::mutex file_access_mutex_;
183 };
184 
// Register MiniDB as a DB backend under both the "MiniDB" and "minidb" names
// (presumably into Caffe2DBRegistry, defined above — confirm against db.h).
185 REGISTER_CAFFE2_DB(MiniDB, MiniDB);
186 REGISTER_CAFFE2_DB(minidb, MiniDB);
187 
189  const Blob& blob,
190  const string& name,
191  BlobSerializerBase::SerializationAcceptor acceptor) {
192  CAFFE_ENFORCE(blob.IsType<DBReader>());
193  auto& reader = blob.Get<DBReader>();
194  DBReaderProto proto;
195  proto.set_name(name);
196  proto.set_source(reader.source_);
197  proto.set_db_type(reader.db_type_);
198  if (reader.cursor() && reader.cursor()->SupportsSeek()) {
199  proto.set_key(reader.cursor()->key());
200  }
201  BlobProto blob_proto;
202  blob_proto.set_name(name);
203  blob_proto.set_type("DBReader");
204  blob_proto.set_content(proto.SerializeAsString());
205  acceptor(name, blob_proto.SerializeAsString());
206 }
207 
208 void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {
209  DBReaderProto reader_proto;
210  CAFFE_ENFORCE(
211  reader_proto.ParseFromString(proto.content()),
212  "Cannot parse content into a DBReaderProto.");
213  blob->Reset(new DBReader(reader_proto));
214 }
215 
216 namespace {
217 // Serialize TensorCPU.
218 REGISTER_BLOB_SERIALIZER((TypeMeta::Id<DBReader>()),
220 REGISTER_BLOB_DESERIALIZER(DBReader, DBReaderDeserializer);
221 } // namespace
222 
223 } // namespace db
224 } // namespace caffe2
unique_ptr< Transaction > NewTransaction() override
Returns a transaction to write data to the database.
Definition: db.cc:173
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
string key() override
Returns the current key.
Definition: db.cc:86
An abstract class for the current database transaction while writing.
Definition: db.h:77
An abstract class for the cursor of the database while reading.
Definition: db.h:38
bool Valid() override
Returns whether the current location is valid - for example, if we have reached the end of the database, returns false.
Definition: db.cc:96
unique_ptr< Cursor > NewCursor() override
Returns a cursor to read the database.
Definition: db.cc:168
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:160
void Next() override
Go to the next location in the database.
Definition: db.cc:59
void SeekToFirst() override
Seek to the first key in the database.
Definition: db.cc:51
void Close() override
Closes the database.
Definition: db.cc:161
void Put(const string &key, const string &value) override
Puts the key value pair to the database.
Definition: db.cc:116
string value() override
Returns the current value.
Definition: db.cc:91
Copyright (c) 2016-present, Facebook, Inc.
An abstract class for accessing a database of key-value pairs.
Definition: db.h:96
void Serialize(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor) override
Serializes a DBReader.
Definition: db.cc:188
T * Reset(T *allocated)
Sets the underlying object to the allocated one.
Definition: blob.h:137
void Seek(const string &) override
Seek to a specific key (or if the key does not exist, seek to the immediate next).
Definition: db.cc:47
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:74
const T & Get() const
Gets the const reference of the stored object.
Definition: blob.h:91
void Commit() override
Commits the current writes.
Definition: db.cc:127