1 #include "caffe2/operators/top_k.h" 9 #include "caffe2/proto/caffe2_pb.h" 10 #include "caffe2/utils/math.h" 19 const std::pair<T, int64_t>& lhs,
20 const std::pair<T, int64_t>& rhs)
const {
21 return lhs.first > rhs.first ||
22 (lhs.first == rhs.first && lhs.second < rhs.second);
// Writes the top-k elements of one strided 1-D slice of `input`.
//
// The slice starts at input[src_offset] and consists of `n` elements spaced
// `stride` apart. Results are written (descending by value; ties broken by
// lower index first) into `values`/`indices` starting at dst_offset with the
// same stride. If `flatten_indices` is non-null, the element's flat offset
// into `input` (src_offset + index * stride) is recorded as well.
// If k > n, only the first n output slots are written.
template <typename T>
void GetTopK(
    const T* input,
    const int64_t n,
    const int64_t k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* values,
    int64_t* indices,
    int64_t* flatten_indices) {
  // Min-heap ordering: top() is the smallest kept value (ties: highest
  // index), i.e. the element to evict when a better candidate appears.
  const auto comp = [](const std::pair<T, int64_t>& lhs,
                       const std::pair<T, int64_t>& rhs) {
    return lhs.first > rhs.first ||
        (lhs.first == rhs.first && lhs.second < rhs.second);
  };
  const T* src_ptr = input + src_offset;
  std::vector<std::pair<T, int64_t>> heap_data;
  heap_data.reserve(std::min(k, n));
  // Seed the heap with the first min(k, n) elements of the slice.
  for (int64_t i = 0; i < k && i < n; ++i) {
    heap_data.emplace_back(*src_ptr, i);
    src_ptr += stride;
  }
  std::priority_queue<
      std::pair<T, int64_t>,
      std::vector<std::pair<T, int64_t>>,
      decltype(comp)>
      pq(comp, std::move(heap_data));
  // Scan the remainder, replacing the current worst when beaten.
  for (int64_t i = k; i < n; ++i) {
    if (pq.top().first < *src_ptr) {
      pq.pop();
      pq.emplace(*src_ptr, i);
    }
    src_ptr += stride;
  }
  // Drain the heap smallest-first, filling the destination back-to-front so
  // the final order is descending.
  int64_t dst_pos = dst_offset + (std::min(k, n) - 1) * stride;
  while (!pq.empty()) {
    const auto& item = pq.top();
    values[dst_pos] = item.first;
    indices[dst_pos] = item.second;
    if (flatten_indices != nullptr) {
      flatten_indices[dst_pos] = src_offset + item.second * stride;
    }
    pq.pop();
    dst_pos -= stride;
  }
}
// Scatters one strided slice of top-k gradients back to the input layout.
//
// Reads k (value, index) pairs from `values`/`indices` starting at
// src_offset with the given stride, and writes each value into
// gradient[dst_offset + index * stride]. An index < 0 is the sentinel left
// by the forward pass for unfilled slots (k > slice length) and contributes
// no gradient.
//
// Bug fix: src_pos now advances every iteration. Previously the `continue`
// taken on a sentinel index skipped the `src_pos += stride` at the loop
// tail, so the walk stalled on the first -1 and re-read the same slot for
// the remaining iterations.
template <typename T>
void SetTopKGradient(
    const T* values,
    const int64_t* indices,
    const int k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* gradient) {
  int64_t src_pos = src_offset;
  for (int i = 0; i < k; ++i, src_pos += stride) {
    const int64_t idx = indices[src_pos];
    if (idx < 0) {
      continue; // padded slot — no gradient to route
    }
    gradient[dst_offset + idx * stride] = values[src_pos];
  }
}
90 template <
typename T,
class Context>
91 bool TopKOp<T, Context>::RunOnDevice() {
92 const auto& input = Input(0);
93 auto* values = Output(0);
94 auto* indices = Output(1);
95 auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
99 axis_ = input_dims.
size() - 1;
101 CAFFE_ENFORCE_GE(axis_, 0);
102 CAFFE_ENFORCE_LT(axis_, input_dims.
size());
104 std::vector<int64_t> output_dims = input_dims.vec();
105 output_dims[axis_] = k_;
106 values->Resize(output_dims);
107 indices->Resize(output_dims);
108 if (flatten_indices !=
nullptr) {
109 flatten_indices->Resize(indices->numel());
111 const T* input_data = input.template data<T>();
112 T* values_data = values->template mutable_data<T>();
113 int64_t* indices_data = indices->template mutable_data<int64_t>();
114 int64_t* flatten_indices_data = flatten_indices ==
nullptr 116 : flatten_indices->template mutable_data<int64_t>();
118 math::Set<T, Context>(values->numel(),
T(0), values_data, &context_);
119 math::Set<int64_t, Context>(
120 indices->numel(), int64_t(-1), indices_data, &context_);
121 if (flatten_indices_data !=
nullptr) {
122 math::Set<int64_t, Context>(
123 flatten_indices->numel(), int64_t(-1), flatten_indices_data, &context_);
126 const int64_t prev_size = std::accumulate(
128 input_dims.cbegin() + axis_,
130 std::multiplies<int64_t>());
131 const int64_t next_size = std::accumulate(
132 input_dims.cbegin() + axis_ + 1,
135 std::multiplies<int64_t>());
136 const int64_t src_offset_stride = input_dims[axis_] * next_size;
137 const int64_t dst_offset_stride = k_ * next_size;
138 int64_t src_offset = 0;
139 int64_t dst_offset = 0;
140 for (int64_t i = 0; i < prev_size; ++i) {
141 for (int64_t j = 0; j < next_size; ++j) {
151 flatten_indices_data);
153 src_offset += src_offset_stride;
154 dst_offset += dst_offset_stride;
159 template <
typename T,
class Context>
160 bool TopKGradientOp<T, Context>::RunOnDevice() {
161 const auto& values = Input(0);
162 const auto& indices = Input(1);
163 const auto& original_input = Input(2);
164 auto* output = Output(0);
167 CAFFE_ENFORCE_EQ(values_dims.
size(), origin_dims.
size());
168 output->Resize(origin_dims);
169 const T* values_data = values.template data<T>();
170 const int64_t* indices_data = indices.template data<int64_t>();
171 T* output_data = output->template mutable_data<T>();
173 axis_ = values_dims.
size() - 1;
175 const int k = values_dims[axis_];
176 math::Set<T, Context>(output->numel(),
T(0), output_data, &context_);
177 const int64_t prev_size = std::accumulate(
178 values_dims.cbegin(),
179 values_dims.cbegin() + axis_,
181 std::multiplies<int64_t>());
182 const int64_t next_size = std::accumulate(
183 values_dims.cbegin() + axis_ + 1,
186 std::multiplies<int64_t>());
187 const int64_t src_offset_stride = k * next_size;
188 const int64_t dst_offset_stride = origin_dims[axis_] * next_size;
189 int64_t src_offset = 0;
190 int64_t dst_offset = 0;
191 for (int64_t i = 0; i < prev_size; ++i) {
192 for (int64_t j = 0; j < next_size; ++j) {
202 src_offset += src_offset_stride;
203 dst_offset += dst_offset_stride;
208 REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
209 REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
// Schema for TopK: one input, two or three outputs. Shape inference copies
// the input shape and replaces the last dimension with the `k` argument.
// (Text below is kept verbatim from a damaged copy; only comments added.)
211 OPERATOR_SCHEMA(TopK)
214 .TensorInferenceFunction([](
const OperatorDef& def,
215 const vector<TensorShape>& in) {
216 vector<TensorShape> out = {in[0], in[0]};
217 ArgumentHelper helper(def);
// `k` defaults to -1 when the argument is absent.
218 auto k = helper.GetSingleArgument(
"k", -1);
219 auto dims_size = in[0].dims_size();
220 out[0].set_dims(dims_size - 1, k);
221 out[1].set_dims(dims_size - 1, k);
// NOTE(review): the runtime op writes int64_t indices, but inference
// declares INT32 here — confirm whether this mismatch is intentional.
222 out[1].set_data_type(TensorProto_DataType_INT32);
// Optional third output: flattened indices of shape (prod(dims[:-1]) * k,).
223 if (def.output_size() > 2) {
224 TensorShape flatten_indices_shape;
225 flatten_indices_shape.set_data_type(TensorProto_DataType_INT32);
226 flatten_indices_shape.add_dims(
228 in[0].dims().begin(),
229 in[0].dims().end() - 1,
231 std::multiplies<long>()) *
233 out.push_back(flatten_indices_shape);
// Operator documentation and I/O descriptions follow.
238 Retrieve the top-K elements of the last dimension. Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$ and integer argument `k`, return up to three outputs: 240 1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the values of the top k elements along the last dimension 241 2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the indices of the top k elements (original indices from the input tensor). 242 3. [OPTIONAL] Flattened index tensor of shape $(a_1 * a_2 * ... * a_n * k,)$. 244 Given two equivalent values, this operator uses the indices along the last dimension as a tiebreaker. That is, the element with the lower index will appear first. 247 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/top_k.cc 252 <summary> <b>Example</b> </summary> 258 workspace.ResetWorkspace() 260 op = core.CreateOperator( 263 ["Values", "Indices", "Flattened_indices"], 267 workspace.FeedBlob("X", np.random.randint(10, size=(3,3,3)).astype(np.float32)) 268 print("X:", workspace.FetchBlob("X")) 269 workspace.RunOperatorOnce(op) 270 print("Values:", workspace.FetchBlob("Values")) 271 print("Indices:", workspace.FetchBlob("Indices")) 272 print("Flattened_indices:", workspace.FetchBlob("Flattened_indices")) 316 Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25] 326 "(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
330 "(*Tensor`<float>`*): output tensor of shape $(a_1, a_2, ..., a_n, k)$")
334 "(*Tensor`<int>`*): tensor of indices of shape $(a_1, a_2, ..., a_n, k)$; indices values refer to each element's index in the last dimension of the `X` input tensor")
338 "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`")
339 .Arg(
"k",
"(*int*): number of top elements to retrieve");
341 OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);
344 using GradientMakerBase::GradientMakerBase;
345 vector<OperatorDef> GetGradientDefs()
override {
346 return SingleGradientDef(
349 vector<string>{GO(0), O(1), I(0)},
350 vector<string>{GI(0)});
constexpr size_t size() const
size - Get the array size.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...