Caffe2 - C++ API
A deep learning, cross platform ML framework
make_mnist_db.cc
1 
17 // This script converts the MNIST dataset to a Caffe2 db (leveldb by default).
18 // The MNIST dataset can be downloaded from
19 // http://yann.lecun.com/exdb/mnist/
20 
21 #include <fstream> // NOLINT(readability/streams)
22 #include <string>
23 
24 #include "caffe2/core/common.h"
25 #include "caffe2/core/db.h"
26 #include "caffe2/core/init.h"
27 #include "caffe2/proto/caffe2_pb.h"
28 #include "caffe2/core/logging.h"
29 
30 C10_DEFINE_string(image_file, "", "The input image file name.");
31 C10_DEFINE_string(label_file, "", "The label file name.");
32 C10_DEFINE_string(output_file, "", "The output db name.");
33 C10_DEFINE_string(db, "leveldb", "The db type.");
34 C10_DEFINE_int(
35  data_limit,
36  -1,
37  "If set, only output this number of data points.");
38 C10_DEFINE_bool(
39  channel_first,
40  false,
41  "If set, write the data as channel-first (CHW order) as the old "
42  "Caffe does.");
43 
44 namespace caffe2 {
// Returns `val` with its four bytes in reverse order, i.e. converts a 32-bit
// integer between big-endian and little-endian representation.
uint32_t swap_endian(uint32_t val) {
  const uint32_t lowest_to_top = (val & 0x000000FFu) << 24;
  const uint32_t second_to_third = (val & 0x0000FF00u) << 8;
  const uint32_t third_to_second = (val & 0x00FF0000u) >> 8;
  const uint32_t top_to_lowest = (val & 0xFF000000u) >> 24;
  return lowest_to_top | second_to_third | third_to_second | top_to_lowest;
}
49 
50 void convert_dataset(const char* image_filename, const char* label_filename,
51  const char* db_path, const int data_limit) {
52  // Open files
53  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
54  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
55  CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
56  CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
57  // Read the magic and the meta data
58  uint32_t magic;
59  uint32_t num_items;
60  uint32_t num_labels;
61  uint32_t rows;
62  uint32_t cols;
63 
64  image_file.read(reinterpret_cast<char*>(&magic), 4);
65  magic = swap_endian(magic);
66  if (magic == 529205256) {
67  LOG(FATAL) <<
68  "It seems that you forgot to unzip the mnist dataset. You should "
69  "first unzip them using e.g. gunzip on Linux.";
70  }
71  CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
72  label_file.read(reinterpret_cast<char*>(&magic), 4);
73  magic = swap_endian(magic);
74  CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
75  image_file.read(reinterpret_cast<char*>(&num_items), 4);
76  num_items = swap_endian(num_items);
77  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
78  num_labels = swap_endian(num_labels);
79  CAFFE_ENFORCE_EQ(num_items, num_labels);
80  image_file.read(reinterpret_cast<char*>(&rows), 4);
81  rows = swap_endian(rows);
82  image_file.read(reinterpret_cast<char*>(&cols), 4);
83  cols = swap_endian(cols);
84 
85  // leveldb
86  std::unique_ptr<db::DB> mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW));
87  std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
88  // Storing to db
89  char label_value;
90  std::vector<char> pixels(rows * cols);
91  int count = 0;
92  const int kMaxKeyLength = 10;
93  char key_cstr[kMaxKeyLength];
94  string value;
95 
96  TensorProtos protos;
97  TensorProto* data = protos.add_protos();
98  TensorProto* label = protos.add_protos();
99  data->set_data_type(TensorProto::BYTE);
100  if (FLAGS_channel_first) {
101  data->add_dims(1);
102  data->add_dims(rows);
103  data->add_dims(cols);
104  } else {
105  data->add_dims(rows);
106  data->add_dims(cols);
107  data->add_dims(1);
108  }
109  label->set_data_type(TensorProto::INT32);
110  label->add_int32_data(0);
111 
112  LOG(INFO) << "A total of " << num_items << " items.";
113  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
114  for (int item_id = 0; item_id < num_items; ++item_id) {
115  image_file.read(pixels.data(), rows * cols);
116  label_file.read(&label_value, 1);
117  for (int i = 0; i < rows * cols; ++i) {
118  data->set_byte_data(pixels.data(), rows * cols);
119  }
120  label->set_int32_data(0, static_cast<int>(label_value));
121  snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
122  protos.SerializeToString(&value);
123  string keystr(key_cstr);
124 
125  // Put in db
126  transaction->Put(keystr, value);
127  if (++count % 1000 == 0) {
128  transaction->Commit();
129  }
130  if (data_limit > 0 && count == data_limit) {
131  LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
132  break;
133  }
134  }
135 }
136 } // namespace caffe2
137 
138 int main(int argc, char** argv) {
139  caffe2::GlobalInit(&argc, &argv);
140  caffe2::convert_dataset(
141  FLAGS_image_file.c_str(),
142  FLAGS_label_file.c_str(),
143  FLAGS_output_file.c_str(),
144  FLAGS_data_limit);
145  return 0;
146 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:44