1 #ifndef CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ 2 #define CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (fragment: several lines of the original are not visible in this
// listing, including the constructor name, the argument-name strings passed to
// the float GetRepeatedArgument calls, and the enforce macros that own the
// message strings below).
//
// Visible behaviour:
//  - pct_raw_, pct_mapping_, pct_lower_ and pct_upper_ are initialised from
//    repeated float operator arguments.
//  - a repeated int argument named "lengths" (default: empty vector) is read;
//    presumably it initialises pct_lens_ -- TODO confirm, the member name on
//    that initialiser line is not visible.
//  - `index` is filled with the running sum of pct_lens_, so index[i] is the
//    offset at which feature i starts in the flat tables.
16 template <
class... Args>
19 pct_raw_(OperatorBase::GetRepeatedArgument<float>(
22 pct_mapping_(OperatorBase::GetRepeatedArgument<float>(
25 pct_lower_(OperatorBase::GetRepeatedArgument<float>(
28 pct_upper_(OperatorBase::GetRepeatedArgument<float>(
32 OperatorBase::GetRepeatedArgument<int>(
"lengths", vector<int>{})) {
// Size checks between the raw table and the other three tables (the checked
// expressions themselves are on lines not shown here).
36 "Feature (raw) data and percentile value dimension should match.");
40 "Feature (raw) data and lower bound dimension should match.");
44 "Feature (raw) data and upper bound dimension should match.");
45 n_features = pct_lens_.size();
// FIX: this used to be reserve(). reserve() only grows capacity and leaves
// size() unchanged, so the index[i] writes below were out-of-bounds accesses
// (undefined behaviour). resize() creates n_features + 1 value-initialised
// elements, which also makes index[0] start at zero.
// NOTE(review): assumes `index` is a std::vector of an integer type -- its
// declaration is not visible in this listing.
46 index.resize(n_features + 1);
// Prefix sum: index[i] = pct_lens_[0] + ... + pct_lens_[i - 1].
48 for (
int i = 1; i <= n_features; ++i) {
49 index[i] = index[i - 1] + pct_lens_[i - 1];
54 "Sum of lengths should be equal to the total number of percentile " 55 "mapping data samples");
// Operator entry point (fragment: some lines of the original are not visible
// in this listing, e.g. the declaration/initialisation of cur_index, the
// trailing arguments of the compute_percentile call, and the return).
//
// Reads the RAW input, allocates the PCT output with the same sizes, and
// fills pct_output by calling compute_percentile with the table slices that
// belong to each feature.
58 bool RunOnDevice()
override {
60 const auto& raw =
Input(RAW);
// RAW must be 2-D; dimension 0 is used as the batch size and dimension 1 as
// the number of features.
61 CAFFE_ENFORCE_EQ(raw.dim(), 2);
62 const auto batch_size = raw.size(0);
63 const auto num_features = raw.size(1);
// There must be exactly one pct_lens_ entry per feature.
64 CAFFE_ENFORCE_EQ(num_features, pct_lens_.size());
65 const float* raw_data = raw.template data<float>();
// Output tensor: same sizes as the input, float dtype.
69 auto* pct = Output(PCT, raw.sizes(), at::dtype<float>());
70 float* pct_output = pct->template mutable_data<float>();
73 int feature_start_index = 0;
74 int feature_length = 0;
// Outer loop over features, inner loop over batch entries.
77 for (
int i = 0; i < num_features; ++i) {
// index[i]: start offset of feature i inside the flat tables;
// pct_lens_[i]: number of table entries for feature i.
79 feature_start_index = index[i];
80 feature_length = pct_lens_[i];
81 for (
int j = 0; j < batch_size; ++j) {
// All four tables are offset by the same start index, so each iterator
// points at the first entry of feature i.
82 pct_output[cur_index] = compute_percentile(
83 pct_raw_.begin() + feature_start_index,
84 pct_mapping_.begin() + feature_start_index,
85 pct_lower_.begin() + feature_start_index,
86 pct_upper_.begin() + feature_start_index,
// Step by num_features per batch entry -- presumably walking one feature
// column of a row-major (batch, feature) buffer; cur_index's starting value
// is set on a line not shown here. TODO confirm.
89 cur_index += num_features;
// Flat lookup tables, each initialised from a repeated float operator
// argument in the constructor. RunOnDevice offsets all four by the same
// per-feature start index, so they are addressed in parallel.
// pct_raw_: values that the input is compared against and searched in.
101 vector<float> pct_raw_;
// pct_mapping_: result returned when the input equals a pct_raw_ entry.
102 vector<float> pct_mapping_;
// pct_lower_ / pct_upper_: endpoints used when interpolating between two
// neighbouring pct_raw_ entries.
103 vector<float> pct_lower_;
104 vector<float> pct_upper_;
// Number of table entries per feature; its size is used as the feature count.
105 vector<int> pct_lens_;
// NOTE(review): not referenced by any code visible in this listing.
107 vector<std::map<float, float>> fast_pct;
// Added to the interpolation denominator in compute_percentile.
109 const float kEPSILON = 1e-10;
// Fragment of the bisection helper -- presumably the binary_search called
// from compute_percentile; the signature line, the loop header and the
// return statements are not visible in this listing.
// `data` is a random-access iterator into a float table and is indexed
// relative to that position.
112 const std::vector<float>::iterator& data,
117 bool low_cond, high_cond;
// Midpoint of the current [lo, hi] range.
120 mid = (lo + hi) >> 1;
// low_cond && high_cond means data[mid] <= val < data[mid + 1], i.e. val lies
// in the interval that starts at mid.
121 low_cond = (data[mid] <= val);
122 high_cond = (val < data[mid + 1]);
123 if (low_cond && high_cond) {
125 }
// !low_cond: data[mid] > val.
else if (!low_cond) {
// Maps a single value `val` to a float result using the table slices passed
// in as iterators (fragment: the remaining parameters -- `size` and `val` are
// used below --, the bodies of the two range checks and the final return are
// not visible in this listing).
134 float compute_percentile(
135 const std::vector<float>::iterator& pct_raw_it,
136 const std::vector<float>::iterator& pct_mapping_it,
137 const std::vector<float>::iterator& pct_lower_it,
138 const std::vector<float>::iterator& pct_upper_it,
// val below the first raw entry: handled separately (body not shown).
142 if (val < pct_raw_it[0]) {
// val above the last raw entry: handled separately (body not shown).
145 if (val > pct_raw_it[size - 1]) {
// Locate an index k within [0, size - 1] for val.
151 const auto k = binary_search(pct_raw_it, 0, size - 1, val);
// Exact hit on a raw entry: use the mapped value directly.
153 if (pct_raw_it[k] == val) {
155 result = pct_mapping_it[k];
// Otherwise interpolate: w is the position of val between pct_raw_it[k] and
// pct_raw_it[k + 1]; kEPSILON keeps the denominator non-zero when those two
// entries are equal. The result blends pct_upper_it[k] (weight 1 - w) with
// pct_lower_it[k + 1] (weight w).
158 float w = (val - pct_raw_it[k]) /
159 (pct_raw_it[k + 1] - pct_raw_it[k] + kEPSILON);
160 result = (1 - w) * pct_upper_it[k] + w * pct_lower_it[k + 1];
168 #endif // CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...