Caffe2 - C++ API
A deep learning, cross-platform ML framework
optical_flow.cc
1 #include <caffe2/video/optical_flow.h>
2 
3 namespace caffe2 {
4 
5 void OpticalFlowExtractor(
6  const cv::Mat& prev_gray,
7  const cv::Mat& curr_gray,
8  const int flow_alg_type,
9  cv::Mat& flow) {
10 #if CV_MAJOR_VERSION >= 4
11  cv::Ptr<cv::DISOpticalFlow> tvl1 = cv::DISOpticalFlow::create();
12 #else
13  cv::Ptr<cv::DualTVL1OpticalFlow> tvl1 = cv::DualTVL1OpticalFlow::create();
14 #endif
15  switch (flow_alg_type) {
16  case FLowAlgType::FarnebackOpticalFlow:
17  cv::calcOpticalFlowFarneback(
18  prev_gray,
19  curr_gray,
20  flow,
21  std::sqrt(2) / 2.0,
22  5,
23  10,
24  2,
25  7,
26  1.5,
27  cv::OPTFLOW_FARNEBACK_GAUSSIAN);
28  break;
29  case FLowAlgType::DensePyrLKOpticalFlow:
30  LOG(ERROR) << "DensePyrLKOpticalFlow only has sparse version on CPU";
31  break;
32  case FLowAlgType::BroxOpticalFlow:
33  LOG(ERROR) << "BroxOpticalFlow on CPU is not available";
34  break;
35  case FLowAlgType::OpticalFlowDual_TVL1:
36  tvl1->calc(prev_gray, curr_gray, flow);
37  break;
38  default:
39  LOG(ERROR) << "Unsupported optical flow type " << flow_alg_type;
40  break;
41  }
42 }
43 
44 void MergeOpticalFlow(cv::Mat& prev_flow, const cv::Mat& curr_flow) {
45  const int rows = prev_flow.rows;
46  const int cols = prev_flow.cols;
47 
48  // merge two optical flows into one
49  for (int y = 0; y < rows; y++) {
50  for (int x = 0; x < cols; x++) {
51  cv::Point2f u = prev_flow.at<cv::Point2f>(y, x);
52  // get the new location
53  int x_new = std::min(cols - 1, std::max(0, cvRound(u.x + x)));
54  int y_new = std::min(rows - 1, std::max(0, cvRound(u.y + y)));
55  cv::Point2f u_new = curr_flow.at<cv::Point2f>(y_new, x_new);
56 
57  // update the flow
58  prev_flow.at<cv::Point2f>(y, x) += u_new;
59  }
60  }
61 }
62 
63 void MultiFrameOpticalFlowExtractor(
64  const std::vector<cv::Mat>& grays,
65  const int optical_flow_alg_type,
66  cv::Mat& flow) {
67  int num_frames = grays.size();
68  CAFFE_ENFORCE_GE(num_frames, 2, "need at least 2 frames!");
69 
70  // compute optical flow for every two frames
71  std::vector<cv::Mat> flows;
72  for (int i = 0; i < num_frames - 1; i++) {
73  cv::Mat tmp;
74  OpticalFlowExtractor(grays[i], grays[i + 1], optical_flow_alg_type, tmp);
75  flows.push_back(tmp);
76  }
77 
78  flows[0].copyTo(flow);
79  // aggregate optical flow across multiple frame
80  for (int i = 1; i < num_frames - 1; i++) {
81  MergeOpticalFlow(flow, flows[i]);
82  }
83 }
84 
85 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13