Caffe2 - C++ API
A deep learning, cross platform ML framework
stump_func_op.cc
1 
17 #include "caffe2/operators/stump_func_op.h"
18 
19 namespace caffe2 {
20 
21 template <>
22 bool StumpFuncOp<float, float, CPUContext>::RunOnDevice() {
23  auto& in = Input(0);
24  const float* in_data = in.template data<float>();
25 
26  auto* out = Output(0, in.sizes(), at::dtype<float>());
27  float* out_data = out->template mutable_data<float>();
28  for (int i = 0; i < in.numel(); i++) {
29  out_data[i] = (in_data[i] <= threshold_) ? low_value_ : high_value_;
30  }
31  return true;
32 }
33 
34 template <>
35 bool StumpFuncIndexOp<float, int64_t, CPUContext>::RunOnDevice() {
36  auto& in = Input(0);
37  const float* in_data = in.template data<float>();
38 
39  int lo_cnt = 0;
40  for (int i = 0; i < in.numel(); i++) {
41  lo_cnt += (in_data[i] <= threshold_);
42  }
43  auto* out_lo = Output(0, {lo_cnt}, at::dtype<int64_t>());
44  auto* out_hi = Output(1, {in.numel() - lo_cnt}, at::dtype<int64_t>());
45  int64_t* lo_data = out_lo->template mutable_data<int64_t>();
46  int64_t* hi_data = out_hi->template mutable_data<int64_t>();
47  int lidx = 0;
48  int hidx = 0;
49  for (int i = 0; i < in.numel(); i++) {
50  if (in_data[i] <= threshold_) {
51  lo_data[lidx++] = i;
52  } else {
53  hi_data[hidx++] = i;
54  }
55  }
56  return true;
57 }
58 
59 REGISTER_CPU_OPERATOR(StumpFunc, StumpFuncOp<float, float, CPUContext>);
60 
61 OPERATOR_SCHEMA(StumpFunc)
62  .NumInputs(1)
63  .NumOutputs(1)
64  .Input(0, "X", "tensor of float")
65  .Output(0, "Y", "tensor of float")
66  .SetDoc(R"DOC(
67 Converts each input element into either high_ or low_value
68 based on the given threshold.
69 )DOC");
70 
71 NO_GRADIENT(StumpFunc);
72 
73 REGISTER_CPU_OPERATOR(
74  StumpFuncIndex,
75  StumpFuncIndexOp<float, int64_t, CPUContext>);
76 
77 OPERATOR_SCHEMA(StumpFuncIndex)
78  .NumInputs(1)
79  .NumOutputs(2)
80  .Input(0, "X", "tensor of float")
81  .Output(
82  0,
83  "Index_Low",
84  "tensor of int64 indices for elements below/equal threshold")
85  .Output(
86  1,
87  "Index_High",
88  "tensor of int64 indices for elements above threshold")
89  .SetDoc(R"DOC(
90 Split the elemnts and return the indices based on the given threshold.
91 )DOC");
92 
93 NO_GRADIENT(StumpFuncIndex);
94 
95 } // caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13