Caffe2 - C++ API
A deep learning, cross platform ML framework
find_op.h
1 
17 #ifndef CAFFE2_OPERATORS_FIND_OP_H_
18 #define CAFFE2_OPERATORS_FIND_OP_H_
19 
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

#include <cstdint>
#include <unordered_map>
25 
26 namespace caffe2 {
27 
28 template <class Context>
29 class FindOp final : public Operator<Context> {
30  public:
31  FindOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  missing_value_(
34  OperatorBase::GetSingleArgument<int>("missing_value", -1)) {}
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36  USE_DISPATCH_HELPER;
37 
38  bool RunOnDevice() {
39  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(0));
40  }
41 
42  protected:
43  template <typename T>
44  bool DoRunWithType() {
45  auto& idx = Input(0);
46  auto& needles = Input(1);
47  auto* res_indices = Output(0);
48  res_indices->ResizeLike(needles);
49 
50  const T* idx_data = idx.template data<T>();
51  const T* needles_data = needles.template data<T>();
52  T* res_data = res_indices->template mutable_data<T>();
53  auto idx_size = idx.size();
54 
55  // Use an arbitrary cut-off for when to use brute-force
56  // search. For larger needle sizes we first put the
57  // index into a map
58  if (needles.size() < 16) {
59  // Brute force O(nm)
60  for (int i = 0; i < needles.size(); i++) {
61  T x = needles_data[i];
62  T res = static_cast<T>(missing_value_);
63  for (int j = idx_size - 1; j >= 0; j--) {
64  if (idx_data[j] == x) {
65  res = j;
66  break;
67  }
68  }
69  res_data[i] = res;
70  }
71  } else {
72  // O(n + m)
73  std::unordered_map<T, int> idx_map;
74  for (int j = 0; j < idx_size; j++) {
75  idx_map[idx_data[j]] = j;
76  }
77  for (int i = 0; i < needles.size(); i++) {
78  T x = needles_data[i];
79  auto it = idx_map.find(x);
80  res_data[i] = (it == idx_map.end() ? missing_value_ : it->second);
81  }
82  }
83 
84  return true;
85  }
86 
87  protected:
88  int missing_value_;
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_FIND_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.