// NOTE(review): this listing is an excerpt; the leading integers look like
// line numbers of the original file and the gaps between them are lines that
// are not shown here. Comments below describe only the visible statements.
//
// bilinear_interpolate_gradient: for a sample point (y, x) it produces four
// weights w1..w4 and four integer corner indices x_low/x_high/y_low/y_high.
// The parameter list is not visible in this excerpt.
1 #include "roi_align_gradient_op.h" 3 #include "caffe2/utils/eigen_utils.h" 4 #include "caffe2/utils/math.h" 10 void bilinear_interpolate_gradient(
// Sample lies outside the band [-1, height] x [-1, width]: every weight is
// zeroed and every corner index is set to -1.
25 if (y < -1.0 || y > height || x < -1.0 || x > width) {
27 w1 = w2 = w3 = w4 = 0.;
28 x_low = x_high = y_low = y_high = -1;
// When the low row index reaches the last row, both row indices collapse
// onto height - 1.
42 if (y_low >= height - 1) {
43 y_high = y_low = height - 1;
// Same collapse for the column indices at the last column.
49 if (x_low >= width - 1) {
50 x_high = x_low = width - 1;
// hy and hx are the complements (1 - l) of ly and lx.
// presumably ly/lx are the fractional offsets of y/x from y_low/x_low;
// their definitions are not visible here -- TODO confirm.
58 T hy = 1. - ly, hx = 1. - lx;
// Each weight is the product of one y-direction and one x-direction factor.
67 w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
// add: inline helper that takes a value by const reference and a destination
// pointer. Its body is not visible in this excerpt.
// presumably it accumulates val into *address -- TODO confirm against the
// full file.
73 inline void add(
const T& val,
T* address) {
// ROIAlignBackwardFeature: walks nthreads flat indices, each identifying one
// pooled bin (n, c, ph, pw). For every bin it reads that bin's entry of
// top_diff and spreads it over bottom_diff: each sample point of the bin
// contributes to four neighbouring positions, weighted by w1..w4 and divided
// by the number of samples in the bin.
// Only part of the parameter list is visible in this excerpt.
78 void ROIAlignBackwardFeature(
82 const T& spatial_scale,
86 const int pooled_height,
87 const int pooled_width,
88 const int sampling_ratio,
// Debug-only check: each RoI row has either 4 or 5 columns.
92 DCHECK(rois_cols == 4 || rois_cols == 5);
94 for (
int index = 0; index < nthreads; index++) {
// Decompose the flat index into (n, c, ph, pw), with pw varying fastest.
96 int pw = index % pooled_width;
97 int ph = (index / pooled_width) % pooled_height;
98 int c = (index / pooled_width / pooled_height) % channels;
99 int n = index / pooled_width / pooled_height / channels;
// Row n of the RoI table.
101 const T* offset_bottom_rois = bottom_rois + n * rois_cols;
// With 5 columns the first entry is taken as the batch index and the pointer
// is advanced past it; with 4 columns the batch index stays 0.
// NOTE(review): the assignment below converts T to int implicitly.
102 int roi_batch_ind = 0;
103 if (rois_cols == 5) {
104 roi_batch_ind = offset_bottom_rois[0];
105 offset_bottom_rois++;
// The four remaining entries are read in the order start_w, start_h, end_w,
// end_h and multiplied by spatial_scale.
109 T roi_start_w = offset_bottom_rois[0] * spatial_scale;
110 T roi_start_h = offset_bottom_rois[1] * spatial_scale;
111 T roi_end_w = offset_bottom_rois[2] * spatial_scale;
112 T roi_end_h = offset_bottom_rois[3] * spatial_scale;
// RoI extent is floored at 1 in each dimension.
119 T roi_width = std::max(roi_end_w - roi_start_w, (
T)1.);
120 T roi_height = std::max(roi_end_h - roi_start_h, (
T)1.);
// Size of one pooled bin along each axis.
121 T bin_size_h =
static_cast<T>(roi_height) / static_cast<T>(pooled_height);
122 T bin_size_w =
static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// Start of the height*width plane for (roi_batch_ind, c) in bottom_diff.
124 T* offset_bottom_diff =
125 bottom_diff + (roi_batch_ind * channels + c) * height * width;
// Start of the pooled_height*pooled_width plane for (n, c) in top_diff, and
// the incoming value for this bin.
127 int top_offset = (n * channels + c) * pooled_height * pooled_width;
128 const T* offset_top_diff = top_diff + top_offset;
129 const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
// Samples per bin along each axis: sampling_ratio when it is positive,
// otherwise ceil(roi extent / pooled extent).
// The declaration line for roi_bin_grid_w is not visible in this excerpt;
// the second expression below is presumably its initializer.
132 int roi_bin_grid_h = (sampling_ratio > 0)
134 : ceil(roi_height / pooled_height);
136 (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// Total number of samples in the bin; each contribution is divided by it.
139 const T count = roi_bin_grid_h * roi_bin_grid_w;
// Visit every sample of the bin. The (i + .5) term places each sample at the
// centre of its sub-cell of the bin.
141 for (
int iy = 0; iy < roi_bin_grid_h; iy++) {
142 const T y = roi_start_h + ph * bin_size_h +
143 static_cast<T>(iy + .5f) * bin_size_h /
144 static_cast<T>(roi_bin_grid_h);
145 for (
int ix = 0; ix < roi_bin_grid_w; ix++) {
146 const T x = roi_start_w + pw * bin_size_w +
147 static_cast<T>(ix + .5f) * bin_size_w /
148 static_cast<T>(roi_bin_grid_w);
// Corner indices for this sample.
// presumably filled in, together with w1..w4, by the call below; its
// argument list is not visible in this excerpt.
151 int x_low, x_high, y_low, y_high;
153 bilinear_interpolate_gradient(
// Contribution for each of the four corners.
168 T g1 = top_diff_this_bin * w1 / count;
169 T g2 = top_diff_this_bin * w2 / count;
170 T g3 = top_diff_this_bin * w3 / count;
171 T g4 = top_diff_this_bin * w4 / count;
// Skip samples with a negative corner index; bilinear_interpolate_gradient
// sets all four indices to -1 for out-of-range points.
173 if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
// g1 -> (y_low, x_low), g2 -> (y_low, x_high),
// g3 -> (y_high, x_low), g4 -> (y_high, x_high).
175 add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
176 add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
177 add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
178 add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
// RunOnDevice for the float / CPUContext specialization.
// Visible steps: validate the shape of R, zero-fill dX, then call
// ROIAlignBackwardFeature<float> when dY is non-empty.
// The definitions of R, dY and dX are not visible in this excerpt.
188 bool RoIAlignGradientOp<float, CPUContext>::RunOnDevice() {
// R must be 2-dimensional with 4 or 5 entries along dimension 1.
194 CAFFE_ENFORCE_EQ(R.dim(), 2);
196 CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
// Set all dX->numel() elements of dX to 0.
// presumably required because the backward pass accumulates into dX --
// TODO confirm.
206 math::Set<float, CPUContext>(
207 dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
// The backward pass is skipped when dY has no elements; dX then stays zero.
209 if (dY.numel() > 0) {
210 ROIAlignBackwardFeature<float>(
221 dX->template mutable_data<float>(),
// Registers RoIAlignGradientOp<float, CPUContext> under the operator name
// RoIAlignGradient.
228 REGISTER_CPU_OPERATOR(RoIAlignGradient, RoIAlignGradientOp<float, CPUContext>);
// Schema for RoIAlignGradient: three documented inputs (X, RoIs, dY) and one
// documented output (dX).
// NOTE(review): original lines 233-234 are not visible in this excerpt.
232 OPERATOR_SCHEMA(RoIAlignGradient)
235 .Input(0,
"X",
"See RoIPoolF.")
236 .Input(1,
"RoIs",
"See RoIPoolF.")
237 .Input(2,
"dY",
"Gradient of forward output 0 (Y)")
238 .Output(0,
"dX",
"Gradient of forward input 0 (X)");
// Gradient maker derived from GradientMakerBase; it inherits the base-class
// constructors and overrides GetGradientDefs().
242 class GetRoIAlignGradient :
public GradientMakerBase {
243 using GradientMakerBase::GradientMakerBase;
// Returns a single gradient definition whose inputs are I(0), I(1) and GO(0)
// and whose output is GI(0).
// presumably these are forward inputs 0 and 1, the gradient of forward
// output 0 and the gradient of forward input 0 -- TODO confirm.
// The leading arguments of SingleGradientDef are not visible in this excerpt.
244 vector<OperatorDef> GetGradientDefs()
override {
245 return SingleGradientDef(
248 vector<string>{I(0), I(1), GO(0)},
249 vector<string>{GI(0)});
// Associates the RoIAlign operator with GetRoIAlignGradient as its gradient
// maker.
255 REGISTER_GRADIENT(RoIAlign, GetRoIAlignGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...