Caffe2 - C++ API
A deep learning, cross platform ML framework
box_with_nms_limit_op.cc
1 #include "box_with_nms_limit_op.h"
2 #include "caffe2/utils/eigen_utils.h"
3 #include "generate_proposals_op_util_nms.h"
4 
5 #ifdef CAFFE2_USE_MKL
6 #include "caffe2/mkl/operators/operator_fallback_mkl.h"
7 #endif // CAFFE2_USE_MKL
8 
9 namespace caffe2 {
10 
11 namespace {
12 
13 template <class Derived, class Func>
14 vector<int> filter_with_indices(
15  const Eigen::ArrayBase<Derived>& array,
16  const vector<int>& indices,
17  const Func& func) {
18  vector<int> ret;
19  for (auto& cur : indices) {
20  if (func(array[cur])) {
21  ret.push_back(cur);
22  }
23  }
24  return ret;
25 }
26 
27 } // namespace
28 
29 template <>
30 bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
31  const auto& tscores = Input(0);
32  const auto& tboxes = Input(1);
33  auto* out_scores = Output(0);
34  auto* out_boxes = Output(1);
35  auto* out_classes = Output(2);
36 
37  // tscores: (num_boxes, num_classes), 0 for background
38  if (tscores.ndim() == 4) {
39  CAFFE_ENFORCE_EQ(tscores.dim(2), 1, tscores.dim(2));
40  CAFFE_ENFORCE_EQ(tscores.dim(3), 1, tscores.dim(3));
41  } else {
42  CAFFE_ENFORCE_EQ(tscores.ndim(), 2, tscores.ndim());
43  }
44  CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.meta().name());
45  // tboxes: (num_boxes, num_classes * 4)
46  if (tboxes.ndim() == 4) {
47  CAFFE_ENFORCE_EQ(tboxes.dim(2), 1, tboxes.dim(2));
48  CAFFE_ENFORCE_EQ(tboxes.dim(3), 1, tboxes.dim(3));
49  } else {
50  CAFFE_ENFORCE_EQ(tboxes.ndim(), 2, tboxes.ndim());
51  }
52  CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.meta().name());
53 
54  int num_classes = tscores.dim(1);
55 
56  CAFFE_ENFORCE_EQ(tscores.dim(0), tboxes.dim(0));
57  CAFFE_ENFORCE_EQ(num_classes * 4, tboxes.dim(1));
58 
59  Eigen::Map<const ERArrXXf> scores(
60  tscores.data<float>(), tscores.dim(0), tscores.dim(1));
61  Eigen::Map<const ERArrXXf> boxes(
62  tboxes.data<float>(), tboxes.dim(0), tboxes.dim(1));
63 
64  // To store updated scores if SoftNMS is used
65  ERArrXXf soft_nms_scores(tscores.dim(0), tscores.dim(1));
66 
67  vector<vector<int>> keeps(num_classes);
68 
69  // Perform nms to each class
70  // skip j = 0, because it's the background class
71  int total_keep_count = 0;
72  for (int j = 1; j < num_classes; j++) {
73  auto cur_scores = scores.col(j);
74  auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
75  auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
76 
77  if (soft_nms_enabled_) {
78  auto out_scores = soft_nms_scores.col(j);
79  keeps[j] = utils::soft_nms_cpu(
80  &out_scores,
81  cur_boxes,
82  cur_scores,
83  inds,
84  soft_nms_sigma_,
85  nms_thres_,
86  soft_nms_min_score_thres_,
87  soft_nms_method_);
88  } else {
89  std::sort(
90  inds.data(),
91  inds.data() + inds.size(),
92  [&cur_scores](int lhs, int rhs) {
93  return cur_scores(lhs) > cur_scores(rhs);
94  });
95  keeps[j] = utils::nms_cpu(cur_boxes, cur_scores, inds, nms_thres_);
96  }
97  total_keep_count += keeps[j].size();
98  }
99 
100  if (soft_nms_enabled_) {
101  // Re-map scores to the updated SoftNMS scores
102  new (&scores) Eigen::Map<const ERArrXXf>(
103  soft_nms_scores.data(), soft_nms_scores.rows(), soft_nms_scores.cols());
104  }
105 
106  // Limit to max_per_image detections *over all classes*
107  if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
108  // merge all scores together and sort
109  auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
110  EArrXf ret(total_keep_count);
111 
112  int ret_idx = 0;
113  for (int i = 1; i < keeps.size(); i++) {
114  auto& cur_keep = keeps[i];
115  auto cur_scores = scores.col(i);
116  auto cur_ret = ret.segment(ret_idx, cur_keep.size());
117  utils::GetSubArray(cur_scores, utils::AsEArrXt(keeps[i]), &cur_ret);
118  ret_idx += cur_keep.size();
119  }
120 
121  std::sort(ret.data(), ret.data() + ret.size());
122 
123  return ret;
124  };
125 
126  // Compute image thres based on all classes
127  auto all_scores_sorted = get_all_scores_sorted();
128  DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
129  auto image_thresh =
130  all_scores_sorted[all_scores_sorted.size() - detections_per_im_];
131 
132  total_keep_count = 0;
133  // filter results with image_thresh
134  for (int j = 1; j < num_classes; j++) {
135  auto& cur_keep = keeps[j];
136  auto cur_scores = scores.col(j);
137  keeps[j] =
138  filter_with_indices(cur_scores, cur_keep, [&image_thresh](float sc) {
139  return sc >= image_thresh;
140  });
141  total_keep_count += keeps[j].size();
142  }
143  }
144 
145  // Write results
146  out_scores->Resize(total_keep_count);
147  out_boxes->Resize(total_keep_count, 4);
148  out_classes->Resize(total_keep_count);
149  int cur_out_idx = 0;
150  for (int j = 1; j < num_classes; j++) {
151  auto cur_scores = scores.col(j);
152  auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
153  auto& cur_keep = keeps[j];
154  Eigen::Map<EArrXf> cur_out_scores(
155  out_scores->mutable_data<float>() + cur_out_idx, cur_keep.size());
156  Eigen::Map<ERArrXXf> cur_out_boxes(
157  out_boxes->mutable_data<float>() + cur_out_idx * 4, cur_keep.size(), 4);
158  Eigen::Map<EArrXf> cur_out_classes(
159  out_classes->mutable_data<float>() + cur_out_idx, cur_keep.size());
160 
161  utils::GetSubArray(cur_scores, utils::AsEArrXt(cur_keep), &cur_out_scores);
162  utils::GetSubArrayRows(
163  cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
164  for (int k = 0; k < cur_keep.size(); k++) {
165  cur_out_classes[k] = static_cast<float>(j);
166  }
167 
168  cur_out_idx += cur_keep.size();
169  }
170 
171  if (OutputSize() > 3) {
172  auto* out_keeps = Output(3);
173  auto* out_keeps_size = Output(4);
174  out_keeps->Resize(total_keep_count);
175  out_keeps_size->Resize(num_classes);
176 
177  Eigen::Map<EArrXi> cur_out_keeps_size(
178  out_keeps_size->mutable_data<int>(), num_classes);
179 
180  cur_out_idx = 0;
181  Eigen::Map<EArrXi> out_keeps_arr(
182  out_keeps->mutable_data<int>(), total_keep_count);
183  for (int j = 0; j < num_classes; j++) {
184  out_keeps_arr.segment(cur_out_idx, keeps[j].size()) =
185  utils::AsEArrXt(keeps[j]);
186 
187  cur_out_keeps_size[j] = keeps[j].size();
188  cur_out_idx += keeps[j].size();
189  }
190  }
191 
192  return true;
193 }
194 
195 namespace {
196 
197 REGISTER_CPU_OPERATOR(BoxWithNMSLimit, BoxWithNMSLimitOp<CPUContext>);
198 
199 #ifdef CAFFE2_HAS_MKL_DNN
200 REGISTER_MKL_OPERATOR(
201  BoxWithNMSLimit,
202  mkl::MKLFallbackOp<BoxWithNMSLimitOp<CPUContext>>);
203 #endif // CAFFE2_HAS_MKL_DNN
204 
205 OPERATOR_SCHEMA(BoxWithNMSLimit)
206  .NumInputs(2)
207  .NumOutputs(3, 5)
208  .SetDoc(R"DOC(
209 Apply NMS to each class (except background) and limit the number of
210 returned boxes.
211 )DOC")
212  .Arg("score_thres", "(float) TEST.SCORE_THRES")
213  .Arg("nms", "(float) TEST.NMS")
214  .Arg("detections_per_im", "(int) TEST.DEECTIONS_PER_IM")
215  .Arg("soft_nms_enabled", "(bool) TEST.SOFT_NMS.ENABLED")
216  .Arg("soft_nms_method", "(string) TEST.SOFT_NMS.METHOD")
217  .Arg("soft_nms_sigma", "(float) TEST.SOFT_NMS.SIGMA")
218  .Arg(
219  "soft_nms_min_score_thres",
220  "(float) Lower bound on updated scores to discard boxes")
221  .Input(0, "scores", "Scores, size (count, num_classes)")
222  .Input(
223  1,
224  "boxes",
225  "Bounding box for each class, size (count, num_classes * 4)")
226  .Output(0, "scores", "Filtered scores, size (n)")
227  .Output(1, "boxes", "Filtered boxes, size (n, 4)")
228  .Output(2, "classes", "Class id for each filtered score/box, size (n)")
229  .Output(3, "keeps", "Optional filtered indices, size (n)")
230  .Output(
231  4,
232  "keeps_size",
233  "Optional number of filtered indices per class, size (num_classes)");
234 
235 SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit);
236 
237 } // namespace
238 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.