Caffe2 - C++ API
A deep learning, cross-platform ML framework
make_image_db.cc
// This script converts an image dataset to a database.
//
// FLAGS_input_folder is the root folder that holds all the images.
//
// FLAGS_list_file is the path to a file containing a list of files
// and their labels, as follows:
//
//   subfolder1/file1.JPEG 7
//   subfolder1/file2.JPEG 7
//   subfolder2/file1.JPEG 8
//   ...
//

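// Example invocation (a sketch only: the binary name and the paths are
// hypothetical, but every flag used here is defined further down in this
// file):
//
//   make_image_db --input_folder=/path/to/images/ \
//       --list_file=/path/to/list.txt \
//       --output_db_name=/path/to/output_db \
//       --db=leveldb --shuffle=true --num_threads=4
//
// Note that FLAGS_input_folder is concatenated directly with each listed
// filename, so include the trailing slash when the list entries are
// relative paths.
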
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <condition_variable>
#include <cstdio>
#include <fstream>
#include <mutex>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

C10_DEFINE_bool(
    shuffle,
    false,
    "Randomly shuffle the order of images and their labels.");
C10_DEFINE_string(input_folder, "", "The root folder of the input images.");
C10_DEFINE_string(
    list_file,
    "",
    "The text file containing the list of images and their labels.");
C10_DEFINE_string(output_db_name, "", "The output training database name.");
C10_DEFINE_string(db, "leveldb", "The db type.");
C10_DEFINE_bool(
    raw,
    false,
    "If set, pre-decode the images and store the raw pixel buffer.");
C10_DEFINE_bool(color, true, "If set, load images in color.");
C10_DEFINE_int(
    scale,
    256,
    "If FLAGS_raw is set, scale the shorter edge to the given value.");
C10_DEFINE_bool(warp, false, "If warp is set, warp the images to square.");
C10_DEFINE_int(
    num_threads,
    -1,
    "Number of image parsing and conversion threads.");

namespace caffe2 {

class Converter {
 public:
  explicit Converter() {
    data_ = protos_.add_protos();
    label_ = protos_.add_protos();
    if (FLAGS_raw) {
      data_->set_data_type(TensorProto::BYTE);
      data_->add_dims(0);
      data_->add_dims(0);
      if (FLAGS_color) {
        data_->add_dims(3);
      }
    } else {
      data_->set_data_type(TensorProto::STRING);
      data_->add_dims(1);
      data_->add_string_data("");
    }
    label_->set_data_type(TensorProto::INT32);
    label_->add_dims(1);
    label_->add_int32_data(0);
  }

  ~Converter() {
    if (thread_.joinable()) {
      thread_.join();
    }
  }

  void queue(const std::pair<std::string, int>& pair) {
    in_.push(pair);
  }

  void start() {
    thread_ = std::thread(&Converter::run, this);
  }

  // Blocks until the worker thread has produced the next serialized proto.
  std::string get() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (out_.empty()) {
      cv_.wait(lock);
    }

    auto value = out_.front();
    out_.pop();
    cv_.notify_one();
    return value;
  }

  void run() {
    const auto& input_folder = FLAGS_input_folder;
    std::unique_lock<std::mutex> lock(mutex_);
    std::string value;
    while (!in_.empty()) {
      auto pair = in_.front();
      in_.pop();
      lock.unlock();

      label_->set_int32_data(0, pair.second);

      if (!FLAGS_raw) {
        // Not raw: store the encoded file contents verbatim as string data.
        std::ifstream image_file_stream(input_folder + pair.first);
        if (!image_file_stream) {
          LOG(ERROR) << "Cannot open " << input_folder << pair.first
                     << ". Skipping.";
        } else {
          data_->mutable_string_data(0)->assign(
              std::istreambuf_iterator<char>(image_file_stream),
              std::istreambuf_iterator<char>());
        }
      } else {
        // Raw: decode the image with OpenCV.
        cv::Mat img = cv::imread(
            input_folder + pair.first,
            FLAGS_color ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);

        // Resize so that the shorter edge equals FLAGS_scale (or warp both
        // edges to FLAGS_scale if FLAGS_warp is set).
        cv::Mat resized_img;
        int scaled_width, scaled_height;
        if (FLAGS_warp) {
          scaled_width = FLAGS_scale;
          scaled_height = FLAGS_scale;
        } else if (img.rows > img.cols) {
          scaled_width = FLAGS_scale;
          scaled_height = static_cast<float>(img.rows) * FLAGS_scale / img.cols;
        } else {
          scaled_height = FLAGS_scale;
          scaled_width = static_cast<float>(img.cols) * FLAGS_scale / img.rows;
        }
        cv::resize(
            img,
            resized_img,
            cv::Size(scaled_width, scaled_height),
            0,
            0,
            cv::INTER_LINEAR);
        data_->set_dims(0, scaled_height);
        data_->set_dims(1, scaled_width);

        // Assert we don't have to deal with alignment.
        DCHECK(resized_img.isContinuous());
        auto nbytes = resized_img.total() * resized_img.elemSize();
        data_->set_byte_data(resized_img.ptr(), nbytes);
      }

      protos_.SerializeToString(&value);

      // Hand the serialized proto to the consumer. The out queue is kept at
      // most one element deep, so results come out in input order and memory
      // stays bounded.
      lock.lock();
      while (!out_.empty()) {
        cv_.wait(lock);
      }
      out_.push(value);
      cv_.notify_one();
    }
  }

 protected:
  TensorProtos protos_;
  TensorProto* data_;
  TensorProto* label_;
  std::queue<std::pair<std::string, int>> in_;
  std::queue<std::string> out_;

  std::mutex mutex_;
  std::condition_variable cv_;
  std::thread thread_;
};

void ConvertImageDataset(
    const string& input_folder,
    const string& list_filename,
    const string& output_db_name,
    const bool /*shuffle*/) {
  std::ifstream list_file(list_filename);
  std::vector<std::pair<std::string, int>> lines;
  std::string filename;
  int file_label;
  while (list_file >> filename >> file_label) {
    lines.push_back(std::make_pair(filename, file_label));
  }

  if (FLAGS_shuffle) {
    LOG(INFO) << "Shuffling data";
    std::shuffle(lines.begin(), lines.end(), std::default_random_engine(1701));
  }

  auto num_threads = FLAGS_num_threads;
  if (num_threads < 1) {
    num_threads = std::thread::hardware_concurrency();
  }

  LOG(INFO) << "Processing " << lines.size() << " images...";
  LOG(INFO) << "Opening DB " << output_db_name;

  auto db = db::CreateDB(FLAGS_db, output_db_name, db::NEW);
  auto transaction = db->NewTransaction();

  LOG(INFO) << "Using " << num_threads << " processing threads...";
  std::vector<Converter> converters(num_threads);

  // Distribute entries round-robin across converters
  for (size_t i = 0; i < lines.size(); i++) {
    converters[i % converters.size()].queue(lines[i]);
  }

  // Start all converters
  for (auto& converter : converters) {
    converter.start();
  }

  constexpr auto key_max_length = 256;
  char key_cstr[key_max_length];
  int count = 0;
  for (size_t i = 0; i < lines.size(); i++) {
    // Get serialized proto for this entry
    auto value = converters[i % converters.size()].get();

    // Synthesize key for this entry, e.g. "00000000_subfolder1/file1.JPEG"
    auto key_len = snprintf(
        key_cstr,
        sizeof(key_cstr),
        "%08d_%s",
        static_cast<int>(i),
        lines[i].first.c_str());
    DCHECK_LE(key_len, sizeof(key_cstr));

    // Put in db
    transaction->Put(string(key_cstr), value);

    if (++count % 1000 == 0) {
      // Commit the current writes.
      transaction->Commit();
      LOG(INFO) << "Processed " << count << " files.";
    }
  }

  // Commit final transaction
  transaction->Commit();
  LOG(INFO) << "Processed " << count << " files.";
}

} // namespace caffe2

int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);
  caffe2::ConvertImageDataset(
      FLAGS_input_folder, FLAGS_list_file, FLAGS_output_db_name, FLAGS_shuffle);
  return 0;
}
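
For reference, below is a minimal sketch of how the resulting database could be read back with the same caffe2::db API used above. It is not part of make_image_db.cc; the function name ReadImageDB is made up for illustration, and the entry layout (image proto first, label proto second) follows the Converter class.

#include <string>

#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

// Walk every key/value pair written by make_image_db and recover the
// (data, label) TensorProtos stored for each image.
void ReadImageDB(const std::string& db_type, const std::string& db_name) {
  auto db = caffe2::db::CreateDB(db_type, db_name, caffe2::db::READ);
  auto cursor = db->NewCursor();
  for (cursor->SeekToFirst(); cursor->Valid(); cursor->Next()) {
    caffe2::TensorProtos protos;
    CAFFE_ENFORCE(protos.ParseFromString(cursor->value()));
    const auto& data = protos.protos(0);   // encoded file bytes or raw pixels
    const auto& label = protos.protos(1);  // single INT32 label
    LOG(INFO) << cursor->key() << " -> label " << label.int32_data(0)
              << ", data type " << data.data_type();
  }
}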