Caffe2 - C++ API
A deep learning, cross-platform ML framework
image_input_op.h

#ifndef CAFFE2_IMAGE_IMAGE_INPUT_OP_H_
#define CAFFE2_IMAGE_IMAGE_INPUT_OP_H_

#include <opencv2/opencv.hpp>

#include <iostream>
#include <algorithm>

#include "c10/core/thread_pool.h"
#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/image/transform_gpu.h"
#include "caffe2/operators/prefetch_op.h"
#include "caffe2/proto/caffe2_legacy.pb.h"
#include "caffe2/utils/cast.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

class CUDAContext;

template <class Context>
class ImageInputOp final
    : public PrefetchOperator<Context> {
  // SINGLE_LABEL: single integer label for multi-class classification
  // MULTI_LABEL_SPARSE: sparse active label indices for multi-label classification
  // MULTI_LABEL_DENSE: dense label embedding vector for label embedding regression
  // MULTI_LABEL_WEIGHTED_SPARSE: sparse active label indices with per-label weights
  // for multi-label classification
  // SINGLE_LABEL_WEIGHTED: single integer label for multi-class classification
  // with weighted sampling
  // EMBEDDING_LABEL: an array of floating-point numbers representing a dense
  // embedding; useful for model distillation
  enum LABEL_TYPE {
    SINGLE_LABEL = 0,
    MULTI_LABEL_SPARSE = 1,
    MULTI_LABEL_DENSE = 2,
    MULTI_LABEL_WEIGHTED_SPARSE = 3,
    SINGLE_LABEL_WEIGHTED = 4,
    EMBEDDING_LABEL = 5,
  };
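
  // Example record layouts (illustrative, not in the original header): for
  // MULTI_LABEL_SPARSE the label proto carries the active indices, e.g.
  // float_data: [3, 17] sets labels 3 and 17 to 1.0; for
  // MULTI_LABEL_WEIGHTED_SPARSE, protos(1) carries the indices and protos(2)
  // the matching weights, e.g. float_data: [0.3, 0.7].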

  // INCEPTION_STYLE: random crop covering 8%-100% of the image area with
  // aspect ratio in [3/4, 4/3]. Reference: the GoogLeNet paper.
  enum SCALE_JITTER_TYPE {
    NO_SCALE_JITTER = 0,
    INCEPTION_STYLE = 1
    // TODO(zyan3): ResNet-style random scale jitter
  };

 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit ImageInputOp(const OperatorDef& operator_def,
                        Workspace* ws);
  ~ImageInputOp() {
    PrefetchOperator<Context>::Finalize();
  }

  bool Prefetch() override;
  bool CopyPrefetched() override;

 private:
  using BoundingBox = struct {
    bool valid;
    int ymin;
    int xmin;
    int height;
    int width;
  };

  // Structure to store per-image information.
  // This can be modified by the DecodeAnd* methods, so it needs
  // to be privatized per launch.
  using PerImageArg = struct {
    BoundingBox bounding_params;
  };

  bool GetImageAndLabelAndInfoFromDBValue(
      const string& value, cv::Mat* img, PerImageArg& info, int item_id,
      std::mt19937* randgen);
  void DecodeAndTransform(
      const std::string& value, float* image_data, int item_id,
      const int channels, std::size_t thread_index);
  void DecodeAndTransposeOnly(
      const std::string& value, uint8_t* image_data, int item_id,
      const int channels, std::size_t thread_index);
  bool ApplyTransformOnGPU(
      const std::vector<std::int64_t>& dims,
      const c10::Device& type);

  unique_ptr<db::DBReader> owned_reader_;
  const db::DBReader* reader_;
  Tensor prefetched_image_;
  Tensor prefetched_label_;
  vector<Tensor> prefetched_additional_outputs_;
  Tensor prefetched_image_on_device_;
  Tensor prefetched_label_on_device_;
  vector<Tensor> prefetched_additional_outputs_on_device_;
  // Default parameters for images
  PerImageArg default_arg_;
  int batch_size_;
  LABEL_TYPE label_type_;
  int num_labels_;

  bool color_;
  bool color_jitter_;
  float img_saturation_;
  float img_brightness_;
  float img_contrast_;
  bool color_lighting_;
  float color_lighting_std_;
  std::vector<std::vector<float>> color_lighting_eigvecs_;
  std::vector<float> color_lighting_eigvals_;
  SCALE_JITTER_TYPE scale_jitter_type_;
  int scale_;
  // Minsize is similar to scale, except that it only forces the image to
  // scale up if it is too small: it ensures that both dimensions of the
  // image are at least minsize_.
  int minsize_;
  bool warp_;
  int crop_;
  std::vector<float> mean_;
  std::vector<float> std_;
  Tensor mean_gpu_;
  Tensor std_gpu_;
  bool mirror_;
  bool is_test_;
  bool use_caffe_datum_;
  bool gpu_transform_;
  bool mean_std_copied_ = false;

  // thread pool for parse + decode
  int num_decode_threads_;
  int additional_inputs_offset_;
  int additional_inputs_count_;
  std::vector<int> additional_output_sizes_;
  std::shared_ptr<TaskThreadPool> thread_pool_;

  // Output type for the GPU transform path
  TensorProto_DataType output_type_;

  // random minsize
  vector<int> random_scale_;
  bool random_scaling_;

  // Working variables
  std::vector<std::mt19937> randgen_per_thread_;

  // number of exceptions produced by OpenCV while decoding image data
  std::atomic<long> num_decode_errors_in_batch_{0};
  // tolerated ratio of OpenCV decode errors
  float max_decode_error_ratio_;
};

template <class Context>
ImageInputOp<Context>::ImageInputOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : PrefetchOperator<Context>(operator_def, ws),
      reader_(nullptr),
      batch_size_(
          OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
      label_type_(static_cast<LABEL_TYPE>(
          OperatorBase::template GetSingleArgument<int>("label_type", 0))),
      num_labels_(
          OperatorBase::template GetSingleArgument<int>("num_labels", 0)),
      color_(OperatorBase::template GetSingleArgument<int>("color", 1)),
      color_jitter_(
          OperatorBase::template GetSingleArgument<int>("color_jitter", 0)),
      img_saturation_(OperatorBase::template GetSingleArgument<float>(
          "img_saturation",
          0.4)),
      img_brightness_(OperatorBase::template GetSingleArgument<float>(
          "img_brightness",
          0.4)),
      img_contrast_(
          OperatorBase::template GetSingleArgument<float>("img_contrast", 0.4)),
      color_lighting_(
          OperatorBase::template GetSingleArgument<int>("color_lighting", 0)),
      color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
          "color_lighting_std",
          0.1)),
      scale_jitter_type_(static_cast<SCALE_JITTER_TYPE>(
          OperatorBase::template GetSingleArgument<int>(
              "scale_jitter_type",
              0))),
      scale_(OperatorBase::template GetSingleArgument<int>("scale", -1)),
      minsize_(OperatorBase::template GetSingleArgument<int>("minsize", -1)),
      warp_(OperatorBase::template GetSingleArgument<int>("warp", 0)),
      crop_(OperatorBase::template GetSingleArgument<int>("crop", -1)),
      mirror_(OperatorBase::template GetSingleArgument<int>("mirror", 0)),
      is_test_(OperatorBase::template GetSingleArgument<int>(
          OpSchema::Arg_IsTest,
          0)),
      use_caffe_datum_(
          OperatorBase::template GetSingleArgument<int>("use_caffe_datum", 0)),
      gpu_transform_(OperatorBase::template GetSingleArgument<int>(
          "use_gpu_transform",
          0)),
      num_decode_threads_(
          OperatorBase::template GetSingleArgument<int>("decode_threads", 4)),
      additional_output_sizes_(OperatorBase::template GetRepeatedArgument<int>(
          "output_sizes", {})),
      thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)),
      // output type only supported with CUDA and use_gpu_transform for now
      output_type_(
          cast::GetCastDataType(ArgumentHelper(operator_def), "output_type")),
      random_scale_(OperatorBase::template GetRepeatedArgument<int>(
          "random_scale",
          {-1, -1})),
      max_decode_error_ratio_(OperatorBase::template GetSingleArgument<float>(
          "max_decode_error_ratio",
          1.0)) {
  if ((random_scale_[0] == -1) || (random_scale_[1] == -1)) {
    random_scaling_ = false;
  } else {
    random_scaling_ = true;
    minsize_ = random_scale_[0];
  }

  mean_ = OperatorBase::template GetRepeatedArgument<float>(
      "mean_per_channel",
      {OperatorBase::template GetSingleArgument<float>("mean", 0.)});

  std_ = OperatorBase::template GetRepeatedArgument<float>(
      "std_per_channel",
      {OperatorBase::template GetSingleArgument<float>("std", 1.)});

  if (additional_output_sizes_.size() == 0) {
    additional_output_sizes_ = std::vector<int>(OutputSize() - 2, 1);
  } else {
    CAFFE_ENFORCE(
        additional_output_sizes_.size() == OutputSize() - 2,
        "If the output sizes are specified, they must be specified for all "
        "additional outputs");
  }
  additional_inputs_count_ = OutputSize() - 2;

  default_arg_.bounding_params = {
      false,
      OperatorBase::template GetSingleArgument<int>("bounding_ymin", -1),
      OperatorBase::template GetSingleArgument<int>("bounding_xmin", -1),
      OperatorBase::template GetSingleArgument<int>("bounding_height", -1),
      OperatorBase::template GetSingleArgument<int>("bounding_width", -1),
  };

  if (operator_def.input_size() == 0) {
    LOG(ERROR) << "You are using an old ImageInputOp format that creates "
                  "a local db reader. Consider moving to the new style "
                  "that takes in a DBReader blob instead.";
    string db_name =
        OperatorBase::template GetSingleArgument<string>("db", "");
    CAFFE_ENFORCE_GT(db_name.size(), 0, "Must specify a db name.");
    owned_reader_.reset(new db::DBReader(
        OperatorBase::template GetSingleArgument<string>(
            "db_type", "leveldb"),
        db_name));
    reader_ = owned_reader_.get();
  }

  // hard-coded PCA eigenvectors and eigenvalues, based on RGB channel order
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-144.7125f, 183.396f, 102.2295f});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.104f, -1.1475f, -207.57f});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.818f, -177.174f, 107.1765f});

  color_lighting_eigvals_ = std::vector<float>{0.2175f, 0.0188f, 0.0045f};

  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size must be positive.");
  if (use_caffe_datum_) {
    CAFFE_ENFORCE(
        label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED,
        "Caffe datum only supports a single integer label");
  }
  if (label_type_ != SINGLE_LABEL && label_type_ != SINGLE_LABEL_WEIGHTED) {
    CAFFE_ENFORCE_GT(
        num_labels_,
        0,
        "Number of labels must be set for using either sparse label indices "
        "or dense label embedding.");
  }
  if (label_type_ == MULTI_LABEL_WEIGHTED_SPARSE ||
      label_type_ == SINGLE_LABEL_WEIGHTED) {
    additional_inputs_offset_ = 3;
  } else {
    additional_inputs_offset_ = 2;
  }
  CAFFE_ENFORCE(
      (scale_ > 0) != (minsize_ > 0),
      "Must provide one and only one of scaling or minsize");
  CAFFE_ENFORCE_GT(crop_, 0, "Must provide the cropping value.");
  CAFFE_ENFORCE_GE(
      scale_ > 0 ? scale_ : minsize_,
      crop_,
      "The scale/minsize value must be no smaller than the crop value.");

  CAFFE_ENFORCE_EQ(
      mean_.size(),
      std_.size(),
      "The mean and std. dev vectors must be of the same size.");
  CAFFE_ENFORCE(
      mean_.size() == 1 || mean_.size() == 3,
      "The mean and std. dev vectors must be of size 1 or 3");
  CAFFE_ENFORCE(
      !use_caffe_datum_ || OutputSize() == 2,
      "There can only be 2 outputs if the Caffe datum format is used");

  CAFFE_ENFORCE(
      random_scale_.size() == 2, "Must provide [scale_min, scale_max]");
  CAFFE_ENFORCE_GE(
      random_scale_[1],
      random_scale_[0],
      "random_scale must provide a range [min, max]");

  if (default_arg_.bounding_params.ymin < 0 ||
      default_arg_.bounding_params.xmin < 0 ||
      default_arg_.bounding_params.height < 0 ||
      default_arg_.bounding_params.width < 0) {
    default_arg_.bounding_params.valid = false;
  } else {
    default_arg_.bounding_params.valid = true;
  }

  if (mean_.size() == 1) {
    // Extend mean/std to 3 channels using the first value.
    mean_.resize(3, mean_[0]);
    std_.resize(3, std_[0]);
  }

  LOG(INFO) << "Creating an image input op with the following settings: ";
  LOG(INFO) << "  Using " << num_decode_threads_ << " CPU threads;";
  if (gpu_transform_) {
    LOG(INFO) << "  Performing transformation on GPU";
  }
  LOG(INFO) << "  Outputting in batches of " << batch_size_ << " images;";
  LOG(INFO) << "  Treating input image as "
            << (color_ ? "color " : "grayscale ") << "image;";
  if (default_arg_.bounding_params.valid) {
    LOG(INFO) << "  Applying a default bounding box of Y ["
              << default_arg_.bounding_params.ymin << "; "
              << default_arg_.bounding_params.ymin +
                  default_arg_.bounding_params.height
              << ") x X ["
              << default_arg_.bounding_params.xmin << "; "
              << default_arg_.bounding_params.xmin +
                  default_arg_.bounding_params.width
              << ")";
  }
  if (scale_ > 0 && !random_scaling_) {
    LOG(INFO) << "  Scaling image to " << scale_
              << (warp_ ? " with " : " without ") << "warping;";
  } else {
    if (random_scaling_) {
      // randomly set minsize_ for each image
      LOG(INFO) << "  Randomly scaling shortest side between "
                << random_scale_[0] << " and " << random_scale_[1];
    } else {
      // Here, minsize_ > 0
      LOG(INFO) << "  Ensuring minimum image size of " << minsize_
                << (warp_ ? " with " : " without ") << "warping;";
    }
  }
  LOG(INFO) << "  " << (is_test_ ? "Central" : "Random")
            << " cropping image to " << crop_
            << (mirror_ ? " with " : " without ") << "random mirroring;";
  LOG(INFO) << "Label Type: " << label_type_;
  LOG(INFO) << "Num Labels: " << num_labels_;

  auto mit = mean_.begin();
  auto sit = std_.begin();

  for (int i = 0; mit != mean_.end() && sit != std_.end(); ++mit, ++sit, ++i) {
    LOG(INFO) << "  Default [Channel " << i << "] Subtract mean " << *mit
              << " and divide by std " << *sit << ".";
    // We will actually use the inverse of std, so invert it here.
    *sit = 1.f / *sit;
  }
  LOG(INFO) << "  Outputting images as "
            << OperatorBase::template GetSingleArgument<string>(
                   "output_type", "unknown")
            << ".";

  std::mt19937 meta_randgen(time(nullptr));
  for (int i = 0; i < num_decode_threads_; ++i) {
    randgen_per_thread_.emplace_back(meta_randgen());
  }
  ReinitializeTensor(
      &prefetched_image_,
      {int64_t(batch_size_),
       int64_t(crop_),
       int64_t(crop_),
       int64_t(color_ ? 3 : 1)},
      at::dtype<uint8_t>().device(CPU));
  std::vector<int64_t> sizes;
  if (label_type_ != SINGLE_LABEL && label_type_ != SINGLE_LABEL_WEIGHTED) {
    sizes = std::vector<int64_t>{int64_t(batch_size_), int64_t(num_labels_)};
  } else {
    sizes = std::vector<int64_t>{batch_size_};
  }
  // The data type for prefetched_label_ is not actually known here;
  // it is fixed up on the first Prefetch().
  ReinitializeTensor(
      &prefetched_label_,
      sizes,
      at::dtype<int>().device(CPU));

  for (int i = 0; i < additional_output_sizes_.size(); ++i) {
    prefetched_additional_outputs_on_device_.emplace_back();
    prefetched_additional_outputs_.emplace_back();
  }
}
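
// Example configuration (illustrative; not part of the original header). In
// the recommended style, input 0 is a DBReader blob and the first two outputs
// are the image and label batches. A typical training setup, in protobuf text
// format, might look like:
//
//   op {
//     type: "ImageInput"
//     input: "reader"
//     output: "data"
//     output: "label"
//     arg { name: "batch_size" i: 32 }
//     arg { name: "scale" i: 256 }
//     arg { name: "crop" i: 224 }
//     arg { name: "mirror" i: 1 }
//     arg { name: "use_gpu_transform" i: 1 }
//   }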

// Inception-style scale jittering
template <class Context>
bool RandomSizedCropping(
    cv::Mat* img,
    const int crop,
    std::mt19937* randgen) {
  cv::Mat scaled_img;
  bool inception_scale_jitter = false;
  int im_height = img->rows, im_width = img->cols;
  int area = im_height * im_width;
  std::uniform_real_distribution<> area_dis(0.08, 1.0);
  std::uniform_real_distribution<> aspect_ratio_dis(3.0 / 4.0, 4.0 / 3.0);

  cv::Mat cropping;
  for (int i = 0; i < 10; ++i) {
    int target_area = int(ceil(area_dis(*randgen) * area));
    float aspect_ratio = aspect_ratio_dis(*randgen);
    int nh = floor(std::sqrt(((float)target_area / aspect_ratio)));
    int nw = floor(std::sqrt(((float)target_area * aspect_ratio)));
    if (nh >= 1 && nh <= im_height && nw >= 1 && nw <= im_width) {
      int height_offset =
          std::uniform_int_distribution<>(0, im_height - nh)(*randgen);
      int width_offset =
          std::uniform_int_distribution<>(0, im_width - nw)(*randgen);
      cv::Rect ROI(width_offset, height_offset, nw, nh);
      cropping = (*img)(ROI);
      cv::resize(
          cropping, scaled_img, cv::Size(crop, crop), 0, 0, cv::INTER_AREA);
      *img = scaled_img;
      inception_scale_jitter = true;
      break;
    }
  }
  return inception_scale_jitter;
}
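
// Illustrative sketch (not part of the original header): exercising the
// Inception-style crop above on a synthetic image. The Context parameter is
// unused by RandomSizedCropping, so a local placeholder tag suffices here.
namespace example_sketches {
struct ContextTag {};

inline bool TryInceptionCrop() {
  cv::Mat img(480, 640, CV_8UC3, cv::Scalar(0, 0, 0));
  std::mt19937 rng(42);
  // On success, img has been replaced by a random 8%-100% area crop with
  // aspect ratio in [3/4, 4/3], resized to 224 x 224.
  return RandomSizedCropping<ContextTag>(&img, 224, &rng);
}
} // namespace example_sketches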

template <class Context>
bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
    const string& value,
    cv::Mat* img,
    PerImageArg& info,
    int item_id,
    std::mt19937* randgen) {
  //
  // We recommend using --caffe2_use_fatal_for_enforce=1 with ImageInputOp,
  // as this function runs on a worker thread and exceptions from
  // CAFFE_ENFORCE are silently dropped by the thread worker functions.
  //
  cv::Mat src;

  // Use the default information for images
  info = default_arg_;
  if (use_caffe_datum_) {
    // The input is in Caffe datum format.
    CaffeDatum datum;
    CAFFE_ENFORCE(datum.ParseFromString(value));

    prefetched_label_.mutable_data<int>()[item_id] = datum.label();
    if (datum.encoded()) {
      // Encoded image in datum.
      // Count the number of exceptions from OpenCV imdecode.
      try {
        src = cv::imdecode(
            cv::Mat(
                1,
                datum.data().size(),
                CV_8UC1,
                const_cast<char*>(datum.data().data())),
            color_ ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
        if (src.rows == 0 || src.cols == 0) {
          num_decode_errors_in_batch_++;
          src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
        }
      } catch (cv::Exception& e) {
        num_decode_errors_in_batch_++;
        src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
      }
    } else {
      // Raw image in datum.
      CAFFE_ENFORCE(datum.channels() == 3 || datum.channels() == 1);

      int src_c = datum.channels();
      src.create(
          datum.height(), datum.width(), (src_c == 3) ? CV_8UC3 : CV_8UC1);

      if (src_c == 1) {
        memcpy(src.ptr<uchar>(0), datum.data().data(), datum.data().size());
      } else {
        // Datum stores things in CHW order; convert to HWC to be more
        // consistent with conventional image storage.
        for (int c = 0; c < 3; ++c) {
          const char* datum_buffer =
              datum.data().data() + datum.height() * datum.width() * c;
          uchar* ptr = src.ptr<uchar>(0) + c;
          for (int h = 0; h < datum.height(); ++h) {
            for (int w = 0; w < datum.width(); ++w) {
              *ptr = *(datum_buffer++);
              ptr += 3;
            }
          }
        }
      }
    }
  } else {
    // The input is in Caffe2 format.
    TensorProtos protos;
    CAFFE_ENFORCE(protos.ParseFromString(value));
    const TensorProto& image_proto = protos.protos(0);
    const TensorProto& label_proto = protos.protos(1);
    // Handle the additional output protos.
    vector<TensorProto> additional_output_protos;
    int start = additional_inputs_offset_;
    int end = start + additional_inputs_count_;
    for (int i = start; i < end; ++i) {
      additional_output_protos.push_back(protos.protos(i));
    }

    if (protos.protos_size() == end + 1) {
      // We have bounding box information
      const TensorProto& bounding_proto = protos.protos(end);
      DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
      DCHECK_EQ(bounding_proto.int32_data_size(), 4);
      info.bounding_params.valid = true;
      info.bounding_params.ymin = bounding_proto.int32_data(0);
      info.bounding_params.xmin = bounding_proto.int32_data(1);
      info.bounding_params.height = bounding_proto.int32_data(2);
      info.bounding_params.width = bounding_proto.int32_data(3);
    }

    if (image_proto.data_type() == TensorProto::STRING) {
      // Encoded image string.
      DCHECK_EQ(image_proto.string_data_size(), 1);
      const string& encoded_image_str = image_proto.string_data(0);
      int encoded_size = encoded_image_str.size();
      // We use a cv::Mat to wrap the encoded str so we do not need a copy.
      // Count the number of exceptions from OpenCV imdecode.
      try {
        src = cv::imdecode(
            cv::Mat(
                1,
                &encoded_size,
                CV_8UC1,
                const_cast<char*>(encoded_image_str.data())),
            color_ ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
        if (src.rows == 0 || src.cols == 0) {
          num_decode_errors_in_batch_++;
          src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
        }
      } catch (cv::Exception& e) {
        num_decode_errors_in_batch_++;
        src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
      }
    } else if (image_proto.data_type() == TensorProto::BYTE) {
      // Raw image content.
      int src_c = (image_proto.dims_size() == 3) ? image_proto.dims(2) : 1;
      CAFFE_ENFORCE(src_c == 3 || src_c == 1);

      src.create(
          image_proto.dims(0),
          image_proto.dims(1),
          (src_c == 3) ? CV_8UC3 : CV_8UC1);
      memcpy(
          src.ptr<uchar>(0),
          image_proto.byte_data().data(),
          image_proto.byte_data().size());
    } else {
      LOG(FATAL) << "Unknown image data type.";
    }

    // TODO: if image decoding was unsuccessful, set label to 0
    if (label_proto.data_type() == TensorProto::FLOAT) {
      if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
        DCHECK_EQ(label_proto.float_data_size(), 1);
        prefetched_label_.mutable_data<float>()[item_id] =
            label_proto.float_data(0);
      } else if (label_type_ == MULTI_LABEL_SPARSE) {
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (int i = 0; i < label_proto.float_data_size(); ++i) {
          label_data[(int)label_proto.float_data(i)] = 1.0;
        }
      } else if (label_type_ == MULTI_LABEL_WEIGHTED_SPARSE) {
        const TensorProto& weight_proto = protos.protos(2);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (int i = 0; i < label_proto.float_data_size(); ++i) {
          label_data[(int)label_proto.float_data(i)] =
              weight_proto.float_data(i);
        }
      } else if (
          label_type_ == MULTI_LABEL_DENSE || label_type_ == EMBEDDING_LABEL) {
        CAFFE_ENFORCE(label_proto.float_data_size() == num_labels_);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        for (int i = 0; i < label_proto.float_data_size(); ++i) {
          label_data[i] = label_proto.float_data(i);
        }
      } else {
        LOG(ERROR) << "Unknown label type: " << label_type_;
      }
    } else if (label_proto.data_type() == TensorProto::INT32) {
      if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
        DCHECK_EQ(label_proto.int32_data_size(), 1);
        prefetched_label_.mutable_data<int>()[item_id] =
            label_proto.int32_data(0);
      } else if (label_type_ == MULTI_LABEL_SPARSE) {
        int* label_data =
            prefetched_label_.mutable_data<int>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(int) * num_labels_);
        for (int i = 0; i < label_proto.int32_data_size(); ++i) {
          label_data[label_proto.int32_data(i)] = 1;
        }
      } else if (label_type_ == MULTI_LABEL_WEIGHTED_SPARSE) {
        const TensorProto& weight_proto = protos.protos(2);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (int i = 0; i < label_proto.int32_data_size(); ++i) {
          label_data[label_proto.int32_data(i)] = weight_proto.float_data(i);
        }
      } else if (
          label_type_ == MULTI_LABEL_DENSE || label_type_ == EMBEDDING_LABEL) {
        CAFFE_ENFORCE(label_proto.int32_data_size() == num_labels_);
        int* label_data =
            prefetched_label_.mutable_data<int>() + item_id * num_labels_;
        for (int i = 0; i < label_proto.int32_data_size(); ++i) {
          label_data[i] = label_proto.int32_data(i);
        }
      } else {
        LOG(ERROR) << "Unknown label type: " << label_type_;
      }
    } else {
      LOG(FATAL) << "Unsupported label data type.";
    }

    for (int i = 0; i < additional_output_protos.size(); ++i) {
      auto additional_output_proto = additional_output_protos[i];
      if (additional_output_proto.data_type() == TensorProto::FLOAT) {
        float* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<float>() +
            item_id * additional_output_proto.float_data_size();

        for (int j = 0; j < additional_output_proto.float_data_size(); ++j) {
          additional_output[j] = additional_output_proto.float_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::INT32) {
        int* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<int>() +
            item_id * additional_output_proto.int32_data_size();

        for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) {
          additional_output[j] = additional_output_proto.int32_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::INT64) {
        int64_t* additional_output =
            prefetched_additional_outputs_[i]
                .template mutable_data<int64_t>() +
            item_id * additional_output_proto.int64_data_size();

        for (int j = 0; j < additional_output_proto.int64_data_size(); ++j) {
          additional_output[j] = additional_output_proto.int64_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::UINT8) {
        // UINT8 payloads are stored in the int32_data field.
        uint8_t* additional_output =
            prefetched_additional_outputs_[i]
                .template mutable_data<uint8_t>() +
            item_id * additional_output_proto.int32_data_size();

        for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) {
          additional_output[j] =
              static_cast<uint8_t>(additional_output_proto.int32_data(j));
        }
      } else {
        LOG(FATAL) << "Unsupported output type.";
      }
    }
  }

  //
  // Convert the decoded image to the color format requested by the op.
  //
  int out_c = color_ ? 3 : 1;
  if (out_c == src.channels()) {
    *img = src;
  } else {
    cv::cvtColor(
        src, *img, (out_c == 1) ? cv::COLOR_BGR2GRAY : cv::COLOR_GRAY2BGR);
  }

  // Note(Yangqing): I believe that the mat should be created continuous.
  CAFFE_ENFORCE(img->isContinuous());

  // Sanity check now that we decoded everything.

  // Ensure that the bounding box is legit.
  if (info.bounding_params.valid &&
      (src.rows < info.bounding_params.ymin + info.bounding_params.height ||
       src.cols < info.bounding_params.xmin + info.bounding_params.width)) {
    info.bounding_params.valid = false;
  }

  // Apply the bounding box if requested.
  if (info.bounding_params.valid) {
    // If we reach here, we know the parameters are sane.
    cv::Rect bounding_box(
        info.bounding_params.xmin,
        info.bounding_params.ymin,
        info.bounding_params.width,
        info.bounding_params.height);
    *img = (*img)(bounding_box);

    /*
    LOG(INFO) << "Did bounding with ymin:"
              << info.bounding_params.ymin
              << " xmin:" << info.bounding_params.xmin
              << " height:" << info.bounding_params.height
              << " width:" << info.bounding_params.width << "\n";
    LOG(INFO) << "Bounded matrix: " << img;
    */
  } else {
    // LOG(INFO) << "No bounding\n";
  }

  cv::Mat scaled_img;
  bool inception_scale_jitter = false;
  if (scale_jitter_type_ == INCEPTION_STYLE) {
    if (!is_test_) {
      // Inception-style scale jittering is only used for training.
      inception_scale_jitter =
          RandomSizedCropping<Context>(img, crop_, randgen);
      // If a random crop is still not found, fall back to simple random
      // cropping later.
    }
  }

  if ((scale_jitter_type_ == NO_SCALE_JITTER) ||
      (scale_jitter_type_ == INCEPTION_STYLE && !inception_scale_jitter)) {
    int scaled_width, scaled_height;
    int scale_to_use = scale_ > 0 ? scale_ : minsize_;

    // Set the random minsize.
    if (random_scaling_) {
      scale_to_use = std::uniform_int_distribution<>(
          random_scale_[0], random_scale_[1])(*randgen);
    }

    if (warp_) {
      scaled_width = scale_to_use;
      scaled_height = scale_to_use;
    } else if (img->rows > img->cols) {
      scaled_width = scale_to_use;
      scaled_height =
          static_cast<float>(img->rows) * scale_to_use / img->cols;
    } else {
      scaled_height = scale_to_use;
      scaled_width =
          static_cast<float>(img->cols) * scale_to_use / img->rows;
    }
    if ((scale_ > 0 &&
         (scaled_height != img->rows || scaled_width != img->cols)) ||
        (scaled_height > img->rows || scaled_width > img->cols)) {
      // We rescale in all cases if we are using scale_,
      // but only to make the image bigger if using minsize_.
      /*
      LOG(INFO) << "Scaling to " << scaled_width << " x " << scaled_height
                << " from " << img->cols << " x " << img->rows;
      */
      cv::resize(
          *img,
          scaled_img,
          cv::Size(scaled_width, scaled_height),
          0,
          0,
          cv::INTER_AREA);
      *img = scaled_img;
    }
  }

  // TODO(Yangqing): return false if any error happens.
  return true;
}

// Assumes HWC order and BGR color channels.
template <class Context>
void Saturation(
    float* img,
    const int img_size,
    const float alpha_rand,
    std::mt19937* randgen) {
  float alpha = 1.0f +
      std::uniform_real_distribution<float>(-alpha_rand, alpha_rand)(*randgen);
  // BGR to grayscale: R -> 0.299, G -> 0.587, B -> 0.114
  int p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      float gray_color = img[3 * p] * 0.114f + img[3 * p + 1] * 0.587f +
          img[3 * p + 2] * 0.299f;
      for (int c = 0; c < 3; ++c) {
        img[3 * p + c] = img[3 * p + c] * alpha + gray_color * (1.0f - alpha);
      }
      p++;
    }
  }
}

// Assumes HWC order and BGR color channels.
template <class Context>
void Brightness(
    float* img,
    const int img_size,
    const float alpha_rand,
    std::mt19937* randgen) {
  float alpha = 1.0f +
      std::uniform_real_distribution<float>(-alpha_rand, alpha_rand)(*randgen);
  int p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      for (int c = 0; c < 3; ++c) {
        img[p++] *= alpha;
      }
    }
  }
}

// Assumes HWC order and BGR color channels.
template <class Context>
void Contrast(
    float* img,
    const int img_size,
    const float alpha_rand,
    std::mt19937* randgen) {
  float gray_mean = 0;
  int p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      // BGR to grayscale: R -> 0.299, G -> 0.587, B -> 0.114
      gray_mean += img[3 * p] * 0.114f + img[3 * p + 1] * 0.587f +
          img[3 * p + 2] * 0.299f;
      p++;
    }
  }
  gray_mean /= (img_size * img_size);

  float alpha = 1.0f +
      std::uniform_real_distribution<float>(-alpha_rand, alpha_rand)(*randgen);
  p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      for (int c = 0; c < 3; ++c) {
        img[p] = img[p] * alpha + gray_mean * (1.0f - alpha);
        p++;
      }
    }
  }
}

// Assumes HWC order and BGR color channels.
template <class Context>
void ColorJitter(
    float* img,
    const int img_size,
    const float saturation,
    const float brightness,
    const float contrast,
    std::mt19937* randgen) {
  std::vector<int> jitter_order{0, 1, 2};
  // Shuffle the jitter order with a time-based seed. (The original code also
  // called std::srand here, but std::rand is never used below.)
  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::shuffle(
      jitter_order.begin(), jitter_order.end(),
      std::default_random_engine(seed));

  for (int i = 0; i < 3; ++i) {
    if (jitter_order[i] == 0) {
      Saturation<Context>(img, img_size, saturation, randgen);
    } else if (jitter_order[i] == 1) {
      Brightness<Context>(img, img_size, brightness, randgen);
    } else {
      Contrast<Context>(img, img_size, contrast, randgen);
    }
  }
}
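
// Illustrative sketch (not part of the original header): applying the jitter
// chain to a tiny 2x2 HWC/BGR buffer, the same way TransformImage does for a
// full crop below.
namespace example_sketches {
inline void TryColorJitter() {
  std::vector<float> img(2 * 2 * 3, 128.0f); // img_size = 2
  std::mt19937 rng(0);
  ColorJitter<ContextTag>(
      img.data(), 2, /*saturation=*/0.4f, /*brightness=*/0.4f,
      /*contrast=*/0.4f, &rng);
}
} // namespace example_sketches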

// Assumes HWC order and BGR color channels.
template <class Context>
void ColorLighting(
    float* img,
    const int img_size,
    const float alpha_std,
    const std::vector<std::vector<float>>& eigvecs,
    const std::vector<float>& eigvals,
    std::mt19937* randgen) {
  std::normal_distribution<float> d(0, alpha_std);
  std::vector<float> alphas(3);
  for (int i = 0; i < 3; ++i) {
    alphas[i] = d(*randgen);
  }

  std::vector<float> delta_rgb(3, 0.0);
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      delta_rgb[i] += eigvecs[i][j] * eigvals[j] * alphas[j];
    }
  }

  int p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      for (int c = 0; c < 3; ++c) {
        // delta_rgb is in RGB order while the image is BGR, hence 2 - c.
        img[p++] += delta_rgb[2 - c];
      }
    }
  }
}
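
// Illustrative note (not part of the original header): ColorLighting is the
// AlexNet-style PCA color augmentation. Each pixel's RGB value is shifted by
//   delta[i] = sum_j eigvecs[i][j] * eigvals[j] * alphas[j],
// with alphas[j] ~ N(0, alpha_std), using the eigenvectors and eigenvalues
// hard-coded in the constructor above.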

// Assumes HWC order and BGR color channels.
// Mean subtraction and scaling.
template <class Context>
void ColorNormalization(
    float* img,
    const int img_size,
    const int channels,
    const std::vector<float>& mean,
    const std::vector<float>& std) {
  int p = 0;
  for (int h = 0; h < img_size; ++h) {
    for (int w = 0; w < img_size; ++w) {
      for (int c = 0; c < channels; ++c) {
        img[p] = (img[p] - mean[c]) * std[c];
        p++;
      }
    }
  }
}
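
// Illustrative sketch (not part of the original header): note that
// ColorNormalization multiplies by `std` because the constructor stored the
// reciprocal of each std value; callers passing raw per-channel stds should
// invert them first. The channel values below are hypothetical.
namespace example_sketches {
inline void TryColorNormalization() {
  std::vector<float> img(2 * 2 * 3, 128.0f); // HWC
  const std::vector<float> mean{104.0f, 117.0f, 123.0f};
  const std::vector<float> inv_std{1.0f / 58.0f, 1.0f / 57.0f, 1.0f / 57.0f};
  ColorNormalization<ContextTag>(img.data(), 2, 3, mean, inv_std);
}
} // namespace example_sketches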

// Factored-out image transformation.
template <class Context>
void TransformImage(
    const cv::Mat& scaled_img,
    const int channels,
    float* image_data,
    const bool color_jitter,
    const float saturation,
    const float brightness,
    const float contrast,
    const bool color_lighting,
    const float color_lighting_std,
    const std::vector<std::vector<float>>& color_lighting_eigvecs,
    const std::vector<float>& color_lighting_eigvals,
    const int crop,
    const bool mirror,
    const std::vector<float>& mean,
    const std::vector<float>& std,
    std::mt19937* randgen,
    std::bernoulli_distribution* mirror_this_image,
    bool is_test = false) {
  CAFFE_ENFORCE_GE(
      scaled_img.rows, crop, "Image height must be at least the crop size.");
  CAFFE_ENFORCE_GE(
      scaled_img.cols, crop, "Image width must be at least the crop size.");

  // Find the cropped region and copy it to the destination matrix.
  int width_offset, height_offset;
  if (is_test) {
    width_offset = (scaled_img.cols - crop) / 2;
    height_offset = (scaled_img.rows - crop) / 2;
  } else {
    width_offset =
        std::uniform_int_distribution<>(0, scaled_img.cols - crop)(*randgen);
    height_offset =
        std::uniform_int_distribution<>(0, scaled_img.rows - crop)(*randgen);
  }

  float* image_data_ptr = image_data;
  if (!is_test && mirror && (*mirror_this_image)(*randgen)) {
    // Copy mirrored image.
    for (int h = height_offset; h < height_offset + crop; ++h) {
      for (int w = width_offset + crop - 1; w >= width_offset; --w) {
        const uint8_t* cv_data = scaled_img.ptr(h) + w * channels;
        for (int c = 0; c < channels; ++c) {
          *(image_data_ptr++) = static_cast<float>(cv_data[c]);
        }
      }
    }
  } else {
    // Copy normally.
    for (int h = height_offset; h < height_offset + crop; ++h) {
      for (int w = width_offset; w < width_offset + crop; ++w) {
        const uint8_t* cv_data = scaled_img.ptr(h) + w * channels;
        for (int c = 0; c < channels; ++c) {
          *(image_data_ptr++) = static_cast<float>(cv_data[c]);
        }
      }
    }
  }

  if (color_jitter && channels == 3 && !is_test) {
    ColorJitter<Context>(
        image_data, crop, saturation, brightness, contrast, randgen);
  }
  if (color_lighting && channels == 3 && !is_test) {
    ColorLighting<Context>(
        image_data, crop, color_lighting_std, color_lighting_eigvecs,
        color_lighting_eigvals, randgen);
  }

  // Color normalization: mean subtraction and scaling.
  ColorNormalization<Context>(image_data, crop, channels, mean, std);
}
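
// Illustrative sketch (not part of the original header): a deterministic
// center-crop call to TransformImage with augmentation disabled; the mean and
// std values here are placeholders (std must already be inverted, see above).
namespace example_sketches {
inline void TryTransformImage() {
  cv::Mat img(256, 256, CV_8UC3, cv::Scalar(127, 127, 127));
  std::vector<float> out(224 * 224 * 3);
  std::mt19937 rng(1);
  std::bernoulli_distribution mirror(0.5);
  const std::vector<float> mean{0.f, 0.f, 0.f};
  const std::vector<float> inv_std{1.f, 1.f, 1.f};
  TransformImage<ContextTag>(
      img, 3, out.data(), /*color_jitter=*/false, 0.4f, 0.4f, 0.4f,
      /*color_lighting=*/false, 0.1f, {}, {}, /*crop=*/224,
      /*mirror=*/false, mean, inv_std, &rng, &mirror, /*is_test=*/true);
}
} // namespace example_sketches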

// Crop and transpose the image only, leaving it in uint8_t.
template <class Context>
void CropTransposeImage(
    const cv::Mat& scaled_img,
    const int channels,
    uint8_t* cropped_data,
    const int crop,
    const bool mirror,
    std::mt19937* randgen,
    std::bernoulli_distribution* mirror_this_image,
    bool is_test = false) {
  CAFFE_ENFORCE_GE(
      scaled_img.rows, crop, "Image height must be at least the crop size.");
  CAFFE_ENFORCE_GE(
      scaled_img.cols, crop, "Image width must be at least the crop size.");

  // Find the cropped region and copy it to the destination matrix.
  int width_offset, height_offset;
  if (is_test) {
    width_offset = (scaled_img.cols - crop) / 2;
    height_offset = (scaled_img.rows - crop) / 2;
  } else {
    width_offset =
        std::uniform_int_distribution<>(0, scaled_img.cols - crop)(*randgen);
    height_offset =
        std::uniform_int_distribution<>(0, scaled_img.rows - crop)(*randgen);
  }

  if (mirror && (*mirror_this_image)(*randgen)) {
    // Copy mirrored image.
    for (int h = height_offset; h < height_offset + crop; ++h) {
      for (int w = width_offset + crop - 1; w >= width_offset; --w) {
        const uint8_t* cv_data = scaled_img.ptr(h) + w * channels;
        for (int c = 0; c < channels; ++c) {
          *(cropped_data++) = cv_data[c];
        }
      }
    }
  } else {
    // Copy normally.
    for (int h = height_offset; h < height_offset + crop; ++h) {
      for (int w = width_offset; w < width_offset + crop; ++w) {
        const uint8_t* cv_data = scaled_img.ptr(h) + w * channels;
        for (int c = 0; c < channels; ++c) {
          *(cropped_data++) = cv_data[c];
        }
      }
    }
  }
}
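
// Illustrative sketch (not part of the original header): the uint8 path used
// when use_gpu_transform is set; normalization and color ops happen later on
// the GPU.
namespace example_sketches {
inline void TryCropTransposeImage() {
  cv::Mat img(256, 256, CV_8UC3, cv::Scalar(0, 0, 0));
  std::vector<uint8_t> out(224 * 224 * 3);
  std::mt19937 rng(3);
  std::bernoulli_distribution mirror(0.5);
  CropTransposeImage<ContextTag>(
      img, 3, out.data(), /*crop=*/224, /*mirror=*/false, &rng, &mirror,
      /*is_test=*/true);
}
} // namespace example_sketches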

// Parse datum, decode image, perform transform.
// Intended as an entry point for binding to the thread pool.
template <class Context>
void ImageInputOp<Context>::DecodeAndTransform(
    const std::string& value, float* image_data, int item_id,
    const int channels, std::size_t thread_index) {
  CAFFE_ENFORCE((int)thread_index < num_decode_threads_);

  std::bernoulli_distribution mirror_this_image(0.5f);
  std::mt19937* randgen = &(randgen_per_thread_[thread_index]);

  cv::Mat img;
  // Decode the image.
  PerImageArg info;
  CHECK(
      GetImageAndLabelAndInfoFromDBValue(value, &img, info, item_id, randgen));
  // Factored-out image transformation.
  TransformImage<Context>(
      img, channels, image_data,
      color_jitter_, img_saturation_, img_brightness_, img_contrast_,
      color_lighting_, color_lighting_std_, color_lighting_eigvecs_,
      color_lighting_eigvals_, crop_, mirror_, mean_, std_,
      randgen, &mirror_this_image, is_test_);
}

template <class Context>
void ImageInputOp<Context>::DecodeAndTransposeOnly(
    const std::string& value, uint8_t* image_data, int item_id,
    const int channels, std::size_t thread_index) {
  CAFFE_ENFORCE((int)thread_index < num_decode_threads_);

  std::bernoulli_distribution mirror_this_image(0.5f);
  std::mt19937* randgen = &(randgen_per_thread_[thread_index]);

  cv::Mat img;
  // Decode the image.
  PerImageArg info;
  CHECK(
      GetImageAndLabelAndInfoFromDBValue(value, &img, info, item_id, randgen));

  // Crop and transpose only; the rest of the transform runs on the GPU.
  CropTransposeImage<Context>(
      img, channels, image_data, crop_, mirror_, randgen, &mirror_this_image,
      is_test_);
}

template <class Context>
bool ImageInputOp<Context>::Prefetch() {
  if (!owned_reader_.get()) {
    // If we are not owning the reader, we will get the reader pointer from
    // input. Otherwise the constructor should have already set the reader
    // pointer.
    reader_ = &OperatorBase::Input<db::DBReader>(0);
  }
  const int channels = color_ ? 3 : 1;
  // Call mutable_data() once to allocate the underlying memory.
  if (gpu_transform_) {
    // We will transfer up in uint8, then convert later.
    prefetched_image_.mutable_data<uint8_t>();
  } else {
    prefetched_image_.mutable_data<float>();
  }

  prefetched_label_.mutable_data<int>();
  // Prefetching is handled with a thread pool of "decode_threads" threads.

  for (int item_id = 0; item_id < batch_size_; ++item_id) {
    std::string key, value;
    cv::Mat img;

    // Read data.
    reader_->Read(&key, &value);

    // Determine the label type based on the first item.
    if (item_id == 0) {
      if (use_caffe_datum_) {
        prefetched_label_.mutable_data<int>();
      } else {
        TensorProtos protos;
        CAFFE_ENFORCE(protos.ParseFromString(value));
        TensorProto_DataType labeldt = protos.protos(1).data_type();
        if (labeldt == TensorProto::INT32) {
          prefetched_label_.mutable_data<int>();
        } else if (labeldt == TensorProto::FLOAT) {
          prefetched_label_.mutable_data<float>();
        } else {
          LOG(FATAL) << "Unsupported label type.";
        }

        for (int i = 0; i < additional_inputs_count_; ++i) {
          int index = additional_inputs_offset_ + i;
          TensorProto additional_output_proto = protos.protos(index);
          auto sizes =
              std::vector<int64_t>({batch_size_, additional_output_sizes_[i]});
          if (additional_output_proto.data_type() == TensorProto::FLOAT) {
            prefetched_additional_outputs_[i] =
                caffe2::empty(sizes, at::dtype<float>().device(CPU));
          } else if (
              additional_output_proto.data_type() == TensorProto::INT32) {
            prefetched_additional_outputs_[i] =
                caffe2::empty(sizes, at::dtype<int>().device(CPU));
          } else if (
              additional_output_proto.data_type() == TensorProto::INT64) {
            prefetched_additional_outputs_[i] =
                caffe2::empty(sizes, at::dtype<int64_t>().device(CPU));
          } else if (
              additional_output_proto.data_type() == TensorProto::UINT8) {
            prefetched_additional_outputs_[i] =
                caffe2::empty(sizes, at::dtype<uint8_t>().device(CPU));
          } else {
            LOG(FATAL) << "Unsupported output type.";
          }
        }
      }
    }

    // Launch into the thread pool for processing.
    // TODO: support color jitter and color lighting in gpu_transform
    if (gpu_transform_) {
      // The output of decode will still be uint8.
      uint8_t* image_data = prefetched_image_.mutable_data<uint8_t>() +
          crop_ * crop_ * channels * item_id;
      thread_pool_->runTaskWithID(std::bind(
          &ImageInputOp<Context>::DecodeAndTransposeOnly,
          this,
          std::string(value),
          image_data,
          item_id,
          channels,
          std::placeholders::_1));
    } else {
      float* image_data = prefetched_image_.mutable_data<float>() +
          crop_ * crop_ * channels * item_id;
      thread_pool_->runTaskWithID(std::bind(
          &ImageInputOp<Context>::DecodeAndTransform,
          this,
          std::string(value),
          image_data,
          item_id,
          channels,
          std::placeholders::_1));
    }
  }
  thread_pool_->waitWorkComplete();

  // We tolerate a ratio of at most max_decode_error_ratio_ OpenCV imdecode
  // errors before raising a runtime exception.
  if ((float)num_decode_errors_in_batch_ / batch_size_ >
      max_decode_error_ratio_) {
    throw std::runtime_error(
        "max_decode_error_ratio exceeded " +
        c10::to_string(max_decode_error_ratio_));
  }

  // If the context is not CPUContext, we will need to do a copy in the
  // prefetch function as well.
  auto device = at::device(Context::GetDeviceType());
  if (!std::is_same<Context, CPUContext>::value) {
    // Do sync copies.
    ReinitializeAndCopyFrom(
        &prefetched_image_on_device_, device, prefetched_image_);
    ReinitializeAndCopyFrom(
        &prefetched_label_on_device_, device, prefetched_label_);

    for (int i = 0; i < prefetched_additional_outputs_on_device_.size(); ++i) {
      ReinitializeAndCopyFrom(
          &prefetched_additional_outputs_on_device_[i],
          device,
          prefetched_additional_outputs_[i]);
    }
  }

  num_decode_errors_in_batch_ = 0;

  return true;
}

template <class Context>
bool ImageInputOp<Context>::CopyPrefetched() {
  auto type = Device(Context::GetDeviceType());
  auto options = at::device(type);

  // Note(jiayq): The if statement below should be optimized away by the
  // compiler since std::is_same is a constexpr.
  if (std::is_same<Context, CPUContext>::value) {
    OperatorBase::OutputTensorCopyFrom(
        0, options, prefetched_image_, /* async */ true);
    OperatorBase::OutputTensorCopyFrom(
        1, options, prefetched_label_, /* async */ true);

    for (int i = 2; i < OutputSize(); ++i) {
      OperatorBase::OutputTensorCopyFrom(
          i, options, prefetched_additional_outputs_[i - 2], /* async */ true);
    }
  } else {
    // TODO: support color jitter and color lighting in gpu_transform
    if (gpu_transform_) {
      if (!mean_std_copied_) {
        ReinitializeTensor(
            &mean_gpu_,
            {static_cast<int64_t>(mean_.size())},
            at::dtype<float>().device(Context::GetDeviceType()));
        ReinitializeTensor(
            &std_gpu_,
            {static_cast<int64_t>(std_.size())},
            at::dtype<float>().device(Context::GetDeviceType()));

        context_.template CopyFromCPU<float>(
            mean_.size(),
            mean_.data(),
            mean_gpu_.template mutable_data<float>());
        context_.template CopyFromCPU<float>(
            std_.size(), std_.data(), std_gpu_.template mutable_data<float>());
        mean_std_copied_ = true;
      }
      const auto& X = prefetched_image_on_device_;
      // Data comes in as NHWC...
      const int N = X.dim32(0), C = X.dim32(3), H = X.dim32(1), W = X.dim32(2);
      // ...and goes out as NCHW.
      auto dims = std::vector<int64_t>{N, C, H, W};
      if (!ApplyTransformOnGPU(dims, type)) {
        return false;
      }

    } else {
      OperatorBase::OutputTensorCopyFrom(
          0, type, prefetched_image_on_device_, /* async */ true);
    }
    OperatorBase::OutputTensorCopyFrom(
        1, type, prefetched_label_on_device_, /* async */ true);

    for (int i = 2; i < OutputSize(); ++i) {
      OperatorBase::OutputTensorCopyFrom(
          i,
          type,
          prefetched_additional_outputs_on_device_[i - 2],
          /* async */ true);
    }
  }
  return true;
}
} // namespace caffe2

#endif // CAFFE2_IMAGE_IMAGE_INPUT_OP_H_