Caffe2 — C++ API
A deep-learning, cross-platform ML framework.
Source file: roi_align_gradient_op.cc
1 #include "roi_align_gradient_op.h"
2 
3 #include "caffe2/utils/eigen_utils.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
7 namespace {
8 
// Computes the four bilinear-interpolation weights (w1..w4) and the integer
// coordinates of the surrounding 2x2 pixel neighborhood (x_low/x_high,
// y_low/y_high) for the continuous sample point (y, x). The backward pass
// scatters the incoming gradient onto those four pixels with these weights,
// mirroring the forward interpolation shown in the reference comment below.
//
// height, width: bottom feature-map extent.
// y, x:          continuous sample coordinates (taken by value; clamped
//                locally).
// w1..w4:        output interpolation weights for (low,low), (low,high),
//                (high,low), (high,high) respectively.
// x_low..y_high: output pixel indices; all set to -1 when the sample falls
//                outside the feature map so callers can skip the update.
template <typename T>
void bilinear_interpolate_gradient(
    const int height,
    const int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high,
    const int /*index*/ /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty: no pixel receives gradient from this sample
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  // Slightly-negative coordinates (in [-1, 0)) are snapped to the border.
  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }

  y_low = static_cast<int>(y); // named cast instead of C-style (int)y
  x_low = static_cast<int>(x);

  // At the high border, collapse the neighborhood onto the last row/column
  // and zero the fractional part so all weight lands on that pixel.
  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = static_cast<T>(y_low);
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = static_cast<T>(x_low);
  } else {
    x_high = x_low + 1;
  }

  // Fractional offsets within the cell and their complements.
  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = bottom_data[y_low * width + x_low];
  // T v2 = bottom_data[y_low * width + x_high];
  // T v3 = bottom_data[y_high * width + x_low];
  // T v4 = bottom_data[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}
71 
// Accumulates `val` into the location pointed to by `address`.
// Plain (non-atomic) accumulation: the backward kernel runs single-threaded.
template <class T>
inline void add(const T& val, T* address) {
  *address = *address + val;
}
76 
// Scatters the pooled-output gradient (top_diff) back onto the input
// feature-map gradient (bottom_diff) for ROIAlign, on a single CPU thread.
//
// nthreads:       total number of pooled output elements (n * C * PH * PW);
//                 each outer-loop iteration handles one of them.
// top_diff:       gradient w.r.t. the pooled output, NCHW-contiguous.
// spatial_scale:  factor mapping RoI coordinates onto feature-map coords.
// channels, height, width:       bottom feature-map shape.
// pooled_height, pooled_width:   pooled output shape per RoI.
// sampling_ratio: samples per bin side; <= 0 means adaptive (ceil of the
//                 bin size), matching the forward op.
// bottom_diff:    output buffer; this routine only ACCUMULATES, so the
//                 caller must zero it first (RunOnDevice does).
// bottom_rois:    RoI matrix with rois_cols columns per row; when
//                 rois_cols == 5 the first column is the batch index and
//                 the remaining four are (x1, y1, x2, y2) before scaling.
template <typename T>
void ROIAlignBackwardFeature(
    const int nthreads,
    const T* top_diff,
    const int /*num_rois*/,
    const T& spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    T* bottom_diff,
    const T* bottom_rois,
    int rois_cols) {
  DCHECK(rois_cols == 4 || rois_cols == 5);

  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_bottom_rois = bottom_rois + n * rois_cols;
    int roi_batch_ind = 0;
    if (rois_cols == 5) {
      // First column is the batch index; skip past it to the box coords.
      roi_batch_ind = offset_bottom_rois[0];
      offset_bottom_rois++;
    }

    // Do not use rounding; this implementation detail is critical.
    // NOTE: no half-pixel (-0.5) offset is applied here, i.e. this is the
    // original (non-"aligned") ROIAlign coordinate transform.
    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);

    // Force malformed ROIs to be 1x1
    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Start of the (roi_batch_ind, c) plane inside the gradient buffer.
    T* offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels + c) * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_top_diff = top_diff + top_offset;
    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin, so each sample
    // receives an equal 1/count share of the bin's gradient.
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      // Sample points are placed at the centers of a uniform sub-grid
      // inside the bin (offset by .5 of a sub-cell).
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        // Split this sample's share of the gradient across the four
        // neighboring pixels by their bilinear weights.
        T g1 = top_diff_this_bin * w1 / count;
        T g2 = top_diff_this_bin * w2 / count;
        T g3 = top_diff_this_bin * w3 / count;
        T g4 = top_diff_this_bin * w4 / count;

        // Indices are all -1 when the sample fell outside the feature map.
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
          add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
          add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
          add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
        } // if
      } // ix
    } // iy
  } // for
} // ROIAlignBackward
184 
185 } // namespace
186 
187 template <>
188 bool RoIAlignGradientOp<float, CPUContext>::RunOnDevice() {
189  auto& X = Input(0); // Input data to pool
190  auto& R = Input(1); // RoIs
191  auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
192  // (aka "gradOutput")
193 
194  CAFFE_ENFORCE_EQ(R.dim(), 2);
195  // if R has 5 columns, the first column is the index, otherwise 0
196  CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
197 
198  auto* dX = Output(
199  0,
200  X.sizes(),
201  at::dtype<float>()); // Gradient of net w.r.t. input to "forward" op (aka
202  // "gradInput")
203 
204  // Must zero-out dX before accumulating gradients
205  // (TODO): Kaiming - is this safe?
206  math::Set<float, CPUContext>(
207  dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
208 
209  if (dY.numel() > 0) { // Handle possibly empty gradient if there were no rois
210  ROIAlignBackwardFeature<float>(
211  dY.numel(),
212  dY.data<float>(),
213  R.dim32(0),
214  spatial_scale_,
215  X.dim32(1),
216  X.dim32(2),
217  X.dim32(3),
218  pooled_height_,
219  pooled_width_,
220  sampling_ratio_,
221  dX->template mutable_data<float>(),
222  R.data<float>(),
223  R.dim32(1));
224  }
225  return true;
226 }
227 
// Register the CPU implementation of the backward op.
REGISTER_CPU_OPERATOR(RoIAlignGradient, RoIAlignGradientOp<float, CPUContext>);

// Input: X, rois, dY (aka "gradOutput");
// Output: dX (aka "gradInput")
OPERATOR_SCHEMA(RoIAlignGradient)
    .NumInputs(3)
    .NumOutputs(1)
    .Input(0, "X", "See RoIPoolF.")
    .Input(1, "RoIs", "See RoIPoolF.")
    .Input(2, "dY", "Gradient of forward output 0 (Y)")
    .Output(0, "dX", "Gradient of forward input 0 (X)");
239 
240 namespace {
241 
242 class GetRoIAlignGradient : public GradientMakerBase {
243  using GradientMakerBase::GradientMakerBase;
244  vector<OperatorDef> GetGradientDefs() override {
245  return SingleGradientDef(
246  "RoIAlignGradient",
247  "",
248  vector<string>{I(0), I(1), GO(0)},
249  vector<string>{GI(0)});
250  }
251 };
252 
253 } // namespace
254 
// Hook RoIAlign's backward pass up to the gradient maker defined above.
REGISTER_GRADIENT(RoIAlign, GetRoIAlignGradient);
256 
257 } // namespace caffe2
See blob.h:13 for the global dictionary that holds information about which Caffe2 modules have been loaded in the current runtime.