Caffe2 - C++ API
A deep learning, cross platform ML framework
zmqdb.cc
1 #include <atomic>
2 #include <condition_variable>
3 #include <mutex>
4 #include <thread> // NOLINT
5 
6 #include "caffe2/core/db.h"
7 #include "caffe2/utils/zmq_helper.h"
8 #include "caffe2/core/logging.h"
9 
10 namespace caffe2 {
11 namespace db {
12 
13 class ZmqDBCursor : public Cursor {
14  public:
15  explicit ZmqDBCursor(const string& source)
16  : source_(source), socket_(ZMQ_PULL),
17  prefetched_(false), finalize_(false) {
18  socket_.Connect(source_);
19  // Start prefetching thread.
20  prefetch_thread_.reset(
21  new std::thread([this] { this->Prefetch(); }));
22  // obtain the first value.
23  Next();
24  }
25 
26  ~ZmqDBCursor() override {
27  finalize_ = true;
28  prefetched_ = false;
29  producer_.notify_one();
30  // Wait for the prefetch thread to finish elegantly.
31  prefetch_thread_->join();
32  socket_.Disconnect(source_);
33  }
34 
35  void Seek(const string& /*key*/) override { /* do nothing */
36  }
37 
38  void SeekToFirst() override { /* do nothing */ }
39 
40  void Next() override {
41  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
42  while (!prefetched_) consumer_.wait(lock);
43  key_ = prefetch_key_;
44  value_ = prefetch_value_;
45  prefetched_ = false;
46  producer_.notify_one();
47  }
48 
49  string key() override { return key_; }
50  string value() override { return value_; }
51  bool Valid() override { return true; }
52 
53  private:
54 
55  void Prefetch() {
56  while (!finalize_) {
57  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
58  while (prefetched_) producer_.wait(lock);
59  if (finalize_) {
60  return;
61  }
62  ZmqMessage msg;
63  socket_.RecvTillSuccess(&msg);
64  prefetch_key_.assign(static_cast<char*>(msg.data()), msg.size());
65  socket_.RecvTillSuccess(&msg);
66  prefetch_value_.assign(static_cast<char*>(msg.data()), msg.size());
67  prefetched_ = true;
68  consumer_.notify_one();
69  }
70  }
71 
72  string source_;
73  ZmqSocket socket_;
74  string key_;
75  string value_;
76  string prefetch_key_;
77  string prefetch_value_;
78 
79  unique_ptr<std::thread> prefetch_thread_;
80  std::mutex prefetch_access_mutex_;
81  std::condition_variable producer_, consumer_;
82  std::atomic<bool> prefetched_;
83  // finalize_ is used to tell the prefetcher to quit.
84  std::atomic<bool> finalize_;
85 };
86 
87 class ZmqDB : public DB {
88  public:
89  ZmqDB(const string& source, Mode mode)
90  : DB(source, mode), source_(source) {
91  CAFFE_ENFORCE(mode == READ, "ZeroMQ DB only supports read mode.");
92  }
93 
94  ~ZmqDB() override {}
95 
96  void Close() override {}
97 
98  unique_ptr<Cursor> NewCursor() override {
99  return make_unique<ZmqDBCursor>(source_);
100  }
101 
102  unique_ptr<Transaction> NewTransaction() override {
103  CAFFE_THROW("ZeroMQ DB does not support writing with a transaction.");
104  return nullptr; // dummy placeholder to suppress old compiler warnings.
105  }
106 
107  private:
108  string source_;
109 };
110 
111 REGISTER_CAFFE2_DB(ZmqDB, ZmqDB);
112 // For lazy-minded, one can also call with lower-case name.
113 REGISTER_CAFFE2_DB(zmqdb, ZmqDB);
114 
115 } // namespace db
116 } // namespace caffe2
string value() override
Returns the current value.
Definition: zmqdb.cc:50
An abstract class for the cursor of the database while reading.
Definition: db.h:22
void Seek(const string &) override
Seek to a specific key (or if the key does not exist, seek to the immediate next).
Definition: zmqdb.cc:35
string key() override
Returns the current key.
Definition: zmqdb.cc:49
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
unique_ptr< Cursor > NewCursor() override
Returns a cursor to read the database.
Definition: zmqdb.cc:98
void SeekToFirst() override
Seek to the first key in the database.
Definition: zmqdb.cc:38
void Close() override
Closes the database.
Definition: zmqdb.cc:96
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
bool Valid() override
Returns whether the current location is valid - for example, if we have reached the end of the database, return false.
Definition: zmqdb.cc:51
void Next() override
Go to the next location in the database.
Definition: zmqdb.cc:40
unique_ptr< Transaction > NewTransaction() override
Returns a transaction to write data to the database.
Definition: zmqdb.cc:102