// LengthsTopKOp<T, Context>::RunOnDevice
//
// For every segment i in [0, N) this routine fills k_ slots of two outputs,
// both shaped {N, k_}: the selected values (dtype T) and their positions
// inside the segment (dtype int). Slots that cannot be filled because the
// segment holds fewer than k_ elements receive value 0 and index -1.
//
// NOTE(review): this listing skips several source lines. The declarations of
// X, Y, N, next_index and p_queue, the pop() calls, the `else` keyword and the
// closing braces are not visible here, so remarks about them are hedged.
1 #include "caffe2/operators/lengths_top_k_op.h" 5 template <
typename T,
class Context>
6 bool LengthsTopKOp<T, Context>::RunOnDevice() {
// X is read as a flat array of T; Y is read as int and is used below as the
// per-segment element count (input_len[i]).
10 const T* X_data = X.template data<T>();
11 const int* input_len = Y.template data<int>();
// Both outputs share the same {N, k_} shape.
13 auto output_dims = std::vector<int64_t>({N, k_});
14 auto* output_topk_values = Output(TOPK_VALUES_OUT, output_dims, at::dtype<T>());
15 auto* output_topk_indices =
16 Output(TOPK_INDICES_OUT, output_dims, at::dtype<int>());
17 T* output_topk_values_data = output_topk_values->template mutable_data<T>();
18 int* output_topk_indices_data =
19 output_topk_indices->template mutable_data<int>();
// Ordering on (value, index) pairs: lhs comes "before" rhs when its value is
// larger, or when the values tie and its index is smaller.
// Presumably this is the comparator handed to p_queue, which would put the
// smallest retained value at top() — confirm against the full source.
// NOTE(review): the parameters are non-const references although the lambda
// does not modify them.
21 auto cmp = [](std::pair<T, int64_t>& lhs, std::pair<T, int64_t>& rhs) {
22 return lhs.first > rhs.first ||
23 (lhs.first == rhs.first && lhs.second < rhs.second);
// One pass per segment. The two lines after the loop header are the visible
// template arguments of the queue type: it stores (value, index) pairs in a
// std::vector.
28 for (int64_t i = 0; i < N; ++i) {
32 std::pair<T, int64_t>,
33 std::vector<std::pair<T, int64_t>>,
// Walk the input_len[i] elements of segment i. next_index advances through
// X_data on every element, so it is a position in the flat input, while j is
// the position inside the current segment and is what gets stored.
39 for (int64_t j = 0; j < input_len[i]; ++j) {
40 const auto value = X_data[next_index++];
// Admit the element if fewer than k_ are held or it beats the current top.
// NOTE(review): p_queue.size() is unsigned; if k_ is a signed int this is a
// signed/unsigned comparison — confirm the type of k_.
41 if (p_queue.size() < k_ || value > p_queue.top().first) {
42 p_queue.push(std::make_pair(value, j));
// Size check after the push; the body is not shown (presumably a pop() that
// trims the queue back to k_ entries).
44 if (p_queue.size() > k_) {
// Emit phase: last_index is the number of retained elements. While the queue
// is non-empty, top() is written to slot (last_index - j - 1), i.e. the
// retained elements are laid out back-to-front within row i.
49 int last_index = p_queue.size();
50 for (int64_t j = 0; j < k_; ++j) {
51 if (p_queue.size() > 0) {
52 auto& pqElem = p_queue.top();
53 output_topk_values_data[i * k_ + last_index - j - 1] = pqElem.first;
54 output_topk_indices_data[i * k_ + last_index - j - 1] = pqElem.second;
// Padding for the remaining slots (presumably the `else` branch; the keyword
// is on a line this listing omits): value 0, index -1.
57 output_topk_values_data[i * k_ + j] = 0;
58 output_topk_indices_data[i * k_ + j] = -1;
// LengthsTopKGradientOp<T, Context>::RunOnDevice
//
// Builds a zero-filled 1-D output (DER_X_OUT) whose length is the sum of all
// segment lengths, then copies each incoming top-k gradient value to the
// position named by the matching stored index, offset by the start of that
// value's segment.
//
// NOTE(review): this listing skips several source lines. The enforce macro
// names, the declarations of num_indices and index_offset, and the closing
// braces are not visible here.
66 template <
typename T,
class Context>
67 bool LengthsTopKGradientOp<T, Context>::RunOnDevice() {
// N is the element count of the lengths input, i.e. the number of segments.
68 auto& input_len = Input(LENGTH_IN);
69 int N = input_len.numel();
70 auto& input_indices = Input(INDICES_IN);
// Shape validation: the indices tensor must have at least two dimensions.
71 CAFFE_ENFORCE_GE(input_indices.dim(), 2,
"input dim must be >= 2");
// Arguments of a check relating input_indices.numel() to N * k_ (the macro
// name is on an omitted line; presumably an equality enforce).
73 input_indices.numel(), N * k_,
"input_indices shape is not correct");
74 auto& input_topk = Input(DER_TOPK_IN);
// Same kind of check for the incoming top-k gradient tensor.
76 input_topk.numel(), N * k_,
"input_topk shape is not correct");
78 const int* input_len_data = input_len.template data<int>();
79 const int* input_indices_data = input_indices.template data<int>();
80 const T* input_topk_data = input_topk.template data<T>();
// Sum the segment lengths to obtain the size of the output.
83 for (
int i = 0; i < N; i++) {
84 num_indices += input_len_data[i];
// Allocate the 1-D output and clear it, so positions that receive no
// gradient below stay at 0.
86 auto* X_out = Output(DER_X_OUT, {num_indices}, at::dtype<T>());
87 T* X_out_data = X_out->template mutable_data<T>();
88 math::Set<T, Context>(num_indices, 0.0, X_out_data, &context_);
// Scatter pass. For segment i only the first min(length, k_) slots are
// visited, so slots beyond the segment length are never read.
91 for (
int i = 0; i < N; i++) {
92 for (
int j = 0; j < std::min(input_len_data[i], k_); j++) {
// Stored indices are relative to the segment; index_offset converts them to
// positions in the flat output.
93 int cur_index = index_offset + input_indices_data[i * k_ + j];
// Arguments of a bounds check on cur_index against num_indices (the macro
// name is on an omitted line).
95 cur_index, num_indices,
"cur_index should be less than num_indices");
96 X_out_data[cur_index] = input_topk_data[i * k_ + j];
// Advance to the start of the next segment.
98 index_offset += input_len_data[i];
// CPU registrations. The first binds the operator name LengthsTopK to the
// float / CPUContext instantiation of LengthsTopKOp. The second registers the
// float / CPUContext instantiation of LengthsTopKGradientOp; its name argument
// sits on a line this listing omits.
104 REGISTER_CPU_OPERATOR(LengthsTopK, LengthsTopKOp<float, CPUContext>);
105 REGISTER_CPU_OPERATOR(
107 LengthsTopKGradientOp<float, CPUContext>);
// Schema declarations. The LengthsTopK schema carries the user-facing
// documentation text for the operator, its inputs (input 1 is LENGTHS), its
// outputs and the k argument; several lines of that declaration are omitted
// from this listing. The LengthsTopKGradient schema only fixes the arity:
// 3 inputs, 1 output.
108 OPERATOR_SCHEMA(LengthsTopK)
112 Apply TopK to each segment of the input tensor, where segments are defined by 113 their LENGTHS, and concatenate them in an output tensor of 114 shape=(SIZE(LENGTHs), k). In case there's less than k values in a segment, 115 the output value will be padded by 0, and the corresponding output indices will 121 "Tensor of rank 1. First dimension must be equal to the sum of " 123 .Input(1,
"LENGTHS",
"Tensor of int32 lengths of rank 1")
127 "Output top k elements for each segment, with" 128 "shape=(SIZE(lengths), k)")
132 "Output indices in DATA corresponding to value in TopKValue")
135 "the number of top values to return for each segment, if the number " 136 "of values is smaller than k, the values would be padded with 0 and " 137 "indices would be padded with -1.");
138 OPERATOR_SCHEMA(LengthsTopKGradient).NumInputs(3).NumOutputs(1);
// Gradient maker for LengthsTopK: produces a single LengthsTopKGradient
// operator definition.
// NOTE(review): the closing braces of the method and class are on lines this
// listing omits.
142 class GetLengthsTopKGradient :
public GradientMakerBase {
143 using GradientMakerBase::GradientMakerBase;
144 vector<OperatorDef> GetGradientDefs()
override {
145 return SingleGradientDef(
146 "LengthsTopKGradient",
// Gradient op inputs: forward input 1, forward output 1, and the gradient of
// forward output 0 (presumably the lengths, the top-k indices, and the
// gradient w.r.t. the top-k values — confirm against the operator's enums).
148 vector<string>{I(1), O(1), GO(0)},
// Gradient op output: the gradient of forward input 0.
149 vector<string>{GI(0)});
// Associates the LengthsTopK operator with its gradient maker.
155 REGISTER_GRADIENT(LengthsTopK, GetLengthsTopKGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...