Caffe2 - C++ API
A deep learning, cross platform ML framework
convert_caffe_image_db.cc
1 
17 #include "caffe2/core/db.h"
18 #include "caffe2/core/init.h"
19 #include "caffe2/proto/caffe2.pb.h"
20 #include "caffe/proto/caffe.pb.h"
21 #include "caffe2/core/logging.h"
22 
23 CAFFE2_DEFINE_string(input_db, "", "The input db.");
24 CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
25 CAFFE2_DEFINE_string(output_db, "", "The output db.");
26 CAFFE2_DEFINE_string(output_db_type, "", "The output db type.");
27 CAFFE2_DEFINE_int(batch_size, 1000, "The write batch size.");
28 
29 using caffe2::db::Cursor;
30 using caffe2::db::DB;
32 using caffe2::TensorProto;
33 using caffe2::TensorProtos;
34 
35 int main(int argc, char** argv) {
36  caffe2::GlobalInit(&argc, &argv);
37 
38  std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
39  caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ));
40  std::unique_ptr<DB> out_db(caffe2::db::CreateDB(
41  caffe2::FLAGS_output_db_type, caffe2::FLAGS_output_db, caffe2::db::NEW));
42  std::unique_ptr<Cursor> cursor(in_db->NewCursor());
43  std::unique_ptr<Transaction> transaction(out_db->NewTransaction());
44  int count = 0;
45  for (; cursor->Valid(); cursor->Next()) {
46  caffe::Datum datum;
47  CAFFE_ENFORCE(datum.ParseFromString(cursor->value()));
48  TensorProtos protos;
49  TensorProto* data = protos.add_protos();
50  TensorProto* label = protos.add_protos();
51  label->set_data_type(TensorProto::INT32);
52  label->add_dims(1);
53  label->add_int32_data(datum.label());
54  if (datum.encoded()) {
55  // This is an encoded image. we will copy over the data directly.
56  data->set_data_type(TensorProto::STRING);
57  data->add_dims(1);
58  data->add_string_data(datum.data());
59  } else {
60  // float data not supported right now.
61  CAFFE_ENFORCE_EQ(datum.float_data_size(), 0);
62  std::vector<char> buffer_vec(datum.data().size());
63  char* buffer = buffer_vec.data();
64  // swap order from CHW to HWC
65  int channels = datum.channels();
66  int size = datum.height() * datum.width();
67  CAFFE_ENFORCE_EQ(datum.data().size(), channels * size);
68  for (int c = 0; c < channels; ++c) {
69  char* dst = buffer + c;
70  const char* src = datum.data().c_str() + c * size;
71  for (int n = 0; n < size; ++n) {
72  dst[n*channels] = src[n];
73  }
74  }
75  data->set_data_type(TensorProto::BYTE);
76  data->add_dims(datum.height());
77  data->add_dims(datum.width());
78  data->add_dims(datum.channels());
79  data->set_byte_data(buffer, datum.data().size());
80  }
81  transaction->Put(cursor->key(), protos.SerializeAsString());
82  if (++count % caffe2::FLAGS_batch_size == 0) {
83  transaction->Commit();
84  LOG(INFO) << "Converted " << count << " items so far.";
85  }
86  }
87  LOG(INFO) << "A total of " << count << " items processed.";
88  return 0;
89 }
90 
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:34
An abstract class for the current database transaction while writing.
Definition: db.h:77
An abstract class for the cursor of the database while reading.
Definition: db.h:38
An abstract class for accessing a database of key-value pairs.
Definition: db.h:96