Caffe2 - C++ API
A deep learning, cross-platform ML framework
flexible_top_k.cc
1 #include "caffe2/operators/flexible_top_k.h"
2 
3 #include "caffe2/proto/caffe2.pb.h"
4 
5 namespace caffe2 {
6 
7 namespace {
8 
9 template <typename T>
10 struct ValueCmp {
11  bool operator()(
12  const std::pair<T, TIndex>& lhs,
13  const std::pair<T, TIndex>& rhs) {
14  return (
15  lhs.first > rhs.first ||
16  (lhs.first == rhs.first && lhs.second < rhs.second));
17  }
18 };
19 
20 } // namespace
21 
22 template <typename T, class Context>
23 bool FlexibleTopKOp<T, Context>::RunOnDevice() {
24  auto& input = Input(0);
25  auto& k = Input(1);
26  auto* values = Output(0);
27  auto* indices = Output(1);
28 
29  const T* input_data = input.template data<T>();
30  const TIndex* k_data = k.template data<TIndex>();
31 
32  // get flatten shape of input
33  CAFFE_ENFORCE_GT(input.ndim(), 0);
34  vector<TIndex> input_dims = input.dims();
35  vector<TIndex> linear_shape = {
36  size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
37  CAFFE_ENFORCE_EQ(
38  linear_shape[0],
39  k.size(),
40  "first n-1 dims of input data and K does not match.");
41 
42  TIndex output_size = 0;
43  for (TIndex i = 0; i < linear_shape[0]; ++i) {
44  CAFFE_ENFORCE(
45  linear_shape[1] >= k_data[i],
46  "k should not be greater than last dim, error at index ",
47  i,
48  ", with value: ",
49  k_data[i]);
50  CAFFE_ENFORCE(
51  k_data[i] > 0,
52  "k should be greater than 0, error at index ",
53  i,
54  ", with value: ",
55  k_data[i]);
56  output_size += k_data[i];
57  }
58  values->Resize(output_size);
59  indices->Resize(output_size);
60  T* values_data = values->template mutable_data<T>();
61  TIndex* indices_data = indices->template mutable_data<TIndex>();
62 
63  TIndex output_offset = 0;
64  // Sort preserving indices
65  for (TIndex i = 0; i < linear_shape[0]; ++i) {
66  // Build a min-heap, the heap element is pair of (value, idx)
67  // the top of the heap is the smallest value
68  std::priority_queue<
69  std::pair<T, TIndex>,
70  std::vector<std::pair<T, TIndex>>,
71  ValueCmp<T>>
72  PQ;
73 
74  TIndex k_ = k_data[i];
75  for (TIndex j = 0; j < linear_shape[1]; ++j) {
76  const T value = input_data[i * linear_shape[1] + j];
77  if (PQ.size() < k_ || value > PQ.top().first) {
78  PQ.push(std::make_pair(value, j));
79  }
80  if (PQ.size() > k_) {
81  PQ.pop();
82  }
83  }
84  for (TIndex j = 0; j < k_; ++j) {
85  auto& pqElem = PQ.top();
86  values_data[output_offset + k_ - j - 1] = pqElem.first;
87  indices_data[output_offset + k_ - j - 1] = pqElem.second;
88  PQ.pop();
89  }
90  output_offset += k_;
91  }
92 
93  return true;
94 }
95 
96 template <typename T, class Context>
97 bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
98  auto& original_input = Input(0);
99  auto& k = Input(1);
100  auto& values = Input(2);
101  auto& indices = Input(3);
102  auto* output = Output(0);
103 
104  const TIndex* k_data = k.template data<TIndex>();
105  const T* values_data = values.template data<T>();
106  const TIndex* indices_data = indices.template data<TIndex>();
107 
108  // Resize output tensors to be as orignial_input size and initialized with 0
109  CAFFE_ENFORCE_GT(original_input.ndim(), 0);
110  vector<TIndex> original_dims = original_input.dims();
111  output->Resize(original_dims);
112  T* output_data = output->template mutable_data<T>();
113  math::Set<T, Context>(
114  output->size(), static_cast<T>(0), output_data, &context_);
115 
116  TIndex index_offset = 0;
117  for (TIndex i = 0; i < k.size(); ++i) {
118  // offset of output_data
119  TIndex output_offset = i * original_dims.back();
120  for (TIndex j = 0; j < k_data[i]; ++j) {
121  TIndex index = indices_data[index_offset + j];
122  T value = values_data[index_offset + j];
123  output_data[output_offset + index] = value;
124  }
125  index_offset += k_data[i];
126  }
127 
128  return true;
129 }
130 
// CPU registrations: both the forward op and its gradient run on float data.
REGISTER_CPU_OPERATOR(FlexibleTopK, FlexibleTopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    FlexibleTopKGradient,
    FlexibleTopKGradientOp<float, CPUContext>);
135 
// Schema for the forward op. Note the DOC raw string is user-facing runtime
// documentation registered with the schema — its content is left untouched.
OPERATOR_SCHEMA(FlexibleTopK)
    .NumInputs(2)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Given two tensors: X and K,
retrieve the top K[..., 1] elements from X on the last dimension.
X is an input tensor of shape [a_1, a_2, ..., a_n, r].
K is an input tensor of shape [a_1, a_2, ..., a_n, 1],
where for each element, r >= K[..., 1] > 0
Output two outputs:
-Flatten values tensor of shape [ \sum_i K[i, 1] ] which contains the values of
 the top K[..., 1] elements along the last dimension
-Flatten indices tensor of shape [ \sum_i K[i, 1] ] which contains the indices
 of the top K[..., 1] elements, flatten indices from the input tensor).
These two outputs should be used with the input K, so that we know which indices
in X are picked.

Given two equivalent values, this operator uses the indices along the last dim-
ension as a tiebreaker. That is, the element with the lower index will appear
first.
    )DOC")
    .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]")
    .Input(1, "K", "Tensor of shape [a_1, a_2, ..., a_n, 1]")
    .Output(
        0,
        "Flatten values",
        "Tensor of shape [ \\sum_i K[i, 1] ] containing"
        " top K[..., 1] values from the input tensor")
    .Output(
        1,
        "Flatten indices",
        "Tensor of shape [ \\sum_i K[i, 1] ] containing the indices "
        "into the flatten input");
169 
170 OPERATOR_SCHEMA(FlexibleTopKGradient).NumInputs(4).NumOutputs(1);
171 
172 class GetFlexibleTopKGradient : public GradientMakerBase {
173  using GradientMakerBase::GradientMakerBase;
174  vector<OperatorDef> GetGradientDefs() override {
175  return SingleGradientDef(
176  "FlexibleTopKGradient",
177  "",
178  vector<string>{I(0), I(1), GO(0), O(1)},
179  vector<string>{GI(0)});
180  }
181 };
182 
183 REGISTER_GRADIENT(FlexibleTopK, GetFlexibleTopKGradient);
184 
185 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.