#include "caffe2/operators/top_k.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>

#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/math.h"
12 namespace caffe2 {
14 namespace {
// Ordering for (value, index) pairs used by GetTopK.
//
// Returns true when `lhs` ranks strictly ahead of `rhs` in the final top-k
// output: a larger value ranks first, and on equal values the smaller index
// ranks first. Used as the comparator of a std::priority_queue, this places
// the weakest retained candidate (smallest value; largest index among ties)
// at top(), which is the one evicted when a better element shows up.
template <typename T>
struct ValueComp {
  bool operator()(
      const std::pair<T, int64_t>& lhs,
      const std::pair<T, int64_t>& rhs) const {
    return lhs.first > rhs.first ||
        (lhs.first == rhs.first && lhs.second < rhs.second);
  }
};

// Selects the top-k elements of one strided slice of `input`.
//
// The slice is the n elements input[src_offset + i * stride], i in [0, n).
// The selected elements are written in descending order (ties broken by the
// lower index) to positions dst_offset + j * stride, j in [0, min(k, n)):
//   values[...]          - the element value,
//   indices[...]         - its position i within the slice,
//   flatten_indices[...] - its flat offset into `input` (skipped if nullptr).
// Output slots with j >= min(k, n) are left untouched, and nothing at all is
// written when k <= 0 or n <= 0.
template <typename T>
void GetTopK(
    const T* input,
    const int64_t n,
    const int64_t k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* values,
    int64_t* indices,
    int64_t* flatten_indices) {
  // Nothing to select; returning here also keeps the scan loop below from
  // calling top() on an empty heap when k == 0.
  if (k <= 0 || n <= 0) {
    return;
  }
  const int64_t m = std::min(k, n);

  // Seed the heap with the first m elements. Reserving m (not k) avoids a
  // huge allocation when k is much larger than the slice length.
  std::vector<std::pair<T, int64_t>> heap_data;
  heap_data.reserve(static_cast<size_t>(m));
  for (int64_t i = 0; i < m; ++i) {
    heap_data.emplace_back(input[src_offset + i * stride], i);
  }
  std::priority_queue<
      std::pair<T, int64_t>,
      std::vector<std::pair<T, int64_t>>,
      ValueComp<T>>
      pq(ValueComp<T>(), std::move(heap_data));

  // A later element replaces the weakest kept one only if it is strictly
  // larger, so among equal values the earlier (lower-index) one is retained.
  for (int64_t i = m; i < n; ++i) {
    const T& candidate = input[src_offset + i * stride];
    if (pq.top().first < candidate) {
      pq.pop();
      pq.emplace(candidate, i);
    }
  }

  // The heap yields the weakest element first, so fill the output from the
  // last selected slot back to the first.
  int64_t dst_pos = dst_offset + (m - 1) * stride;
  while (!pq.empty()) {
    const auto& item = pq.top();
    values[dst_pos] = item.first;
    indices[dst_pos] = item.second;
    if (flatten_indices != nullptr) {
      flatten_indices[dst_pos] = src_offset + item.second * stride;
    }
    pq.pop();
    dst_pos -= stride;
  }
}
// Scatters one slice of top-k gradient values back into the input gradient.
//
// Reads the k entries at values/indices[src_offset + i * stride], i in [0, k),
// and writes values[...] to gradient[dst_offset + indices[...] * stride].
// Entries whose index is negative (the -1 padding TopK leaves when k exceeds
// the axis length) are skipped.
template <typename T>
void SetTopKGradient(
    const T* values,
    const int64_t* indices,
    const int k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* gradient) {
  int64_t src_pos = src_offset;
  // Advance src_pos in the loop header so that skipped (padding) entries
  // still move the cursor forward to the next slot.
  for (int i = 0; i < k; ++i, src_pos += stride) {
    const int64_t index = indices[src_pos];
    if (index < 0) {
      continue;
    }
    gradient[dst_offset + index * stride] = values[src_pos];
  }
}
88 } // namespace
90 template <typename T, class Context>
91 bool TopKOp<T, Context>::RunOnDevice() {
92  const auto& input = Input(0);
93  auto* values = Output(0);
94  auto* indices = Output(1);
95  auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
97  at::IntArrayRef input_dims = input.sizes();
98  if (axis_ == -1) {
99  axis_ = input_dims.size() - 1;
100  }
101  CAFFE_ENFORCE_GE(axis_, 0);
102  CAFFE_ENFORCE_LT(axis_, input_dims.size());
104  std::vector<int64_t> output_dims = input_dims.vec();
105  output_dims[axis_] = k_;
106  values->Resize(output_dims);
107  indices->Resize(output_dims);
108  if (flatten_indices != nullptr) {
109  flatten_indices->Resize(indices->numel());
110  }
111  const T* input_data = input.template data<T>();
112  T* values_data = values->template mutable_data<T>();
113  int64_t* indices_data = indices->template mutable_data<int64_t>();
114  int64_t* flatten_indices_data = flatten_indices == nullptr
115  ? nullptr
116  : flatten_indices->template mutable_data<int64_t>();
117  // init values as the default value
118  math::Set<T, Context>(values->numel(), T(0), values_data, &context_);
119  math::Set<int64_t, Context>(
120  indices->numel(), int64_t(-1), indices_data, &context_);
121  if (flatten_indices_data != nullptr) {
122  math::Set<int64_t, Context>(
123  flatten_indices->numel(), int64_t(-1), flatten_indices_data, &context_);
124  }
126  const int64_t prev_size = std::accumulate(
127  input_dims.cbegin(),
128  input_dims.cbegin() + axis_,
129  int64_t(1),
130  std::multiplies<int64_t>());
131  const int64_t next_size = std::accumulate(
132  input_dims.cbegin() + axis_ + 1,
133  input_dims.cend(),
134  int64_t(1),
135  std::multiplies<int64_t>());
136  const int64_t src_offset_stride = input_dims[axis_] * next_size;
137  const int64_t dst_offset_stride = k_ * next_size;
138  int64_t src_offset = 0;
139  int64_t dst_offset = 0;
140  for (int64_t i = 0; i < prev_size; ++i) {
141  for (int64_t j = 0; j < next_size; ++j) {
142  GetTopK(
143  input_data,
144  input_dims[axis_],
145  k_,
146  src_offset + j,
147  dst_offset + j,
148  next_size,
149  values_data,
150  indices_data,
151  flatten_indices_data);
152  }
153  src_offset += src_offset_stride;
154  dst_offset += dst_offset_stride;
155  }
156  return true;
157 }
159 template <typename T, class Context>
160 bool TopKGradientOp<T, Context>::RunOnDevice() {
161  const auto& values = Input(0);
162  const auto& indices = Input(1);
163  const auto& original_input = Input(2);
164  auto* output = Output(0);
165  at::IntArrayRef values_dims = values.sizes();
166  at::IntArrayRef origin_dims = original_input.sizes();
167  CAFFE_ENFORCE_EQ(values_dims.size(), origin_dims.size());
168  output->Resize(origin_dims);
169  const T* values_data = values.template data<T>();
170  const int64_t* indices_data = indices.template data<int64_t>();
171  T* output_data = output->template mutable_data<T>();
172  if (axis_ == -1) {
173  axis_ = values_dims.size() - 1;
174  }
175  const int k = values_dims[axis_];
176  math::Set<T, Context>(output->numel(), T(0), output_data, &context_);
177  const int64_t prev_size = std::accumulate(
178  values_dims.cbegin(),
179  values_dims.cbegin() + axis_,
180  int64_t(1),
181  std::multiplies<int64_t>());
182  const int64_t next_size = std::accumulate(
183  values_dims.cbegin() + axis_ + 1,
184  values_dims.cend(),
185  int64_t(1),
186  std::multiplies<int64_t>());
187  const int64_t src_offset_stride = k * next_size;
188  const int64_t dst_offset_stride = origin_dims[axis_] * next_size;
189  int64_t src_offset = 0;
190  int64_t dst_offset = 0;
191  for (int64_t i = 0; i < prev_size; ++i) {
192  for (int64_t j = 0; j < next_size; ++j) {
193  SetTopKGradient(
194  values_data,
195  indices_data,
196  k,
197  src_offset + j,
198  dst_offset + j,
199  next_size,
200  output_data);
201  }
202  src_offset += src_offset_stride;
203  dst_offset += dst_offset_stride;
204  }
205  return true;
206 }
// CPU registrations for the forward and backward ops; only the float
// instantiations are registered here.
REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
212  .NumInputs(1)
213  .NumOutputs(2, 3)
214  .TensorInferenceFunction([](const OperatorDef& def,
215  const vector<TensorShape>& in) {
216  vector<TensorShape> out = {in[0], in[0]};
217  ArgumentHelper helper(def);
218  auto k = helper.GetSingleArgument("k", -1);
219  auto dims_size = in[0].dims_size();
220  out[0].set_dims(dims_size - 1, k);
221  out[1].set_dims(dims_size - 1, k);
222  out[1].set_data_type(TensorProto_DataType_INT32);
223  if (def.output_size() > 2) {
224  TensorShape flatten_indices_shape;
225  flatten_indices_shape.set_data_type(TensorProto_DataType_INT32);
226  flatten_indices_shape.add_dims(
227  std::accumulate(
228  in[0].dims().begin(),
229  in[0].dims().end() - 1,
230  1,
231  std::multiplies<long>()) *
232  k);
233  out.push_back(flatten_indices_shape);
234  }
235  return out;
236  })
237  .SetDoc(R"DOC(
Retrieve the top-K elements along the last dimension (or along the dimension selected by the `axis` argument). Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$ and integer argument `k`, return up to three outputs:
240 1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the values of the top k elements along the last dimension
241 2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the indices of the top k elements (original indices from the input tensor).
242 3. [OPTIONAL] Flattened index tensor of shape $(a_1 * a_2 * ... * a_n * k,)$.
244 Given two equivalent values, this operator uses the indices along the last dimension as a tiebreaker. That is, the element with the lower index will appear first.
246 Github Links:
247 -
250 <details>
252 <summary> <b>Example</b> </summary>
254 **Code**
256 ```
258 workspace.ResetWorkspace()
260 op = core.CreateOperator(
261  "TopK",
262  ["X"],
263  ["Values", "Indices", "Flattened_indices"],
264  k=2
265 )
267 workspace.FeedBlob("X", np.random.randint(10, size=(3,3,3)).astype(np.float32))
268 print("X:", workspace.FetchBlob("X"))
269 workspace.RunOperatorOnce(op)
270 print("Values:", workspace.FetchBlob("Values"))
271 print("Indices:", workspace.FetchBlob("Indices"))
272 print("Flattened_indices:", workspace.FetchBlob("Flattened_indices"))
274 ```
276 **Result**
278 ```
280 X:
281 [[[6. 7. 0.]
282  [8. 7. 7.]
283  [1. 5. 6.]]
285  [[0. 6. 1.]
286  [2. 8. 4.]
287  [1. 2. 9.]]
289  [[4. 3. 7.]
290  [0. 1. 7.]
291  [0. 1. 8.]]]
292 Values:
293 [[[7. 6.]
294  [8. 7.]
295  [6. 5.]]
297  [[6. 1.]
298  [8. 4.]
299  [9. 2.]]
301  [[7. 4.]
302  [7. 1.]
303  [8. 1.]]]
304 Indices:
305 [[[1 0]
306  [0 1]
307  [2 1]]
309  [[1 2]
310  [1 2]
311  [2 1]]
313  [[2 0]
314  [2 1]
315  [2 1]]]
316 Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
318 ```
320 </details>
322  )DOC")
323  .Input(
324  0,
325  "X",
326  "(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
327  .Output(
328  0,
329  "Values",
330  "(*Tensor`<float>`*): output tensor of shape $(a_1, a_2, ..., a_n, k)$")
331  .Output(
332  1,
333  "Indices",
334  "(*Tensor`<int>`*): tensor of indices of shape $(a_1, a_2, ..., a_n, k)$; indices values refer to each element's index in the last dimension of the `X` input tensor")
335  .Output(
336  2,
337  "Flattened_indices",
338  "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`")
339  .Arg("k", "(*int*): number of top elements to retrieve");
// Inputs: (0) gradient of TopK's Values, (1) TopK's Indices, (2) the original
// TopK input X, which the gradient op reads only for its shape.
// Output: (0) gradient with respect to X.
OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);
344  using GradientMakerBase::GradientMakerBase;
345  vector<OperatorDef> GetGradientDefs() override {
346  return SingleGradientDef(
347  "TopKGradient",
348  "",
349  vector<string>{GO(0), O(1), I(0)},
350  vector<string>{GI(0)});
351  }
352 };
// TopK's backward pass is built by GetTopKGradient (one TopKGradient op).
REGISTER_GRADIENT(TopK, GetTopKGradient);
356 } // namespace caffe2
