Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_input_op.h
#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_

#include <istream>
#include <ostream>
#include <random>
#include <string>

#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/utils/thread_pool.h>
#include <caffe2/video/video_io.h>

namespace caffe2 {

template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
  ~VideoInputOp() {
    PrefetchOperator<Context>::Finalize();
  }

  // override methods
  bool Prefetch() override;
  bool CopyPrefetched() override;

 private:
  void CheckParamsAndPrint();

  bool GetClipsAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data,
      int* video_id_data);

  void DecodeAndTransform(
      const std::string& value,
      float* clip_rgb_data,
      float* clip_of_data,
      int* label_data,
      int* video_id_data,
      std::mt19937* randgen,
      std::bernoulli_distribution* mirror_this_clip);

  const db::DBReader* reader_;
  CPUContext cpu_context_;
  TensorCPU prefetched_clip_rgb_;
  TensorCPU prefetched_clip_of_;
  TensorCPU prefetched_label_;
  TensorCPU prefetched_video_id_;
  Tensor<Context> prefetched_clip_rgb_on_device_;
  Tensor<Context> prefetched_clip_of_on_device_;
  Tensor<Context> prefetched_label_on_device_;
  Tensor<Context> prefetched_video_id_on_device_;
  int batch_size_;
  int clip_per_video_;
  std::vector<float> mean_rgb_;
  std::vector<float> inv_std_rgb_;
  std::vector<float> mean_of_;
  std::vector<float> inv_std_of_;
  int channels_rgb_;
  int channels_of_;
  int crop_height_;
  int crop_width_;
  int scale_h_;
  int scale_w_;
  int height_min_;
  int width_min_;
  int length_rgb_;
  int sampling_rate_rgb_;
  bool color_jitter_;
  float img_saturation_;
  float img_brightness_;
  float img_contrast_;
  bool color_lighting_;
  float color_lighting_std_;
  std::vector<std::vector<float>> color_lighting_eigvecs_;
  std::vector<float> color_lighting_eigvals_;
  int num_of_required_frame_;
  int length_of_;
  int sampling_rate_of_;
  int frame_gap_of_;
  bool random_mirror_;
  int num_of_class_;
  bool use_local_file_;
  bool random_crop_;
  bool multi_crop_;
  int multi_crop_count_;
  int flow_data_type_;
  int flow_alg_type_;
  int decode_type_;
  int video_res_type_;
  bool do_flow_aggregation_;
  bool get_rgb_;
  bool get_optical_flow_;
  bool get_video_id_;
  bool do_multi_label_;

  // thread pool for parse + decode
  int num_decode_threads_;
  std::shared_ptr<TaskThreadPool> thread_pool_;
};
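
// --- Editorial example ----------------------------------------------------
// A minimal sketch of configuring this operator directly from C++, assuming
// the standard "VideoInput" registration and the AddArgument<T> helper from
// caffe2/utils/proto_utils.h. The helper name ExampleVideoInputDef, the blob
// names, and the argument values are hypothetical, not part of this header.
inline OperatorDef ExampleVideoInputDef() {
  OperatorDef def;
  def.set_type("VideoInput");
  def.add_input("db_reader"); // a db::DBReader blob; see Prefetch() below
  def.add_output("clip_rgb"); // outputs appear as: rgb, [of], label, [id]
  def.add_output("label");
  AddArgument<int>("batch_size", 16, &def);
  AddArgument<int>("clip_per_video", 1, &def);
  AddArgument<int>("crop_size", 112, &def); // sets both crop_height/width
  AddArgument<int>("length_rgb", 8, &def);
  AddArgument<int>("sampling_rate_rgb", 2, &def);
  return def;
}
// ---------------------------------------------------------------------------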

template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
  // check whether the input parameters are valid or not
  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
  CAFFE_ENFORCE_GT(
      clip_per_video_, 0, "Number of clips per video should be positive.");
  CAFFE_ENFORCE_GT(crop_height_, 0, "Must provide the cropping height value.");
  CAFFE_ENFORCE_GT(crop_width_, 0, "Must provide the cropping width value.");

  CAFFE_ENFORCE_GT(
      num_of_required_frame_, 0, "Required number of frames must be positive.");

  if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(height_min_, 0, "Must provide the minimal height value.");
    CAFFE_ENFORCE_GT(width_min_, 0, "Must provide the minimal width value.");
    CAFFE_ENFORCE_GE(
        height_min_,
        crop_height_,
        "The minimal height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        width_min_,
        crop_width_,
        "The minimal width must be no smaller than the cropping width.");
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
    CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
    CAFFE_ENFORCE_GE(
        scale_h_,
        crop_height_,
        "The scaled height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        scale_w_,
        crop_width_,
        "The scaled width must be no smaller than the cropping width.");
  }

  if (get_rgb_) {
    CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_rgb_, 0, "RGB sampling rate must be positive.");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, mean_rgb_.size(), "Number of rgb channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, inv_std_rgb_.size(), "Number of rgb channels is wrong!");
  }

  if (get_optical_flow_) {
    CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_of_, 0, "Optical flow sampling rate must be positive.");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        mean_of_.size(),
        "Number of optical flow channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        inv_std_of_.size(),
        "Number of optical flow channels is wrong!");
  }

  if (clip_per_video_ > 1) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::DO_UNIFORM_SMP,
        "Only uniform sampling is supported when sampling multiple clips!");
  }

  if (do_multi_label_) {
    CAFFE_ENFORCE_GT(
        num_of_class_,
        0,
        "Number of classes must be set when using multiple labels.");
  }

  // print out the parameter settings
  LOG(INFO) << "Creating a clip input op with the following settings:";
  LOG(INFO) << " Using " << num_decode_threads_ << " CPU threads;";
  LOG(INFO) << " Outputting in batches of " << batch_size_ << " videos;";
  LOG(INFO) << " Each video has " << clip_per_video_ << " clips;";
  LOG(INFO) << " Scaling image to " << scale_h_ << "x" << scale_w_;
  LOG(INFO) << " (Height, Width) is at least (" << height_min_ << ", "
            << width_min_ << ")";
  LOG(INFO) << " Cropping video frame to " << crop_height_ << "x"
            << crop_width_ << (random_mirror_ ? " with " : " without ")
            << "random mirroring;";
  LOG(INFO) << " Using " << (random_crop_ ? "random" : "center") << " crop";
  LOG(INFO) << " Is multi-cropping enabled: " << multi_crop_;

  if (get_rgb_) {
    LOG(INFO) << " Using a clip of " << length_rgb_ << " rgb frames "
              << "with " << channels_rgb_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_rgb_;
    LOG(INFO) << " RGB data augmentation. Color jittering: " << color_jitter_
              << ". Color lighting: " << color_lighting_;
    for (int i = 0; i < channels_rgb_; i++) {
      LOG(INFO) << " RGB " << i << "-th channel mean: " << mean_rgb_[i]
                << " std: " << 1.f / inv_std_rgb_[i];
    }
  }

  if (get_optical_flow_) {
    LOG(INFO) << " Using a clip of " << length_of_ << " optical flow frames "
              << "with " << channels_of_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_of_
              << " flow_data_type_: " << flow_data_type_
              << " flow_alg_type_: " << flow_alg_type_;
    for (int i = 0; i < channels_of_; i++) {
      LOG(INFO) << " Optical flow " << i
                << "-th channel mean: " << mean_of_[i]
                << " std: " << 1.f / inv_std_of_[i];
    }
  }

  if (video_res_type_ == VideoResType::ORIGINAL_RES) {
    LOG(INFO) << " Using the original resolution";
  } else if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    LOG(INFO) << " Resizing to the minimal size, keeping the aspect ratio";
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    LOG(INFO) << " Resizing while ignoring the aspect ratio";
  } else {
    LOG(ERROR) << " Unknown video resolution type";
  }

  if (decode_type_ == DecodeType::DO_TMP_JITTER) {
    LOG(INFO) << " Doing temporal jittering";
  } else if (decode_type_ == DecodeType::USE_START_FRM) {
    LOG(INFO) << " Using start_frm for decoding";
  } else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
    LOG(INFO) << " Doing uniform sampling";
  } else {
    LOG(ERROR) << " Unknown video decoding type";
  }
}

template <class Context>
VideoInputOp<Context>::VideoInputOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : PrefetchOperator<Context>(operator_def, ws),
      reader_(nullptr),
      batch_size_(
          OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
      clip_per_video_(
          OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
      mean_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_rgb", 128.)})),
      inv_std_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "std_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_rgb", 1.)})),
      mean_of_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_of", 0.)})),
      inv_std_of_(OperatorBase::template GetRepeatedArgument<float>(
          "std_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_of", 1.)})),
      channels_rgb_(
          OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
      channels_of_(
          OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
      crop_height_(OperatorBase::template GetSingleArgument<int>(
          "crop_height",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      crop_width_(OperatorBase::template GetSingleArgument<int>(
          "crop_width",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
      scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
      height_min_(OperatorBase::template GetSingleArgument<int>(
          "height_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      width_min_(OperatorBase::template GetSingleArgument<int>(
          "width_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      length_rgb_(
          OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
      sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
          "sampling_rate_rgb",
          1)),
      color_jitter_(OperatorBase::template GetSingleArgument<bool>(
          "color_jitter",
          false)),
      img_saturation_(OperatorBase::template GetSingleArgument<float>(
          "img_saturation",
          0.4)),
      img_brightness_(OperatorBase::template GetSingleArgument<float>(
          "img_brightness",
          0.4)),
      img_contrast_(
          OperatorBase::template GetSingleArgument<float>("img_contrast", 0.4)),
      color_lighting_(OperatorBase::template GetSingleArgument<bool>(
          "color_lighting",
          false)),
      color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
          "color_lighting_std",
          0.1)),
      length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
      sampling_rate_of_(
          OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
      frame_gap_of_(
          OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
      random_mirror_(OperatorBase::template GetSingleArgument<bool>(
          "random_mirror",
          true)),
      num_of_class_(
          OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
      use_local_file_(OperatorBase::template GetSingleArgument<bool>(
          "use_local_file",
          false)),
      random_crop_(
          OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
      multi_crop_(
          OperatorBase::template GetSingleArgument<bool>("multi_crop", false)),
      flow_data_type_(
          OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
      flow_alg_type_(
          OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
      decode_type_(
          OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
      video_res_type_(
          OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
      do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
          "do_flow_aggregation",
          true)),
      get_rgb_(OperatorBase::template GetSingleArgument<bool>("get_rgb", true)),
      get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
          "get_optical_flow",
          false)),
      get_video_id_(OperatorBase::template GetSingleArgument<bool>(
          "get_video_id",
          false)),
      do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
          "do_multi_label",
          false)),
      num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
          "num_decode_threads",
          4)),
      thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
  // hard-coded PCA eigenvectors and eigenvalues, based on RGB channel order
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-144.7125, 183.396, 102.2295});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.104, -1.1475, -207.57});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.818, -177.174, 107.1765});

  color_lighting_eigvals_ = std::vector<float>{0.2175, 0.0188, 0.0045};

  // multi-cropping for testing
  multi_crop_count_ = 1;
  if (multi_crop_) {
    // we take left-top, central-top, right-top, left-bottom, central-bottom,
    // right-bottom and central-central croppings as well as their mirrorings;
    // in total, 14 croppings
    multi_crop_count_ = 14;
  }

  num_of_required_frame_ = 0;

  // mean and std for normalizing different optical flow data types;
  // example statistics generated from SOA are shown below, and you may
  // want to change them if you are running on a different dataset

  // 7 channels: (flow_x, flow_y, flow_magnitude, gray, Red, Green, Blue)
  // const std::vector<float> InputDataMean =
  //     {0.0046635, 0.0046261, 0.963986, 102.976, 110.201, 100.64, 95.9966};
  // const std::vector<float> InputDataStd =
  //     {0.972347, 0.755146, 1.43588, 55.3691, 58.1489, 56.4701, 55.3324};

  // if we need RGB as an input
  if (get_rgb_) {
    // how many frames we need for RGB
    num_of_required_frame_ = std::max(
        num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);
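    // e.g., with length_rgb_ = 8 and sampling_rate_rgb_ = 2 (illustrative
    // numbers), a clip spans (8 - 1) * 2 + 1 = 15 decoded frames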

    channels_rgb_ = 3;

    CAFFE_ENFORCE_EQ(
        mean_rgb_.size(),
        inv_std_rgb_.size(),
        "The mean and std. vectors for RGB must be of the same size.");
    if (mean_rgb_.size() == 1) {
      mean_rgb_.resize(3, mean_rgb_[0]);
      inv_std_rgb_.resize(3, inv_std_rgb_[0]);
    }
    CAFFE_ENFORCE_EQ(mean_rgb_.size(), 3, "RGB should have 3 channels");
    for (int i = 0; i < 3; ++i) {
      inv_std_rgb_[i] = 1.f / inv_std_rgb_[i];
    }
  }
  // if we need optical flow as an input
  if (get_optical_flow_) {
    // how many frames we need for optical flow
    num_of_required_frame_ = std::max(
        num_of_required_frame_,
        (length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);
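    // e.g., with length_of_ = 8, sampling_rate_of_ = 1, and frame_gap_of_ = 1
    // (illustrative numbers), a clip spans (8 - 1) * 1 + 1 + 1 = 9 decoded
    // frames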

    CAFFE_ENFORCE_EQ(
        mean_of_.size(),
        inv_std_of_.size(),
        "The mean and std. vectors for Optical Flow must be of the same size.");
    // set the parameters for different input data types
    switch (flow_data_type_) {
      // (flow_x, flow_y)
      case FlowDataType::Flow2C:
        channels_of_ = 2;
        break;
      // (flow_x, flow_y, flow_magnitude)
      case FlowDataType::Flow3C:
        channels_of_ = 3;
        break;
      // early fusion with gray
      // (flow_x, flow_y, gray)
      case FlowDataType::FlowWithGray:
        channels_of_ = 3;
        break;
      // early fusion with RGB
      // (flow_x, flow_y, Red, Green, Blue)
      case FlowDataType::FlowWithRGB:
        channels_of_ = 5;
        break;
      default:
        LOG(ERROR) << "Unknown optical flow type " << flow_data_type_;
        break;
    }
    LOG(INFO) << "channels_of_: " << channels_of_;
    if (mean_of_.size() == 1) {
      mean_of_.resize(channels_of_, mean_of_[0]);
      inv_std_of_.resize(channels_of_, inv_std_of_[0]);
    }
    for (int i = 0; i < channels_of_; ++i) {
      inv_std_of_[i] = 1.f / inv_std_of_[i];
    }
  }

  CheckParamsAndPrint();
  // Always need a dbreader, even when using local video files
  CAFFE_ENFORCE_GT(
      operator_def.input_size(), 0, "Need to have a DBReader blob input");

  vector<TIndex> data_shape(5);
  vector<TIndex> label_shape(2);

  // for RGB data
  data_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
  data_shape[1] = channels_rgb_;
  data_shape[2] = length_rgb_;
  data_shape[3] = crop_height_;
  data_shape[4] = crop_width_;
  prefetched_clip_rgb_.Resize(data_shape);
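  // e.g., batch_size_ = 16, clip_per_video_ = 1, multi_crop_count_ = 1,
  // channels_rgb_ = 3, length_rgb_ = 8, and a 112x112 crop (illustrative
  // numbers) give an RGB output of shape {16, 3, 8, 112, 112}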

  // for optical flow data
  data_shape[1] = channels_of_;
  data_shape[2] = length_of_;
  prefetched_clip_of_.Resize(data_shape);

  // If do_multi_label is used, the output label is a binary vector
  // of length num_of_class indicating which labels are present
  if (do_multi_label_) {
    label_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
    label_shape[1] = num_of_class_;
    prefetched_label_.Resize(label_shape);
  } else {
    prefetched_label_.Resize(
        vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
  }
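  // e.g., with num_of_class_ = 4 and labels {0, 2} present (illustrative
  // numbers), each clip's label row becomes [1, 0, 1, 0]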

  prefetched_video_id_.Resize(
      vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
}

template <class Context>
bool VideoInputOp<Context>::GetClipsAndLabelsFromDBValue(
    const std::string& value,
    int& height,
    int& width,
    std::vector<unsigned char*>& buffer_rgb,
    int* label_data,
    int* video_id_data) {
  TensorProtos protos;
  int curr_proto_idx = 0;
  CAFFE_ENFORCE(protos.ParseFromString(value));
  const TensorProto& video_proto = protos.protos(curr_proto_idx++);
  const TensorProto& label_proto = protos.protos(curr_proto_idx++);

  int start_frm = 0;
  // start_frm is only valid when sampling 1 clip per video without
  // temporal jittering
  if (decode_type_ == DecodeType::USE_START_FRM) {
    CAFFE_ENFORCE_LT(
        curr_proto_idx,
        protos.protos_size(),
        "No proto is found for the starting frame");
    const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
    start_frm = start_frm_proto.int32_data(0);
  }
  if (get_video_id_) {
    CAFFE_ENFORCE_LT(
        curr_proto_idx, protos.protos_size(), "No proto is found for video id");
    const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      video_id_data[i] = video_id_proto.int64_data(0);
    }
  }
  // assign labels
  if (!do_multi_label_) {
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      label_data[i] = label_proto.int32_data(0);
    }
  } else {
    // For the multiple-label case, the output label is a binary vector
    // where the present concepts are marked 1
    memset(
        label_data,
        0,
        sizeof(int) * num_of_class_ * multi_crop_count_ * clip_per_video_);
    for (int i = 0; i < clip_per_video_; i++) {
      for (int j = 0; j < multi_crop_count_; ++j) {
        for (int k = 0; k < label_proto.int32_data_size(); k++) {
          label_data
              [(i * multi_crop_count_ + j) * num_of_class_ +
               label_proto.int32_data(k)] = 1;
        }
      }
    }
  }

  if (use_local_file_) {
    CAFFE_ENFORCE_EQ(
        video_proto.data_type(),
        TensorProto::STRING,
        "Database with a file_list is expected to be string data");
  }

  // initializing the decoding params
  Params params;
  params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
  params.video_res_type_ = video_res_type_;
  params.crop_height_ = crop_height_;
  params.crop_width_ = crop_width_;
  params.height_min_ = height_min_;
  params.width_min_ = width_min_;
  params.scale_w_ = scale_w_;
  params.scale_h_ = scale_h_;
  params.decode_type_ = decode_type_;
  params.num_of_required_frame_ = num_of_required_frame_;

  char* video_buffer = nullptr; // for decoding from buffer
  std::string video_filename; // for decoding from file
  int encoded_size = 0;
  if (video_proto.data_type() == TensorProto::STRING) {
    const string& encoded_video_str = video_proto.string_data(0);
    if (!use_local_file_) {
      encoded_size = encoded_video_str.size();
      video_buffer = const_cast<char*>(encoded_video_str.data());
    } else {
      video_filename = encoded_video_str;
    }
  } else if (video_proto.data_type() == TensorProto::BYTE) {
    if (!use_local_file_) {
      encoded_size = video_proto.byte_data().size();
      video_buffer = const_cast<char*>(video_proto.byte_data().data());
    } else {
      // TODO: does this work?
      video_filename = video_proto.string_data(0);
    }
  } else {
    LOG(FATAL) << "Unknown video data type.";
  }

  DecodeMultipleClipsFromVideo(
      video_buffer,
      video_filename,
      encoded_size,
      params,
      start_frm,
      clip_per_video_,
      use_local_file_,
      height,
      width,
      buffer_rgb);

  return true;
}
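
// --- Editorial example ----------------------------------------------------
// A sketch of the record layout the parser above expects: each DB value is a
// serialized TensorProtos holding the encoded video (or its file name) first,
// then the label, then optionally a start frame and a video id, in the order
// consumed via curr_proto_idx. ExampleMakeDBValue is a hypothetical helper
// for the no-start-frame case, not part of this header.
inline std::string ExampleMakeDBValue(
    const std::string& encoded_video,
    int label,
    int64_t video_id) {
  TensorProtos protos;
  TensorProto* video = protos.add_protos();
  video->set_data_type(TensorProto::STRING);
  video->add_string_data(encoded_video); // or a file name if use_local_file
  TensorProto* label_proto = protos.add_protos();
  label_proto->set_data_type(TensorProto::INT32);
  label_proto->add_int32_data(label);
  TensorProto* id_proto = protos.add_protos(); // read when get_video_id is set
  id_proto->set_data_type(TensorProto::INT64);
  id_proto->add_int64_data(video_id);
  return protos.SerializeAsString();
}
// ---------------------------------------------------------------------------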

template <class Context>
void VideoInputOp<Context>::DecodeAndTransform(
    const std::string& value,
    float* clip_rgb_data,
    float* clip_of_data,
    int* label_data,
    int* video_id_data,
    std::mt19937* randgen,
    std::bernoulli_distribution* mirror_this_clip) {
  std::vector<unsigned char*> buffer_rgb;
  // get the video resolution after decoding
  int height = 0;
  int width = 0;
  // Decode the video from memory or read from a local file
  CHECK(GetClipsAndLabelsFromDBValue(
      value, height, width, buffer_rgb, label_data, video_id_data));
  int clip_offset_rgb = multi_crop_count_ * channels_rgb_ * length_rgb_ *
      crop_height_ * crop_width_;
  int clip_crop_offset_of =
      channels_of_ * length_of_ * crop_height_ * crop_width_;
  int clip_offset_of = multi_crop_count_ * clip_crop_offset_of;
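  // clip_offset_* is the number of floats one clip occupies across all of its
  // multi_crop_count_ crops; clip_crop_offset_of is the stride between two
  // consecutive crops of the same clip in the optical flow output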
  for (int i = 0; i < std::min(clip_per_video_, int(buffer_rgb.size())); i++) {
    // get the rectangle for cropping
    int h_off = 0;
    int w_off = 0;
    if (random_crop_) {
      // using random crop for training
      h_off =
          std::uniform_int_distribution<>(0, height - crop_height_)(*randgen);
      w_off = std::uniform_int_distribution<>(0, width - crop_width_)(*randgen);
    } else {
      // using center crop for testing
      h_off = (height - crop_height_) / 2;
      w_off = (width - crop_width_) / 2;
    }
    // cv::Rect rect(w_off, h_off, crop_width_, crop_height_);

    // Multi cropping: we take left-top, central-top, right-top, left-bottom,
    // central-bottom, right-bottom and central-central croppings as well as
    // their mirrorings. In total, 14 croppings
    int multi_crop_w_off[7] = {0,
                               (width - crop_width_) / 2,
                               width - crop_width_,
                               (width - crop_width_) / 2,
                               0,
                               (width - crop_width_) / 2,
                               width - crop_width_};
    int multi_crop_h_off[7] = {0,
                               0,
                               0,
                               (height - crop_height_) / 2,
                               height - crop_height_,
                               height - crop_height_,
                               height - crop_height_};
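    // The seven (w, h) pairs above are, in order: left-top, central-top,
    // right-top, central-central, left-bottom, central-bottom, right-bottom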

    // randomly mirror the image or not
    bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
    if (get_rgb_ && clip_rgb_data) {
      ClipTransformRGB(
          buffer_rgb[i],
          multi_crop_count_,
          crop_height_,
          crop_width_,
          length_rgb_,
          channels_rgb_,
          sampling_rate_rgb_,
          height,
          width,
          h_off,
          w_off,
          multi_crop_h_off,
          multi_crop_w_off,
          mirror_me,
          color_jitter_,
          img_saturation_,
          img_brightness_,
          img_contrast_,
          color_lighting_,
          color_lighting_std_,
          color_lighting_eigvecs_,
          color_lighting_eigvals_,
          mean_rgb_,
          inv_std_rgb_,
          randgen,
          clip_rgb_data + (i * clip_offset_rgb));
    }
    if (get_optical_flow_ && clip_of_data) {
      cv::Rect rect;
      for (int j = 0; j < multi_crop_count_; ++j) {
        if (multi_crop_count_ == 1) {
          rect = cv::Rect(w_off, h_off, crop_width_, crop_height_);
        } else {
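          // crops 7..13 are the mirrored versions of crops 0..6
          // (j / 7 selects mirroring, j % 7 selects the crop position)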
          mirror_me = j / (multi_crop_count_ / 2);
          int k = j % (multi_crop_count_ / 2);
          rect = cv::Rect(
              multi_crop_w_off[k],
              multi_crop_h_off[k],
              crop_width_,
              crop_height_);
        }
        ClipTransformOpticalFlow(
            buffer_rgb[i],
            crop_height_,
            crop_width_,
            length_of_,
            channels_of_,
            sampling_rate_of_,
            height,
            width,
            rect,
            channels_rgb_,
            mirror_me,
            flow_alg_type_,
            flow_data_type_,
            frame_gap_of_,
            do_flow_aggregation_,
            mean_of_,
            inv_std_of_,
            clip_of_data + (i * clip_offset_of) + j * clip_crop_offset_of);
      }
    }
  }

  for (size_t i = 0; i < buffer_rgb.size(); i++) {
    unsigned char* buff = buffer_rgb[i];
    delete[] buff;
  }
  buffer_rgb.clear();
}

template <class Context>
bool VideoInputOp<Context>::Prefetch() {
  // We will get the reader pointer from the input blob.
  // If we use local clips, the db will store the file list
  reader_ = &OperatorBase::Input<db::DBReader>(0);

  // Call mutable_data() once to allocate the underlying memory.
  prefetched_clip_rgb_.mutable_data<float>();
  prefetched_clip_of_.mutable_data<float>();
  prefetched_label_.mutable_data<int>();
  prefetched_video_id_.mutable_data<int>();

  // Prefetching is handled by a thread pool of num_decode_threads_ threads.
  std::mt19937 meta_randgen(time(nullptr));
  std::vector<std::mt19937> randgen_per_thread;
  for (int i = 0; i < num_decode_threads_; ++i) {
    randgen_per_thread.emplace_back(meta_randgen());
  }

  std::bernoulli_distribution mirror_this_clip(0.5);
  for (int item_id = 0; item_id < batch_size_; ++item_id) {
    std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];

    int frame_size = crop_height_ * crop_width_;
    // get the clip data pointer for the item_id-th example
    float* clip_rgb_data = prefetched_clip_rgb_.mutable_data<float>() +
        frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the optical flow data for the current clip
    float* clip_of_data = prefetched_clip_of_.mutable_data<float>() +
        frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the label data pointer for the item_id-th example
    int* label_data = prefetched_label_.mutable_data<int>() +
        (do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the video id data pointer for the item_id-th example
    int* video_id_data = prefetched_video_id_.mutable_data<int>() +
        item_id * clip_per_video_ * multi_crop_count_;
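    // Each item thus owns a contiguous slab of clip_per_video_ *
    // multi_crop_count_ clips in every prefetched tensor, matching the first
    // dimension of the shapes set up in the constructor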

    std::string key, value;
    // read data
    reader_->Read(&key, &value);

    thread_pool_->runTask(std::bind(
        &VideoInputOp<Context>::DecodeAndTransform,
        this,
        std::string(value),
        clip_rgb_data,
        clip_of_data,
        label_data,
        video_id_data,
        randgen,
        &mirror_this_clip));
  } // for over the batch
  thread_pool_->waitWorkComplete();

  // If the context is not CPUContext, we will need to do a copy in the
  // prefetch function as well.
  if (!std::is_same<Context, CPUContext>::value) {
    if (get_rgb_) {
      prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
    }
    if (get_optical_flow_) {
      prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
    }
    prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
    if (get_video_id_) {
      prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
    }
  }
  return true;
}

template <class Context>
bool VideoInputOp<Context>::CopyPrefetched() {
  int index = 0;
  if (get_rgb_) {
    auto* clip_rgb_output = OperatorBase::Output<Tensor<Context>>(index++);
    if (std::is_same<Context, CPUContext>::value) {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
    } else {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
    }
  }
  if (get_optical_flow_) {
    auto* clip_of_output = OperatorBase::Output<Tensor<Context>>(index++);
    if (std::is_same<Context, CPUContext>::value) {
      clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
    } else {
      clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
    }
  }
  auto* label_output = OperatorBase::Output<Tensor<Context>>(index++);
  if (std::is_same<Context, CPUContext>::value) {
    label_output->CopyFrom(prefetched_label_, &context_);
  } else {
    label_output->CopyFrom(prefetched_label_on_device_, &context_);
  }
  if (get_video_id_) {
    auto* video_id_output = OperatorBase::Output<Tensor<Context>>(index);
    if (std::is_same<Context, CPUContext>::value) {
      video_id_output->CopyFrom(prefetched_video_id_, &context_);
    } else {
      video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
    }
  }
  return true;
}
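
// --- Editorial example ----------------------------------------------------
// A usage sketch, assuming a CPU workspace, an lmdb database at a
// hypothetical path, and the ExampleVideoInputDef helper sketched after the
// class declaration above. Error handling is omitted.
inline void ExampleRunVideoInput() {
  Workspace ws;
  ws.CreateBlob("db_reader")->Reset(
      new db::DBReader("lmdb", "/path/to/video_db")); // hypothetical path
  auto op = CreateOperator(ExampleVideoInputDef(), &ws);
  op->Run(); // each Run() delivers one prefetched batch on the outputs
}
// ---------------------------------------------------------------------------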

} // namespace caffe2

#endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_