// NOTE(review): this chunk looks like a lossy extraction of the original
// source. The leading numbers on several lines (1, 7, 13, 14, ...) appear to be
// original line numbers, and the gaps between them suggest lines were dropped
// (e.g. template headers, return statements and closing braces are not
// visible here). Restore from upstream before trying to compile.
//
// ReluNFunctor<CPUContext>::operator(): clamps each of the N elements of X
// and writes the result to Y, i.e. Y[i] = min(max(X[i], 0), n).
// `n` is not declared in these lines -- presumably a functor member set from
// the "n" op argument; confirm in relu_n_op.h. The CPUContext parameter is
// unnamed and therefore unused.
1 #include "caffe2/operators/relu_n_op.h" 7 #include "caffe2/utils/eigen_utils.h" 13 bool ReluNFunctor<CPUContext>::
14 operator()(
const int N,
const T* X,
T* Y, CPUContext* )
const {
// View Y and X as length-N Eigen vectors (presumably non-owning maps over
// the raw buffers). Max with 0 first, then min with n, so each result lies
// in [0, n] provided n >= 0.
15 EigenVectorMap<T>(Y, N) =
16 ConstEigenVectorMap<T>(X, N).cwiseMax(
T(0)).cwiseMin(
T(n));
// ReluNGradientFunctor<CPUContext>::Forward: backward pass of ReluN.
// Computes dX[i] = dY[i] where 0 < Y[i] < n, and dX[i] = 0 elsewhere, using
// the forward *output* Y rather than the input X.
// The second std::vector<int> parameter is unnamed and unused (presumably the
// dY dims).
// NOTE(review): the body uses Y, dY and dX, which are not declared in the
// visible lines; the embedded numbering jumps from 24 to 29, so those
// parameter lines (and the trailing return/closing brace) appear to be missing.
22 bool ReluNGradientFunctor<CPUContext>::Forward(
23 const std::vector<int>& Y_dims,
24 const std::vector<int>& ,
// Element count = product of Y's dimensions (1 for an empty dims list).
// NOTE(review): the product is accumulated in int, so very large tensors
// could overflow.
29 const int size = std::accumulate(
30 Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
31 ConstEigenVectorArrayMap<T> Y_arr(Y, size);
// Strict inequalities: elements that were clamped (Y <= 0 or Y >= n) get a
// zero gradient, and dY passes through unchanged elsewhere. Testing Y instead
// of X works because, for n > 0, 0 < Y < n holds exactly when 0 < X < n.
32 EigenVectorArrayMap<T>(dX, size) =
33 (Y_arr >
T(0) && Y_arr <
T(n))
34 .select(ConstEigenVectorArrayMap<T>(dY, size),
T(0));
// Cost-inference hook for the ReluN schema. Starts from the generic pointwise
// estimate and then zeroes params_bytes.
// The template argument 2 is presumably the per-element op count (one max and
// one min); confirm against PointwiseCostInference's definition.
// NOTE(review): the embedded numbering jumps from 44 to 50, so the
// `return cost;` and closing brace appear to be missing from this chunk.
40 OpSchema::Cost CostInferenceForReluN(
41 const OperatorDef& def,
42 const vector<TensorShape>& in) {
43 struct OpSchema::Cost cost = PointwiseCostInference<2>(def, in);
// Presumably zero because the op has no parameter tensors (its cap `n` is a
// scalar argument); confirm.
44 cost.params_bytes = 0;
// CPU registration of the forward op: ReluNFunctor wrapped in the generic
// UnaryElementwiseWithArgsOp.
// NOTE(review): the embedded numbering skips 51 and 53-54, so the op-name
// argument (presumably ReluN) and the leading template arguments appear to be
// missing.
50 REGISTER_CPU_OPERATOR(
52 UnaryElementwiseWithArgsOp<
55 ReluNFunctor<CPUContext>>);
// CPU registration of the gradient op: ReluNGradientFunctor wrapped in the
// generic BinaryElementwiseWithArgsOp.
// NOTE(review): as above, the op name (presumably ReluNGradient) and leading
// template arguments appear to be missing (numbering skips 57 and 59-60).
56 REGISTER_CPU_OPERATOR(
58 BinaryElementwiseWithArgsOp<
61 ReluNGradientFunctor<CPUContext>>);
// Schema for the forward ReluN operator: y = min(max(0, x), n), applied
// elementwise, with the cap `n` supplied as an op argument.
// NOTE(review): the embedded numbering jumps 64 -> 67 and 70 -> 72, so some
// chained calls and the opening of the doc string appear to be missing.
64 OPERATOR_SCHEMA(ReluN)
67 .Arg(
"n",
"the cap of output")
// Output 0 may reuse input 0's buffer; the forward computation
// (ReluNFunctor) is coefficient-wise.
68 .AllowInplace({{0, 0}})
69 .CostInferenceFunction(CostInferenceForReluN)
// The output has the same type and shape as the input.
70 .IdenticalTypeAndShape()
72 Relu takes one input data (Tensor) and produces one output data 73 (Tensor) where the rectified linear function, y = min(max(0, x), n), 74 is applied to the tensor elementwise. 76 .Input(0, "X",
"1D input tensor")
77 .Output(0,
"Y",
// Y is the op's output; this description previously said "input tensor"
// (copy-paste slip from the Input line above).
"1D output tensor");
// Schema for ReluNGradient, the backward op of ReluN. Per the gradient maker
// in this block, its inputs are (O(0), GO(0)) and its output is GI(0).
// AllowInplace({{1, 0}}) lets input 1 (GO(0), the incoming gradient) share a
// buffer with output 0 (GI(0)).
// NOTE(review): the embedded numbering skips 81-82 and 85, so some chained
// calls and the opening of the doc string appear to be missing.
80 OPERATOR_SCHEMA(ReluNGradient)
83 .Arg(
"n",
"the cap of forward op output")
84 .AllowInplace({{1, 0}})
86 ReluGradient takes both Y and dY and uses this to update dX according to the 87 chain rule and derivatives of the rectified linear function. 92 class GetReluNGradient :
public GradientMakerBase {
// Gradient maker for ReluN: emits a single op of type
// def_.type() + "Gradient" (ReluNGradient when the forward type is "ReluN").
93 using GradientMakerBase::GradientMakerBase;
94 std::vector<OperatorDef> GetGradientDefs()
override {
// Inputs O(0), GO(0) -- presumably the forward output Y and its gradient
// dY, matching ReluNGradientFunctor::Forward's use of Y and dY. Output
// GI(0) is presumably dX.
// NOTE(review): the numbering skips 97, so one argument of
// SingleGradientDef (possibly the op name) appears to be missing.
95 return SingleGradientDef(
96 def_.type() +
"Gradient",
98 std::vector<std::string>{O(0), GO(0)},
99 std::vector<std::string>{GI(0)});
// Associates GetReluNGradient with the ReluN operator.
105 REGISTER_GRADIENT(ReluN, GetReluNGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...