Caffe2 - C++ API
A deep learning, cross platform ML framework
generate_proposals_op.h
1 #ifndef CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
2 #define CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/eigen_utils.h"
7 #include "caffe2/utils/math.h"
8 
9 C10_DECLARE_CAFFE2_OPERATOR(GenerateProposals);
10 
11 namespace caffe2 {
12 
13 namespace utils {
14 
15 // A sub tensor view
16 // TODO: Remove???
17 template <class T>
19  public:
20  ConstTensorView(const T* data, const std::vector<int>& dims)
21  : data_(data), dims_(dims) {}
22 
23  int ndim() const {
24  return dims_.size();
25  }
26  const std::vector<int>& dims() const {
27  return dims_;
28  }
29  int dim(int i) const {
30  DCHECK_LE(i, dims_.size());
31  return dims_[i];
32  }
33  const T* data() const {
34  return data_;
35  }
36  size_t size() const {
37  return std::accumulate(
38  dims_.begin(), dims_.end(), 1, std::multiplies<size_t>());
39  }
40 
41  private:
42  const T* data_ = nullptr;
43  std::vector<int> dims_;
44 };
45 
46 // Generate a list of bounding box shapes for each pixel based on predefined
47 // bounding box shapes 'anchors'.
48 // anchors: predefined anchors, size(A, 4)
49 // Return: all_anchors_vec: (H * W, A * 4)
50 // Need to reshape to (H * W * A, 4) to match the format in python
51 CAFFE2_API ERMatXf ComputeAllAnchors(
52  const TensorCPU& anchors,
53  int height,
54  int width,
55  float feat_stride);
56 
57 // Like ComputeAllAnchors, but instead of computing anchors for every single
58 // spatial location, only computes anchors for the already sorted and filtered
59 // positions after NMS is applied to avoid unnecessary computation.
60 // `order` is a raveled array of sorted indices in (A, H, W) format.
61 CAFFE2_API ERArrXXf ComputeSortedAnchors(
62  const Eigen::Map<const ERArrXXf>& anchors,
63  int height,
64  int width,
65  float feat_stride,
66  const vector<int>& order);
67 
68 } // namespace utils
69 
70 // C++ implementation of GenerateProposalsOp
71 // Generate bounding box proposals for Faster RCNN. The propoasls are generated
72 // for a list of images based on image score 'score', bounding box
73 // regression result 'deltas' as well as predefined bounding box shapes
74 // 'anchors'. Greedy non-maximum suppression is applied to generate the
75 // final bounding boxes.
76 // Reference: facebookresearch/Detectron/detectron/ops/generate_proposals.py
77 template <class Context>
78 class GenerateProposalsOp final : public Operator<Context> {
79  public:
80  USE_OPERATOR_CONTEXT_FUNCTIONS;
81  template<class... Args>
82  explicit GenerateProposalsOp(Args&&... args)
83  : Operator<Context>(std::forward<Args>(args)...),
84  spatial_scale_(
85  this->template GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
86  feat_stride_(1.0 / spatial_scale_),
87  rpn_pre_nms_topN_(
88  this->template GetSingleArgument<int>("pre_nms_topN", 6000)),
89  rpn_post_nms_topN_(
90  this->template GetSingleArgument<int>("post_nms_topN", 300)),
91  rpn_nms_thresh_(
92  this->template GetSingleArgument<float>("nms_thresh", 0.7f)),
93  rpn_min_size_(this->template GetSingleArgument<float>("min_size", 16)),
94  angle_bound_on_(
95  this->template GetSingleArgument<bool>("angle_bound_on", true)),
96  angle_bound_lo_(
97  this->template GetSingleArgument<int>("angle_bound_lo", -90)),
98  angle_bound_hi_(
99  this->template GetSingleArgument<int>("angle_bound_hi", 90)),
100  clip_angle_thresh_(
101  this->template GetSingleArgument<float>("clip_angle_thresh", 1.0)) {}
102 
103  ~GenerateProposalsOp() {}
104 
105  bool RunOnDevice() override;
106 
107  // Generate bounding box proposals for a given image
108  // im_info: [height, width, im_scale]
109  // all_anchors: (H * W * A, 4)
110  // bbox_deltas_tensor: (4 * A, H, W)
111  // scores_tensor: (A, H, W)
112  // out_boxes: (n, 5)
113  // out_probs: n
114  void ProposalsForOneImage(
115  const Eigen::Array3f& im_info,
116  const Eigen::Map<const ERArrXXf>& anchors,
117  const utils::ConstTensorView<float>& bbox_deltas_tensor,
118  const utils::ConstTensorView<float>& scores_tensor,
119  ERArrXXf* out_boxes,
120  EArrXf* out_probs) const;
121 
122  protected:
123  // spatial_scale_ must be declared before feat_stride_
124  float spatial_scale_{1.0};
125  float feat_stride_{1.0};
126 
127  // RPN_PRE_NMS_TOP_N
128  int rpn_pre_nms_topN_{6000};
129  // RPN_POST_NMS_TOP_N
130  int rpn_post_nms_topN_{300};
131  // RPN_NMS_THRESH
132  float rpn_nms_thresh_{0.7};
133  // RPN_MIN_SIZE
134  float rpn_min_size_{16};
135  // If set, for rotated boxes in RRPN, output angles are normalized to be
136  // within [angle_bound_lo, angle_bound_hi].
137  bool angle_bound_on_{true};
138  int angle_bound_lo_{-90};
139  int angle_bound_hi_{90};
140  // For RRPN, clip almost horizontal boxes within this threshold of
141  // tolerance for backward compatibility. Set to negative value for
142  // no clipping.
143  float clip_angle_thresh_{1.0};
144 
145  // Scratch space required by the CUDA version
146  // CUB buffers
147  Tensor dev_cub_sort_buffer_{Context::GetDeviceType()};
148  Tensor dev_cub_select_buffer_{Context::GetDeviceType()};
149  Tensor dev_image_offset_{Context::GetDeviceType()};
150  Tensor dev_conv_layer_indexes_{Context::GetDeviceType()};
151  Tensor dev_sorted_conv_layer_indexes_{Context::GetDeviceType()};
152  Tensor dev_sorted_scores_{Context::GetDeviceType()};
153  Tensor dev_boxes_{Context::GetDeviceType()};
154  Tensor dev_boxes_keep_flags_{Context::GetDeviceType()};
155 
156  // prenms proposals (raw proposals minus empty boxes)
157  Tensor dev_image_prenms_boxes_{Context::GetDeviceType()};
158  Tensor dev_image_prenms_scores_{Context::GetDeviceType()};
159  Tensor dev_prenms_nboxes_{Context::GetDeviceType()};
160  Tensor host_prenms_nboxes_{CPU};
161 
162  Tensor dev_image_boxes_keep_list_{Context::GetDeviceType()};
163 
164  // Tensors used by NMS
165  Tensor dev_nms_mask_{Context::GetDeviceType()};
166  Tensor host_nms_mask_{CPU};
167 
168  // Buffer for output
169  Tensor dev_postnms_rois_{Context::GetDeviceType()};
170  Tensor dev_postnms_rois_probs_{Context::GetDeviceType()};
171 };
172 
173 } // namespace caffe2
174 
175 #endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13