1 #include "heatmap_max_keypoint_op.h" 2 #include "caffe2/utils/eigen_utils.h" 9 HeatmapMaxKeypointOp<float, CPUContext>);
// Operator schema for HeatmapMaxKeypoint: exactly two inputs and exactly one
// output. The compute code below reads input 0 as the heatmaps and input 1 as
// the bounding boxes.
13 OPERATOR_SCHEMA(HeatmapMaxKeypoint).NumInputs(2).NumOutputs(1);
// Marks the operator as one for which no gradient should be generated.
15 SHOULD_NOT_DO_GRADIENT(HeatmapMaxKeypoint);
// Compute body of the HeatmapMaxKeypoint CPU operator (the enclosing function
// signature and the closing braces are not visible in this excerpt).
//
// For every (box, keypoint) heatmap it locates the maximum cell, refines that
// location with a local second-order fit over the 3x3 neighbourhood, maps the
// refined location into the box's coordinate frame, and writes four values per
// keypoint: x, y, refined score, and (optionally) the softmax probability.
//
// Input 0: heatmaps. Input 1: bounding boxes.
26 const auto& heatmaps_in = Input(0);
27 const auto& bboxes_in = Input(1);
// Heatmaps must be 4-D; dim 0 is the box count N, dim 1 the keypoint count,
// and dims 2 and 3 the (square) spatial size, which must be at least 2.
29 CAFFE_ENFORCE_EQ(heatmaps_in.dim(), 4);
30 const int N = heatmaps_in.dim32(0);
// NOTE(review): N was just initialised from heatmaps_in.dim32(0), so this check
// can never fail; it looks redundant with the bboxes check below.
31 CAFFE_ENFORCE_EQ(heatmaps_in.dim32(0), N);
32 const int keypoint_count = heatmaps_in.dim32(1);
33 const int heatmap_size = heatmaps_in.dim32(2);
34 CAFFE_ENFORCE_GE(heatmap_size, 2);
35 CAFFE_ENFORCE_EQ(heatmaps_in.dim32(2), heatmaps_in.dim32(3));
// Boxes must be 2-D with one row per heatmap batch entry and at least 4
// columns; only columns 0..3 are read below.
37 CAFFE_ENFORCE_EQ(bboxes_in.dim(), 2);
38 CAFFE_ENFORCE_EQ(bboxes_in.dim32(0), N);
39 CAFFE_ENFORCE_GE(bboxes_in.dim32(1), 4);
// Non-owning 2-D view of the heatmaps: one row per (box, keypoint) pair, one
// column per spatial cell. (ERArrXXf is presumably the row-major float array
// alias from eigen_utils.h -- confirm.)
42 Eigen::Map<const ERArrXXf> heatmaps(
43 heatmaps_in.data<
float>(),
44 heatmaps_in.dim32(0) * heatmaps_in.dim32(1),
45 heatmaps_in.dim32(2) * heatmaps_in.dim32(3));
// Non-owning 2-D view of the boxes with the input's own row/column counts.
46 Eigen::Map<const ERArrXXf> bboxes(
47 bboxes_in.data<
float>(), bboxes_in.dim32(0), bboxes_in.dim32(1));
// The next two lines are the dimension arguments of a declaration whose first
// line is not visible in this excerpt (apparently `probs`, used below); the
// dimensions match the `heatmaps` view.
51 heatmaps_in.dim32(0) * heatmaps_in.dim32(1),
52 heatmaps_in.dim32(2) * heatmaps_in.dim32(3));
// Optional per-heatmap softmax: exponentiate every cell, then divide each row
// by its own sum.
// NOTE(review): exp() is applied to the raw values with no max-subtraction
// visible here; confirm large inputs cannot overflow to inf.
53 if (should_output_softmax_) {
56 ERArrXXf heatmap_exp = heatmaps.exp();
57 for (
int r = 0; r < N * keypoint_count; r++) {
58 probs.row(r) = heatmap_exp.row(r) / heatmap_exp.row(r).sum();
// Output 0 is allocated as {N, 4, keypoint_count} floats and viewed as an
// N x (4 * keypoint_count) array, i.e. four consecutive channels of
// keypoint_count values per box.
63 auto* keypoints_out = Output(0, {N, 4, keypoint_count}, at::dtype<float>());
64 Eigen::Map<ERArrXXf> keypoints(
65 keypoints_out->mutable_data<
float>(), N, 4 * keypoint_count);
// One flat arg-max index per heatmap row.
67 EArrXi maxIndices(N * keypoint_count);
// Row-wise maximum score, then a linear scan of each row for a cell equal to
// that maximum. The body of the `if` is not visible in this excerpt;
// presumably it stores `c` into maxIndices[r] -- confirm.
70 EArrXf maxScores = heatmaps.rowwise().maxCoeff();
71 for (
int r = 0; r < N * keypoint_count; r++) {
72 float maxScore = maxScores[r];
73 for (
int c = 0; c < heatmap_size * heatmap_size; c++) {
74 if (heatmaps(r, c) == maxScore) {
// Per-box pass. Columns 0 and 1 are used as the box origin; the extents are
// column 2 minus column 0 and column 3 minus column 1, each clamped to be at
// least 1.
82 for (
int k = 0; k < N; k++) {
84 float x0 = bboxes(k, 0);
85 float y0 = bboxes(k, 1);
86 float xLen = std::max(bboxes(k, 2) - bboxes(k, 0), 1.0f);
87 float yLen = std::max(bboxes(k, 3) - bboxes(k, 1), 1.0f);
// Per-keypoint pass: decode the flat arg-max index into a row (maxY) and a
// column (maxX) of the square heatmap.
90 for (
int j = 0; j < keypoint_count; j++) {
91 const int heatmap_index = k * keypoint_count + j;
92 const int maxIndex = maxIndices[heatmap_index];
93 const float maxScore = maxScores[heatmap_index];
94 const int maxY = maxIndex / heatmap_size;
95 const int maxX = maxIndex - heatmap_size * maxY;
97 assert(heatmaps(heatmap_index, maxIndex) == maxScore);
// Gather the 3x3 neighbourhood around the maximum into fmax, indexed as
// fmax(y + 1, x + 1). An offset that would step outside the heatmap is
// replaced by the opposite offset (x -> x - 2 or x + 2), i.e. the
// neighbourhood is mirrored at the border.
98 ERArrXXf fmax = ERArrXXf::Zero(3, 3);
102 for (
int y = -1; y <= 1; y++) {
103 for (
int x = -1; x <= 1; x++) {
104 int xx = x - 2 * (x + maxX >= heatmap_size) + 2 * (x + maxX < 0);
105 int yy = y - 2 * (y + maxY >= heatmap_size) + 2 * (y + maxY < 0);
106 assert((xx + maxX < heatmap_size) && (xx + maxX >= 0));
107 assert((yy + maxY < heatmap_size) && (yy + maxY >= 0));
108 const int coord_index = (yy + maxY) * heatmap_size + xx + maxX;
109 fmax(y + 1, x + 1) = heatmaps(heatmap_index, coord_index);
// b holds the negated central-difference gradient (x component first, then y).
// The declarations of b, A, delta and deltaScore are not visible in this
// excerpt.
115 b << -(fmax(1, 2) - fmax(1, 0)) / 2, -(fmax(2, 1) - fmax(0, 1)) / 2;
// A holds the finite-difference second derivatives: xx on the diagonal first,
// the mixed xy term in both off-diagonal slots, then yy.
117 A << fmax(1, 0) - 2 * fmax(1, 1) + fmax(1, 2),
118 (fmax(2, 2) - fmax(2, 0) - fmax(0, 2) + fmax(0, 0)) / 4,
119 (fmax(2, 2) - fmax(2, 0) - fmax(0, 2) + fmax(0, 0)) / 4,
120 fmax(0, 1) - 2 * fmax(1, 1) + fmax(2, 1);
// When |det(A)| is below 1e-4 the refinement is skipped and the score falls
// back to the raw maximum. How delta is set in that branch is on a line not
// visible here.
123 const float div = A.determinant();
126 const float MAX_DELTA = 1.5;
127 if (std::abs(div) < 1e-4f) {
129 deltaScore = maxScore;
// Otherwise (presumably the else branch -- the keyword is not visible) solve
// A * delta = b with an LDLT factorisation to get the sub-cell offset.
131 delta = A.ldlt().solve(b);
// If either component exceeds MAX_DELTA in magnitude, rescale both by the same
// factor so the larger one has magnitude exactly MAX_DELTA.
133 if (std::abs(delta(0)) > MAX_DELTA || std::abs(delta(1)) > MAX_DELTA) {
134 float larger_delta = std::max(std::abs(delta(0)), std::abs(delta(1)));
135 delta(0) = delta(0) / larger_delta * MAX_DELTA;
136 delta(1) = delta(1) / larger_delta * MAX_DELTA;
// Second-order estimate of the score at the refined location:
// centre value - b^T * delta + 0.5 * delta^T * A * delta.
138 deltaScore = fmax(1, 1) - b.transpose() * delta +
139 1.0 / 2.0 * delta.transpose() * A * delta;
// Debug-only sanity check that the offset respects the clamp.
141 assert(std::abs(delta(0)) <= MAX_DELTA);
142 assert(std::abs(delta(1)) <= MAX_DELTA);
// Channel 0 / 1: refined x / y. The cell position (index + 0.5 + offset) is
// scaled by box extent / heatmap size and shifted by the box origin.
144 keypoints(k, 0 * keypoint_count + j) =
145 x0 + (0.5 + maxX + delta(0)) * xLen / heatmap_size;
146 keypoints(k, 1 * keypoint_count + j) =
147 y0 + (0.5 + maxY + delta(1)) * yLen / heatmap_size;
// Channel 2: refined score.
148 keypoints(k, 2 * keypoint_count + j) = deltaScore;
// Channel 3: softmax probability at the arg-max cell when softmax output is
// enabled; the 0 written below is presumably the else branch (keyword not
// visible in this excerpt).
149 if (should_output_softmax_) {
150 keypoints(k, 3 * keypoint_count + j) = probs(heatmap_index, maxIndex);
152 keypoints(k, 3 * keypoint_count + j) = .0f;
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...