Caffe2 - C++ API
A deep learning, cross platform ML framework
piecewise_linear_transform_op.h
1 #ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
2 #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class PiecewiseLinearTransformOp final : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13 
14  template <class... Args>
15  explicit PiecewiseLinearTransformOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...) {
17  binary_ = this->template GetSingleArgument<bool>("binary", false);
18 
19  // Retrieve transform params (i.e., the linear functions).
20  bounds_from_arg_ = this->template GetRepeatedArgument<T>("bounds");
21  slopes_from_arg_ = this->template GetRepeatedArgument<T>("slopes");
22  intercepts_from_arg_ = this->template GetRepeatedArgument<T>("intercepts");
23  transform_param_from_arg_ = CheckTransParamFromArg();
24  }
25 
26  bool RunOnDevice() override {
27  return binary_ ? TransformBinary() : TransformGeneral();
28  }
29 
30  private:
31  // num_func_per_group is the number of pieces of linear functions of
32  // each group.
33  // num_group: The number of groups of linear functions. Each group is for
34  // transforming one column of predictions.
35  void InferNumFunctionsPerGroup(
36  const int64_t num_bounds,
37  const int64_t num_slopes,
38  const int64_t num_intercepts,
39  int64_t* num_func_per_group,
40  int64_t* num_group) {
41  CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
42 
43  // This is based on the facts:
44  // 1. in each group, the num of bounds minus the num of slopes is 1;
45  // 2. each group has the same number of pieces.
46  *num_group = num_bounds - num_slopes;
47  CAFFE_ENFORCE_GT(*num_group, 0);
48  if (binary_) {
49  CAFFE_ENFORCE_EQ(*num_group, 1);
50  }
51  *num_func_per_group = num_slopes / *num_group;
52  CAFFE_ENFORCE_GT(*num_func_per_group, 0);
53  CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
54  }
55 
56  bool CheckBoundsSorted(
57  const T* bounds,
58  const int64_t num_bounds_per_group,
59  const int64_t num_group) {
60  const T* start = bounds;
61  for (int64_t i = 0; i < num_group; i++) {
62  if (!std::is_sorted(start, start + num_bounds_per_group)) {
63  return false;
64  }
65  start += num_bounds_per_group;
66  }
67  return true;
68  }
69 
70  // Returns true if the transform params from arg are valid.
71  // Otherwise, we will assume the transform params will pass from Input blobs.
72  bool CheckTransParamFromArg() {
73  int good_param = 0;
74  good_param += bounds_from_arg_.size() > 0;
75  good_param += slopes_from_arg_.size() > 0;
76  good_param += intercepts_from_arg_.size() > 0;
77  CAFFE_ENFORCE(
78  good_param == 0 || good_param == 3,
79  "bounds, slopes, intercepts must be all set or all not set");
80  if (good_param == 3) {
81  int64_t num_func_per_group;
82  int64_t num_group;
83  InferNumFunctionsPerGroup(
84  bounds_from_arg_.size(),
85  slopes_from_arg_.size(),
86  intercepts_from_arg_.size(),
87  &num_func_per_group,
88  &num_group);
89  CAFFE_ENFORCE(
90  CheckBoundsSorted(
91  bounds_from_arg_.data(), num_func_per_group + 1, num_group),
92  "bounds must be sorted for each group");
93  }
94 
95  return good_param == 3;
96  }
97 
98  void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t M);
99 
100  void GetTransParamData(
101  const T** bounds,
102  const T** slopes,
103  const T** intercepts,
104  int64_t* num_func_per_group,
105  int64_t* num_group) {
106  int64_t num_bounds;
107  int64_t num_slopes;
108  int64_t num_intercepts;
109 
110  if (transform_param_from_arg_) {
111  CAFFE_ENFORCE_EQ(InputSize(), 1);
112  *bounds = bounds_from_arg_.data();
113  *slopes = slopes_from_arg_.data();
114  *intercepts = intercepts_from_arg_.data();
115  num_bounds = bounds_from_arg_.size();
116  num_slopes = slopes_from_arg_.size();
117  num_intercepts = intercepts_from_arg_.size();
118  } else {
119  CAFFE_ENFORCE_EQ(InputSize(), 4);
120  auto& bounds_input = Input(BOUNDS);
121  auto& slopes_input = Input(SLOPES);
122  auto& intercepts_input = Input(INTERCEPTS);
123  *bounds = bounds_input.template data<T>();
124  *slopes = slopes_input.template data<T>();
125  *intercepts = intercepts_input.template data<T>();
126  num_bounds = bounds_input.numel();
127  num_slopes = slopes_input.numel();
128  num_intercepts = intercepts_input.numel();
129  }
130  InferNumFunctionsPerGroup(
131  num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
132  }
133 
134  bool TransformGeneral() {
135  auto& X = Input(0);
136 
137  CAFFE_ENFORCE_EQ(X.dim(), 2);
138  int64_t N = X.dim32(0);
139  int64_t M = X.dim32(1);
140  auto* Y = Output(0, X.sizes(), at::dtype<T>());
141  const auto* Xdata = X.template data<T>();
142  T* Ydata = Y->template mutable_data<T>();
143 
144  const T* bounds;
145  const T* slopes;
146  const T* intercepts;
147  int64_t num_func_per_group;
148  int64_t num_group;
149  GetTransParamData(
150  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
151  CAFFE_ENFORCE_EQ(num_group, M);
152 
153  for (int64_t j = 0; j < M; ++j) {
154  const T* bounds_group = bounds + j * (num_func_per_group + 1);
155  const T* slopes_group = slopes + j * num_func_per_group;
156  const T* intercepts_group = intercepts + j * num_func_per_group;
157  for (int64_t i = 0; i < N; ++i) {
158  Ydata[i * M + j] = PiecewiseLinearTransform(
159  Xdata[i * M + j],
160  bounds_group,
161  slopes_group,
162  intercepts_group,
163  num_func_per_group);
164  }
165  }
166  return true;
167  }
168 
169  bool TransformBinary() {
170  auto& X = Input(PREDICTIONS);
171 
172  CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2);
173  int64_t N = X.dim32(0);
174  int64_t M = X.dim() == 2 ? X.dim32(1) : 1;
175  CAFFE_ENFORCE(
176  M == 1 || M == 2,
177  "If binary is set to true, the input must be Nx2 or Nx1 tensor");
178  auto* Y = Output(0, X.sizes(), at::dtype<T>());
179  const auto* Xdata = X.template data<T>();
180  T* Ydata = Y->template mutable_data<T>();
181 
182  const T* bounds;
183  const T* slopes;
184  const T* intercepts;
185  int64_t num_func_per_group;
186  int64_t num_group;
187  GetTransParamData(
188  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
189  CAFFE_ENFORCE_EQ(num_group, 1);
190 
191  if (M == 1) {
192  for (int64_t i = 0; i < N; ++i) {
193  Ydata[i] = PiecewiseLinearTransform(
194  Xdata[i], bounds, slopes, intercepts, num_func_per_group);
195  }
196  } else {
197  for (int64_t i = 0; i < N; ++i) {
198  Ydata[i * M + 1] = PiecewiseLinearTransform(
199  Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
200  Ydata[i * M] = 1.0f - Ydata[i * M + 1];
201  }
202  }
203 
204  return true;
205  }
206 
207  T PiecewiseLinearTransform(
208  const T x,
209  const T* bounds,
210  const T* slopes,
211  const T* intercepts,
212  const int64_t num_func_per_group) {
213  T y = 0;
214  // deal with samples out of bounds
215  // make it the same as the upper/lower bound value
216  if (x <= bounds[0]) {
217  y = slopes[0] * bounds[0] + intercepts[0];
218  } else if (x >= bounds[num_func_per_group]) {
219  y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
220  intercepts[num_func_per_group - 1];
221  } else {
222  auto low_bound =
223  std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
224  int bounds_idx = low_bound - bounds - 1;
225  // compute the piecewise linear transformation as Y
226  y = slopes[bounds_idx] * x + intercepts[bounds_idx];
227  }
228  return y;
229  }
230 
231  private:
232  bool binary_;
233  vector<T> bounds_from_arg_;
234  vector<T> slopes_from_arg_;
235  vector<T> intercepts_from_arg_;
236 
237  Tensor bounds_device_{Context::GetDeviceType()};
238  Tensor intercepts_device_{Context::GetDeviceType()};
239  Tensor slopes_device_{Context::GetDeviceType()};
240  bool gpu_copied_ = false;
241 
242  // If true, the piecewise linear functions are passed through args,
243  // otherwise, they are passed through Input blobs.
244  bool transform_param_from_arg_;
245 
246  INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
247 };
248 
249 } // namespace caffe2
250 
251 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13