Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_input_op.cc
1 
17 #include <caffe2/video/video_input_op.h>
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(VideoInput, VideoInputOp<CPUContext>);
22 
23 OPERATOR_SCHEMA(VideoInput)
24  .NumInputs(0, 1)
25  .NumOutputs(2, 4)
26  .TensorInferenceFunction(
27  [](const OperatorDef& def,
28  const vector<TensorShape>& /* unused */ /*in*/) {
29  ArgumentHelper helper(def);
30  int batch_size = helper.GetSingleArgument<int>("batch_size", 0);
31  int clip_per_video =
32  helper.GetSingleArgument<int>("clip_per_video", 1);
33  int crop_height = helper.GetSingleArgument<int>(
34  "crop_height", helper.GetSingleArgument<int>("crop_size", 0));
35  int crop_width = helper.GetSingleArgument<int>(
36  "crop_width", helper.GetSingleArgument<int>("crop_size", 0));
37  int length_rgb = helper.GetSingleArgument<int>("length_rgb", 0);
38  int channels_rgb = helper.GetSingleArgument<int>("channels_rgb", 3);
39  int length_of = helper.GetSingleArgument<int>("length_of", 0);
40  int channels_of = helper.GetSingleArgument<int>("channels_of", 2);
41 
42  // get the flags
43  bool get_rgb = helper.GetSingleArgument<bool>("get_rgb", true);
44  bool get_optical_flow =
45  helper.GetSingleArgument<bool>("get_optical_flow", false);
46  bool do_multi_label =
47  helper.GetSingleArgument<bool>("do_multi_label", false);
48  bool get_video_id =
49  helper.GetSingleArgument<bool>("get_video_id", false);
50 
51  int output_size = 1;
52  if (get_rgb) {
53  output_size++;
54  }
55  if (get_optical_flow) {
56  output_size++;
57  }
58  if (get_video_id) {
59  output_size++;
60  }
61 
62  int index = 0;
63  vector<TensorShape> out(output_size);
64  CHECK_GT(crop_height, 0);
65  CHECK_GT(crop_width, 0);
66  batch_size *= clip_per_video;
67  if (get_rgb) {
68  out[index++] = CreateTensorShape(
69  vector<int>{batch_size,
70  channels_rgb,
71  length_rgb,
72  crop_height,
73  crop_width},
74  TensorProto::FLOAT);
75  }
76  if (get_optical_flow) {
77  out[index++] = CreateTensorShape(
78  vector<int>{batch_size,
79  channels_of,
80  length_of,
81  crop_height,
82  crop_width},
83  TensorProto::FLOAT);
84  }
85  if (!do_multi_label) {
86  out[index++] = CreateTensorShape(
87  vector<int>{1, batch_size}, TensorProto::INT32);
88  } else {
89  int num_of_class = helper.GetSingleArgument<int>("num_of_class", 0);
90  out[index++] = CreateTensorShape(
91  vector<int>{batch_size, num_of_class}, TensorProto::INT32);
92  }
93  if (get_video_id) {
94  out[index] = CreateTensorShape(
95  vector<int>{1, batch_size}, TensorProto::INT32);
96  }
97 
98  return out;
99  });
100 
101 NO_GRADIENT(VideoInput);
102 
103 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.