// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// flexible_top_k.cc
1 #include "caffe2/operators/flexible_top_k.h"
2 
3 #include "caffe2/proto/caffe2_pb.h"
4 
5 namespace caffe2 {
6 
7 namespace {
8 
// Comparator for (value, index) pairs: orders larger values first and, for
// equal values, lower indices first. Used as the std::priority_queue
// comparator below, which turns the queue into a min-heap on (value, index).
// operator() is const so the functor is usable in const contexts.
template <typename T>
struct ValueCmp {
  bool operator()(
      const std::pair<T, int64_t>& lhs,
      const std::pair<T, int64_t>& rhs) const {
    return (
        lhs.first > rhs.first ||
        (lhs.first == rhs.first && lhs.second < rhs.second));
  }
};
19 
20 } // namespace
21 
22 template <typename T, class Context>
23 bool FlexibleTopKOp<T, Context>::RunOnDevice() {
24  auto& input = Input(0);
25  auto& k = Input(1);
26 
27  const T* input_data = input.template data<T>();
28  const int64_t* k_data = k.template data<int64_t>();
29 
30  // get flatten shape of input
31  CAFFE_ENFORCE_GT(input.dim(), 0);
32  vector<int64_t> input_dims = input.sizes().vec();
33  vector<int64_t> linear_shape = {
34  size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
35  CAFFE_ENFORCE_EQ(
36  linear_shape[0],
37  k.numel(),
38  "first n-1 dims of input data and K does not match.");
39 
40  int64_t output_size = 0;
41  for (int64_t i = 0; i < linear_shape[0]; ++i) {
42  CAFFE_ENFORCE(
43  linear_shape[1] >= k_data[i],
44  "k should not be greater than last dim, error at index ",
45  i,
46  ", with value: ",
47  k_data[i]);
48  CAFFE_ENFORCE(
49  k_data[i] > 0,
50  "k should be greater than 0, error at index ",
51  i,
52  ", with value: ",
53  k_data[i]);
54  output_size += k_data[i];
55  }
56  auto* values = Output(0, {output_size}, at::dtype<T>());
57  auto* indices = Output(1, {output_size}, at::dtype<int64_t>());
58  T* values_data = values->template mutable_data<T>();
59  int64_t* indices_data = indices->template mutable_data<int64_t>();
60 
61  int64_t output_offset = 0;
62  // Sort preserving indices
63  for (int64_t i = 0; i < linear_shape[0]; ++i) {
64  // Build a min-heap, the heap element is pair of (value, idx)
65  // the top of the heap is the smallest value
66  std::priority_queue<
67  std::pair<T, int64_t>,
68  std::vector<std::pair<T, int64_t>>,
69  ValueCmp<T>>
70  PQ;
71 
72  int64_t k_ = k_data[i];
73  for (int64_t j = 0; j < linear_shape[1]; ++j) {
74  const T value = input_data[i * linear_shape[1] + j];
75  if (PQ.size() < k_ || value > PQ.top().first) {
76  PQ.push(std::make_pair(value, j));
77  }
78  if (PQ.size() > k_) {
79  PQ.pop();
80  }
81  }
82  for (int64_t j = 0; j < k_; ++j) {
83  auto& pqElem = PQ.top();
84  values_data[output_offset + k_ - j - 1] = pqElem.first;
85  indices_data[output_offset + k_ - j - 1] = pqElem.second;
86  PQ.pop();
87  }
88  output_offset += k_;
89  }
90 
91  return true;
92 }
93 
94 template <typename T, class Context>
95 bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
96  auto& original_input = Input(0);
97  auto& k = Input(1);
98  auto& values = Input(2);
99  auto& indices = Input(3);
100 
101  const int64_t* k_data = k.template data<int64_t>();
102  const T* values_data = values.template data<T>();
103  const int64_t* indices_data = indices.template data<int64_t>();
104 
105  // Resize output tensors to be as orignial_input size and initialized with 0
106  CAFFE_ENFORCE_GT(original_input.dim(), 0);
107  vector<int64_t> original_dims = original_input.sizes().vec();
108  auto* output = Output(0, original_dims, at::dtype<T>());
109  T* output_data = output->template mutable_data<T>();
110  math::Set<T, Context>(
111  output->numel(), static_cast<T>(0), output_data, &context_);
112 
113  int64_t index_offset = 0;
114  for (int64_t i = 0; i < k.numel(); ++i) {
115  // offset of output_data
116  int64_t output_offset = i * original_dims.back();
117  for (int64_t j = 0; j < k_data[i]; ++j) {
118  int64_t index = indices_data[index_offset + j];
119  T value = values_data[index_offset + j];
120  output_data[output_offset + index] = value;
121  }
122  index_offset += k_data[i];
123  }
124 
125  return true;
126 }
127 
// CPU kernels are registered for float data only; K is always int64.
REGISTER_CPU_OPERATOR(FlexibleTopK, FlexibleTopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    FlexibleTopKGradient,
    FlexibleTopKGradientOp<float, CPUContext>);
132 
// Schema for the forward op: inputs (X, K), outputs (values, indices), both
// outputs flat because K may differ per slice.
OPERATOR_SCHEMA(FlexibleTopK)
    .NumInputs(2)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Given two tensors: X and K,
retrieve the top K[..., 1] elements from X on the last dimension.
X is an input tensor of shape [a_1, a_2, ..., a_n, r].
K is an input tensor of shape [a_1, a_2, ..., a_n, 1],
where for each element, r >= K[..., 1] > 0
Output two outputs:
-Flatten values tensor of shape [ \sum_i K[i, 1] ] which contains the values of
 the top K[..., 1] elements along the last dimension
-Flatten indices tensor of shape [ \sum_i K[i, 1] ] which contains the indices
 of the top K[..., 1] elements, flatten indices from the input tensor).
These two outputs should be used with the input K, so that we know which indices
in X are picked.

Given two equivalent values, this operator uses the indices along the last dim-
ension as a tiebreaker. That is, the element with the lower index will appear
first.
    )DOC")
    .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]")
    .Input(1, "K", "Tensor of shape [a_1, a_2, ..., a_n, 1]")
    .Output(
        0,
        "Flatten values",
        "Tensor of shape [ \\sum_i K[i, 1] ] containing"
        " top K[..., 1] values from the input tensor")
    .Output(
        1,
        "Flatten indices",
        "Tensor of shape [ \\sum_i K[i, 1] ] containing the indices "
        "into the flatten input");

// Gradient schema: inputs (X, K, values grad, indices), output (X grad).
OPERATOR_SCHEMA(FlexibleTopKGradient).NumInputs(4).NumOutputs(1);
168 
// Gradient maker: wires FlexibleTopKGradient's inputs as (original input X,
// K, gradient of the values output, the forward indices output) and its
// single output as the gradient w.r.t. X.
class GetFlexibleTopKGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "FlexibleTopKGradient",
        "",
        vector<string>{I(0), I(1), GO(0), O(1)},
        vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(FlexibleTopK, GetFlexibleTopKGradient);
181 
182 } // namespace caffe2
// NOTE(review): the two lines below are Doxygen cross-reference residue
// (apparently describing a global registry in blob.h:13) and are not part of
// this translation unit's code.