Caffe2 - C++ API
A cross-platform deep learning framework
convert_caffe_image_db.cc
1 
17 #include "caffe2/core/db.h"
18 #include "caffe2/core/init.h"
19 #include "caffe2/proto/caffe2_pb.h"
20 #include "caffe2/proto/caffe2_legacy.pb.h"
21 #include "caffe2/core/logging.h"
22 
23 C10_DEFINE_string(input_db, "", "The input db.");
24 C10_DEFINE_string(input_db_type, "", "The input db type.");
25 C10_DEFINE_string(output_db, "", "The output db.");
26 C10_DEFINE_string(output_db_type, "", "The output db type.");
27 C10_DEFINE_int(batch_size, 1000, "The write batch size.");
28 
29 using caffe2::db::Cursor;
30 using caffe2::db::DB;
32 using caffe2::CaffeDatum;
33 using caffe2::TensorProto;
34 using caffe2::TensorProtos;
35 
36 int main(int argc, char** argv) {
37  caffe2::GlobalInit(&argc, &argv);
38 
39  std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
40  FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ));
41  std::unique_ptr<DB> out_db(caffe2::db::CreateDB(
42  FLAGS_output_db_type, FLAGS_output_db, caffe2::db::NEW));
43  std::unique_ptr<Cursor> cursor(in_db->NewCursor());
44  std::unique_ptr<Transaction> transaction(out_db->NewTransaction());
45  int count = 0;
46  for (; cursor->Valid(); cursor->Next()) {
47  CaffeDatum datum;
48  CAFFE_ENFORCE(datum.ParseFromString(cursor->value()));
49  TensorProtos protos;
50  TensorProto* data = protos.add_protos();
51  TensorProto* label = protos.add_protos();
52  label->set_data_type(TensorProto::INT32);
53  label->add_dims(1);
54  label->add_int32_data(datum.label());
55  if (datum.encoded()) {
56  // This is an encoded image. we will copy over the data directly.
57  data->set_data_type(TensorProto::STRING);
58  data->add_dims(1);
59  data->add_string_data(datum.data());
60  } else {
61  // float data not supported right now.
62  CAFFE_ENFORCE_EQ(datum.float_data_size(), 0);
63  std::vector<char> buffer_vec(datum.data().size());
64  char* buffer = buffer_vec.data();
65  // swap order from CHW to HWC
66  int channels = datum.channels();
67  int size = datum.height() * datum.width();
68  CAFFE_ENFORCE_EQ(datum.data().size(), channels * size);
69  for (int c = 0; c < channels; ++c) {
70  char* dst = buffer + c;
71  const char* src = datum.data().c_str() + c * size;
72  for (int n = 0; n < size; ++n) {
73  dst[n*channels] = src[n];
74  }
75  }
76  data->set_data_type(TensorProto::BYTE);
77  data->add_dims(datum.height());
78  data->add_dims(datum.width());
79  data->add_dims(datum.channels());
80  data->set_byte_data(buffer, datum.data().size());
81  }
82  transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
83  if (++count % FLAGS_batch_size == 0) {
84  transaction->Commit();
85  LOG(INFO) << "Converted " << count << " items so far.";
86  }
87  }
88  LOG(INFO) << "A total of " << count << " items processed.";
89  return 0;
90 }
An abstract class for the current database transaction while writing.
Definition: db.h:61
An abstract class for the cursor of the database while reading.
Definition: db.h:22
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:44