Caffe2 - C++ API
A deep learning, cross platform ML framework
split_db.cc
1 
17 #include <string>
18 #include <sstream>
19 
20 #include "caffe2/core/db.h"
21 #include "caffe2/core/init.h"
22 #include "caffe2/proto/caffe2_pb.h"
23 #include "caffe2/core/logging.h"
24 
25 C10_DEFINE_string(input_db, "", "The input db.");
26 C10_DEFINE_int(splits, 0, "The number of splits.");
27 C10_DEFINE_string(db_type, "", "The db type.");
28 C10_DEFINE_int(batch_size, 1000, "The write batch size.");
29 
30 namespace caffe2 {
31 
32 static int Split(int argc, char** argv) {
33  GlobalInit(&argc, &argv);
34 
35  CAFFE_ENFORCE(FLAGS_input_db.size(), "Must specify --input_db=/path/to/db.");
36  CAFFE_ENFORCE(FLAGS_splits > 0, "Must specify a nonnegative split number.");
37  CAFFE_ENFORCE(FLAGS_db_type.size(), "Must specify --db_type=[a db type].");
38 
39  unique_ptr<db::DB> in_db(
40  db::CreateDB(FLAGS_db_type, FLAGS_input_db, db::READ));
41  CAFFE_ENFORCE(in_db != nullptr, "Cannot open input db: ", FLAGS_input_db);
42  unique_ptr<db::Cursor> cursor(in_db->NewCursor());
43  // This usually won't happen, but FWIW.
44  CAFFE_ENFORCE(
45  cursor != nullptr, "Cannot obtain cursor for input db: ", FLAGS_input_db);
46 
47  vector<unique_ptr<db::DB>> out_dbs;
48  vector<unique_ptr<db::Transaction>> transactions;
49  for (int i = 0; i < FLAGS_splits; ++i) {
50  out_dbs.push_back(unique_ptr<db::DB>(db::CreateDB(
51  FLAGS_db_type, FLAGS_input_db + "_split_" + to_string(i), db::NEW)));
52  CAFFE_ENFORCE(out_dbs.back().get(), "Cannot create output db #", i);
53  transactions.push_back(
54  unique_ptr<db::Transaction>(out_dbs[i]->NewTransaction()));
55  CAFFE_ENFORCE(
56  transactions.back().get(), "Cannot get transaction for output db #", i);
57  }
58 
59  int count = 0;
60  for (; cursor->Valid(); cursor->Next()) {
61  transactions[count % FLAGS_splits]->Put(cursor->key(), cursor->value());
62  if (++count % FLAGS_batch_size == 0) {
63  for (int i = 0; i < FLAGS_splits; ++i) {
64  transactions[i]->Commit();
65  }
66  LOG(INFO) << "Split " << count << " items so far.";
67  }
68  }
69  LOG(INFO) << "A total of " << count << " items processed.";
70  return 0;
71 }
72 
73 } // namespace caffe2
74 
75 int main(int argc, char** argv) {
76  return caffe2::Split(argc, argv);
77 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:44