// Header preamble (include guard + includes) followed by two shared constants.
// BBOX_XFORM_CLIP_DEFAULT = log(1000 / 16) is the default value of the
// `bbox_xform_clip` parameter of the bbox_transform* functions, where it is
// the upper bound applied to the scaled dw / dh deltas before exp() is taken.
1 #ifndef CAFFE2_OPERATORS_UTILS_BOXES_H_ 2 #define CAFFE2_OPERATORS_UTILS_BOXES_H_ 4 #include "caffe2/utils/eigen_utils.h" 5 #include "caffe2/utils/math.h" 15 const float BBOX_XFORM_CLIP_DEFAULT = log(1000.0 / 16.0);
// Pi stored as a float. bbox_transform_rotated multiplies the angle delta by
// 180 / PI (the radians-to-degrees factor).
16 const float PI = 3.14159265358979323846;
// Applies box-regression deltas to upright boxes and builds the predicted
// boxes, one row per input box.
//
//   boxes:   N x 4 array (column count enforced below). Columns are used as
//            x1, y1, x2, y2: width = col2 - col0 + 1, height = col3 - col1 + 1.
//   deltas:  N x 4 array; columns are used as dx, dy, dw, dh.
//   weights: delta column i is divided by weights[i] (defaults to all 1.0).
//   bbox_xform_clip: upper bound (cwiseMin) on the scaled dw / dh values
//            before they are exponentiated.
//
// When `boxes` has no rows, a zero-row array with deltas.cols() columns is
// returned before any shape checks run.
//
// NOTE(review): this listing appears to be missing lines of the original
// header: the closing brace of the early-return `if`, the `auto dw =` and
// `auto dh =` declarations that the two cwiseMin expressions belong to
// (`dw` / `dh` are used below but never declared here), and the final return
// of `pred_boxes` with the closing brace. Confirm against the upstream file.
36 template <
class Derived1,
class Derived2>
37 EArrXXt<typename Derived1::Scalar> bbox_transform_upright(
38 const Eigen::ArrayBase<Derived1>& boxes,
39 const Eigen::ArrayBase<Derived2>& deltas,
40 const std::vector<typename Derived2::Scalar>& weights =
41 std::vector<typename Derived2::Scalar>{1.0, 1.0, 1.0, 1.0},
42 const float bbox_xform_clip = BBOX_XFORM_CLIP_DEFAULT) {
43 using T =
typename Derived1::Scalar;
44 using EArrXX = EArrXXt<T>;
45 using EArrX = EArrXt<T>;
// Empty input: nothing to transform.
// NOTE(review): the row count is passed as T(0) (a scalar of the box type)
// rather than an integer 0 -- it relies on an implicit conversion.
47 if (boxes.rows() == 0) {
48 return EArrXX::Zero(
T(0), deltas.cols());
// Shape checks: one delta row per box, exactly 4 columns each.
51 CAFFE_ENFORCE_EQ(boxes.rows(), deltas.rows());
52 CAFFE_ENFORCE_EQ(boxes.cols(), 4);
53 CAFFE_ENFORCE_EQ(deltas.cols(), 4);
// The +1 means a box with x1 == x2 (or y1 == y2) has width (height) 1.
55 EArrX widths = boxes.col(2) - boxes.col(0) +
T(1.0);
56 EArrX heights = boxes.col(3) - boxes.col(1) +
T(1.0);
57 auto ctr_x = boxes.col(0) +
T(0.5) * widths;
58 auto ctr_y = boxes.col(1) +
T(0.5) * heights;
// Scale each delta column by its weight; dw / dh are additionally capped at
// bbox_xform_clip so that exp() below stays bounded.
60 auto dx = deltas.col(0).template cast<T>() / weights[0];
61 auto dy = deltas.col(1).template cast<T>() / weights[1];
63 (deltas.col(2).template cast<T>() / weights[2]).cwiseMin(bbox_xform_clip);
65 (deltas.col(3).template cast<T>() / weights[3]).cwiseMin(bbox_xform_clip);
// Center moves by a fraction of the box size; size is scaled by exp(delta).
67 EArrX pred_ctr_x = dx * widths + ctr_x;
68 EArrX pred_ctr_y = dy * heights + ctr_y;
69 EArrX pred_w = dw.exp() * widths;
70 EArrX pred_h = dh.exp() * heights;
72 EArrXX pred_boxes = EArrXX::Zero(deltas.rows(), deltas.cols());
// Convert (center, size) back to corner form; the -1 on the far corner
// mirrors the +1 used when computing widths / heights above.
74 pred_boxes.col(0) = pred_ctr_x -
T(0.5) * pred_w;
76 pred_boxes.col(1) = pred_ctr_y -
T(0.5) * pred_h;
78 pred_boxes.col(2) = pred_ctr_x +
T(0.5) * pred_w -
T(1.0);
80 pred_boxes.col(3) = pred_ctr_y +
T(0.5) * pred_h -
T(1.0);
// Applies box-regression deltas to rotated boxes and builds the predicted
// boxes, one row per input box.
//
//   boxes:   N x 5 array (enforced below); columns are read as
//            ctr_x, ctr_y, width, height, angle.
//   deltas:  N x 5 array; columns are used as dx, dy, dw, dh, da.
//   weights: delta columns 0..3 are divided by weights[0..3]; the angle delta
//            (column 4) is not divided by any weight.
//   bbox_xform_clip: upper bound (cwiseMin) on the scaled dw / dh values
//            before they are exponentiated.
//   angle_bound_on / angle_bound_lo / angle_bound_hi: when enabled, predicted
//            angles outside [lo, hi] are shifted once by (hi - lo).
//
// When `boxes` has no rows, a zero-row array with deltas.cols() columns is
// returned before any shape checks run.
//
// NOTE(review): this listing appears to be missing lines of the original
// header: the closing brace of the early-return `if`, the `auto dw =` and
// `auto dh =` declarations (`dw` / `dh` are used below but never declared
// here), the closing braces of the angle loop, and the final return of
// `pred_boxes`. Confirm against the upstream file.
95 template <
class Derived1,
class Derived2>
96 EArrXXt<typename Derived1::Scalar> bbox_transform_rotated(
97 const Eigen::ArrayBase<Derived1>& boxes,
98 const Eigen::ArrayBase<Derived2>& deltas,
99 const std::vector<typename Derived2::Scalar>& weights =
100 std::vector<typename Derived2::Scalar>{1.0, 1.0, 1.0, 1.0},
101 const float bbox_xform_clip = BBOX_XFORM_CLIP_DEFAULT,
102 const bool angle_bound_on =
true,
103 const int angle_bound_lo = -90,
104 const int angle_bound_hi = 90) {
105 using T =
typename Derived1::Scalar;
106 using EArrXX = EArrXXt<T>;
// Empty input: nothing to transform.
108 if (boxes.rows() == 0) {
109 return EArrXX::Zero(
T(0), deltas.cols());
// Shape checks: one delta row per box, exactly 5 columns each.
112 CAFFE_ENFORCE_EQ(boxes.rows(), deltas.rows());
113 CAFFE_ENFORCE_EQ(boxes.cols(), 5);
114 CAFFE_ENFORCE_EQ(deltas.cols(), 5);
// Named views of the input columns (no +1 adjustment is applied here).
116 const auto& ctr_x = boxes.col(0);
117 const auto& ctr_y = boxes.col(1);
118 const auto& widths = boxes.col(2);
119 const auto& heights = boxes.col(3);
120 const auto& angles = boxes.col(4);
// Scale position / size deltas by their weights; cap dw / dh before exp().
122 auto dx = deltas.col(0).template cast<T>() / weights[0];
123 auto dy = deltas.col(1).template cast<T>() / weights[1];
125 (deltas.col(2).template cast<T>() / weights[2]).cwiseMin(bbox_xform_clip);
127 (deltas.col(3).template cast<T>() / weights[3]).cwiseMin(bbox_xform_clip);
// Angle delta is multiplied by 180 / PI before being added to the box angle
// (presumably deltas are in radians and box angles in degrees -- confirm).
129 auto da = deltas.col(4).template cast<T>() * 180.0 / PI;
131 EArrXX pred_boxes = EArrXX::Zero(deltas.rows(), deltas.cols());
133 pred_boxes.col(0) = dx * widths + ctr_x;
135 pred_boxes.col(1) = dy * heights + ctr_y;
137 pred_boxes.col(2) = dw.exp() * widths;
139 pred_boxes.col(3) = dh.exp() * heights;
141 pred_boxes.col(4) = da + angles;
// Optional angle normalization. The period must be a positive multiple of
// 180. Each angle is adjusted at most once (if / else-if, no repetition), so
// a value more than one period outside the bounds stays out of range.
143 if (angle_bound_on) {
147 const int period = angle_bound_hi - angle_bound_lo;
148 CAFFE_ENFORCE(period > 0 && period % 180 == 0);
// This local `angles` shadows the input-column reference declared above; it
// is taken from pred_boxes.col(4), so the updates below are meant to land in
// pred_boxes.
149 auto angles = pred_boxes.col(4);
150 for (
int i = 0; i < angles.size(); ++i) {
151 if (angles[i] < angle_bound_lo) {
152 angles[i] +=
T(period);
153 }
else if (angles[i] > angle_bound_hi) {
154 angles[i] -=
T(period);
// Dispatches to the upright or rotated box transform based on the number of
// columns in `boxes`: 4 columns -> bbox_transform_upright, otherwise (5
// columns, the only other value the enforce allows) -> bbox_transform_rotated.
// The angle_bound_* parameters are not passed to the upright variant.
//
// NOTE(review): the argument list of the bbox_transform_rotated call and the
// closing braces are not visible in this listing; confirm against upstream.
162 template <
class Derived1,
class Derived2>
163 EArrXXt<typename Derived1::Scalar> bbox_transform(
164 const Eigen::ArrayBase<Derived1>& boxes,
165 const Eigen::ArrayBase<Derived2>& deltas,
166 const std::vector<typename Derived2::Scalar>& weights =
167 std::vector<typename Derived2::Scalar>{1.0, 1.0, 1.0, 1.0},
168 const float bbox_xform_clip = BBOX_XFORM_CLIP_DEFAULT,
169 const bool angle_bound_on =
true,
170 const int angle_bound_lo = -90,
171 const int angle_bound_hi = 90) {
// Only 4-column (upright) and 5-column (rotated) inputs are accepted.
172 CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5);
173 if (boxes.cols() == 4) {
175 return bbox_transform_upright(boxes, deltas, weights, bbox_xform_clip);
178 return bbox_transform_rotated(
// Converts N x 4 corner-form boxes (x1, y1, x2, y2) into center/size form
// (ctr_x, ctr_y, width, height).
//   center = (p1 + p2) / 2, size = p2 - p1 + 1.
// Enforces that `boxes` has exactly 4 columns.
//
// NOTE(review): the `return ret;` and closing brace are not visible in this
// listing; confirm against upstream.
189 template <
class Derived>
190 EArrXXt<typename Derived::Scalar> bbox_xyxy_to_ctrwh(
191 const Eigen::ArrayBase<Derived>& boxes) {
192 CAFFE_ENFORCE_EQ(boxes.cols(), 4);
// Named views of the corner coordinates.
194 const auto& x1 = boxes.col(0);
195 const auto& y1 = boxes.col(1);
196 const auto& x2 = boxes.col(2);
197 const auto& y2 = boxes.col(3);
199 EArrXXt<typename Derived::Scalar> ret(boxes.rows(), 4);
200 ret.col(0) = (x1 + x2) / 2.0;
201 ret.col(1) = (y1 + y2) / 2.0;
// +1 so that a box with x1 == x2 (y1 == y2) has width (height) 1.
202 ret.col(2) = x2 - x1 + 1.0;
203 ret.col(3) = y2 - y1 + 1.0;
// Converts N x 4 center/size boxes (ctr_x, ctr_y, width, height) into corner
// form (x1, y1, x2, y2).
//   p1 = ctr - (size - 1) / 2, p2 = ctr + (size - 1) / 2,
// which is the inverse of the arithmetic used in bbox_xyxy_to_ctrwh.
// Enforces that `boxes` has exactly 4 columns.
//
// NOTE(review): the `return ret;` and closing brace are not visible in this
// listing; confirm against upstream.
207 template <
class Derived>
208 EArrXXt<typename Derived::Scalar> bbox_ctrwh_to_xyxy(
209 const Eigen::ArrayBase<Derived>& boxes) {
210 CAFFE_ENFORCE_EQ(boxes.cols(), 4);
// Named views of the center/size columns.
212 const auto& x_ctr = boxes.col(0);
213 const auto& y_ctr = boxes.col(1);
214 const auto& w = boxes.col(2);
215 const auto& h = boxes.col(3);
217 EArrXXt<typename Derived::Scalar> ret(boxes.rows(), 4);
218 ret.col(0) = x_ctr - (w - 1) / 2.0;
219 ret.col(1) = y_ctr - (h - 1) / 2.0;
220 ret.col(2) = x_ctr + (w - 1) / 2.0;
221 ret.col(3) = y_ctr + (h - 1) / 2.0;
// Clamps the coordinates of N x 4 corner-form boxes to the image:
// columns 0 and 2 (x) to [0, width - 1], columns 1 and 3 (y) to
// [0, height - 1]. The upper bound is applied first (cwiseMin), then the
// lower bound (cwiseMax).
// Enforces that `boxes` has exactly 4 columns.
//
// NOTE(review): the `height` / `width` parameters used below, the end of the
// parameter list, the `return ret;` and the closing brace are not visible in
// this listing; confirm against upstream.
227 template <
class Derived>
228 EArrXXt<typename Derived::Scalar> clip_boxes_upright(
229 const Eigen::ArrayBase<Derived>& boxes,
232 CAFFE_ENFORCE(boxes.cols() == 4);
234 EArrXXt<typename Derived::Scalar> ret(boxes.rows(), boxes.cols());
// x1
237 ret.col(0) = boxes.col(0).cwiseMin(width - 1).cwiseMax(0);
// y1
239 ret.col(1) = boxes.col(1).cwiseMin(height - 1).cwiseMax(0);
// x2
241 ret.col(2) = boxes.col(2).cwiseMin(width - 1).cwiseMax(0);
// y2
243 ret.col(3) = boxes.col(3).cwiseMin(height - 1).cwiseMax(0);
// Clips only the near-upright boxes of an N x 5 rotated-box array
// (ctr_x, ctr_y, width, height, angle).
//
// Rows whose |angle| is <= angle_thresh are selected, their first four
// columns are converted to corner form, clipped with clip_boxes_upright,
// converted back to center/size form, and written into the matching rows of
// the result. The angle column of those rows is carried over unchanged.
// Enforces that `boxes` has exactly 5 columns.
//
// NOTE(review): several lines appear to be missing from this listing: the
// `height` / `width` parameters, any initialization of `ret` for the rows
// that are NOT selected (as shown, those rows are never assigned -- upstream
// presumably copies `boxes` into `ret` first), the loop's closing brace and
// the `return ret;`. Confirm against upstream.
// GetArrayIndices / GetSubArrayRows / AsEArrXt are defined outside this
// listing (presumably in caffe2/utils/eigen_utils.h).
261 template <
class Derived>
262 EArrXXt<typename Derived::Scalar> clip_boxes_rotated(
263 const Eigen::ArrayBase<Derived>& boxes,
266 float angle_thresh = 1.0) {
267 CAFFE_ENFORCE(boxes.cols() == 5);
269 const auto& angles = boxes.col(4);
// Gather the rows that are close enough to axis-aligned to be clipped.
272 EArrXXt<typename Derived::Scalar> upright_boxes;
273 const auto& indices = GetArrayIndices(angles.abs() <= angle_thresh);
274 GetSubArrayRows(boxes, AsEArrXt(indices), &upright_boxes);
// ctrwh -> xyxy, clip to the image, then back to ctrwh below.
277 const auto& upright_boxes_xyxy =
278 bbox_ctrwh_to_xyxy(upright_boxes.leftCols(4));
279 const auto& clipped_upright_boxes_xyxy =
280 clip_boxes_upright(upright_boxes_xyxy, height, width);
// Overwrite only the first four columns; column 4 (angle) is left as is.
283 upright_boxes.block(0, 0, upright_boxes.rows(), 4) =
284 bbox_xyxy_to_ctrwh(clipped_upright_boxes_xyxy);
286 EArrXXt<typename Derived::Scalar> ret(boxes.rows(), boxes.cols());
// Scatter the clipped rows back to their original row positions.
288 for (
int i = 0; i < upright_boxes.rows(); ++i) {
289 ret.row(indices[i]) = upright_boxes.row(i);
// Dispatches box clipping on the number of columns in `boxes`:
// 4 columns -> clip_boxes_upright, otherwise (5 columns, the only other value
// the enforce allows) -> clip_boxes_rotated. `angle_thresh` is only forwarded
// to the rotated variant.
//
// NOTE(review): the `height` / `width` parameters and the closing braces are
// not visible in this listing; confirm against upstream.
295 template <
class Derived>
296 EArrXXt<typename Derived::Scalar> clip_boxes(
297 const Eigen::ArrayBase<Derived>& boxes,
300 float angle_thresh = 1.0) {
301 CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5);
302 if (boxes.cols() == 4) {
304 return clip_boxes_upright(boxes, height, width);
307 return clip_boxes_rotated(boxes, height, width, angle_thresh);
// Returns the row indices of the N x 4 corner-form boxes that pass all of:
//   width  >= min_size * im_info[2]
//   height >= min_size * im_info[2]
//   center x < im_info[1]
//   center y < im_info[0]
// Width / height are computed as (p2 - p1 + 1); the center as p1 + size / 2.
// Enforces that `boxes` has exactly 4 columns.
//
// NOTE(review): im_info looks like (height, width, scale) given how its
// entries are used -- confirm with callers. The `min_size` parameter
// declaration and the closing brace are not visible in this listing.
315 template <
class Derived>
316 std::vector<int> filter_boxes_upright(
317 const Eigen::ArrayBase<Derived>& boxes,
319 const Eigen::Array3f& im_info) {
320 CAFFE_ENFORCE_EQ(boxes.cols(), 4);
// Rescale the threshold by im_info[2] before comparing against box sizes.
323 min_size *= im_info[2];
325 using T =
typename Derived::Scalar;
326 using EArrX = EArrXt<T>;
328 EArrX ws = boxes.col(2) - boxes.col(0) +
T(1);
329 EArrX hs = boxes.col(3) - boxes.col(1) +
T(1);
330 EArrX x_ctr = boxes.col(0) + ws /
T(2);
331 EArrX y_ctr = boxes.col(1) + hs /
T(2);
// Element-wise mask: large enough and centered inside the im_info bounds.
333 EArrXb keep = (ws >= min_size) && (hs >= min_size) &&
334 (x_ctr <
T(im_info[1])) && (y_ctr <
T(im_info[0]));
// Convert the boolean mask into a list of row indices.
336 return GetArrayIndices(keep);
// Returns the row indices of the N x 5 rotated boxes that pass all of:
//   width  (col 2) >= min_size * im_info[2]
//   height (col 3) >= min_size * im_info[2]
//   ctr_x  (col 0) <  im_info[1]
//   ctr_y  (col 1) <  im_info[0]
// Columns are used directly; the angle column (col 4) does not take part in
// the test. Enforces that `boxes` has exactly 5 columns.
//
// NOTE(review): im_info looks like (height, width, scale) given how its
// entries are used -- confirm with callers. The `min_size` parameter
// declaration and the closing brace are not visible in this listing.
344 template <
class Derived>
345 std::vector<int> filter_boxes_rotated(
346 const Eigen::ArrayBase<Derived>& boxes,
348 const Eigen::Array3f& im_info) {
349 CAFFE_ENFORCE_EQ(boxes.cols(), 5);
// Rescale the threshold by im_info[2] before comparing against box sizes.
352 min_size *= im_info[2];
354 using T =
typename Derived::Scalar;
// Named views of the center / size columns.
356 const auto& x_ctr = boxes.col(0);
357 const auto& y_ctr = boxes.col(1);
358 const auto& ws = boxes.col(2);
359 const auto& hs = boxes.col(3);
// Element-wise mask: large enough and centered inside the im_info bounds.
361 EArrXb keep = (ws >= min_size) && (hs >= min_size) &&
362 (x_ctr <
T(im_info[1])) && (y_ctr <
T(im_info[0]));
// Convert the boolean mask into a list of row indices.
364 return GetArrayIndices(keep);
// Dispatches box filtering on the number of columns in `boxes`:
// 4 columns -> filter_boxes_upright, otherwise (5 columns, the only other
// value the enforce allows) -> filter_boxes_rotated. Returns the indices
// produced by the selected variant.
//
// NOTE(review): the `min_size` parameter declaration and the closing braces
// are not visible in this listing; confirm against upstream.
367 template <
class Derived>
368 std::vector<int> filter_boxes(
369 const Eigen::ArrayBase<Derived>& boxes,
371 const Eigen::Array3f& im_info) {
372 CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5);
373 if (boxes.cols() == 4) {
375 return filter_boxes_upright(boxes, min_size, im_info);
378 return filter_boxes_rotated(boxes, min_size, im_info);
385 #endif // CAFFE2_OPERATORS_UTILS_BOXES_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...