1 #include "caffe2/operators/thresholded_relu_op.h" 3 #include "caffe2/utils/eigen_utils.h" 4 #include "caffe2/utils/math.h" 9 bool ThresholdedReluOp<float, CPUContext>::RunOnDevice() {
12 auto* Y = Output(0, X.sizes(), at::dtype<float>());
14 ConstEigenVectorArrayMap<float> Xvec(X.data<
float>(), X.numel());
15 EigenVectorArrayMap<float> Yvec(
16 Y->template mutable_data<float>(), Y->numel());
17 Yvec = (Xvec > alpha_).select(Xvec, 0.f);
30 bool ThresholdedReluGradientOp<float, CPUContext>::RunOnDevice() {
34 CAFFE_ENFORCE_EQ(dY.numel(), Y.numel());
35 auto* dX = Output(0, Y.sizes(), at::dtype<float>());
37 const float* Ydata = Y.data<
float>();
38 const float* dYdata = dY.data<
float>();
39 float* dXdata = dX->template mutable_data<float>();
40 EigenVectorArrayMap<float> dXvec(dXdata, dX->numel());
41 ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.numel());
42 ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.numel());
43 dXvec = dYvec * Yvec.cwiseSign();
52 REGISTER_CPU_OPERATOR(ThresholdedRelu, ThresholdedReluOp<float, CPUContext>);
53 REGISTER_CPU_OPERATOR(
54 ThresholdedReluGradient,
55 ThresholdedReluGradientOp<float, CPUContext>);
58 OPERATOR_SCHEMA(ThresholdedRelu)
61 .AllowInplace({{0, 0}})
62 .CostInferenceFunction(PointwiseCostInference<2>)
63 .IdenticalTypeAndShape()
65 ThresholdedRelu takes one input data (Tensor) and produces one output data 66 (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 67 otherwise, is applied to the tensor elementwise. 69 .Arg("alpha",
"(float) defaults to 1.0.")
70 .Input(0,
"X",
"1D input tensor")
71 .Output(0,
"Y",
"1D input tensor");
74 OPERATOR_SCHEMA(ThresholdedReluGradient)
77 .AllowInplace({{1, 0}})
79 ThresholdedReluGradient takes both Y and dY and uses this to update dX 80 according to the chain rule and derivatives of the rectified linear function. 83 class GetThresholdedReluGradient :
public GradientMakerBase {
84 using GradientMakerBase::GradientMakerBase;
85 vector<OperatorDef> GetGradientDefs()
override {
86 return SingleGradientDef(
87 def_.type() +
"Gradient",
89 vector<string>{O(0), GO(0)},
90 vector<string>{GI(0)});
93 REGISTER_GRADIENT(ThresholdedRelu, GetThresholdedReluGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...