Caffe2 - C++ API
A deep learning, cross platform ML framework
make_cifar_db.cc
1 
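//
// This script converts the CIFAR-10 or CIFAR-100 binary dataset into the
// database format (leveldb by default) used by Caffe2 to perform
// classification.
// Usage (binary name assumed from the file name, make_cifar_db):
//   make_cifar_db --input_folder=<dir> --output_train_db_name=<train_db>
//       --output_test_db_name=<test_db> [--db=<db_type>] [--is_cifar100]
// The CIFAR dataset can be downloaded at
//   http://www.cs.toronto.edu/~kriz/cifar.html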
24 
25 #include <array>
26 #include <fstream> // NOLINT(readability/streams)
27 #include <sstream>
28 #include <string>
29 
30 #include "caffe2/core/common.h"
31 #include "caffe2/core/db.h"
32 #include "caffe2/core/init.h"
33 #include "caffe2/proto/caffe2_pb.h"
34 #include "caffe2/core/logging.h"
35 
36 C10_DEFINE_string(input_folder, "", "The input folder name.");
37 C10_DEFINE_string(output_train_db_name, "", "The output training db name.");
38 C10_DEFINE_string(output_test_db_name, "", "The output testing db name.");
39 C10_DEFINE_string(db, "leveldb", "The db type.");
40 C10_DEFINE_bool(
41  is_cifar100,
42  false,
43  "If set, convert cifar100. Otherwise do cifar10.");
44 
45 namespace caffe2 {
46 
47 using std::stringstream;
48 
// Edge length, in pixels, of one (square) CIFAR image.
constexpr int kCIFARSize = 32;
// Number of pixel bytes in one image: kCIFARSize x kCIFARSize, 3 channels.
constexpr int kCIFARImageNBytes = 3 * kCIFARSize * kCIFARSize;

// CIFAR-10: number of records converted from each training batch file, from
// the test file, and the number of training batch files.
constexpr int kCIFAR10BatchSize = 10000;
constexpr int kCIFAR10TestDataSize = 10000;
constexpr int kCIFAR10TrainBatches = 5;

// CIFAR-100: number of records converted from the single training file and
// from the single test file.
constexpr int kCIFAR100TrainDataSize = 50000;
constexpr int kCIFAR100TestDataSize = 10000;
57 
58 void ReadImage(std::ifstream* file, int* label, char* buffer) {
59  char label_char;
60  if (FLAGS_is_cifar100) {
61  // Skip the coarse label.
62  file->read(&label_char, 1);
63  }
64  file->read(&label_char, 1);
65  *label = label_char;
66  // Yes, there are better ways to do it, like in-place swap... but I am too
67  // lazy so let's just write it in a memory-wasteful way.
68  std::array<char, kCIFARImageNBytes> channel_first_storage;
69  file->read(channel_first_storage.data(), kCIFARImageNBytes);
70  for (int c = 0; c < 3; ++c) {
71  for (int i = 0; i < kCIFARSize * kCIFARSize; ++i) {
72  buffer[i * 3 + c] =
73  channel_first_storage[c * kCIFARSize * kCIFARSize + i];
74  }
75  }
76  return;
77 }
78 
79 void WriteToDB(const string& filename, const int num_items,
80  const int& offset, db::DB* db) {
81  TensorProtos protos;
82  TensorProto* data = protos.add_protos();
83  TensorProto* label = protos.add_protos();
84  data->set_data_type(TensorProto::BYTE);
85  data->add_dims(kCIFARSize);
86  data->add_dims(kCIFARSize);
87  data->add_dims(3);
88  label->set_data_type(TensorProto::INT32);
89  label->add_dims(1);
90  label->add_int32_data(0);
91 
92  LOG(INFO) << "Converting file " << filename;
93  std::ifstream data_file(filename.c_str(),
94  std::ios::in | std::ios::binary);
95  CAFFE_ENFORCE(data_file, "Unable to open file ", filename);
96  char str_buffer[kCIFARImageNBytes];
97  int label_value;
98  string serialized_protos;
99  std::unique_ptr<db::Transaction> transaction(db->NewTransaction());
100  for (int itemid = 0; itemid < num_items; ++itemid) {
101  ReadImage(&data_file, &label_value, str_buffer);
102  data->set_byte_data(str_buffer, kCIFARImageNBytes);
103  label->set_int32_data(0, label_value);
104  protos.SerializeToString(&serialized_protos);
105  snprintf(str_buffer, kCIFARImageNBytes, "%05d",
106  offset + itemid);
107  transaction->Put(string(str_buffer), serialized_protos);
108  }
109 }
110 
111 void ConvertCIFAR() {
112  std::unique_ptr<db::DB> train_db(
113  db::CreateDB(FLAGS_db, FLAGS_output_train_db_name, db::NEW));
114  std::unique_ptr<db::DB> test_db(
115  db::CreateDB(FLAGS_db, FLAGS_output_test_db_name, db::NEW));
116 
117  if (!FLAGS_is_cifar100) {
118  // This is cifar 10.
119  for (int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) {
120  stringstream train_file;
121  train_file << FLAGS_input_folder << "/data_batch_" << fileid + 1
122  << ".bin";
123  WriteToDB(train_file.str(), kCIFAR10BatchSize,
124  fileid * kCIFAR10BatchSize, train_db.get());
125  }
126  stringstream test_file;
127  test_file << FLAGS_input_folder << "/test_batch.bin";
128  WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get());
129  } else {
130  // This is cifar 100.
131  stringstream train_file;
132  train_file << FLAGS_input_folder << "/train.bin";
133  WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get());
134  stringstream test_file;
135  test_file << FLAGS_input_folder << "/test.bin";
136  WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get());
137  }
138 }
139 
140 } // namespace caffe2
141 
142 int main(int argc, char** argv) {
143  caffe2::GlobalInit(&argc, &argv);
144  caffe2::ConvertCIFAR();
145  return 0;
146 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:44