// Operator declarations for StumpFuncOp / StumpFuncIndexOp come from the header
// included here; what follows is StumpFuncOp::RunOnDevice specialized for
// float input, float output, CPUContext.
//
// StumpFuncOp: elementwise step ("stump") function. Each input element that is
// <= threshold_ becomes low_value_; every other element becomes high_value_.
// The output is allocated with the input's sizes and dtype float.
17 #include "caffe2/operators/stump_func_op.h" 22 bool StumpFuncOp<float, float, CPUContext>::RunOnDevice() {
// NOTE(review): `in` is declared on a line not visible in this view
// (presumably `auto& in = Input(0);`) — confirm against the full file.
24 const float* in_data = in.template data<float>();
// Output 0 takes the input's shape; its element type is float.
26 auto* out = Output(0, in.sizes(), at::dtype<float>());
27 float* out_data = out->template mutable_data<float>();
// NOTE(review): `int i` is compared against in.numel(); if numel() returns a
// 64-bit count, inputs with more than INT_MAX elements would overflow `i`.
// Consider an int64_t index — confirm numel()'s return type.
28 for (
int i = 0; i < in.numel(); i++) {
// Ties (== threshold_) map to low_value_. A NaN input compares false under
// `<=`, so it maps to high_value_.
29 out_data[i] = (in_data[i] <= threshold_) ? low_value_ : high_value_;
// StumpFuncIndexOp::RunOnDevice specialized for float input, int64 indices,
// CPUContext. Partitions the element positions of the input by threshold_:
// output 0 is sized to the number of elements <= threshold_ and output 1 to
// the remaining elements. Both outputs have dtype int64.
// NOTE(review): the declarations of `in` and `lo_cnt` and the body of the
// second loop are not visible in this view — confirm against the full file.
35 bool StumpFuncIndexOp<float, int64_t, CPUContext>::RunOnDevice() {
37 const float* in_data = in.template data<float>();
// First pass: count the elements at or below the threshold, so that both
// outputs can be allocated at their exact sizes before any index is written.
// NOTE(review): `int i` vs in.numel() raises the same potential overflow
// concern as in StumpFuncOp if numel() is a 64-bit count.
40 for (
int i = 0; i < in.numel(); i++) {
// The bool result of the comparison promotes to 0 or 1, so this counts
// matches without a branch.
41 lo_cnt += (in_data[i] <= threshold_);
// Output 0: positions whose value is <= threshold_ (lo_cnt of them).
// Output 1: all other positions (in.numel() - lo_cnt of them).
43 auto* out_lo = Output(0, {lo_cnt}, at::dtype<int64_t>());
44 auto* out_hi = Output(1, {in.numel() - lo_cnt}, at::dtype<int64_t>());
45 int64_t* lo_data = out_lo->template mutable_data<int64_t>();
46 int64_t* hi_data = out_hi->template mutable_data<int64_t>();
// Second pass: route each position using the same predicate as the counting
// pass, so the group sizes match the allocations. Presumably index i is
// written into lo_data or hi_data; those statements are not visible here.
// A NaN input fails the `<=` test and is therefore not counted in output 0.
49 for (
int i = 0; i < in.numel(); i++) {
50 if (in_data[i] <= threshold_) {
// Registers the CPU StumpFunc operator (StumpFuncOp<float, float, CPUContext>)
// and its schema. The schema documents input 0 "X" and output 0 "Y" as
// tensors of float. NO_GRADIENT(StumpFunc) registers no gradient for it.
// StumpFuncIndexOp<float, int64_t, CPUContext> is registered as a CPU operator
// as well (the operator-name line is not visible here; presumably
// StumpFuncIndex, matching OPERATOR_SCHEMA(StumpFuncIndex)). Its schema
// documents input 0 "X" as a float tensor and two int64 index outputs
// (below/equal threshold vs. above threshold). NO_GRADIENT(StumpFuncIndex)
// registers no gradient for it.
// NOTE(review): the StumpFuncIndex doc string misspells "elements" as
// "elemnts"; fix it in the DOC text.
// NOTE(review): the text after NO_GRADIENT(StumpFuncIndex) below does not look
// like part of this source file (likely extraction residue) — verify.
59 REGISTER_CPU_OPERATOR(StumpFunc, StumpFuncOp<float, float, CPUContext>);
61 OPERATOR_SCHEMA(StumpFunc)
64 .Input(0,
"X",
"tensor of float")
65 .Output(0,
"Y",
"tensor of float")
67 Converts each input element into either high_ or low_value 68 based on the given threshold. 71 NO_GRADIENT(StumpFunc); 73 REGISTER_CPU_OPERATOR( 75 StumpFuncIndexOp<float, int64_t, CPUContext>); 77 OPERATOR_SCHEMA(StumpFuncIndex) 80 .Input(0, "X",
"tensor of float")
84 "tensor of int64 indices for elements below/equal threshold")
88 "tensor of int64 indices for elements above threshold")
90 Split the elemnts and return the indices based on the given threshold. 93 NO_GRADIENT(StumpFuncIndex); A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...