Caffe2 - C++ API
A deep learning, cross-platform ML framework
db_throughput.cc
1 
17 #include <cstdio>
18 #include <thread>
19 #include <vector>
20 
21 #include "caffe2/core/db.h"
22 #include "caffe2/core/init.h"
23 #include "caffe2/core/timer.h"
24 #include "caffe2/core/logging.h"
25 
// Command-line flags for the throughput benchmark. Each one is read by the
// code below through its FLAGS_<name> variable.
// input_db / input_db_type: passed to caffe2::db::CreateDB and to the
// caffe2::db::DBReader constructor to select the database to read.
26 C10_DEFINE_string(input_db, "", "The input db.");
27 C10_DEFINE_string(input_db_type, "", "The input db type.");
// report_interval: number of reads performed inside each timed iteration.
28 C10_DEFINE_int(report_interval, 1000, "The report interval.");
// repeat: number of timed iterations each test loop runs.
29 C10_DEFINE_int(repeat, 10, "The number to repeat the throughput test.");
// use_reader: main() runs TestThroughputWithReader when true and
// TestThroughputWithDB otherwise.
30 C10_DEFINE_bool(use_reader, false, "If true, use the reader interface.");
// num_read_threads: number of worker threads TestThroughputWithReader starts.
31 C10_DEFINE_int(
32  num_read_threads,
33  1,
34  "The number of concurrent reading threads.");
35 
36 using caffe2::db::Cursor;
37 using caffe2::db::DB;
39 using caffe2::string;
40 
41 void TestThroughputWithDB() {
42  std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
43  FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ));
44  std::unique_ptr<Cursor> cursor(in_db->NewCursor());
45  for (int iter_id = 0; iter_id < FLAGS_repeat; ++iter_id) {
46  caffe2::Timer timer;
47  for (int i = 0; i < FLAGS_report_interval; ++i) {
48  string key = cursor->key();
49  string value = cursor->value();
50  //VLOG(1) << "Key " << key;
51  cursor->Next();
52  if (!cursor->Valid()) {
53  cursor->SeekToFirst();
54  }
55  }
56  double elapsed_seconds = timer.Seconds();
57  printf(
58  "Iteration %03d, took %4.5f seconds, throughput %f items/sec.\n",
59  iter_id,
60  elapsed_seconds,
61  FLAGS_report_interval / elapsed_seconds);
62  }
63 }
64 
65 void TestThroughputWithReaderWorker(const DBReader* reader, int thread_id) {
66  string key, value;
67  for (int iter_id = 0; iter_id < FLAGS_repeat; ++iter_id) {
68  caffe2::Timer timer;
69  for (int i = 0; i < FLAGS_report_interval; ++i) {
70  reader->Read(&key, &value);
71  }
72  double elapsed_seconds = timer.Seconds();
73  printf(
74  "Thread %03d iteration %03d, took %4.5f seconds, "
75  "throughput %f items/sec.\n",
76  thread_id,
77  iter_id,
78  elapsed_seconds,
79  FLAGS_report_interval / elapsed_seconds);
80  }
81 }
82 
83 void TestThroughputWithReader() {
84  caffe2::db::DBReader reader(FLAGS_input_db_type, FLAGS_input_db);
85  std::vector<std::unique_ptr<std::thread>> reading_threads(
86  FLAGS_num_read_threads);
87  for (int i = 0; i < reading_threads.size(); ++i) {
88  reading_threads[i].reset(new std::thread(
89  TestThroughputWithReaderWorker, &reader, i));
90  }
91  for (int i = 0; i < reading_threads.size(); ++i) {
92  reading_threads[i]->join();
93  }
94 }
95 
96 int main(int argc, char** argv) {
97  caffe2::GlobalInit(&argc, &argv);
98  if (FLAGS_use_reader) {
99  TestThroughputWithReader();
100  } else {
101  TestThroughputWithDB();
102  }
103  return 0;
104 }
void Read(string *key, string *value) const
Read a key/value pair from the db and advance to the next entry.
Definition: db.h:228
An abstract class for the cursor of the database while reading.
Definition: db.h:22
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:144
float Seconds()
Returns the elapsed time in seconds.
Definition: timer.h:40
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:44
A simple timer object for measuring time.
Definition: timer.h:16