// Command-line flags of the converter, declared through the C10_DEFINE_* macros.
// NOTE(review): this listing carries the original file's line numbers inline
// and skips some of them, so parts of the original file are not visible here.
// Flag: folder that holds the input files.
30 #include "caffe2/core/common.h" 31 #include "caffe2/core/db.h" 32 #include "caffe2/core/init.h" 33 #include "caffe2/proto/caffe2_pb.h" 34 #include "caffe2/core/logging.h" 36 C10_DEFINE_string(input_folder,
"",
"The input folder name.");
// Flag: name of the database that receives the training data.
37 C10_DEFINE_string(output_train_db_name,
"",
"The output training db name.");
// Flag: name of the database that receives the testing data.
38 C10_DEFINE_string(output_test_db_name,
"",
"The output testing db name.");
// Flag: database type handed to db::CreateDB; defaults to "leveldb".
39 C10_DEFINE_string(db,
"leveldb",
"The db type.");
// NOTE(review): presumably the help string of the is_cifar100 flag (read
// elsewhere as FLAGS_is_cifar100); the opening of that definition, including
// its default value, is not visible in this listing.
43 "If set, convert cifar100. Otherwise do cifar10.");
47 using std::stringstream;
// Edge length used for both image dimensions.
49 const int kCIFARSize = 32;
// Bytes per image: three planes of kCIFARSize * kCIFARSize bytes.
50 const int kCIFARImageNBytes = kCIFARSize * kCIFARSize * 3;
// Number of records converted from each CIFAR-10 training batch file.
51 const int kCIFAR10BatchSize = 10000;
// Number of records converted from the CIFAR-10 test file.
52 const int kCIFAR10TestDataSize = 10000;
// Number of CIFAR-10 training batch files (data_batch_1 onwards).
53 const int kCIFAR10TrainBatches = 5;
// Number of records converted from the CIFAR-100 train.bin / test.bin files.
55 const int kCIFAR100TrainDataSize = 50000;
56 const int kCIFAR100TestDataSize = 10000;
// Reads one record from `file`: the label byte(s) first, then
// kCIFARImageNBytes bytes of image data.
//   file   - binary input stream to read from.
//   label  - presumably receives the record's label.
//   buffer - presumably receives the image bytes.
// NOTE(review): this listing omits several original lines (the declaration of
// label_char, the store through `label`, the left-hand side of the assignment
// in the inner loop, and the closing braces), so how `label` and `buffer` are
// filled is not visible here; confirm against the full file.
// Nothing in the visible lines checks whether the reads succeeded.
58 void ReadImage(std::ifstream* file,
int* label,
char* buffer) {
// For CIFAR-100 one additional byte is read ahead of the label read.
// NOTE(review): presumably that first byte is discarded and only the second
// read supplies the label; the end of the if-block is not shown.
60 if (FLAGS_is_cifar100) {
62 file->read(&label_char, 1);
64 file->read(&label_char, 1);
// Pull the whole image into a temporary array before the per-plane loop.
68 std::array<char, kCIFARImageNBytes> channel_first_storage;
69 file->read(channel_first_storage.data(), kCIFARImageNBytes);
// Walk 3 planes of kCIFARSize * kCIFARSize bytes each; the source index
// c * kCIFARSize * kCIFARSize + i addresses element i of plane c.
70 for (
int c = 0; c < 3; ++c) {
71 for (
int i = 0; i < kCIFARSize * kCIFARSize; ++i) {
73 channel_first_storage[c * kCIFARSize * kCIFARSize + i];
// Converts `num_items` records from the binary file `filename` into serialized
// protos and stores them in `db` through a single transaction.
//   filename  - path of the input file; opening it is enforced.
//   num_items - number of records to read and store.
//   offset    - not referenced in the visible lines. NOTE(review): presumably
//               it feeds the numeric key formatted by snprintf below, whose
//               last argument is not shown; confirm.
//   db        - database that provides the transaction.
79 void WriteToDB(
const string& filename,
const int num_items,
80 const int& offset, db::DB* db) {
// `protos` holds two entries that are reused for every record: the image
// (BYTE) and its label (INT32).
// NOTE(review): the declaration of `protos` and some add_dims calls are not
// visible in this listing.
82 TensorProto* data = protos.add_protos();
83 TensorProto* label = protos.add_protos();
84 data->set_data_type(TensorProto::BYTE);
85 data->add_dims(kCIFARSize);
86 data->add_dims(kCIFARSize);
88 label->set_data_type(TensorProto::INT32);
// Reserve one int32 slot so that set_int32_data(0, ...) can overwrite it per record.
90 label->add_int32_data(0);
92 LOG(INFO) <<
"Converting file " << filename;
93 std::ifstream data_file(filename.c_str(),
94 std::ios::in | std::ios::binary);
95 CAFFE_ENFORCE(data_file,
"Unable to open file ", filename);
// Scratch buffer: first holds the image bytes, later the formatted key.
96 char str_buffer[kCIFARImageNBytes];
98 string serialized_protos;
99 std::unique_ptr<db::Transaction> transaction(db->NewTransaction());
100 for (
int itemid = 0; itemid < num_items; ++itemid) {
101 ReadImage(&data_file, &label_value, str_buffer);
102 data->set_byte_data(str_buffer, kCIFARImageNBytes);
103 label->set_int32_data(0, label_value);
104 protos.SerializeToString(&serialized_protos);
// The image bytes were already copied into `data`, so the buffer can be
// reused for the zero-padded, at least five digit wide key.
105 snprintf(str_buffer, kCIFARImageNBytes,
"%05d",
107 transaction->Put(
string(str_buffer), serialized_protos);
// Creates the training and testing databases named by the output flags (type
// FLAGS_db, opened with db::NEW) and fills them from the binary files found
// under FLAGS_input_folder.
111 void ConvertCIFAR() {
112 std::unique_ptr<db::DB> train_db(
113 db::CreateDB(FLAGS_db, FLAGS_output_train_db_name, db::NEW));
114 std::unique_ptr<db::DB> test_db(
115 db::CreateDB(FLAGS_db, FLAGS_output_test_db_name, db::NEW));
// CIFAR-10: kCIFAR10TrainBatches training files whose names start with
// "data_batch_<fileid + 1>", plus a single "test_batch.bin".
// NOTE(review): the remainder of the training file name expression is not
// visible in this listing.
117 if (!FLAGS_is_cifar100) {
119 for (
int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) {
120 stringstream train_file;
121 train_file << FLAGS_input_folder <<
"/data_batch_" << fileid + 1
// Each batch gets its own offset (fileid * kCIFAR10BatchSize).
123 WriteToDB(train_file.str(), kCIFAR10BatchSize,
124 fileid * kCIFAR10BatchSize, train_db.get());
126 stringstream test_file;
127 test_file << FLAGS_input_folder <<
"/test_batch.bin";
128 WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get());
// NOTE(review): presumably the CIFAR-100 branch starts here; the else keyword
// is not visible in this listing. It converts one "train.bin" and one
// "test.bin", each with offset 0.
131 stringstream train_file;
132 train_file << FLAGS_input_folder <<
"/train.bin";
133 WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get());
134 stringstream test_file;
135 test_file << FLAGS_input_folder <<
"/test.bin";
136 WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get());
// Program entry point: runs the CIFAR conversion.
// NOTE(review): the statement between the signature and the call below, and
// the return, are not visible in this listing; presumably flag parsing /
// global initialization happens first. Confirm against the full file.
142 int main(
int argc,
char** argv) {
144 caffe2::ConvertCIFAR();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.