1 #ifndef CAFFE2_OPERATORS_FIND_OP_H_ 2 #define CAFFE2_OPERATORS_FIND_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 8 #include <unordered_map> 12 template <
class Context>
15 template <
class... Args>
16 explicit FindOp(Args&&... args)
19 this->
template GetSingleArgument<int>(
"missing_value", -1)) {}
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
29 bool DoRunWithType() {
31 auto& needles =
Input(1);
33 auto* res_indices = Output(0, needles.sizes(), at::dtype<T>());
35 const T* idx_data = idx.template data<T>();
36 const T* needles_data = needles.template data<T>();
37 T* res_data = res_indices->template mutable_data<T>();
38 auto idx_size = idx.numel();
43 if (needles.numel() < 16) {
45 for (
int i = 0; i < needles.numel(); i++) {
46 T x = needles_data[i];
47 T res =
static_cast<T>(missing_value_);
48 for (
int j = idx_size - 1; j >= 0; j--) {
49 if (idx_data[j] == x) {
58 std::unordered_map<T, int> idx_map;
59 for (
int j = 0; j < idx_size; j++) {
60 idx_map[idx_data[j]] = j;
62 for (
int i = 0; i < needles.numel(); i++) {
63 T x = needles_data[i];
64 auto it = idx_map.find(x);
65 res_data[i] = (it == idx_map.end() ? missing_value_ : it->second);
78 #endif // CAFFE2_OPERATORS_FIND_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...