Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_input_op.cc
1 #include <caffe2/video/video_input_op.h>
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(VideoInput, VideoInputOp<CPUContext>);
6 
7 OPERATOR_SCHEMA(VideoInput)
8  .NumInputs(0, 1)
9  .NumOutputs(2, 4)
10  .TensorInferenceFunction(
11  [](const OperatorDef& def,
12  const vector<TensorShape>& /* unused */ /*in*/) {
13  ArgumentHelper helper(def);
14  int batch_size = helper.GetSingleArgument<int>("batch_size", 0);
15  int clip_per_video =
16  helper.GetSingleArgument<int>("clip_per_video", 1);
17  int crop_height = helper.GetSingleArgument<int>(
18  "crop_height", helper.GetSingleArgument<int>("crop_size", 0));
19  int crop_width = helper.GetSingleArgument<int>(
20  "crop_width", helper.GetSingleArgument<int>("crop_size", 0));
21  int length_rgb = helper.GetSingleArgument<int>("length_rgb", 0);
22  int channels_rgb = helper.GetSingleArgument<int>("channels_rgb", 3);
23  int length_of = helper.GetSingleArgument<int>("length_of", 0);
24  int channels_of = helper.GetSingleArgument<int>("channels_of", 2);
25 
26  // get the flags
27  bool get_rgb = helper.GetSingleArgument<bool>("get_rgb", true);
28  bool get_optical_flow =
29  helper.GetSingleArgument<bool>("get_optical_flow", false);
30  bool do_multi_label =
31  helper.GetSingleArgument<bool>("do_multi_label", false);
32  bool get_video_id =
33  helper.GetSingleArgument<bool>("get_video_id", false);
34 
35  int output_size = 1;
36  if (get_rgb) {
37  output_size++;
38  }
39  if (get_optical_flow) {
40  output_size++;
41  }
42  if (get_video_id) {
43  output_size++;
44  }
45 
46  int index = 0;
47  vector<TensorShape> out(output_size);
48  CHECK_GT(crop_height, 0);
49  CHECK_GT(crop_width, 0);
50  batch_size *= clip_per_video;
51  if (get_rgb) {
52  out[index++] = CreateTensorShape(
53  vector<int>{batch_size,
54  channels_rgb,
55  length_rgb,
56  crop_height,
57  crop_width},
58  TensorProto::FLOAT);
59  }
60  if (get_optical_flow) {
61  out[index++] = CreateTensorShape(
62  vector<int>{batch_size,
63  channels_of,
64  length_of,
65  crop_height,
66  crop_width},
67  TensorProto::FLOAT);
68  }
69  if (!do_multi_label) {
70  out[index++] = CreateTensorShape(
71  vector<int>{1, batch_size}, TensorProto::INT32);
72  } else {
73  int num_of_class = helper.GetSingleArgument<int>("num_of_class", 0);
74  out[index++] = CreateTensorShape(
75  vector<int>{batch_size, num_of_class}, TensorProto::INT32);
76  }
77  if (get_video_id) {
78  out[index] = CreateTensorShape(
79  vector<int>{1, batch_size}, TensorProto::INT32);
80  }
81 
82  return out;
83  });
84 
85 NO_GRADIENT(VideoInput);
86 
87 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13