Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_top_k_op.cc
1 
17 #include "caffe2/operators/lengths_top_k_op.h"
18 
19 namespace caffe2 {
20 
21 template <typename T, class Context>
22 bool LengthsTopKOp<T, Context>::RunOnDevice() {
23  auto& X = Input(X_IN);
24  auto& Y = Input(Y_IN);
25  int N = Y.dim32(0);
26  const T* X_data = X.template data<T>();
27  const int* input_len = Y.template data<int>();
28  auto* output_topk_values = Output(TOPK_VALUES_OUT);
29  auto* output_topk_indices = Output(TOPK_INDICES_OUT);
30 
31  output_topk_values->Resize(N * k_);
32  output_topk_indices->Resize(N * k_);
33  std::vector<int> output_dims = std::vector<int>({N, k_});
34  output_topk_values->Reshape(output_dims);
35  output_topk_indices->Reshape(output_dims);
36  T* output_topk_values_data = output_topk_values->template mutable_data<T>();
37  int* output_topk_indices_data =
38  output_topk_indices->template mutable_data<int>();
39 
40  auto cmp = [](std::pair<T, TIndex>& lhs, std::pair<T, TIndex>& rhs) {
41  return lhs.first > rhs.first ||
42  (lhs.first == rhs.first && lhs.second < rhs.second);
43  };
44 
45  // Sort preserving indices
46  int next_index = 0;
47  for (TIndex i = 0; i < N; ++i) {
48  // Build a min-heap, the heap element is pair of (value, idx)
49  // the top of the heap is the smallest value
50  std::priority_queue<
51  std::pair<T, TIndex>,
52  std::vector<std::pair<T, TIndex>>,
53  decltype(cmp)>
54  p_queue(cmp);
55 
56  // Maintain the size of heap to be less or equal to k_, so the
57  // heap will hold the k_ largest values
58  for (TIndex j = 0; j < input_len[i]; ++j) {
59  const auto value = X_data[next_index++];
60  if (p_queue.size() < k_ || value > p_queue.top().first) {
61  p_queue.push(std::make_pair(value, j));
62  }
63  if (p_queue.size() > k_) {
64  p_queue.pop();
65  }
66  }
67 
68  int last_index = p_queue.size();
69  for (TIndex j = 0; j < k_; ++j) {
70  if (p_queue.size() > 0) {
71  auto& pqElem = p_queue.top();
72  output_topk_values_data[i * k_ + last_index - j - 1] = pqElem.first;
73  output_topk_indices_data[i * k_ + last_index - j - 1] = pqElem.second;
74  p_queue.pop();
75  } else {
76  output_topk_values_data[i * k_ + j] = 0;
77  output_topk_indices_data[i * k_ + j] = -1;
78  }
79  }
80  }
81 
82  return true;
83 }
84 
85 template <typename T, class Context>
86 bool LengthsTopKGradientOp<T, Context>::RunOnDevice() {
87  auto& input_len = Input(LENGTH_IN);
88  int N = input_len.size();
89  auto& input_indices = Input(INDICES_IN);
90  CAFFE_ENFORCE_GE(input_indices.ndim(), 2, "input dim must be >= 2");
91  CAFFE_ENFORCE_EQ(
92  input_indices.size(), N * k_, "input_indices shape is not correct");
93  auto& input_topk = Input(DER_TOPK_IN);
94  CAFFE_ENFORCE_EQ(
95  input_topk.size(), N * k_, "input_topk shape is not correct");
96  auto* X_out = Output(DER_X_OUT);
97 
98  const int* input_len_data = input_len.template data<int>();
99  const int* input_indices_data = input_indices.template data<int>();
100  const T* input_topk_data = input_topk.template data<T>();
101 
102  int num_indices = 0;
103  for (int i = 0; i < N; i++) {
104  num_indices += input_len_data[i];
105  }
106  X_out->Resize(num_indices);
107  std::vector<int> output_dims = std::vector<int>({num_indices});
108  X_out->Reshape(output_dims);
109  T* X_out_data = X_out->template mutable_data<T>();
110  math::Set<T, Context>(num_indices, 0.0, X_out_data, &context_);
111 
112  int index_offset = 0;
113  for (int i = 0; i < N; i++) {
114  for (int j = 0; j < std::min(input_len_data[i], k_); j++) {
115  int cur_index = index_offset + input_indices_data[i * k_ + j];
116  CAFFE_ENFORCE_LT(
117  cur_index, num_indices, "cur_index should be less than num_indices");
118  X_out_data[cur_index] = input_topk_data[i * k_ + j];
119  }
120  index_offset += input_len_data[i];
121  }
122 
123  return true;
124 }
125 
126 REGISTER_CPU_OPERATOR(LengthsTopK, LengthsTopKOp<float, CPUContext>);
127 REGISTER_CPU_OPERATOR(
128  LengthsTopKGradient,
129  LengthsTopKGradientOp<float, CPUContext>);
130 OPERATOR_SCHEMA(LengthsTopK)
131  .NumInputs(2)
132  .NumOutputs(2)
133  .SetDoc(R"DOC(
134 Apply TopK to each segment of the input tensor, where segments are defined by
135 their LENGTHS, and concatenate them in an output tensor of
136 shape=(SIZE(LENGTHs), k). In case there's less than k values in a segment,
137 the output value will be padded by 0, and the corresponding output indices will
138 be padded by -1.
139 )DOC")
140  .Input(
141  0,
142  "DATA",
143  "Tensor of rank 1. First dimension must be equal to the sum of "
144  "lengths")
145  .Input(1, "LENGTHS", "Tensor of int32 lengths of rank 1")
146  .Output(
147  0,
148  "TopKValue",
149  "Output top k elements for each segment, with"
150  "shape=(SIZE(lengths), k)")
151  .Output(
152  1,
153  "TopKIndices",
154  "Output indices in DATA corresponding to value in TopKValue")
155  .Arg(
156  "k",
157  "the number of top values to return for each segment, if the number "
158  "of values is smaller than k, the values would be padded with 0 and "
159  "indices would be padded with -1.");
160 OPERATOR_SCHEMA(LengthsTopKGradient).NumInputs(3).NumOutputs(1);
161 
162 namespace {
163 
164 class GetLengthsTopKGradient : public GradientMakerBase {
165  using GradientMakerBase::GradientMakerBase;
166  vector<OperatorDef> GetGradientDefs() override {
167  return SingleGradientDef(
168  "LengthsTopKGradient",
169  "",
170  vector<string>{I(1), O(1), GO(0)},
171  vector<string>{GI(0)});
172  }
173 };
174 
175 } // namespace
176 
177 REGISTER_GRADIENT(LengthsTopK, GetLengthsTopKGradient);
178 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.