24 #include "caffe2/core/common.h" 25 #include "caffe2/core/db.h" 26 #include "caffe2/core/init.h" 27 #include "caffe2/proto/caffe2_pb.h" 28 #include "caffe2/core/logging.h" 30 C10_DEFINE_string(image_file,
"",
"The input image file name.");
31 C10_DEFINE_string(label_file,
"",
"The label file name.");
32 C10_DEFINE_string(output_file,
"",
"The output db name.");
33 C10_DEFINE_string(db,
"leveldb",
"The db type.");
37 "If set, only output this number of data points.");
41 "If set, write the data as channel-first (CHW order) as the old " 45 uint32_t swap_endian(uint32_t val) {
  val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0x00FF00FF);
  return (val << 16) | (val >> 16);
}

void convert_dataset(
    const char* image_filename,
    const char* label_filename,
    const char* db_path,
    const int data_limit) {
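  // Open the raw MNIST IDX files in binary mode.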
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
  CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
  CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
  uint32_t magic;
  uint32_t num_items;
  uint32_t num_labels;
  uint32_t rows;
  uint32_t cols;

  image_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
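  // 529205256 == 0x1F8B0808, the start of a gzip stream: the input file is
  // still compressed.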
  if (magic == 529205256) {
    LOG(FATAL) << "It seems that you forgot to unzip the mnist dataset. You "
                  "should first unzip them using e.g. gunzip on Linux.";
  }
  CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
  label_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
  image_file.read(reinterpret_cast<char*>(&num_items), 4);
  num_items = swap_endian(num_items);
  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
  num_labels = swap_endian(num_labels);
  CAFFE_ENFORCE_EQ(num_items, num_labels);
  image_file.read(reinterpret_cast<char*>(&rows), 4);
  rows = swap_endian(rows);
  image_file.read(reinterpret_cast<char*>(&cols), 4);
  cols = swap_endian(cols);
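  // Create the output DB of the requested type and open one write transaction.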
  std::unique_ptr<db::DB> mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW));
  std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
  char label_value;
  std::vector<char> pixels(rows * cols);
  int count = 0;
  const int kMaxKeyLength = 10;
  char key_cstr[kMaxKeyLength];
  string value;

  TensorProtos protos;
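  // Each DB record is a TensorProtos message with two entries: proto 0 holds
  // the raw image bytes, proto 1 holds the integer label.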
  TensorProto* data = protos.add_protos();
  TensorProto* label = protos.add_protos();
  data->set_data_type(TensorProto::BYTE);
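  // MNIST images have a single channel: CHW order stores dims as
  // (1, rows, cols), while the default HWC order stores (rows, cols, 1).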
  if (FLAGS_channel_first) {
    data->add_dims(1);
    data->add_dims(rows);
    data->add_dims(cols);
  } else {
    data->add_dims(rows);
    data->add_dims(cols);
    data->add_dims(1);
  }
  label->set_data_type(TensorProto::INT32);
  label->add_int32_data(0);
  LOG(INFO) << "A total of " << num_items << " items.";
  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
  for (int item_id = 0; item_id < num_items; ++item_id) {
    image_file.read(pixels.data(), rows * cols);
    label_file.read(&label_value, 1);
    // set_byte_data copies the entire image buffer in one call, so no
    // per-pixel loop is needed here.
    data->set_byte_data(pixels.data(), rows * cols);
    label->set_int32_data(0, static_cast<int>(label_value));
    snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
    protos.SerializeToString(&value);
    string keystr(key_cstr);
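    // Store the record under a zero-padded key so entries sort in insertion
    // order.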
    transaction->Put(keystr, value);
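    // Commit in batches of 1000 puts to bound the size of each write batch.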
    if (++count % 1000 == 0) {
      transaction->Commit();
    }
    if (data_limit > 0 && count == data_limit) {
      LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
      break;
    }
  }
  // Flush any entries left over from the last partial batch.
  transaction->Commit();
}

}  // namespace caffe2
int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);
  caffe2::convert_dataset(
      FLAGS_image_file.c_str(),
      FLAGS_label_file.c_str(),
      FLAGS_output_file.c_str(),
      FLAGS_data_limit);
  return 0;
}