1 #ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ 2 #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
typename T,
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
14 template <
class... Args>
17 binary_ = this->
template GetSingleArgument<bool>(
"binary",
false);
20 bounds_from_arg_ = this->
template GetRepeatedArgument<T>(
"bounds");
21 slopes_from_arg_ = this->
template GetRepeatedArgument<T>(
"slopes");
22 intercepts_from_arg_ = this->
template GetRepeatedArgument<T>(
"intercepts");
23 transform_param_from_arg_ = CheckTransParamFromArg();
26 bool RunOnDevice()
override {
27 return binary_ ? TransformBinary() : TransformGeneral();
35 void InferNumFunctionsPerGroup(
36 const int64_t num_bounds,
37 const int64_t num_slopes,
38 const int64_t num_intercepts,
39 int64_t* num_func_per_group,
41 CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
46 *num_group = num_bounds - num_slopes;
47 CAFFE_ENFORCE_GT(*num_group, 0);
49 CAFFE_ENFORCE_EQ(*num_group, 1);
51 *num_func_per_group = num_slopes / *num_group;
52 CAFFE_ENFORCE_GT(*num_func_per_group, 0);
53 CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
56 bool CheckBoundsSorted(
58 const int64_t num_bounds_per_group,
59 const int64_t num_group) {
60 const T* start = bounds;
61 for (int64_t i = 0; i < num_group; i++) {
62 if (!std::is_sorted(start, start + num_bounds_per_group)) {
65 start += num_bounds_per_group;
72 bool CheckTransParamFromArg() {
74 good_param += bounds_from_arg_.size() > 0;
75 good_param += slopes_from_arg_.size() > 0;
76 good_param += intercepts_from_arg_.size() > 0;
78 good_param == 0 || good_param == 3,
79 "bounds, slopes, intercepts must be all set or all not set");
80 if (good_param == 3) {
81 int64_t num_func_per_group;
83 InferNumFunctionsPerGroup(
84 bounds_from_arg_.size(),
85 slopes_from_arg_.size(),
86 intercepts_from_arg_.size(),
91 bounds_from_arg_.data(), num_func_per_group + 1, num_group),
92 "bounds must be sorted for each group");
95 return good_param == 3;
98 void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t
M);
100 void GetTransParamData(
103 const T** intercepts,
104 int64_t* num_func_per_group,
105 int64_t* num_group) {
108 int64_t num_intercepts;
110 if (transform_param_from_arg_) {
111 CAFFE_ENFORCE_EQ(InputSize(), 1);
112 *bounds = bounds_from_arg_.data();
113 *slopes = slopes_from_arg_.data();
114 *intercepts = intercepts_from_arg_.data();
115 num_bounds = bounds_from_arg_.size();
116 num_slopes = slopes_from_arg_.size();
117 num_intercepts = intercepts_from_arg_.size();
119 CAFFE_ENFORCE_EQ(InputSize(), 4);
120 auto& bounds_input =
Input(BOUNDS);
121 auto& slopes_input =
Input(SLOPES);
122 auto& intercepts_input =
Input(INTERCEPTS);
123 *bounds = bounds_input.template data<T>();
124 *slopes = slopes_input.template data<T>();
125 *intercepts = intercepts_input.template data<T>();
126 num_bounds = bounds_input.numel();
127 num_slopes = slopes_input.numel();
128 num_intercepts = intercepts_input.numel();
130 InferNumFunctionsPerGroup(
131 num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
134 bool TransformGeneral() {
137 CAFFE_ENFORCE_EQ(X.dim(), 2);
138 int64_t N = X.dim32(0);
139 int64_t M = X.dim32(1);
140 auto* Y = Output(0, X.sizes(), at::dtype<T>());
141 const auto* Xdata = X.template data<T>();
142 T* Ydata = Y->template mutable_data<T>();
147 int64_t num_func_per_group;
150 &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
151 CAFFE_ENFORCE_EQ(num_group, M);
153 for (int64_t j = 0; j < M; ++j) {
154 const T* bounds_group = bounds + j * (num_func_per_group + 1);
155 const T* slopes_group = slopes + j * num_func_per_group;
156 const T* intercepts_group = intercepts + j * num_func_per_group;
157 for (int64_t i = 0; i < N; ++i) {
158 Ydata[i * M + j] = PiecewiseLinearTransform(
169 bool TransformBinary() {
170 auto& X =
Input(PREDICTIONS);
172 CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2);
173 int64_t N = X.dim32(0);
174 int64_t M = X.dim() == 2 ? X.dim32(1) : 1;
177 "If binary is set to true, the input must be Nx2 or Nx1 tensor");
178 auto* Y = Output(0, X.sizes(), at::dtype<T>());
179 const auto* Xdata = X.template data<T>();
180 T* Ydata = Y->template mutable_data<T>();
185 int64_t num_func_per_group;
188 &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
189 CAFFE_ENFORCE_EQ(num_group, 1);
192 for (int64_t i = 0; i < N; ++i) {
193 Ydata[i] = PiecewiseLinearTransform(
194 Xdata[i], bounds, slopes, intercepts, num_func_per_group);
197 for (int64_t i = 0; i < N; ++i) {
198 Ydata[i * M + 1] = PiecewiseLinearTransform(
199 Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
200 Ydata[i * M] = 1.0f - Ydata[i * M + 1];
207 T PiecewiseLinearTransform(
212 const int64_t num_func_per_group) {
216 if (x <= bounds[0]) {
217 y = slopes[0] * bounds[0] + intercepts[0];
218 }
else if (x >= bounds[num_func_per_group]) {
219 y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
220 intercepts[num_func_per_group - 1];
223 std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
224 int bounds_idx = low_bound - bounds - 1;
226 y = slopes[bounds_idx] * x + intercepts[bounds_idx];
233 vector<T> bounds_from_arg_;
234 vector<T> slopes_from_arg_;
235 vector<T> intercepts_from_arg_;
237 Tensor bounds_device_{Context::GetDeviceType()};
238 Tensor intercepts_device_{Context::GetDeviceType()};
239 Tensor slopes_device_{Context::GetDeviceType()};
240 bool gpu_copied_ =
false;
244 bool transform_param_from_arg_;
246 INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
251 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...