30 #include <opencv2/opencv.hpp> 39 #include "caffe2/core/common.h" 40 #include "caffe2/core/db.h" 41 #include "caffe2/core/init.h" 42 #include "caffe2/proto/caffe2_pb.h" 43 #include "caffe2/core/logging.h" 48 "Randomly shuffle the order of images and their labels");
// Command-line flag definitions for this tool. Several definitions in this
// listing show only their help string (the opening macro line is missing), so
// the flag names given for those are inferred from FLAGS_* uses further down
// and should be confirmed against the full file.

// --input_folder (string, default ""). The value is concatenated directly in
// front of each image name when a file is opened.
// NOTE(review): the help text says "file name" although the value is used as
// a path prefix; the text looks misleading.
49 C10_DEFINE_string(input_folder,
"",
"The input image file name.");
// Tail of a flag definition whose head is not visible; presumably the
// list_file flag -- TODO confirm.
53 "The text file containing the list of images.");
// --output_db_name (string, default ""). Passed to db::CreateDB as the
// database name.
// NOTE(review): the help text mentions leveldb, but the backend is chosen by
// the separate db flag below.
54 C10_DEFINE_string(output_db_name,
"",
"The output training leveldb name.");
// --db (string, default "leveldb"). Passed to db::CreateDB as the database
// type.
55 C10_DEFINE_string(db,
"leveldb",
"The db type.");
// Tail of a flag definition whose head is not visible; presumably the raw
// flag named in the scale help text below -- TODO confirm.
59 "If set, we pre-read the images and store the raw buffer.");
// --color (bool, default true). Selects cv::IMREAD_COLOR when true and
// cv::IMREAD_GRAYSCALE when false in the cv::imread call.
60 C10_DEFINE_bool(color,
true,
"If set, load images in color.");
// Tail of a flag definition whose head is not visible; presumably the scale
// flag (FLAGS_scale is used as the target edge length) -- TODO confirm.
64 "If FLAGS_raw is set, scale the shorter edge to the given value.");
// --warp (bool, default false).
65 C10_DEFINE_bool(warp,
false,
"If warp is set, warp the images to square.");
// Tail of a flag definition whose head is not visible; presumably the
// num_threads flag (FLAGS_num_threads is read when sizing the worker pool)
// -- TODO confirm.
69 "Number of image parsing and conversion threads.");
// Fragments of the Converter class (name taken from &Converter::run below).
// The class header, most method headers and several branch conditions are not
// visible in this listing, so the comments describe only the statements shown.

// Setup: two entries are added to protos_; the first is kept as data_ and the
// second as label_.
76 data_ = protos_.add_protos();
77 label_ = protos_.add_protos();
// Two alternative encodings for data_ follow: BYTE, or STRING with one empty
// string element reserved. The condition choosing between them is not
// visible; presumably it is FLAGS_raw -- TODO confirm.
79 data_->set_data_type(TensorProto::BYTE);
86 data_->set_data_type(TensorProto::STRING);
88 data_->add_string_data(
"");
// label_ is INT32 with one element reserved; that element is overwritten per
// image by set_int32_data(0, ...) further down.
90 label_->set_data_type(TensorProto::INT32);
92 label_->add_int32_data(0);
// Guard on thread_ being joinable; the guarded statement is not visible
// (presumably a join during destruction -- TODO confirm).
96 if (thread_.joinable()) {
// queue(): accepts one (image name, label) pair. The body is not visible;
// presumably it pushes the pair onto in_ -- TODO confirm.
101 void queue(
const std::pair<std::string, int>& pair) {
// Launches thread_ running Converter::run on this object.
106 thread_ = std::thread(&Converter::run,
this);
// Consumer side: takes mutex_, loops while out_ is empty (loop body not
// visible; presumably a wait on cv_), then copies the front element of out_.
110 std::unique_lock<std::mutex> lock(mutex_);
111 while (out_.empty()) {
115 auto value = out_.front();
// Worker side: takes mutex_ and processes entries while in_ is non-empty.
// Each front entry is copied out of the queue before use.
122 const auto& input_folder = FLAGS_input_folder;
123 std::unique_lock<std::mutex> lock(mutex_);
125 while (!in_.empty()) {
126 auto pair = in_.front();
// Store this entry's label in the single reserved label slot.
130 label_->set_int32_data(0, pair.second);
// Encoded-file path: the file at input_folder + pair.first is opened; a
// failed open is logged, otherwise the whole stream is copied into string
// element 0 of data_.
// NOTE(review): the two path parts are concatenated with no separator, so
// input_folder must already end with one.
// NOTE(review): the stream is opened without std::ios::binary; on platforms
// whose text mode translates bytes, image data could be altered.
134 std::ifstream image_file_stream(input_folder + pair.first);
135 if (!image_file_stream) {
136 LOG(ERROR) <<
"Cannot open " << input_folder << pair.first
139 data_->mutable_string_data(0)->assign(
140 std::istreambuf_iterator<char>(image_file_stream),
141 std::istreambuf_iterator<char>());
// Decoded-image path: the same file is read with OpenCV, in color or
// grayscale depending on FLAGS_color.
// NOTE(review): no check for an empty img is visible in this listing; the
// size computations below divide by img.cols / img.rows.
145 cv::Mat img = cv::imread(
146 input_folder + pair.first,
147 FLAGS_color ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
// Target size. First branch (condition not visible; presumably FLAGS_warp --
// TODO confirm): both edges become FLAGS_scale. Otherwise the shorter edge
// becomes FLAGS_scale and the longer edge is scaled by the same ratio; the
// ratio is computed in float and truncated when assigned to int.
151 int scaled_width, scaled_height;
153 scaled_width = FLAGS_scale;
154 scaled_height = FLAGS_scale;
155 }
else if (img.rows > img.cols) {
156 scaled_width = FLAGS_scale;
157 scaled_height =
static_cast<float>(img.rows) * FLAGS_scale / img.cols;
159 scaled_height = FLAGS_scale;
160 scaled_width =
static_cast<float>(img.cols) * FLAGS_scale / img.rows;
165 cv::Size(scaled_width, scaled_height),
// Record the resized extents: dim 0 is the height, dim 1 is the width.
169 data_->set_dims(0, scaled_height);
170 data_->set_dims(1, scaled_width);
// The pixel buffer is copied as one span of total() * elemSize() bytes, which
// is why continuity of resized_img is checked first.
173 DCHECK(resized_img.isContinuous());
174 auto nbytes = resized_img.total() * resized_img.elemSize();
175 data_->set_byte_data(resized_img.ptr(), nbytes);
// Serialize the data/label pair into value.
178 protos_.SerializeToString(&value);
// Loops while out_ is non-empty (loop body not visible; presumably a wait on
// cv_ until the consumer has taken the previous result -- TODO confirm).
182 while (!out_.empty()) {
// State: the message holding data_ and label_, the queue of pending
// (image name, label) entries, the queue of serialized results, and the
// condition variable shared by the two sides.
191 TensorProtos protos_;
194 std::queue<std::pair<std::string, int>> in_;
195 std::queue<std::string> out_;
198 std::condition_variable cv_;
// Reads (image name, label) pairs from a list file, distributes them
// round-robin over a pool of Converter objects, and writes each serialized
// result into a database under a key built from the entry index and the image
// name. Parts of this function (including the last parameter and several
// closing braces) are not visible in this listing.
//
// input_folder:   not referenced in the visible lines of this function.
// list_filename:  path of the text file parsed as "name label" pairs.
// output_db_name: name passed to db::CreateDB.
202 void ConvertImageDataset(
203 const string& input_folder,
204 const string& list_filename,
205 const string& output_db_name,
// Parse the list file into (name, label) pairs; parsing stops at the first
// extraction that fails.
// NOTE(review): no check that list_file opened is visible; a bad path would
// simply produce an empty list.
207 std::ifstream list_file(list_filename);
208 std::vector<std::pair<std::string, int> > lines;
209 std::string filename;
211 while (list_file >> filename >> file_label) {
212 lines.push_back(std::make_pair(filename, file_label));
// Shuffle with a fixed seed (1701), so the resulting order is reproducible.
// The condition guarding this step is not visible.
216 LOG(INFO) <<
"Shuffling data";
217 std::shuffle(lines.begin(), lines.end(), std::default_random_engine(1701));
// A flag value below 1 falls back to the hardware thread count.
// NOTE(review): std::thread::hardware_concurrency() may return 0; no guard is
// visible, and a zero-sized pool would make the "% converters.size()"
// expressions below divide by zero.
220 auto num_threads = FLAGS_num_threads;
221 if (num_threads < 1) {
222 num_threads = std::thread::hardware_concurrency();
225 LOG(INFO) <<
"Processing " << lines.size() <<
" images...";
226 LOG(INFO) <<
"Opening DB " << output_db_name;
// Create a new database of type FLAGS_db and open a transaction on it.
228 auto db = db::CreateDB(FLAGS_db, output_db_name, db::NEW);
229 auto transaction = db->NewTransaction();
231 LOG(INFO) <<
"Using " << num_threads <<
" processing threads...";
232 std::vector<Converter> converters(num_threads);
// Entry i is assigned to converter i % converters.size().
// NOTE(review): "auto i = 0" deduces int, which is then compared against the
// unsigned lines.size().
235 for (
auto i = 0; i < lines.size(); i++) {
236 converters[i % converters.size()].queue(lines[i]);
// Visits every converter; the loop body is not visible (presumably it starts
// each worker -- TODO confirm).
240 for (
auto& converter : converters) {
// Fixed-size scratch buffer for the database key.
244 constexpr
auto key_max_length = 256;
245 char key_cstr[key_max_length];
// Results are fetched with the same i % converters.size() mapping used when
// queueing, so entries are written in list order.
248 for (
auto i = 0; i < lines.size(); i++) {
250 auto value = converters[i % converters.size()].get();
// Key format: index zero-padded to 8 digits, an underscore, the image name.
// NOTE(review): snprintf returns the length the full string would need, so a
// result equal to sizeof(key_cstr) already means the key was truncated; the
// LE comparison does not catch that case, and it mixes int with size_t.
253 auto key_len = snprintf(
254 key_cstr,
sizeof(key_cstr),
"%08d_%s", i, lines[i].first.c_str());
255 DCHECK_LE(key_len,
sizeof(key_cstr));
258 transaction->Put(
string(key_cstr), value);
// Commit after every 1000 entries and report progress.
260 if (++count % 1000 == 0) {
262 transaction->Commit();
263 LOG(INFO) <<
"Processed " << count <<
" files.";
// Final commit for whatever is still pending, followed by the final count.
268 transaction->Commit();
269 LOG(INFO) <<
"Processed " << count <<
" files.";
// Program entry point. The visible statement forwards the input_folder,
// list_file, output_db_name and shuffle flag values to
// caffe2::ConvertImageDataset. Other statements of main (e.g. any
// initialization call or the return) are not visible in this listing.
275 int main(
int argc,
char** argv) {
277 caffe2::ConvertImageDataset(
278 FLAGS_input_folder, FLAGS_list_file, FLAGS_output_db_name, FLAGS_shuffle);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of Caffe2.