Caffe2 - C++ API
A deep learning, cross platform ML framework
piecewise_linear_transform_op.h
1 
17 #ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
18 #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <typename T, class Context>
26 class PiecewiseLinearTransformOp final : public Operator<Context> {
27  public:
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29 
30  PiecewiseLinearTransformOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws) {
32  binary_ = OperatorBase::GetSingleArgument<bool>("binary", false);
33 
34  // Retrieve transform params (i.e., the linear functions).
35  bounds_from_arg_ = OperatorBase::GetRepeatedArgument<T>("bounds");
36  slopes_from_arg_ = OperatorBase::GetRepeatedArgument<T>("slopes");
37  intercepts_from_arg_ = OperatorBase::GetRepeatedArgument<T>("intercepts");
38  transform_param_from_arg_ = CheckTransParamFromArg();
39  }
40 
41  bool RunOnDevice() override {
42  return binary_ ? TransformBinary() : TransformGeneral();
43  }
44 
45  private:
46  // num_func_per_group is the number of pieces of linear functions of
47  // each group.
48  // num_group: The number of groups of linear functions. Each group is for
49  // transforming one column of predictions.
50  void InferNumFunctionsPerGroup(
51  const TIndex num_bounds,
52  const TIndex num_slopes,
53  const TIndex num_intercepts,
54  TIndex* num_func_per_group,
55  TIndex* num_group) {
56  CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
57 
58  // This is based on the facts:
59  // 1. in each group, the num of bounds minus the num of slopes is 1;
60  // 2. each group has the same number of pieces.
61  *num_group = num_bounds - num_slopes;
62  CAFFE_ENFORCE_GT(*num_group, 0);
63  if (binary_) {
64  CAFFE_ENFORCE_EQ(*num_group, 1);
65  }
66  *num_func_per_group = num_slopes / *num_group;
67  CAFFE_ENFORCE_GT(*num_func_per_group, 0);
68  CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
69  }
70 
71  bool CheckBoundsSorted(
72  const T* bounds,
73  const TIndex num_bounds_per_group,
74  const TIndex num_group) {
75  const T* start = bounds;
76  for (TIndex i = 0; i < num_group; i++) {
77  if (!std::is_sorted(start, start + num_bounds_per_group)) {
78  return false;
79  }
80  start += num_bounds_per_group;
81  }
82  return true;
83  }
84 
85  // Returns true if the transform params from arg are valid.
86  // Otherwise, we will assume the transform params will pass from Input blobs.
87  bool CheckTransParamFromArg() {
88  int good_param = 0;
89  good_param += bounds_from_arg_.size() > 0;
90  good_param += slopes_from_arg_.size() > 0;
91  good_param += intercepts_from_arg_.size() > 0;
92  CAFFE_ENFORCE(
93  good_param == 0 || good_param == 3,
94  "bounds, slopes, intercepts must be all set or all not set");
95  if (good_param == 3) {
96  TIndex num_func_per_group;
97  TIndex num_group;
98  InferNumFunctionsPerGroup(
99  bounds_from_arg_.size(),
100  slopes_from_arg_.size(),
101  intercepts_from_arg_.size(),
102  &num_func_per_group,
103  &num_group);
104  CAFFE_ENFORCE(
105  CheckBoundsSorted(
106  bounds_from_arg_.data(), num_func_per_group + 1, num_group),
107  "bounds must be sorted for each group");
108  }
109 
110  return good_param == 3;
111  }
112 
113  void setUpTensors(TIndex& num_func_per_group, TIndex& num_group, TIndex M);
114 
115  void GetTransParamData(
116  const T** bounds,
117  const T** slopes,
118  const T** intercepts,
119  TIndex* num_func_per_group,
120  TIndex* num_group) {
121  TIndex num_bounds;
122  TIndex num_slopes;
123  TIndex num_intercepts;
124 
125  if (transform_param_from_arg_) {
126  CAFFE_ENFORCE_EQ(InputSize(), 1);
127  *bounds = bounds_from_arg_.data();
128  *slopes = slopes_from_arg_.data();
129  *intercepts = intercepts_from_arg_.data();
130  num_bounds = bounds_from_arg_.size();
131  num_slopes = slopes_from_arg_.size();
132  num_intercepts = intercepts_from_arg_.size();
133  } else {
134  CAFFE_ENFORCE_EQ(InputSize(), 4);
135  auto& bounds_input = Input(BOUNDS);
136  auto& slopes_input = Input(SLOPES);
137  auto& intercepts_input = Input(INTERCEPTS);
138  *bounds = bounds_input.template data<T>();
139  *slopes = slopes_input.template data<T>();
140  *intercepts = intercepts_input.template data<T>();
141  num_bounds = bounds_input.size();
142  num_slopes = slopes_input.size();
143  num_intercepts = intercepts_input.size();
144  }
145  InferNumFunctionsPerGroup(
146  num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
147  }
148 
149  bool TransformGeneral() {
150  auto& X = Input(0);
151  auto* Y = Output(0);
152  CAFFE_ENFORCE_EQ(X.ndim(), 2);
153  TIndex N = X.dim32(0);
154  TIndex M = X.dim32(1);
155  Y->ResizeLike(X);
156  const auto* Xdata = X.template data<T>();
157  T* Ydata = Y->template mutable_data<T>();
158 
159  const T* bounds;
160  const T* slopes;
161  const T* intercepts;
162  TIndex num_func_per_group;
163  TIndex num_group;
164  GetTransParamData(
165  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
166  CAFFE_ENFORCE_EQ(num_group, M);
167 
168  for (TIndex j = 0; j < M; ++j) {
169  const T* bounds_group = bounds + j * (num_func_per_group + 1);
170  const T* slopes_group = slopes + j * num_func_per_group;
171  const T* intercepts_group = intercepts + j * num_func_per_group;
172  for (TIndex i = 0; i < N; ++i) {
173  Ydata[i * M + j] = PiecewiseLinearTransform(
174  Xdata[i * M + j],
175  bounds_group,
176  slopes_group,
177  intercepts_group,
178  num_func_per_group);
179  }
180  }
181  return true;
182  }
183 
184  bool TransformBinary() {
185  auto& X = Input(PREDICTIONS);
186  auto* Y = Output(0);
187  CAFFE_ENFORCE(X.ndim() == 1 || X.ndim() == 2);
188  TIndex N = X.dim32(0);
189  TIndex M = X.ndim() == 2 ? X.dim32(1) : 1;
190  CAFFE_ENFORCE(
191  M == 1 || M == 2,
192  "If binary is set to true, the input must be Nx2 or Nx1 tensor");
193  Y->ResizeLike(X);
194  const auto* Xdata = X.template data<T>();
195  T* Ydata = Y->template mutable_data<T>();
196 
197  const T* bounds;
198  const T* slopes;
199  const T* intercepts;
200  TIndex num_func_per_group;
201  TIndex num_group;
202  GetTransParamData(
203  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
204  CAFFE_ENFORCE_EQ(num_group, 1);
205 
206  if (M == 1) {
207  for (TIndex i = 0; i < N; ++i) {
208  Ydata[i] = PiecewiseLinearTransform(
209  Xdata[i], bounds, slopes, intercepts, num_func_per_group);
210  }
211  } else {
212  for (TIndex i = 0; i < N; ++i) {
213  Ydata[i * M + 1] = PiecewiseLinearTransform(
214  Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
215  Ydata[i * M] = 1.0f - Ydata[i * M + 1];
216  }
217  }
218 
219  return true;
220  }
221 
222  T PiecewiseLinearTransform(
223  const T x,
224  const T* bounds,
225  const T* slopes,
226  const T* intercepts,
227  const TIndex num_func_per_group) {
228  T y = 0;
229  // deal with samples out of bounds
230  // make it the same as the upper/lower bound value
231  if (x <= bounds[0]) {
232  y = slopes[0] * bounds[0] + intercepts[0];
233  } else if (x >= bounds[num_func_per_group]) {
234  y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
235  intercepts[num_func_per_group - 1];
236  } else {
237  auto low_bound =
238  std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
239  int bounds_idx = low_bound - bounds - 1;
240  // compute the piecewise linear transformation as Y
241  y = slopes[bounds_idx] * x + intercepts[bounds_idx];
242  }
243  return y;
244  }
245 
246  private:
247  bool binary_;
248  vector<T> bounds_from_arg_;
249  vector<T> slopes_from_arg_;
250  vector<T> intercepts_from_arg_;
251 
252  Tensor<Context> bounds_device_;
253  Tensor<Context> intercepts_device_;
254  Tensor<Context> slopes_device_;
255  bool gpu_copied_ = false;
256 
257  // If true, the piecewise linear functions are passed through args,
258  // otherwise, they are passed through Input blobs.
259  bool transform_param_from_arg_;
260 
261  INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
262 };
263 
264 } // namespace caffe2
265 
266 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.