Caffe2 - C++ API
A deep learning, cross platform ML framework
lengths_top_k_op.cc
1 #include "caffe2/operators/lengths_top_k_op.h"
2 
3 namespace caffe2 {
4 
5 template <typename T, class Context>
6 bool LengthsTopKOp<T, Context>::RunOnDevice() {
7  auto& X = Input(X_IN);
8  auto& Y = Input(Y_IN);
9  int N = Y.dim32(0);
10  const T* X_data = X.template data<T>();
11  const int* input_len = Y.template data<int>();
12 
13  auto output_dims = std::vector<int64_t>({N, k_});
14  auto* output_topk_values = Output(TOPK_VALUES_OUT, output_dims, at::dtype<T>());
15  auto* output_topk_indices =
16  Output(TOPK_INDICES_OUT, output_dims, at::dtype<int>());
17  T* output_topk_values_data = output_topk_values->template mutable_data<T>();
18  int* output_topk_indices_data =
19  output_topk_indices->template mutable_data<int>();
20 
21  auto cmp = [](std::pair<T, int64_t>& lhs, std::pair<T, int64_t>& rhs) {
22  return lhs.first > rhs.first ||
23  (lhs.first == rhs.first && lhs.second < rhs.second);
24  };
25 
26  // Sort preserving indices
27  int next_index = 0;
28  for (int64_t i = 0; i < N; ++i) {
29  // Build a min-heap, the heap element is pair of (value, idx)
30  // the top of the heap is the smallest value
31  std::priority_queue<
32  std::pair<T, int64_t>,
33  std::vector<std::pair<T, int64_t>>,
34  decltype(cmp)>
35  p_queue(cmp);
36 
37  // Maintain the size of heap to be less or equal to k_, so the
38  // heap will hold the k_ largest values
39  for (int64_t j = 0; j < input_len[i]; ++j) {
40  const auto value = X_data[next_index++];
41  if (p_queue.size() < k_ || value > p_queue.top().first) {
42  p_queue.push(std::make_pair(value, j));
43  }
44  if (p_queue.size() > k_) {
45  p_queue.pop();
46  }
47  }
48 
49  int last_index = p_queue.size();
50  for (int64_t j = 0; j < k_; ++j) {
51  if (p_queue.size() > 0) {
52  auto& pqElem = p_queue.top();
53  output_topk_values_data[i * k_ + last_index - j - 1] = pqElem.first;
54  output_topk_indices_data[i * k_ + last_index - j - 1] = pqElem.second;
55  p_queue.pop();
56  } else {
57  output_topk_values_data[i * k_ + j] = 0;
58  output_topk_indices_data[i * k_ + j] = -1;
59  }
60  }
61  }
62 
63  return true;
64 }
65 
66 template <typename T, class Context>
67 bool LengthsTopKGradientOp<T, Context>::RunOnDevice() {
68  auto& input_len = Input(LENGTH_IN);
69  int N = input_len.numel();
70  auto& input_indices = Input(INDICES_IN);
71  CAFFE_ENFORCE_GE(input_indices.dim(), 2, "input dim must be >= 2");
72  CAFFE_ENFORCE_EQ(
73  input_indices.numel(), N * k_, "input_indices shape is not correct");
74  auto& input_topk = Input(DER_TOPK_IN);
75  CAFFE_ENFORCE_EQ(
76  input_topk.numel(), N * k_, "input_topk shape is not correct");
77 
78  const int* input_len_data = input_len.template data<int>();
79  const int* input_indices_data = input_indices.template data<int>();
80  const T* input_topk_data = input_topk.template data<T>();
81 
82  int num_indices = 0;
83  for (int i = 0; i < N; i++) {
84  num_indices += input_len_data[i];
85  }
86  auto* X_out = Output(DER_X_OUT, {num_indices}, at::dtype<T>());
87  T* X_out_data = X_out->template mutable_data<T>();
88  math::Set<T, Context>(num_indices, 0.0, X_out_data, &context_);
89 
90  int index_offset = 0;
91  for (int i = 0; i < N; i++) {
92  for (int j = 0; j < std::min(input_len_data[i], k_); j++) {
93  int cur_index = index_offset + input_indices_data[i * k_ + j];
94  CAFFE_ENFORCE_LT(
95  cur_index, num_indices, "cur_index should be less than num_indices");
96  X_out_data[cur_index] = input_topk_data[i * k_ + j];
97  }
98  index_offset += input_len_data[i];
99  }
100 
101  return true;
102 }
103 
104 REGISTER_CPU_OPERATOR(LengthsTopK, LengthsTopKOp<float, CPUContext>);
105 REGISTER_CPU_OPERATOR(
106  LengthsTopKGradient,
107  LengthsTopKGradientOp<float, CPUContext>);
108 OPERATOR_SCHEMA(LengthsTopK)
109  .NumInputs(2)
110  .NumOutputs(2)
111  .SetDoc(R"DOC(
112 Apply TopK to each segment of the input tensor, where segments are defined by
113 their LENGTHS, and concatenate them in an output tensor of
114 shape=(SIZE(LENGTHs), k). In case there's less than k values in a segment,
115 the output value will be padded by 0, and the corresponding output indices will
116 be padded by -1.
117 )DOC")
118  .Input(
119  0,
120  "DATA",
121  "Tensor of rank 1. First dimension must be equal to the sum of "
122  "lengths")
123  .Input(1, "LENGTHS", "Tensor of int32 lengths of rank 1")
124  .Output(
125  0,
126  "TopKValue",
127  "Output top k elements for each segment, with"
128  "shape=(SIZE(lengths), k)")
129  .Output(
130  1,
131  "TopKIndices",
132  "Output indices in DATA corresponding to value in TopKValue")
133  .Arg(
134  "k",
135  "the number of top values to return for each segment, if the number "
136  "of values is smaller than k, the values would be padded with 0 and "
137  "indices would be padded with -1.");
138 OPERATOR_SCHEMA(LengthsTopKGradient).NumInputs(3).NumOutputs(1);
139 
140 namespace {
141 
142 class GetLengthsTopKGradient : public GradientMakerBase {
143  using GradientMakerBase::GradientMakerBase;
144  vector<OperatorDef> GetGradientDefs() override {
145  return SingleGradientDef(
146  "LengthsTopKGradient",
147  "",
148  vector<string>{I(1), O(1), GO(0)},
149  vector<string>{GI(0)});
150  }
151 };
152 
153 } // namespace
154 
155 REGISTER_GRADIENT(LengthsTopK, GetLengthsTopKGradient);
156 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13