Caffe2 - C++ API
A deep learning, cross platform ML framework
make_mnist_db.cc
1 
// This script converts the MNIST dataset into a Caffe2 db (leveldb by
// default; the db type is selected with --db). The MNIST dataset can be
// downloaded from http://yann.lecun.com/exdb/mnist/
20 
#include <cstdint>
#include <cstdio>
#include <fstream> // NOLINT(readability/streams)
#include <memory>
#include <string>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2.pb.h"
29 
30 CAFFE2_DEFINE_string(image_file, "", "The input image file name.");
31 CAFFE2_DEFINE_string(label_file, "", "The label file name.");
32 CAFFE2_DEFINE_string(output_file, "", "The output db name.");
33 CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
34 CAFFE2_DEFINE_int(data_limit, -1,
35  "If set, only output this number of data points.");
36 CAFFE2_DEFINE_bool(channel_first, false,
37  "If set, write the data as channel-first (CHW order) as the old "
38  "Caffe does.");
39 
40 namespace caffe2 {
41 uint32_t swap_endian(uint32_t val) {
42  val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
43  return (val << 16) | (val >> 16);
44 }
45 
46 void convert_dataset(const char* image_filename, const char* label_filename,
47  const char* db_path, const int data_limit) {
48  // Open files
49  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
50  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
51  CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
52  CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
53  // Read the magic and the meta data
54  uint32_t magic;
55  uint32_t num_items;
56  uint32_t num_labels;
57  uint32_t rows;
58  uint32_t cols;
59 
60  image_file.read(reinterpret_cast<char*>(&magic), 4);
61  magic = swap_endian(magic);
62  if (magic == 529205256) {
63  LOG(FATAL) <<
64  "It seems that you forgot to unzip the mnist dataset. You should "
65  "first unzip them using e.g. gunzip on Linux.";
66  }
67  CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
68  label_file.read(reinterpret_cast<char*>(&magic), 4);
69  magic = swap_endian(magic);
70  CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
71  image_file.read(reinterpret_cast<char*>(&num_items), 4);
72  num_items = swap_endian(num_items);
73  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
74  num_labels = swap_endian(num_labels);
75  CAFFE_ENFORCE_EQ(num_items, num_labels);
76  image_file.read(reinterpret_cast<char*>(&rows), 4);
77  rows = swap_endian(rows);
78  image_file.read(reinterpret_cast<char*>(&cols), 4);
79  cols = swap_endian(cols);
80 
81  // leveldb
82  std::unique_ptr<db::DB> mnist_db(db::CreateDB(caffe2::FLAGS_db, db_path, db::NEW));
83  std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
84  // Storing to db
85  char label_value;
86  std::vector<char> pixels(rows * cols);
87  int count = 0;
88  const int kMaxKeyLength = 10;
89  char key_cstr[kMaxKeyLength];
90  string value;
91 
92  TensorProtos protos;
93  TensorProto* data = protos.add_protos();
94  TensorProto* label = protos.add_protos();
95  data->set_data_type(TensorProto::BYTE);
96  if (caffe2::FLAGS_channel_first) {
97  data->add_dims(1);
98  data->add_dims(rows);
99  data->add_dims(cols);
100  } else {
101  data->add_dims(rows);
102  data->add_dims(cols);
103  data->add_dims(1);
104  }
105  label->set_data_type(TensorProto::INT32);
106  label->add_int32_data(0);
107 
108  LOG(INFO) << "A total of " << num_items << " items.";
109  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
110  for (int item_id = 0; item_id < num_items; ++item_id) {
111  image_file.read(pixels.data(), rows * cols);
112  label_file.read(&label_value, 1);
113  for (int i = 0; i < rows * cols; ++i) {
114  data->set_byte_data(pixels.data(), rows * cols);
115  }
116  label->set_int32_data(0, static_cast<int>(label_value));
117  snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
118  protos.SerializeToString(&value);
119  string keystr(key_cstr);
120 
121  // Put in db
122  transaction->Put(keystr, value);
123  if (++count % 1000 == 0) {
124  transaction->Commit();
125  }
126  if (data_limit > 0 && count == data_limit) {
127  LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
128  break;
129  }
130  }
131 }
132 } // namespace caffe2
133 
134 int main(int argc, char** argv) {
135  caffe2::GlobalInit(&argc, &argv);
136  caffe2::convert_dataset(caffe2::FLAGS_image_file.c_str(), caffe2::FLAGS_label_file.c_str(),
137  caffe2::FLAGS_output_file.c_str(), caffe2::FLAGS_data_limit);
138  return 0;
139 }
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:34
Copyright (c) 2016-present, Facebook, Inc.