Caffe2 - C++ API
A deep learning, cross platform ML framework
make_image_db.cc
1 
17 // This script converts an image dataset to a database.
18 //
19 // caffe2::FLAGS_input_folder is the root folder that holds all the images
20 //
21 // caffe2::FLAGS_list_file is the path to a file containing a list of files
22 // and their labels, as follows:
23 //
24 // subfolder1/file1.JPEG 7
25 // subfolder1/file2.JPEG 7
26 // subfolder2/file1.JPEG 8
27 // ...
28 //
29 
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <condition_variable>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <mutex>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2.pb.h"
44 
45 CAFFE2_DEFINE_bool(shuffle, false,
46  "Randomly shuffle the order of images and their labels");
47 CAFFE2_DEFINE_string(input_folder, "", "The input image file name.");
48 CAFFE2_DEFINE_string(
49  list_file,
50  "",
51  "The text file containing the list of images.");
52 CAFFE2_DEFINE_string(output_db_name, "", "The output training leveldb name.");
53 CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
54 CAFFE2_DEFINE_bool(raw, false,
55  "If set, we pre-read the images and store the raw buffer.");
56 CAFFE2_DEFINE_bool(color, true, "If set, load images in color.");
57 CAFFE2_DEFINE_int(
58  scale,
59  256,
60  "If caffe2::FLAGS_raw is set, scale the shorter edge to the given value.");
61 CAFFE2_DEFINE_bool(warp, false, "If warp is set, warp the images to square.");
62 CAFFE2_DEFINE_int(
63  num_threads,
64  -1,
65  "Number of image parsing and conversion threads.");
66 
67 namespace caffe2 {
68 
69 class Converter {
70  public:
71  explicit Converter() {
72  data_ = protos_.add_protos();
73  label_ = protos_.add_protos();
74  if (caffe2::FLAGS_raw) {
75  data_->set_data_type(TensorProto::BYTE);
76  data_->add_dims(0);
77  data_->add_dims(0);
78  if (caffe2::FLAGS_color) {
79  data_->add_dims(3);
80  }
81  } else {
82  data_->set_data_type(TensorProto::STRING);
83  data_->add_dims(1);
84  data_->add_string_data("");
85  }
86  label_->set_data_type(TensorProto::INT32);
87  label_->add_dims(1);
88  label_->add_int32_data(0);
89  }
90 
91  ~Converter() {
92  if (thread_.joinable()) {
93  thread_.join();
94  }
95  }
96 
97  void queue(const std::pair<std::string, int>& pair) {
98  in_.push(pair);
99  }
100 
101  void start() {
102  thread_ = std::thread(&Converter::run, this);
103  }
104 
105  std::string get() {
106  std::unique_lock<std::mutex> lock(mutex_);
107  while (out_.empty()) {
108  cv_.wait(lock);
109  }
110 
111  auto value = out_.front();
112  out_.pop();
113  cv_.notify_one();
114  return value;
115  }
116 
117  void run() {
118  const auto& input_folder = caffe2::FLAGS_input_folder;
119  std::unique_lock<std::mutex> lock(mutex_);
120  std::string value;
121  while (!in_.empty()) {
122  auto pair = in_.front();
123  in_.pop();
124  lock.unlock();
125 
126  label_->set_int32_data(0, pair.second);
127 
128  // Add raw file contents to DB if !raw
129  if (!caffe2::FLAGS_raw) {
130  std::ifstream image_file_stream(input_folder + pair.first);
131  if (!image_file_stream) {
132  LOG(ERROR) << "Cannot open " << input_folder << pair.first
133  << ". Skipping.";
134  } else {
135  data_->mutable_string_data(0)->assign(
136  std::istreambuf_iterator<char>(image_file_stream),
137  std::istreambuf_iterator<char>());
138  }
139  } else {
140  // Load image
141  cv::Mat img = cv::imread(
142  input_folder + pair.first,
143  caffe2::FLAGS_color ? CV_LOAD_IMAGE_COLOR
144  : CV_LOAD_IMAGE_GRAYSCALE);
145 
146  // Resize image
147  cv::Mat resized_img;
148  int scaled_width, scaled_height;
149  if (caffe2::FLAGS_warp) {
150  scaled_width = caffe2::FLAGS_scale;
151  scaled_height = caffe2::FLAGS_scale;
152  } else if (img.rows > img.cols) {
153  scaled_width = caffe2::FLAGS_scale;
154  scaled_height =
155  static_cast<float>(img.rows) * caffe2::FLAGS_scale / img.cols;
156  } else {
157  scaled_height = caffe2::FLAGS_scale;
158  scaled_width =
159  static_cast<float>(img.cols) * caffe2::FLAGS_scale / img.rows;
160  }
161  cv::resize(
162  img,
163  resized_img,
164  cv::Size(scaled_width, scaled_height),
165  0,
166  0,
167  cv::INTER_LINEAR);
168  data_->set_dims(0, scaled_height);
169  data_->set_dims(1, scaled_width);
170 
171  // Assert we don't have to deal with alignment
172  DCHECK(resized_img.isContinuous());
173  auto nbytes = resized_img.total() * resized_img.elemSize();
174  data_->set_byte_data(resized_img.ptr(), nbytes);
175  }
176 
177  protos_.SerializeToString(&value);
178 
179  // Add serialized proto to out queue or wait if it is not empty
180  lock.lock();
181  while (!out_.empty()) {
182  cv_.wait(lock);
183  }
184  out_.push(value);
185  cv_.notify_one();
186  }
187  }
188 
189  protected:
190  TensorProtos protos_;
191  TensorProto* data_;
192  TensorProto* label_;
193  std::queue<std::pair<std::string, int>> in_;
194  std::queue<std::string> out_;
195 
196  std::mutex mutex_;
197  std::condition_variable cv_;
198  std::thread thread_;
199 };
200 
201 void ConvertImageDataset(
202  const string& input_folder,
203  const string& list_filename,
204  const string& output_db_name,
205  const bool /*shuffle*/) {
206  std::ifstream list_file(list_filename);
207  std::vector<std::pair<std::string, int> > lines;
208  std::string filename;
209  int file_label;
210  while (list_file >> filename >> file_label) {
211  lines.push_back(std::make_pair(filename, file_label));
212  }
213 
214  if (caffe2::FLAGS_shuffle) {
215  LOG(INFO) << "Shuffling data";
216  std::shuffle(lines.begin(), lines.end(), std::default_random_engine(1701));
217  }
218 
219  auto num_threads = caffe2::FLAGS_num_threads;
220  if (num_threads < 1) {
221  num_threads = std::thread::hardware_concurrency();
222  }
223 
224  LOG(INFO) << "Processing " << lines.size() << " images...";
225  LOG(INFO) << "Opening DB " << output_db_name;
226 
227  auto db = db::CreateDB(caffe2::FLAGS_db, output_db_name, db::NEW);
228  auto transaction = db->NewTransaction();
229 
230  LOG(INFO) << "Using " << num_threads << " processing threads...";
231  std::vector<Converter> converters(num_threads);
232 
233  // Queue entries across converters
234  for (auto i = 0; i < lines.size(); i++) {
235  converters[i % converters.size()].queue(lines[i]);
236  }
237 
238  // Start all converters
239  for (auto& converter : converters) {
240  converter.start();
241  }
242 
243  constexpr auto key_max_length = 256;
244  char key_cstr[key_max_length];
245  string value;
246  int count = 0;
247  for (auto i = 0; i < lines.size(); i++) {
248  // Get serialized proto for this entry
249  auto value = converters[i % converters.size()].get();
250 
251  // Synthesize key for this entry
252  auto key_len = snprintf(
253  key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str());
254  DCHECK_LE(key_len, sizeof(key_cstr));
255 
256  // Put in db
257  transaction->Put(string(key_cstr), value);
258 
259  if (++count % 1000 == 0) {
260  // Commit the current writes.
261  transaction->Commit();
262  LOG(INFO) << "Processed " << count << " files.";
263  }
264  }
265 
266  // Commit final transaction
267  transaction->Commit();
268  LOG(INFO) << "Processed " << count << " files.";
269 }
270 
271 } // namespace caffe2
272 
273 
274 int main(int argc, char** argv) {
275  caffe2::GlobalInit(&argc, &argv);
276  caffe2::ConvertImageDataset(
277  caffe2::FLAGS_input_folder, caffe2::FLAGS_list_file,
278  caffe2::FLAGS_output_db_name, caffe2::FLAGS_shuffle);
279  return 0;
280 }
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:34
Copyright (c) 2016-present, Facebook, Inc.