Caffe2 - C++ API
A deep learning, cross-platform ML framework
bisect_percentile_op.h
1 #ifndef CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
2 #define CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class BisectPercentileOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  template <class... Args>
17  explicit BisectPercentileOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...),
19  pct_raw_(OperatorBase::GetRepeatedArgument<float>(
20  "percentile_raw",
21  vector<float>{})),
22  pct_mapping_(OperatorBase::GetRepeatedArgument<float>(
23  "percentile_mapping",
24  vector<float>{})),
25  pct_lower_(OperatorBase::GetRepeatedArgument<float>(
26  "percentile_lower",
27  vector<float>{})),
28  pct_upper_(OperatorBase::GetRepeatedArgument<float>(
29  "percentile_upper",
30  vector<float>{})),
31  pct_lens_(
32  OperatorBase::GetRepeatedArgument<int>("lengths", vector<int>{})) {
33  CAFFE_ENFORCE_EQ(
34  pct_raw_.size(),
35  pct_mapping_.size(),
36  "Feature (raw) data and percentile value dimension should match.");
37  CAFFE_ENFORCE_EQ(
38  pct_raw_.size(),
39  pct_lower_.size(),
40  "Feature (raw) data and lower bound dimension should match.");
41  CAFFE_ENFORCE_EQ(
42  pct_raw_.size(),
43  pct_upper_.size(),
44  "Feature (raw) data and upper bound dimension should match.");
45  n_features = pct_lens_.size();
46  index.reserve(n_features + 1);
47  index[0] = 0;
48  for (int i = 1; i <= n_features; ++i) {
49  index[i] = index[i - 1] + pct_lens_[i - 1];
50  }
51  CAFFE_ENFORCE_EQ(
52  index[n_features], // The sum of lengths_data
53  pct_raw_.size(),
54  "Sum of lengths should be equal to the total number of percentile "
55  "mapping data samples");
56  }
57 
58  bool RunOnDevice() override {
59  // Input
60  const auto& raw = Input(RAW);
61  CAFFE_ENFORCE_EQ(raw.dim(), 2);
62  const auto batch_size = raw.size(0);
63  const auto num_features = raw.size(1);
64  CAFFE_ENFORCE_EQ(num_features, pct_lens_.size());
65  const float* raw_data = raw.template data<float>();
66 
67  // Output
68 
69  auto* pct = Output(PCT, raw.sizes(), at::dtype<float>());
70  float* pct_output = pct->template mutable_data<float>();
71 
72  // Compute percentile for each raw feature value
73  int feature_start_index = 0;
74  int feature_length = 0;
75  int cur_index = 0;
76 
77  for (int i = 0; i < num_features; ++i) {
78  cur_index = i;
79  feature_start_index = index[i];
80  feature_length = pct_lens_[i];
81  for (int j = 0; j < batch_size; ++j) {
82  pct_output[cur_index] = compute_percentile(
83  pct_raw_.begin() + feature_start_index,
84  pct_mapping_.begin() + feature_start_index,
85  pct_lower_.begin() + feature_start_index,
86  pct_upper_.begin() + feature_start_index,
87  feature_length,
88  raw_data[cur_index]);
89  cur_index += num_features;
90  }
91  }
92  return true;
93  }
94 
95  protected:
96  INPUT_TAGS(RAW);
97  OUTPUT_TAGS(PCT);
98 
99  private:
100  int n_features;
101  vector<float> pct_raw_;
102  vector<float> pct_mapping_;
103  vector<float> pct_lower_;
104  vector<float> pct_upper_;
105  vector<int> pct_lens_;
106  vector<int> index;
107  vector<std::map<float, float>> fast_pct;
108 
109  const float kEPSILON = 1e-10;
110 
111  int binary_search(
112  const std::vector<float>::iterator& data,
113  int lo,
114  int hi,
115  float val) {
116  int mid;
117  bool low_cond, high_cond;
118 
119  while (lo < hi) {
120  mid = (lo + hi) >> 1;
121  low_cond = (data[mid] <= val);
122  high_cond = (val < data[mid + 1]);
123  if (low_cond && high_cond) {
124  return mid;
125  } else if (!low_cond) {
126  hi = mid - 1;
127  } else {
128  lo = mid + 1;
129  }
130  }
131  return lo;
132  }
133 
134  float compute_percentile(
135  const std::vector<float>::iterator& pct_raw_it,
136  const std::vector<float>::iterator& pct_mapping_it,
137  const std::vector<float>::iterator& pct_lower_it,
138  const std::vector<float>::iterator& pct_upper_it,
139  const int size,
140  const float val) {
141  // Corner cases where no interpolation is needed.
142  if (val < pct_raw_it[0]) {
143  return 0.;
144  }
145  if (val > pct_raw_it[size - 1]) {
146  return 1.;
147  }
148 
149  float result;
150  // Interpolation by binary search
151  const auto k = binary_search(pct_raw_it, 0, size - 1, val);
152 
153  if (pct_raw_it[k] == val) {
154  // Exact match
155  result = pct_mapping_it[k];
156  } else {
157  // interpolation
158  float w = (val - pct_raw_it[k]) /
159  (pct_raw_it[k + 1] - pct_raw_it[k] + kEPSILON);
160  result = (1 - w) * pct_upper_it[k] + w * pct_lower_it[k + 1];
161  }
162  return result;
163  }
164 };
165 
166 } // namespace caffe2
167 
168 #endif // CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13