Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_input_op.h
#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_

#include <istream>
#include <ostream>
#include <random>
#include <string>

#include <c10/core/thread_pool.h>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/video/video_io.h>

namespace caffe2 {

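// VideoInputOp reads encoded videos (or a list of local video files) from a
// DBReader, decodes and transforms them on a thread pool, and exposes RGB
// clips, optical flow clips, labels, and (optionally) video ids as
// prefetched output tensors.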
template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
  ~VideoInputOp() {
    PrefetchOperator<Context>::Finalize();
  }

  // override methods
  bool Prefetch() override;
  bool CopyPrefetched() override;

 private:
  void CheckParamsAndPrint();

  bool GetClipsAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data,
      int* video_id_data);

  void DecodeAndTransform(
      const std::string& value,
      float* clip_rgb_data,
      float* clip_of_data,
      int* label_data,
      int* video_id_data,
      std::mt19937* randgen,
      std::bernoulli_distribution* mirror_this_clip);

  const db::DBReader* reader_;
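  // CPU-side staging tensors filled by Prefetch(); the *_on_device_ copies
  // below are used only when Context is not CPUContext.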
  Tensor prefetched_clip_rgb_;
  Tensor prefetched_clip_of_;
  Tensor prefetched_label_;
  Tensor prefetched_video_id_;
  Tensor prefetched_clip_rgb_on_device_{Context::GetDeviceType()};
  Tensor prefetched_clip_of_on_device_{Context::GetDeviceType()};
  Tensor prefetched_label_on_device_{Context::GetDeviceType()};
  Tensor prefetched_video_id_on_device_{Context::GetDeviceType()};
  int batch_size_;
  int clip_per_video_;
  std::vector<float> mean_rgb_;
  std::vector<float> inv_std_rgb_;
  std::vector<float> mean_of_;
  std::vector<float> inv_std_of_;
  int channels_rgb_;
  int channels_of_;
  int crop_height_;
  int crop_width_;
  int scale_h_;
  int scale_w_;
  int height_min_;
  int width_min_;
  int length_rgb_;
  int sampling_rate_rgb_;
  bool color_jitter_;
  float img_saturation_;
  float img_brightness_;
  float img_contrast_;
  bool color_lighting_;
  float color_lighting_std_;
  std::vector<std::vector<float>> color_lighting_eigvecs_;
  std::vector<float> color_lighting_eigvals_;
  int num_of_required_frame_;
  int length_of_;
  int sampling_rate_of_;
  int frame_gap_of_;
  bool random_mirror_;
  int num_of_class_;
  bool use_local_file_;
  bool random_crop_;
  bool multi_crop_;
  int multi_crop_count_;
  int flow_data_type_;
  int flow_alg_type_;
  int decode_type_;
  int video_res_type_;
  bool do_flow_aggregation_;
  bool get_rgb_;
  bool get_optical_flow_;
  bool get_video_id_;
  bool do_multi_label_;

  // thread pool for parse + decode
  int num_decode_threads_;
  std::shared_ptr<TaskThreadPool> thread_pool_;
};

template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
  // check whether the input parameters are valid or not
  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
  CAFFE_ENFORCE_GT(
      clip_per_video_, 0, "Number of clips per video should be positive.");
  CAFFE_ENFORCE_GT(crop_height_, 0, "Must provide the cropping height value.");
  CAFFE_ENFORCE_GT(crop_width_, 0, "Must provide the cropping width value.");

  CAFFE_ENFORCE_GT(
      num_of_required_frame_, 0, "Required number of frames must be positive.");

  if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(height_min_, 0, "Must provide the minimal height value.");
    CAFFE_ENFORCE_GT(width_min_, 0, "Must provide the minimal width value.");
    CAFFE_ENFORCE_GE(
        height_min_,
        crop_height_,
        "The minimal height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        width_min_,
        crop_width_,
        "The minimal width must be no smaller than the cropping width.");
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
    CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
    CAFFE_ENFORCE_GE(
        scale_h_,
        crop_height_,
        "The scaled height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        scale_w_,
        crop_width_,
        "The scaled width must be no smaller than the cropping width.");
  }

  if (get_rgb_) {
    CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_rgb_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, mean_rgb_.size(), "Number of rgb channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, inv_std_rgb_.size(), "Number of rgb channels is wrong!");
  }

  if (get_optical_flow_) {
    CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_of_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        mean_of_.size(),
        "Number of optical flow channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        inv_std_of_.size(),
        "Number of optical flow channels is wrong!");
  }

  if (clip_per_video_ > 1) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::DO_UNIFORM_SMP,
        "Only uniform sampling is supported when sampling multiple clips!");
  }

  if (do_multi_label_) {
    CAFFE_ENFORCE_GT(
        num_of_class_,
        0,
        "Number of classes must be set when using multiple labels.");
  }

  // print out the parameter settings
  LOG(INFO) << "Creating a clip input op with the following setting: ";
  LOG(INFO) << " Using " << num_decode_threads_ << " CPU threads;";
  LOG(INFO) << " Outputting in batches of " << batch_size_ << " videos;";
  LOG(INFO) << " Each video has " << clip_per_video_ << " clips;";
  LOG(INFO) << " Scaling image to " << scale_h_ << "x" << scale_w_;
  LOG(INFO) << " (Height, Width) is at least (" << height_min_ << ", "
            << width_min_ << ")";
  LOG(INFO) << " Cropping video frame to " << crop_height_ << "x"
            << crop_width_ << (random_mirror_ ? " with " : " without ")
            << "random mirroring;";
  LOG(INFO) << " Using " << (random_crop_ ? "random" : "center") << " crop";
  LOG(INFO) << " Is multi-cropping enabled: " << multi_crop_;

  if (get_rgb_) {
    LOG(INFO) << " Using a clip of " << length_rgb_ << " rgb frames "
              << "with " << channels_rgb_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_rgb_;
    LOG(INFO) << " RGB data augmentation. Color jittering: " << color_jitter_
              << ". Color lighting: " << color_lighting_;
    for (int i = 0; i < channels_rgb_; i++) {
      LOG(INFO) << " RGB " << i << "-th channel mean: " << mean_rgb_[i]
                << " std: " << 1.f / inv_std_rgb_[i];
    }
  }

  if (get_optical_flow_) {
    LOG(INFO) << " Using a clip of " << length_of_ << " optical flow frames "
              << "with " << channels_of_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_of_
              << " flow_data_type_: " << flow_data_type_
              << " flow_alg_type_: " << flow_alg_type_;
    for (int i = 0; i < channels_of_; i++) {
      LOG(INFO) << " Optical flow " << i
                << "-th channel mean: " << mean_of_[i]
                << " std: " << 1.f / inv_std_of_[i];
    }
  }

  if (video_res_type_ == VideoResType::ORIGINAL_RES) {
    LOG(INFO) << " Use original resolution";
  } else if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    LOG(INFO) << " Resize with minimal size and keep aspect ratio";
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    LOG(INFO) << " Resize and ignore aspect ratio";
  } else {
    LOG(ERROR) << " Unknown video resolution type";
  }

  if (decode_type_ == DecodeType::DO_TMP_JITTER) {
    LOG(INFO) << " Do temporal jittering";
  } else if (decode_type_ == DecodeType::USE_START_FRM) {
    LOG(INFO) << " Use start_frm for decoding";
  } else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
    LOG(INFO) << " Do uniform sampling";
  } else {
    LOG(ERROR) << " Unknown video decoding type";
  }
}

template <class Context>
VideoInputOp<Context>::VideoInputOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : PrefetchOperator<Context>(operator_def, ws),
      reader_(nullptr),
      batch_size_(
          OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
      clip_per_video_(
          OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
      mean_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_rgb", 128.)})),
      inv_std_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "std_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_rgb", 1.)})),
      mean_of_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_of", 0.)})),
      inv_std_of_(OperatorBase::template GetRepeatedArgument<float>(
          "std_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_of", 1.)})),
      channels_rgb_(
          OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
      channels_of_(
          OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
      crop_height_(OperatorBase::template GetSingleArgument<int>(
          "crop_height",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      crop_width_(OperatorBase::template GetSingleArgument<int>(
          "crop_width",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
      scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
      height_min_(OperatorBase::template GetSingleArgument<int>(
          "height_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      width_min_(OperatorBase::template GetSingleArgument<int>(
          "width_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      length_rgb_(
          OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
      sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
          "sampling_rate_rgb",
          1)),
      color_jitter_(OperatorBase::template GetSingleArgument<bool>(
          "color_jitter",
          false)),
      img_saturation_(OperatorBase::template GetSingleArgument<float>(
          "img_saturation",
          0.4)),
      img_brightness_(OperatorBase::template GetSingleArgument<float>(
          "img_brightness",
          0.4)),
      img_contrast_(
          OperatorBase::template GetSingleArgument<float>("img_contrast", 0.4)),
      color_lighting_(OperatorBase::template GetSingleArgument<bool>(
          "color_lighting",
          false)),
      color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
          "color_lighting_std",
          0.1)),
      length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
      sampling_rate_of_(
          OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
      frame_gap_of_(
          OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
      random_mirror_(OperatorBase::template GetSingleArgument<bool>(
          "random_mirror",
          true)),
      num_of_class_(
          OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
      use_local_file_(OperatorBase::template GetSingleArgument<bool>(
          "use_local_file",
          false)),
      random_crop_(
          OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
      multi_crop_(
          OperatorBase::template GetSingleArgument<bool>("multi_crop", false)),
      flow_data_type_(
          OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
      flow_alg_type_(
          OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
      decode_type_(
          OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
      video_res_type_(
          OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
      do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
          "do_flow_aggregation",
          true)),
      get_rgb_(OperatorBase::template GetSingleArgument<bool>("get_rgb", true)),
      get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
          "get_optical_flow",
          false)),
      get_video_id_(OperatorBase::template GetSingleArgument<bool>(
          "get_video_id",
          false)),
      do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
          "do_multi_label",
          false)),
      num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
          "num_decode_threads",
          4)),
      thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
  // hard-coded PCA eigenvectors and eigenvalues, based on RGB channel order
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-144.7125, 183.396, 102.2295});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.104, -1.1475, -207.57});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.818, -177.174, 107.1765});

  color_lighting_eigvals_ = std::vector<float>{0.2175, 0.0188, 0.0045};
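  // Note: these are the standard ImageNet (AlexNet-style) PCA lighting
  // statistics, with the eigenvectors scaled by 255 to match the 8-bit
  // pixel range.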

  // multi-cropping for testing
  multi_crop_count_ = 1;
  if (multi_crop_) {
    // we take left-top, central-top, right-top, left-bottom, central-bottom,
    // right-bottom and central-central croppings as well as their mirrorings.
    // In total, 14 croppings
    multi_crop_count_ = 14;
  }

  num_of_required_frame_ = 0;

  // mean and std for normalizing different optical flow data types;
  // Example statistics generated from SOA are shown below, and you may
  // want to change them if you are running on a different dataset;

  // 7 channels: (flow_x, flow_y, flow_magnitude, gray, Red, Green, Blue)
  const std::vector<float> InputDataMean = {
      0.0046635, 0.0046261, 0.963986, 102.976, 110.201, 100.64, 95.9966};
  const std::vector<float> InputDataStd = {
      0.972347, 0.755146, 1.43588, 55.3691, 58.1489, 56.4701, 55.3324};

  // if we need RGB as an input
  if (get_rgb_) {
    // how many frames we need for RGB
    num_of_required_frame_ = std::max(
        num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);

    channels_rgb_ = 3;

    if (mean_rgb_.size() != channels_rgb_ ||
        inv_std_rgb_.size() != channels_rgb_) {
      mean_rgb_.clear();
      inv_std_rgb_.clear();
      for (int i = 4; i < 7; i++) {
        mean_rgb_.push_back(InputDataMean[i]);
        inv_std_rgb_.push_back(1.f / InputDataStd[i]);
      }
    }
  }

  // if we need optical flow as an input
  if (get_optical_flow_) {
    // how many frames we need for optical flow
    num_of_required_frame_ = std::max(
        num_of_required_frame_,
        (length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);

    if (mean_of_.size() != channels_of_ || inv_std_of_.size() != channels_of_) {
      mean_of_.clear();
      inv_std_of_.clear();

      // set the parameters for different input data types
      switch (flow_data_type_) {
        case FlowDataType::Flow2C:
          channels_of_ = 2;
          for (int i = 0; i < channels_of_; i++) {
            mean_of_.push_back(InputDataMean[i]);
            inv_std_of_.push_back(1.f / InputDataStd[i]);
          }
          break;

        case FlowDataType::Flow3C:
          channels_of_ = 3;
          for (int i = 0; i < channels_of_; i++) {
            mean_of_.push_back(InputDataMean[i]);
            inv_std_of_.push_back(1.f / InputDataStd[i]);
          }
          break;

        // early fusion with gray
        case FlowDataType::FlowWithGray:
          channels_of_ = 3;
          for (int i = 0; i < 2; i++) {
            mean_of_.push_back(InputDataMean[i]);
            inv_std_of_.push_back(1.f / InputDataStd[i]);
          }
          mean_of_.push_back(InputDataMean[3]);
          inv_std_of_.push_back(1.f / InputDataStd[3]);
          break;

        // early fusion with RGB
        case FlowDataType::FlowWithRGB:
          channels_of_ = 5;
          for (int i = 0; i < 2; i++) {
            mean_of_.push_back(InputDataMean[i]);
            inv_std_of_.push_back(1.f / InputDataStd[i]);
          }
          for (int i = 4; i < 7; i++) {
            mean_of_.push_back(InputDataMean[i]);
            inv_std_of_.push_back(1.f / InputDataStd[i]);
          }
          break;

        default:
          LOG(ERROR) << "Unknown optical flow type " << flow_data_type_;
          break;
      }
    }
  }

  CheckParamsAndPrint();
  // Always need a dbreader, even when using local video files
  CAFFE_ENFORCE_GT(
      operator_def.input_size(), 0, "Need to have a DBReader blob input");

  vector<int64_t> data_shape(5);
  vector<int64_t> label_shape(2);

  // for RGB data
  data_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
  data_shape[1] = channels_rgb_;
  data_shape[2] = length_rgb_;
  data_shape[3] = crop_height_;
  data_shape[4] = crop_width_;
  ReinitializeTensor(
      &prefetched_clip_rgb_, data_shape, at::dtype<float>().device(CPU));

  // for optical flow data
  data_shape[1] = channels_of_;
  data_shape[2] = length_of_;
  ReinitializeTensor(
      &prefetched_clip_of_, data_shape, at::dtype<float>().device(CPU));

  // If do_multi_label is used, output label is a binary vector
  // of length num_of_class indicating which labels are present
  if (do_multi_label_) {
    label_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
    label_shape[1] = num_of_class_;
    ReinitializeTensor(
        &prefetched_label_, label_shape, at::dtype<int>().device(CPU));
  } else {
    prefetched_label_.Resize(
        vector<int64_t>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
  }

  ReinitializeTensor(
      &prefetched_video_id_,
      vector<int64_t>(1, batch_size_ * clip_per_video_ * multi_crop_count_),
      at::dtype<int>().device(CPU));
}

template <class Context>
bool VideoInputOp<Context>::GetClipsAndLabelsFromDBValue(
    const std::string& value,
    int& height,
    int& width,
    std::vector<unsigned char*>& buffer_rgb,
    int* label_data,
    int* video_id_data) {
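  // Expected layout of each DB value: protos(0) holds the encoded video (or
  // a file name), protos(1) holds the label(s), optionally followed by a
  // start-frame proto (when decode_type_ == USE_START_FRM) and a video id
  // proto (when get_video_id_ is set).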
  TensorProtos protos;
  int curr_proto_idx = 0;
  CAFFE_ENFORCE(protos.ParseFromString(value));
  const TensorProto& video_proto = protos.protos(curr_proto_idx++);
  const TensorProto& label_proto = protos.protos(curr_proto_idx++);

  int start_frm = 0;
  // start_frm is only valid when sampling 1 clip per video without
  // temporal jittering
  if (decode_type_ == DecodeType::USE_START_FRM) {
    CAFFE_ENFORCE_GE(
        protos.protos_size(),
        curr_proto_idx + 1,
        "Start frm proto not provided");
    const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
    start_frm = start_frm_proto.int32_data(0);
  }

  if (get_video_id_) {
    CAFFE_ENFORCE_GE(
        protos.protos_size(), curr_proto_idx + 1, "Video Id not provided");
    const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      video_id_data[i] = video_id_proto.int64_data(0);
    }
  }
  // assign labels
  if (!do_multi_label_) {
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      label_data[i] = label_proto.int32_data(0);
    }
  } else {
    // For the multiple label case, output label is a binary vector
    // where the concepts present are marked 1
    memset(
        label_data,
        0,
        sizeof(int) * num_of_class_ * multi_crop_count_ * clip_per_video_);
    for (int i = 0; i < clip_per_video_; i++) {
      for (int j = 0; j < multi_crop_count_; ++j) {
        for (int k = 0; k < label_proto.int32_data_size(); k++) {
          CAFFE_ENFORCE_LT(
              label_proto.int32_data(k),
              num_of_class_,
              "Label should be less than the number of classes.");
          label_data
              [(i * multi_crop_count_ + j) * num_of_class_ +
               label_proto.int32_data(k)] = 1;
        }
      }
    }
  }

  if (use_local_file_) {
    CAFFE_ENFORCE_EQ(
        video_proto.data_type(),
        TensorProto::STRING,
        "Database with a file_list is expected to be string data");
  }

  // initializing the decoding params
  Params params;
  params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
  params.video_res_type_ = video_res_type_;
  params.crop_height_ = crop_height_;
  params.crop_width_ = crop_width_;
  params.height_min_ = height_min_;
  params.width_min_ = width_min_;
  params.scale_w_ = scale_w_;
  params.scale_h_ = scale_h_;
  params.decode_type_ = decode_type_;
  params.num_of_required_frame_ = num_of_required_frame_;

  char* video_buffer = nullptr; // for decoding from buffer
  std::string video_filename; // for decoding from file
  int encoded_size = 0;
  if (video_proto.data_type() == TensorProto::STRING) {
    const string& encoded_video_str = video_proto.string_data(0);
    if (!use_local_file_) {
      encoded_size = encoded_video_str.size();
      video_buffer = const_cast<char*>(encoded_video_str.data());
    } else {
      video_filename = encoded_video_str;
    }
  } else if (video_proto.data_type() == TensorProto::BYTE) {
    if (!use_local_file_) {
      encoded_size = video_proto.byte_data().size();
      video_buffer = const_cast<char*>(video_proto.byte_data().data());
    } else {
      // TODO: does this work?
      video_filename = video_proto.string_data(0);
    }
  } else {
    LOG(FATAL) << "Unknown video data type.";
  }

  DecodeMultipleClipsFromVideo(
      video_buffer,
      video_filename,
      encoded_size,
      params,
      start_frm,
      clip_per_video_,
      use_local_file_,
      height,
      width,
      buffer_rgb);

  return true;
}

template <class Context>
void VideoInputOp<Context>::DecodeAndTransform(
    const std::string& value,
    float* clip_rgb_data,
    float* clip_of_data,
    int* label_data,
    int* video_id_data,
    std::mt19937* randgen,
    std::bernoulli_distribution* mirror_this_clip) {
  std::vector<unsigned char*> buffer_rgb;
  // get the video resolution after decoding
  int height = 0;
  int width = 0;
  // Decode the video from memory or read from a local file
  CHECK(GetClipsAndLabelsFromDBValue(
      value, height, width, buffer_rgb, label_data, video_id_data));

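  // Per-clip strides into the output buffers: each decoded clip produces
  // multi_crop_count_ transformed crops of
  // channels * length * crop_height_ * crop_width_ floats each.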
  int clip_offset_rgb = multi_crop_count_ * channels_rgb_ * length_rgb_ *
      crop_height_ * crop_width_;
  int clip_crop_offset_of =
      channels_of_ * length_of_ * crop_height_ * crop_width_;
  int clip_offset_of = multi_crop_count_ * clip_crop_offset_of;
  for (int i = 0; i < std::min(clip_per_video_, int(buffer_rgb.size())); i++) {
    // get the rectangle for cropping
    int h_off = 0;
    int w_off = 0;
    if (random_crop_) {
      // using random crop for training
      h_off =
          std::uniform_int_distribution<>(0, height - crop_height_)(*randgen);
      w_off = std::uniform_int_distribution<>(0, width - crop_width_)(*randgen);
    } else {
      // using center crop for testing
      h_off = (height - crop_height_) / 2;
      w_off = (width - crop_width_) / 2;
    }
    // cv::Rect rect(w_off, h_off, crop_width_, crop_height_);

    // Multi cropping: we take left-top, central-top, right-top, left-bottom,
    // central-bottom, right-bottom and central-central croppings as well as
    // their mirrorings. In total, 14 croppings
    int multi_crop_w_off[7] = {0,
                               (width - crop_width_) / 2,
                               width - crop_width_,
                               (width - crop_width_) / 2,
                               0,
                               (width - crop_width_) / 2,
                               width - crop_width_};
    int multi_crop_h_off[7] = {0,
                               0,
                               0,
                               (height - crop_height_) / 2,
                               height - crop_height_,
                               height - crop_height_,
                               height - crop_height_};

    // randomly mirror the image or not
    bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
    if (get_rgb_ && clip_rgb_data) {
      ClipTransformRGB(
          buffer_rgb[i],
          multi_crop_count_,
          crop_height_,
          crop_width_,
          length_rgb_,
          channels_rgb_,
          sampling_rate_rgb_,
          height,
          width,
          h_off,
          w_off,
          multi_crop_h_off,
          multi_crop_w_off,
          mirror_me,
          color_jitter_,
          img_saturation_,
          img_brightness_,
          img_contrast_,
          color_lighting_,
          color_lighting_std_,
          color_lighting_eigvecs_,
          color_lighting_eigvals_,
          mean_rgb_,
          inv_std_rgb_,
          randgen,
          clip_rgb_data + (i * clip_offset_rgb));
    }
    if (get_optical_flow_ && clip_of_data) {
      cv::Rect rect;
      for (int j = 0; j < multi_crop_count_; ++j) {
        if (multi_crop_count_ == 1) {
          rect = cv::Rect(w_off, h_off, crop_width_, crop_height_);
        } else {
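          // With 14 croppings, the first half (j < 7) use the 7 crop
          // windows unmirrored and the second half repeat them mirrored.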
          mirror_me = j / (multi_crop_count_ / 2);
          int k = j % (multi_crop_count_ / 2);
          rect = cv::Rect(
              multi_crop_w_off[k],
              multi_crop_h_off[k],
              crop_width_,
              crop_height_);
        }
        ClipTransformOpticalFlow(
            buffer_rgb[i],
            crop_height_,
            crop_width_,
            length_of_,
            channels_of_,
            sampling_rate_of_,
            height,
            width,
            rect,
            channels_rgb_,
            mirror_me,
            flow_alg_type_,
            flow_data_type_,
            frame_gap_of_,
            do_flow_aggregation_,
            mean_of_,
            inv_std_of_,
            clip_of_data + (i * clip_offset_of) + j * clip_crop_offset_of);
      }
    }
  }

  if (buffer_rgb.size() > 0) {
    for (int i = 0; i < buffer_rgb.size(); i++) {
      unsigned char* buff = buffer_rgb[i];
      delete[] buff;
    }
  }
  buffer_rgb.clear();
}

template <class Context>
bool VideoInputOp<Context>::Prefetch() {
  // We will get the reader pointer from input.
  // If we use local clips, the db will store the file list
  reader_ = &OperatorBase::Input<db::DBReader>(0);

  // Call mutable_data() once to allocate the underlying memory.
  prefetched_clip_rgb_.mutable_data<float>();
  prefetched_clip_of_.mutable_data<float>();
  prefetched_label_.mutable_data<int>();
  prefetched_video_id_.mutable_data<int>();

  // Prefetching is handled by a thread pool of num_decode_threads_ threads.
  std::mt19937 meta_randgen(time(nullptr));
  std::vector<std::mt19937> randgen_per_thread;
  for (int i = 0; i < num_decode_threads_; ++i) {
    randgen_per_thread.emplace_back(meta_randgen());
  }

  std::bernoulli_distribution mirror_this_clip(0.5);
  for (int item_id = 0; item_id < batch_size_; ++item_id) {
    std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];

    int frame_size = crop_height_ * crop_width_;
    // get the clip data pointer for the item_id -th example
    float* clip_rgb_data = prefetched_clip_rgb_.mutable_data<float>() +
        frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the optical flow data for the current clip
    float* clip_of_data = prefetched_clip_of_.mutable_data<float>() +
        frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the label data pointer for the item_id -th example
    int* label_data = prefetched_label_.mutable_data<int>() +
        (do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the video id data pointer for the item_id -th example
    int* video_id_data = prefetched_video_id_.mutable_data<int>() +
        item_id * clip_per_video_ * multi_crop_count_;

    std::string key, value;
    // read data
    reader_->Read(&key, &value);

    thread_pool_->run(std::bind(
        &VideoInputOp<Context>::DecodeAndTransform,
        this,
        std::string(value),
        clip_rgb_data,
        clip_of_data,
        label_data,
        video_id_data,
        randgen,
        &mirror_this_clip));
  } // for over the batch
  thread_pool_->waitWorkComplete();

  // If the context is not CPUContext, we will need to do a copy in the
  // prefetch function as well.
  if (!std::is_same<Context, CPUContext>::value) {
    if (get_rgb_) {
      prefetched_clip_rgb_on_device_.CopyFrom(
          prefetched_clip_rgb_, true /*async*/);
    }
    if (get_optical_flow_) {
      prefetched_clip_of_on_device_.CopyFrom(
          prefetched_clip_of_, true /*async*/);
    }
    prefetched_label_on_device_.CopyFrom(prefetched_label_, true /*async*/);
    if (get_video_id_) {
      prefetched_video_id_on_device_.CopyFrom(
          prefetched_video_id_, true /*async*/);
    }
  }
  return true;
}

template <class Context>
bool VideoInputOp<Context>::CopyPrefetched() {
  int index = 0;
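  // Output order: [clip_rgb], [clip_of], label, [video_id]; the optional
  // outputs are present only when the corresponding get_* flag is set.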
  if (get_rgb_) {
    auto* clip_rgb_output =
        OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
    if (std::is_same<Context, CPUContext>::value) {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_, true /*async*/);
    } else {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, true /*async*/);
    }
  }
  if (get_optical_flow_) {
    auto* clip_of_output =
        OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
    if (std::is_same<Context, CPUContext>::value) {
      clip_of_output->CopyFrom(prefetched_clip_of_, true /*async*/);
    } else {
      clip_of_output->CopyFrom(prefetched_clip_of_on_device_, true /*async*/);
    }
  }
  auto* label_output =
      OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
  if (std::is_same<Context, CPUContext>::value) {
    label_output->CopyFrom(prefetched_label_, true /*async*/);
  } else {
    label_output->CopyFrom(prefetched_label_on_device_, true /*async*/);
  }
  if (get_video_id_) {
    auto* video_id_output =
        OperatorBase::Output<Tensor>(index, Context::GetDeviceType());
    if (std::is_same<Context, CPUContext>::value) {
      video_id_output->CopyFrom(prefetched_video_id_, true /*async*/);
    } else {
      video_id_output->CopyFrom(prefetched_video_id_on_device_, true /*async*/);
    }
  }
  return true;
}

} // namespace caffe2

#endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
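
For context, here is a minimal sketch of how this operator might be driven from C++. It is illustrative only: the db type and path, the argument values, the output blob names, and the assumption that video_res_type value 2 selects VideoResType::USE_WIDTH_HEIGHT are placeholders to adapt to your build; the operator is registered under the name "VideoInput" in caffe2/video.

#include <caffe2/core/db.h>
#include <caffe2/core/operator.h>
#include <caffe2/core/workspace.h>
#include <caffe2/utils/proto_utils.h>

void RunVideoInputOnce() {
  caffe2::Workspace ws;
  // The op requires a DBReader as its first input blob (enforced above).
  ws.CreateBlob("reader")->Reset(
      new caffe2::db::DBReader("lmdb", "/path/to/video_db"));

  caffe2::OperatorDef def;
  def.set_type("VideoInput");
  def.add_input("reader");
  def.add_output("clip_rgb");
  def.add_output("label");
  caffe2::AddArgument<int>("batch_size", 2, &def);
  caffe2::AddArgument<int>("length_rgb", 8, &def);
  caffe2::AddArgument<int>("crop_size", 112, &def);
  caffe2::AddArgument<int>("scale_h", 128, &def);
  caffe2::AddArgument<int>("scale_w", 171, &def);
  caffe2::AddArgument<int>("video_res_type", 2, &def); // assumed enum value

  std::unique_ptr<caffe2::OperatorBase> op = caffe2::CreateOperator(def, &ws);
  // Run() prefetches a batch (decode + transform on the thread pool) and
  // copies it into the "clip_rgb" and "label" output blobs.
  op->Run();
}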