// Caffe2 - C++ API
// A deep learning, cross-platform ML framework.
// File: generate_proposals_op.h
// Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #ifndef CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
4 #define CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
5 
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 namespace utils {
14 
15 // A sub tensor view
16 template <class T>
18  public:
19  ConstTensorView(const T* data, const std::vector<int>& dims)
20  : data_(data), dims_(dims) {}
21 
22  int ndim() const {
23  return dims_.size();
24  }
25  const std::vector<int>& dims() const {
26  return dims_;
27  }
28  int dim(int i) const {
29  DCHECK_LE(i, dims_.size());
30  return dims_[i];
31  }
32  const T* data() const {
33  return data_;
34  }
35  size_t size() const {
36  return std::accumulate(
37  dims_.begin(), dims_.end(), 1, std::multiplies<size_t>());
38  }
39 
40  private:
41  const T* data_ = nullptr;
42  std::vector<int> dims_;
43 };
44 
45 // Generate a list of bounding box shapes for each pixel based on predefined
46 // bounding box shapes 'anchors'.
47 // anchors: predefined anchors, size(A, 4)
48 // Return: all_anchors_vec: (H * W, A * 4)
49 // Need to reshape to (H * W * A, 4) to match the format in python
50 ERMatXf ComputeAllAnchors(
51  const TensorCPU& anchors,
52  int height,
53  int width,
54  float feat_stride);
55 
56 } // namespace utils
57 
58 // C++ implementation of GenerateProposalsOp
59 // Generate bounding box proposals for Faster RCNN. The propoasls are generated
60 // for a list of images based on image score 'score', bounding box
61 // regression result 'deltas' as well as predefined bounding box shapes
62 // 'anchors'. Greedy non-maximum suppression is applied to generate the
63 // final bounding boxes.
64 // Reference: detectron/lib/ops/generate_proposals.py
65 template <class Context>
66 class GenerateProposalsOp final : public Operator<Context> {
67  public:
68  USE_OPERATOR_CONTEXT_FUNCTIONS;
69  GenerateProposalsOp(const OperatorDef& operator_def, Workspace* ws)
70  : Operator<Context>(operator_def, ws),
71  spatial_scale_(
72  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
73  feat_stride_(1.0 / spatial_scale_),
74  rpn_pre_nms_topN_(
75  OperatorBase::GetSingleArgument<int>("pre_nms_topN", 6000)),
76  rpn_post_nms_topN_(
77  OperatorBase::GetSingleArgument<int>("post_nms_topN", 300)),
78  rpn_nms_thresh_(
79  OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7f)),
80  rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)),
81  correct_transform_coords_(OperatorBase::GetSingleArgument<bool>(
82  "correct_transform_coords",
83  false)) {}
84 
86 
87  bool RunOnDevice() override;
88 
89  // Generate bounding box proposals for a given image
90  // im_info: [height, width, im_scale]
91  // all_anchors: (H * W * A, 4)
92  // bbox_deltas_tensor: (4 * A, H, W)
93  // scores_tensor: (A, H, W)
94  // out_boxes: (n, 5)
95  // out_probs: n
96  void ProposalsForOneImage(
97  const Eigen::Array3f& im_info,
98  const Eigen::Map<const ERMatXf>& all_anchors,
99  const utils::ConstTensorView<float>& bbox_deltas_tensor,
100  const utils::ConstTensorView<float>& scores_tensor,
101  ERArrXXf* out_boxes,
102  EArrXf* out_probs) const;
103 
104  protected:
105  // spatial_scale_ must be declared before feat_stride_
106  float spatial_scale_{1.0};
107  float feat_stride_{1.0};
108 
109  // RPN_PRE_NMS_TOP_N
110  int rpn_pre_nms_topN_{6000};
111  // RPN_POST_NMS_TOP_N
112  int rpn_post_nms_topN_{300};
113  // RPN_NMS_THRESH
114  float rpn_nms_thresh_{0.7};
115  // RPN_MIN_SIZE
116  float rpn_min_size_{16};
117  // Correct bounding box transform coordates, see bbox_transform() in boxes.py
118  // Set to true to match the detectron code, set to false for backward
119  // compatibility
120  bool correct_transform_coords_{false};
121 };
122 
123 } // namespace caffe2
124 
125 #endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
// Workspace is a class that holds all the related objects created during
// runtime: (1) all blobs... (Definition: workspace.h:63)
// Copyright (c) 2016-present, Facebook, Inc.