Caffe2 - C++ API
A deep learning, cross-platform ML framework
roi_align_op.cc
#include "roi_align_op.h"

#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
namespace {

template <typename T>
struct PreCalc {
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  T w1;
  T w2;
  T w3;
  T w4;
};

template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int iy_upper,
    const int ix_upper,
    T roi_start_h,
    T roi_start_w,
    T bin_size_h,
    T bin_size_w,
    int roi_bin_grid_h,
    int roi_bin_grid_w,
    std::vector<PreCalc<T>>& pre_calc) {
  int pre_calc_index = 0;
  for (int ph = 0; ph < pooled_height; ph++) {
    for (int pw = 0; pw < pooled_width; pw++) {
      for (int iy = 0; iy < iy_upper; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
            static_cast<T>(iy + .5f) * bin_size_h /
                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
        for (int ix = 0; ix < ix_upper; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
              static_cast<T>(ix + .5f) * bin_size_w /
                  static_cast<T>(roi_bin_grid_w);

          T x = xx;
          T y = yy;
          // deal with cases where the inverse-mapped sampling point falls
          // outside the feature map boundary
          if (y < -1.0 || y > height || x < -1.0 || x > width) {
            // empty
            PreCalc<T> pc;
            pc.pos1 = 0;
            pc.pos2 = 0;
            pc.pos3 = 0;
            pc.pos4 = 0;
            pc.w1 = 0;
            pc.w2 = 0;
            pc.w3 = 0;
            pc.w4 = 0;
            pre_calc[pre_calc_index] = pc;
            pre_calc_index += 1;
            continue;
          }

          if (y <= 0) {
            y = 0;
          }
          if (x <= 0) {
            x = 0;
          }

          int y_low = (int)y;
          int x_low = (int)x;
          int y_high;
          int x_high;

          if (y_low >= height - 1) {
            y_high = y_low = height - 1;
            y = (T)y_low;
          } else {
            y_high = y_low + 1;
          }

          if (x_low >= width - 1) {
            x_high = x_low = width - 1;
            x = (T)x_low;
          } else {
            x_high = x_low + 1;
          }

          T ly = y - y_low;
          T lx = x - x_low;
          T hy = 1. - ly, hx = 1. - lx;
          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

          // save weights and indices
          PreCalc<T> pc;
          pc.pos1 = y_low * width + x_low;
          pc.pos2 = y_low * width + x_high;
          pc.pos3 = y_high * width + x_low;
          pc.pos4 = y_high * width + x_high;
          pc.w1 = w1;
          pc.w2 = w2;
          pc.w3 = w3;
          pc.w4 = w4;
          pre_calc[pre_calc_index] = pc;

          pre_calc_index += 1;
        }
      }
    }
  }
}
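
// The weights stored above implement standard bilinear interpolation. For a
// sampling point (y, x) with y_low = floor(y), x_low = floor(x),
// ly = y - y_low and lx = x - x_low, the interpolated value is
//   w1 * v[y_low, x_low]  + w2 * v[y_low, x_high] +
//   w3 * v[y_high, x_low] + w4 * v[y_high, x_high]
// with w1 = (1 - ly)(1 - lx), w2 = (1 - ly)lx, w3 = ly(1 - lx), w4 = ly*lx.
// Worked example (illustrative numbers only): y = 2.3, x = 4.6 gives
// ly = 0.3, lx = 0.6 and w1 = 0.28, w2 = 0.42, w3 = 0.12, w4 = 0.18,
// which sum to 1.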

template <typename T>
void ROIAlignForward(
    const int nthreads,
    const T* bottom_data,
    const T& spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    const T* bottom_rois,
    int roi_cols,
    T* top_data,
    StorageOrder order) {
  DCHECK(roi_cols == 4 || roi_cols == 5);

  int n_rois = nthreads / channels / pooled_width / pooled_height;

#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;

    // roi could have 4 or 5 columns
    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
    int roi_batch_ind = 0;
    if (roi_cols == 5) {
      roi_batch_ind = offset_bottom_rois[0];
      offset_bottom_rois++;
    }

    // Do not use rounding; this implementation detail is critical
    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);

    // Force malformed ROIs to be 1x1
    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and mimic integral pooling
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    // We want to precalculate indices and weights shared by all channels;
    // this is the key point of the optimization.
    std::vector<PreCalc<T>> pre_calc(
        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
    pre_calc_for_bilinear_interpolate(
        height,
        width,
        pooled_height,
        pooled_width,
        roi_bin_grid_h,
        roi_bin_grid_w,
        roi_start_h,
        roi_start_w,
        bin_size_h,
        bin_size_w,
        roi_bin_grid_h,
        roi_bin_grid_w,
        pre_calc);

    if (order == StorageOrder::NCHW) {
      for (int c = 0; c < channels; c++) {
        int index_n_c = index_n + c * pooled_width * pooled_height;
        const T* offset_bottom_data =
            bottom_data + (roi_batch_ind * channels + c) * height * width;
        int pre_calc_index = 0;

        for (int ph = 0; ph < pooled_height; ph++) {
          for (int pw = 0; pw < pooled_width; pw++) {
            int index = index_n_c + ph * pooled_width + pw;

            T output_val = 0.;
            for (int iy = 0; iy < roi_bin_grid_h; iy++) {
              for (int ix = 0; ix < roi_bin_grid_w; ix++) {
                PreCalc<T> pc = pre_calc[pre_calc_index];
                output_val += pc.w1 * offset_bottom_data[pc.pos1] +
                    pc.w2 * offset_bottom_data[pc.pos2] +
                    pc.w3 * offset_bottom_data[pc.pos3] +
                    pc.w4 * offset_bottom_data[pc.pos4];

                pre_calc_index += 1;
              }
            }
            output_val /= count;

            top_data[index] = output_val;
          } // for pw
        } // for ph
      } // for c
    } // if nchw

    if (order == StorageOrder::NHWC) {
      const T* offset_bottom_data =
          bottom_data + roi_batch_ind * channels * height * width;
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          EVecXf output_vals = EVecXf::Zero(channels);

          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              PreCalc<T> pc = pre_calc[pre_calc_index];

              ConstEigenVectorMap<T> data_1(
                  offset_bottom_data + channels * pc.pos1, channels);
              ConstEigenVectorMap<T> data_2(
                  offset_bottom_data + channels * pc.pos2, channels);
              ConstEigenVectorMap<T> data_3(
                  offset_bottom_data + channels * pc.pos3, channels);
              ConstEigenVectorMap<T> data_4(
                  offset_bottom_data + channels * pc.pos4, channels);

              output_vals += pc.w1 * data_1 + pc.w2 * data_2 +
                  pc.w3 * data_3 + pc.w4 * data_4;

              pre_calc_index += 1;
            }
          }
          output_vals /= count;

          int index_nhw = index_n + (ph * pooled_width + pw) * channels;
          std::memcpy(
              top_data + index_nhw, output_vals.data(), channels * sizeof(T));
        } // for pw
      } // for ph
    } // if nhwc

  } // for n
}
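
// Worked example of the sampling setup above (illustrative numbers only):
// with spatial_scale = 0.0625 (stride 16), an RoI spanning x1 = 0 to x2 = 448
// in image coordinates maps to roi_start_w = 0 and roi_end_w = 28, so
// roi_width = 28. With pooled_width = 7 this gives bin_size_w = 4, and with
// sampling_ratio <= 0 the adaptive grid is roi_bin_grid_w = ceil(28 / 7) = 4.
// If the height works out the same way, each output bin averages
// count = 4 * 4 = 16 bilinearly interpolated samples.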

} // namespace

template <>
bool RoIAlignOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // Input data to pool, NCHW
  auto& R = Input(1); // RoIs

  if (R.numel() == 0) {
    std::vector<int64_t> sizes;
    // Handle empty rois
    if (order_ == StorageOrder::NCHW) {
      sizes = {0, X.dim32(1), pooled_height_, pooled_width_};
    } else if (order_ == StorageOrder::NHWC) {
      sizes = {0, pooled_height_, pooled_width_, X.dim32(3)};
    }
    // Output tensor is initialized with proper sizes and data type
    Output(0, sizes, at::dtype<float>());
    return true;
  }

  CAFFE_ENFORCE_EQ(R.dim(), 2);
  // if R has 5 columns, the first column is the batch index; otherwise 0
  CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);

  assert(sampling_ratio_ >= 0);

  if (order_ == StorageOrder::NCHW) {
    auto* Y = Output(
        0,
        {R.dim32(0), X.dim32(1), pooled_height_, pooled_width_},
        at::dtype<float>()); // RoI pooled data
    int output_size = Y->numel();
    ROIAlignForward<float>(
        output_size,
        X.data<float>(),
        spatial_scale_,
        X.dim32(1),
        X.dim32(2),
        X.dim32(3),
        pooled_height_,
        pooled_width_,
        sampling_ratio_,
        R.data<float>(),
        R.dim32(1),
        Y->template mutable_data<float>(),
        order_);
  } else if (order_ == StorageOrder::NHWC) {
    auto* Y = Output(
        0,
        {R.dim32(0), pooled_height_, pooled_width_, X.dim32(3)},
        at::dtype<float>()); // RoI pooled data
    int output_size = Y->numel();
    ROIAlignForward<float>(
        output_size,
        X.data<float>(),
        spatial_scale_,
        X.dim32(3),
        X.dim32(1),
        X.dim32(2),
        pooled_height_,
        pooled_width_,
        sampling_ratio_,
        R.data<float>(),
        R.dim32(1),
        Y->template mutable_data<float>(),
        order_);
  }

  return true;
}

REGISTER_CPU_OPERATOR(RoIAlign, RoIAlignOp<float, CPUContext>);

// Input: X, rois; Output: Y
OPERATOR_SCHEMA(RoIAlign)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Region of Interest (RoI) align operation as used in Mask R-CNN.
)DOC")
    .Arg(
        "spatial_scale",
        "(float) default 1.0; Spatial scale of the input feature map X "
        "relative to the input image. E.g., 0.0625 if X has a stride of 16 "
        "w.r.t. the input image.")
    .Arg("pooled_h", "(int) default 1; Pooled output Y's height.")
    .Arg("pooled_w", "(int) default 1; Pooled output Y's width.")
    .Arg(
        "sampling_ratio",
        "(int) default -1; number of sampling points in the interpolation grid "
        "used to compute the output value of each pooled output bin. If > 0, "
        "then exactly sampling_ratio x sampling_ratio grid points are used. If "
        "<= 0, then an adaptive number of grid points is used (computed as "
        "ceil(roi_width / pooled_w), and likewise for height).")
    .Input(0, "X", "4D feature map input of shape (N, C, H, W).")
    .Input(
        1,
        "RoIs",
        "2D input of shape (R, 4 or 5) specifying R RoIs "
        "representing: batch index in [0, N - 1], x1, y1, x2, y2. The RoI "
        "coordinates are in the coordinate system of the input image. For "
        "inputs corresponding to a single image, the batch index can be "
        "excluded to have just 4 columns.")
    .Output(
        0,
        "Y",
        "4D output of shape (R, C, pooled_h, pooled_w). The r-th batch element "
        "is a pooled feature map corresponding to the r-th RoI.");

} // namespace caffe2

C10_REGISTER_CAFFE2_OPERATOR_CPU(
    RoIAlign,
    (std::vector<c10::Argument>{
        c10::Argument("features"),
        c10::Argument("rois"),
        c10::Argument("order", StringType::get()),
        c10::Argument("spatial_scale", FloatType::get()),
        c10::Argument("pooled_h", IntType::get()),
        c10::Argument("pooled_w", IntType::get()),
        c10::Argument("sampling_ratio", IntType::get()),
    }),
    (std::vector<c10::Argument>{
        c10::Argument("pooled_features"),
    }),
    caffe2::RoIAlignOp<float, caffe2::CPUContext>);
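
Usage note: the snippet below is a minimal sketch of driving the registered CPU operator through Caffe2's core C++ APIs (Workspace, OperatorDef, CreateOperator, BlobGetMutableTensor). It is not part of roi_align_op.cc; the blob names, shapes, and argument values are illustrative only.

#include <algorithm>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

int main() {
  caffe2::Workspace ws;

  // Feature map X: 1 image, 2 channels, 8x8 spatial size (NCHW).
  auto* X = BlobGetMutableTensor(ws.CreateBlob("X"), caffe2::CPU);
  X->Resize(1, 2, 8, 8);
  std::fill_n(X->mutable_data<float>(), X->numel(), 1.0f);

  // One RoI with 5 columns: batch index, x1, y1, x2, y2 (image coordinates).
  auto* R = BlobGetMutableTensor(ws.CreateBlob("R"), caffe2::CPU);
  R->Resize(1, 5);
  const float roi[5] = {0.f, 0.f, 0.f, 64.f, 64.f};
  std::copy_n(roi, 5, R->mutable_data<float>());

  // Build the operator definition using the schema arguments documented above.
  caffe2::OperatorDef def;
  def.set_type("RoIAlign");
  def.add_input("X");
  def.add_input("R");
  def.add_output("Y");
  auto* scale = def.add_arg();
  scale->set_name("spatial_scale");
  scale->set_f(0.125f); // feature map X assumed to have stride 8
  auto* ph = def.add_arg();
  ph->set_name("pooled_h");
  ph->set_i(2);
  auto* pw = def.add_arg();
  pw->set_name("pooled_w");
  pw->set_i(2);
  auto* sr = def.add_arg();
  sr->set_name("sampling_ratio");
  sr->set_i(2);

  // Instantiate and run; Y ends up with shape (1, 2, 2, 2).
  auto op = caffe2::CreateOperator(def, &ws);
  op->Run();
  return 0;
}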