// NOTE(review): this text appears to be a lossy extraction of the
// implementation file for caffe2/operators/flexible_top_k.h — original line
// numbers are fused into the code and many lines (struct/namespace headers,
// closing braces, some statements) are missing, so it does not compile as-is.
1 #include "caffe2/operators/flexible_top_k.h" 3 #include "caffe2/proto/caffe2_pb.h" 12 const std::pair<T, int64_t>& lhs,
13 const std::pair<T, int64_t>& rhs) {
// Ordering on (value, index) pairs: lhs ranks "greater" when its value is
// larger, or when the values are equal and its index is smaller.
// The enclosing declaration is not visible; presumably this is the comparator
// passed to the std::priority_queue in FlexibleTopKOp — TODO confirm. Used
// that way, the queue top is the smallest value and, among equal values, the
// one with the larger index, so later indices are evicted first.
15 lhs.first > rhs.first ||
16 (lhs.first == rhs.first && lhs.second < rhs.second));
// FlexibleTopKOp::RunOnDevice — views X as a 2-D [rows, last_dim] matrix and,
// for each row i, selects the top k_data[i] values along the last dimension.
// The selected values (Output 0) and their column positions within the row
// (Output 1) are written back-to-back into flat 1-D outputs.
// NOTE(review): several lines of this function are missing from this
// extraction (e.g. the declaration of `k`, which is read below, the full
// rows-vs-K check, the priority-queue declaration, the eviction/pop calls, the
// advance of output_offset and the final return). Do not treat it as complete.
22 template <
typename T,
class Context>
23 bool FlexibleTopKOp<T, Context>::RunOnDevice() {
24 auto& input = Input(0);
27 const T* input_data = input.template data<T>();
28 const int64_t* k_data = k.template data<int64_t>();
// Flatten to {leading-dims product, last dim}; size_to_dim_ is presumably the
// product of the first n-1 dims — confirm against its declaration.
31 CAFFE_ENFORCE_GT(input.dim(), 0);
32 vector<int64_t> input_dims = input.sizes().vec();
33 vector<int64_t> linear_shape = {
34 size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
// Only the error message of the rows-vs-K consistency check survives here.
38 "first n-1 dims of input data and K does not match.");
// Validate each per-row k (only partially visible) and sum them to obtain the
// total length of the flat outputs.
40 int64_t output_size = 0;
41 for (int64_t i = 0; i < linear_shape[0]; ++i) {
43 linear_shape[1] >= k_data[i],
44 "k should not be greater than last dim, error at index ",
50 "k should be greater than 0, error at index ",
54 output_size += k_data[i];
// Both outputs are 1-D with one entry per selected element across all rows.
56 auto* values = Output(0, {output_size}, at::dtype<T>());
57 auto* indices = Output(1, {output_size}, at::dtype<int64_t>());
58 T* values_data = values->template mutable_data<T>();
59 int64_t* indices_data = indices->template mutable_data<int64_t>();
// Start of the current row's segment in the flat outputs; the statement that
// advances it after each row is not visible here.
61 int64_t output_offset = 0;
63 for (int64_t i = 0; i < linear_shape[0]; ++i) {
// Element and container types of the (partially visible) priority queue:
// (value, column) pairs.
67 std::pair<T, int64_t>,
68 std::vector<std::pair<T, int64_t>>,
72 int64_t k_ = k_data[i];
73 for (int64_t j = 0; j < linear_shape[1]; ++j) {
74 const T value = input_data[i * linear_shape[1] + j];
// Admit a candidate while fewer than k_ are held, or when it strictly beats
// the current top. Equal values are not admitted once full, so the earlier
// (lower) column is kept, matching the documented tie-break.
// NOTE(review): PQ.size() is presumably size_t while k_ is int64_t, so this
// is a signed/unsigned comparison; it is only benign because k_ appears to be
// validated as positive above (only that check's message is visible).
75 if (PQ.size() < k_ || value > PQ.top().first) {
// Stores j, the column within row i — not an index into the flattened input.
76 PQ.push(std::make_pair(value, j));
// Drain the k_ retained elements, filling the row's segment from its last
// slot backwards. If the top is the smallest retained value (see the
// comparator at the top of the file), the segment ends up in descending order.
82 for (int64_t j = 0; j < k_; ++j) {
83 auto& pqElem = PQ.top();
84 values_data[output_offset + k_ - j - 1] = pqElem.first;
85 indices_data[output_offset + k_ - j - 1] = pqElem.second;
// FlexibleTopKGradientOp::RunOnDevice — scatters the incoming gradient values
// (Input 2) to the positions recorded in `indices` (Input 3). The result has
// the shape of the original input (Input 0) and is zero everywhere else.
// NOTE(review): some lines are missing from this extraction, including the
// declaration of `k` (presumably Input(1); the gradient maker in this file
// passes I(1) as the second input) and the function's closing lines.
94 template <
typename T,
class Context>
95 bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
96 auto& original_input = Input(0);
98 auto& values = Input(2);
99 auto& indices = Input(3);
101 const int64_t* k_data = k.template data<int64_t>();
102 const T* values_data = values.template data<T>();
103 const int64_t* indices_data = indices.template data<int64_t>();
// Output takes the original input's shape and is filled with zeros via
// math::Set before scattering.
106 CAFFE_ENFORCE_GT(original_input.dim(), 0);
107 vector<int64_t> original_dims = original_input.sizes().vec();
108 auto* output = Output(0, original_dims, at::dtype<T>());
109 T* output_data = output->template mutable_data<T>();
110 math::Set<T, Context>(
111 output->numel(),
static_cast<T>(0), output_data, &context_);
// index_offset: start of row i's segment in the flat values/indices inputs.
113 int64_t index_offset = 0;
114 for (int64_t i = 0; i < k.numel(); ++i) {
// Row i begins at i * last_dim in the output. Each index is a column position
// within the row, which is what the forward op stores (it pushes j).
// NOTE(review): nothing visible here checks that k.numel() equals the number
// of rows, or that each index lies in [0, original_dims.back()). A malformed
// indices tensor would write out of bounds.
116 int64_t output_offset = i * original_dims.back();
117 for (int64_t j = 0; j < k_data[i]; ++j) {
118 int64_t index = indices_data[index_offset + j];
119 T value = values_data[index_offset + j];
120 output_data[output_offset + index] = value;
122 index_offset += k_data[i];
// CPU registrations: in the visible code both the forward and gradient ops are
// instantiated for float only.
128 REGISTER_CPU_OPERATOR(FlexibleTopK, FlexibleTopKOp<float, CPUContext>);
129 REGISTER_CPU_OPERATOR(
130 FlexibleTopKGradient,
131 FlexibleTopKGradientOp<float, CPUContext>);
// Schema and user-facing documentation for FlexibleTopK.
// NOTE(review): lines between OPERATOR_SCHEMA and the doc text (presumably
// .NumInputs/.NumOutputs/.SetDoc) and the .Output( openers are missing from
// this extraction. As a result, the doc text below is no longer inside a
// string literal.
// NOTE(review): the text describes the indices output as indices "into the
// flatten input". The implementation, however, stores the column position j
// within each row: the forward op pushes (value, j), and the gradient op adds
// i * last_dim to each index. The wording in the doc strings should probably
// be corrected.
133 OPERATOR_SCHEMA(FlexibleTopK)
137 Given two tensors: X and K, 138 retrieve the top K[..., 1] elements from X on the last dimension. 139 X is an input tensor of shape [a_1, a_2, ..., a_n, r]. 140 K is an input tensor of shape [a_1, a_2, ..., a_n, 1], 141 where for each element, r >= K[..., 1] > 0 143 -Flatten values tensor of shape [ \sum_i K[i, 1] ] which contains the values of 144 the top K[..., 1] elements along the last dimension 145 -Flatten indices tensor of shape [ \sum_i K[i, 1] ] which contains the indices 146 of the top K[..., 1] elements, flatten indices from the input tensor). 147 These two outputs should be used with the input K, so that we know which indices 150 Given two equivalent values, this operator uses the indices along the last dim- 151 ension as a tiebreaker. That is, the element with the lower index will appear 154 .Input(0, "X",
"Tensor of shape [a_1, a_2, ..., a_n, r]")
155 .Input(1,
"K",
"Tensor of shape [a_1, a_2, ..., a_n, 1]")
159 "Tensor of shape [ \\sum_i K[i, 1] ] containing" 160 " top K[..., 1] values from the input tensor")
164 "Tensor of shape [ \\sum_i K[i, 1] ] containing the indices " 165 "into the flatten input");
// The gradient op consumes 4 inputs (original X, K, the values gradient and
// the forward indices) and produces 1 output (the gradient for X).
167 OPERATOR_SCHEMA(FlexibleTopKGradient).NumInputs(4).NumOutputs(1);
// Gradient maker for FlexibleTopK: emits a single FlexibleTopKGradient op.
// Its inputs are the forward input X (I(0)), K (I(1)), the gradient of the
// values output (GO(0)) and the forward indices output (O(1)). These line up
// with FlexibleTopKGradientOp reading Input(0) for the shape, Input(2) as
// values and Input(3) as indices. Only X's gradient (GI(0)) is produced, so K
// gets no gradient.
// NOTE(review): one argument line of SingleGradientDef and the closing braces
// of the method and class are missing from this extraction.
169 class GetFlexibleTopKGradient :
public GradientMakerBase {
170 using GradientMakerBase::GradientMakerBase;
171 vector<OperatorDef> GetGradientDefs()
override {
172 return SingleGradientDef(
173 "FlexibleTopKGradient",
175 vector<string>{I(0), I(1), GO(0), O(1)},
176 vector<string>{GI(0)});
// Associates the gradient maker with the forward FlexibleTopK operator.
180 REGISTER_GRADIENT(FlexibleTopK, GetFlexibleTopKGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...