Caffe2 - C++ API
A deep learning, cross platform ML framework
box_with_nms_limit_op.cc
1 #include "box_with_nms_limit_op.h"
2 #include "caffe2/utils/eigen_utils.h"
3 #include "generate_proposals_op_util_nms.h"
4 
5 namespace caffe2 {
6 
7 template <>
8 bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
9  const auto& tscores = Input(0);
10  const auto& tboxes = Input(1);
11 
12  const int box_dim = rotated_ ? 5 : 4;
13 
14  // tscores: (num_boxes, num_classes), 0 for background
15  if (tscores.dim() == 4) {
16  CAFFE_ENFORCE_EQ(tscores.size(2), 1, tscores.size(2));
17  CAFFE_ENFORCE_EQ(tscores.size(3), 1, tscores.size(3));
18  } else {
19  CAFFE_ENFORCE_EQ(tscores.dim(), 2, tscores.dim());
20  }
21  CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.dtype().name());
22  // tboxes: (num_boxes, num_classes * box_dim)
23  if (tboxes.dim() == 4) {
24  CAFFE_ENFORCE_EQ(tboxes.size(2), 1, tboxes.size(2));
25  CAFFE_ENFORCE_EQ(tboxes.size(3), 1, tboxes.size(3));
26  } else {
27  CAFFE_ENFORCE_EQ(tboxes.dim(), 2, tboxes.dim());
28  }
29  CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.dtype().name());
30 
31  int N = tscores.size(0);
32  int num_classes = tscores.size(1);
33 
34  CAFFE_ENFORCE_EQ(N, tboxes.size(0));
35  CAFFE_ENFORCE_EQ(num_classes * box_dim, tboxes.size(1));
36 
37  int batch_size = 1;
38  vector<float> batch_splits_default(1, tscores.size(0));
39  const float* batch_splits_data = batch_splits_default.data();
40  if (InputSize() > 2) {
41  // tscores and tboxes have items from multiple images in a batch. Get the
42  // corresponding batch splits from input.
43  const auto& tbatch_splits = Input(2);
44  CAFFE_ENFORCE_EQ(tbatch_splits.dim(), 1);
45  batch_size = tbatch_splits.size(0);
46  batch_splits_data = tbatch_splits.data<float>();
47  }
48  Eigen::Map<const EArrXf> batch_splits(batch_splits_data, batch_size);
49  CAFFE_ENFORCE_EQ(batch_splits.sum(), N);
50 
51  auto* out_scores = Output(0, {0}, at::dtype<float>());
52  auto* out_boxes = Output(1, {0, box_dim}, at::dtype<float>());
53  auto* out_classes = Output(2, {0}, at::dtype<float>());
54 
55  Tensor* out_keeps = nullptr;
56  Tensor* out_keeps_size = nullptr;
57  if (OutputSize() > 4) {
58  out_keeps = Output(4);
59  out_keeps_size = Output(5);
60  out_keeps->Resize(0);
61  out_keeps_size->Resize(batch_size, num_classes);
62  }
63 
64  vector<int> total_keep_per_batch(batch_size);
65  int offset = 0;
66  for (int b = 0; b < batch_splits.size(); ++b) {
67  int num_boxes = batch_splits(b);
68  Eigen::Map<const ERArrXXf> scores(
69  tscores.data<float>() + offset * tscores.size(1),
70  num_boxes,
71  tscores.size(1));
72  Eigen::Map<const ERArrXXf> boxes(
73  tboxes.data<float>() + offset * tboxes.size(1),
74  num_boxes,
75  tboxes.size(1));
76 
77  // To store updated scores if SoftNMS is used
78  ERArrXXf soft_nms_scores(num_boxes, tscores.size(1));
79  vector<vector<int>> keeps(num_classes);
80 
81  // Perform nms to each class
82  // skip j = 0, because it's the background class
83  int total_keep_count = 0;
84  for (int j = 1; j < num_classes; j++) {
85  auto cur_scores = scores.col(j);
86  auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
87  auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
88 
89  if (soft_nms_enabled_) {
90  auto cur_soft_nms_scores = soft_nms_scores.col(j);
91  keeps[j] = utils::soft_nms_cpu(
92  &cur_soft_nms_scores,
93  cur_boxes,
94  cur_scores,
95  inds,
96  soft_nms_sigma_,
97  nms_thres_,
98  soft_nms_min_score_thres_,
99  soft_nms_method_);
100  } else {
101  std::sort(
102  inds.data(),
103  inds.data() + inds.size(),
104  [&cur_scores](int lhs, int rhs) {
105  return cur_scores(lhs) > cur_scores(rhs);
106  });
107  int keep_max = detections_per_im_ > 0 ? detections_per_im_ : -1;
108  keeps[j] =
109  utils::nms_cpu(cur_boxes, cur_scores, inds, nms_thres_, keep_max);
110  }
111  total_keep_count += keeps[j].size();
112  }
113 
114  if (soft_nms_enabled_) {
115  // Re-map scores to the updated SoftNMS scores
116  new (&scores) Eigen::Map<const ERArrXXf>(
117  soft_nms_scores.data(),
118  soft_nms_scores.rows(),
119  soft_nms_scores.cols());
120  }
121 
122  // Limit to max_per_image detections *over all classes*
123  if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
124  // merge all scores (represented by indices) together and sort
125  auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
126  // flatten keeps[i][j] to [pair(i, keeps[i][j]), ...]
127  // first: class index (1 ~ keeps.size() - 1),
128  // second: values in keeps[first]
129  using KeepIndex = std::pair<int, int>;
130  vector<KeepIndex> ret(total_keep_count);
131 
132  int ret_idx = 0;
133  for (int i = 1; i < keeps.size(); i++) {
134  auto& cur_keep = keeps[i];
135  for (auto& ckv : cur_keep) {
136  ret[ret_idx++] = {i, ckv};
137  }
138  }
139 
140  std::sort(
141  ret.data(),
142  ret.data() + ret.size(),
143  [&scores](const KeepIndex& lhs, const KeepIndex& rhs) {
144  return scores(lhs.second, lhs.first) >
145  scores(rhs.second, rhs.first);
146  });
147 
148  return ret;
149  };
150 
151  // Pick the first `detections_per_im_` boxes with highest scores
152  auto all_scores_sorted = get_all_scores_sorted();
153  DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
154 
155  // Reconstruct keeps from `all_scores_sorted`
156  for (auto& cur_keep : keeps) {
157  cur_keep.clear();
158  }
159  for (int i = 0; i < detections_per_im_; i++) {
160  DCHECK_GT(all_scores_sorted.size(), i);
161  auto& cur = all_scores_sorted[i];
162  keeps[cur.first].push_back(cur.second);
163  }
164  total_keep_count = detections_per_im_;
165  }
166  total_keep_per_batch[b] = total_keep_count;
167 
168  // Write results
169  int cur_start_idx = out_scores->size(0);
170  out_scores->Extend(total_keep_count, 50);
171  out_boxes->Extend(total_keep_count, 50);
172  out_classes->Extend(total_keep_count, 50);
173 
174  int cur_out_idx = 0;
175  for (int j = 1; j < num_classes; j++) {
176  auto cur_scores = scores.col(j);
177  auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
178  auto& cur_keep = keeps[j];
179  Eigen::Map<EArrXf> cur_out_scores(
180  out_scores->template mutable_data<float>() + cur_start_idx +
181  cur_out_idx,
182  cur_keep.size());
183  Eigen::Map<ERArrXXf> cur_out_boxes(
184  out_boxes->mutable_data<float>() +
185  (cur_start_idx + cur_out_idx) * box_dim,
186  cur_keep.size(),
187  box_dim);
188  Eigen::Map<EArrXf> cur_out_classes(
189  out_classes->template mutable_data<float>() + cur_start_idx +
190  cur_out_idx,
191  cur_keep.size());
192 
193  utils::GetSubArray(
194  cur_scores, utils::AsEArrXt(cur_keep), &cur_out_scores);
195  utils::GetSubArrayRows(
196  cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
197  for (int k = 0; k < cur_keep.size(); k++) {
198  cur_out_classes[k] = static_cast<float>(j);
199  }
200 
201  cur_out_idx += cur_keep.size();
202  }
203 
204  if (out_keeps) {
205  out_keeps->Extend(total_keep_count, 50);
206 
207  Eigen::Map<EArrXi> out_keeps_arr(
208  out_keeps->template mutable_data<int>() + cur_start_idx,
209  total_keep_count);
210  Eigen::Map<EArrXi> cur_out_keeps_size(
211  out_keeps_size->template mutable_data<int>() + b * num_classes,
212  num_classes);
213 
214  cur_out_idx = 0;
215  for (int j = 0; j < num_classes; j++) {
216  out_keeps_arr.segment(cur_out_idx, keeps[j].size()) =
217  utils::AsEArrXt(keeps[j]);
218  cur_out_keeps_size[j] = keeps[j].size();
219  cur_out_idx += keeps[j].size();
220  }
221  }
222 
223  offset += num_boxes;
224  }
225 
226  if (OutputSize() > 3) {
227  auto* batch_splits_out = Output(3, {batch_size}, at::dtype<float>());
228  Eigen::Map<EArrXf> batch_splits_out_map(
229  batch_splits_out->template mutable_data<float>(), batch_size);
230  batch_splits_out_map =
231  Eigen::Map<const EArrXi>(total_keep_per_batch.data(), batch_size)
232  .cast<float>();
233  }
234 
235  return true;
236 }
237 
238 namespace {
239 
240 REGISTER_CPU_OPERATOR(BoxWithNMSLimit, BoxWithNMSLimitOp<CPUContext>);
241 
242 OPERATOR_SCHEMA(BoxWithNMSLimit)
243  .NumInputs(2, 3)
244  .NumOutputs(3, 6)
245  .SetDoc(R"DOC(
246 Apply NMS to each class (except background) and limit the number of
247 returned boxes.
248 )DOC")
249  .Arg("score_thresh", "(float) TEST.SCORE_THRESH")
250  .Arg("nms", "(float) TEST.NMS")
251  .Arg("detections_per_im", "(int) TEST.DEECTIONS_PER_IM")
252  .Arg("soft_nms_enabled", "(bool) TEST.SOFT_NMS.ENABLED")
253  .Arg("soft_nms_method", "(string) TEST.SOFT_NMS.METHOD")
254  .Arg("soft_nms_sigma", "(float) TEST.SOFT_NMS.SIGMA")
255  .Arg(
256  "soft_nms_min_score_thres",
257  "(float) Lower bound on updated scores to discard boxes")
258  .Arg(
259  "rotated",
260  "bool (default false). If true, then boxes (rois and deltas) include "
261  "angle info to handle rotation. The format will be "
262  "[ctr_x, ctr_y, width, height, angle (in degrees)].")
263  .Input(0, "scores", "Scores, size (count, num_classes)")
264  .Input(
265  1,
266  "boxes",
267  "Bounding box for each class, size (count, num_classes * 4). "
268  "For rotated boxes, this would have an additional angle (in degrees) "
269  "in the format [<optionaal_batch_id>, ctr_x, ctr_y, w, h, angle]. "
270  "Size: (count, num_classes * 5).")
271  .Input(
272  2,
273  "batch_splits",
274  "Tensor of shape (batch_size) with each element denoting the number "
275  "of RoIs/boxes belonging to the corresponding image in batch. "
276  "Sum should add up to total count of scores/boxes.")
277  .Output(0, "scores", "Filtered scores, size (n)")
278  .Output(
279  1,
280  "boxes",
281  "Filtered boxes, size (n, 4). "
282  "For rotated boxes, size (n, 5), format [ctr_x, ctr_y, w, h, angle].")
283  .Output(2, "classes", "Class id for each filtered score/box, size (n)")
284  .Output(
285  3,
286  "batch_splits",
287  "Output batch splits for scores/boxes after applying NMS")
288  .Output(4, "keeps", "Optional filtered indices, size (n)")
289  .Output(
290  5,
291  "keeps_size",
292  "Optional number of filtered indices per class, size (num_classes)");
293 
294 SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit);
295 
296 } // namespace
297 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13