Caffe2 - C++ API
A deep learning, cross platform ML framework
thresholded_relu_op.cc
1 #include "caffe2/operators/thresholded_relu_op.h"
2 
3 #include "caffe2/utils/eigen_utils.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
7 
8 template <>
9 bool ThresholdedReluOp<float, CPUContext>::RunOnDevice() {
10  auto& X = Input(0);
11 
12  auto* Y = Output(0, X.sizes(), at::dtype<float>());
13 
14  ConstEigenVectorArrayMap<float> Xvec(X.data<float>(), X.numel());
15  EigenVectorArrayMap<float> Yvec(
16  Y->template mutable_data<float>(), Y->numel());
17  Yvec = (Xvec > alpha_).select(Xvec, 0.f);
18  /* Naive implementation
19  const float* Xdata = X.data<float>();
20  float* Ydata = Y->template mutable_data<float>();
21  for (int i = 0; i < X.size(); ++i) {
22  Xdata[i] -= alpha_;
23  Ydata[i] = std::max(Xdata[i], 0.0f);
24  }
25  */
26  return true;
27 }
28 
29 template <>
30 bool ThresholdedReluGradientOp<float, CPUContext>::RunOnDevice() {
31  auto& Y = Input(0);
32  auto& dY = Input(1);
33 
34  CAFFE_ENFORCE_EQ(dY.numel(), Y.numel());
35  auto* dX = Output(0, Y.sizes(), at::dtype<float>());
36 
37  const float* Ydata = Y.data<float>();
38  const float* dYdata = dY.data<float>();
39  float* dXdata = dX->template mutable_data<float>();
40  EigenVectorArrayMap<float> dXvec(dXdata, dX->numel());
41  ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.numel());
42  ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.numel());
43  dXvec = dYvec * Yvec.cwiseSign();
44  /* Non vectorized implementation
45  for (int i = 0; i < Y.size(); ++i) {
46  dXdata[i] = Ydata[i] > 0 ? dYdata[i] : 0;
47  }
48  */
49  return true;
50 }
51 
52 REGISTER_CPU_OPERATOR(ThresholdedRelu, ThresholdedReluOp<float, CPUContext>);
53 REGISTER_CPU_OPERATOR(
54  ThresholdedReluGradient,
55  ThresholdedReluGradientOp<float, CPUContext>);
56 
57 // Input: X, output: Y
58 OPERATOR_SCHEMA(ThresholdedRelu)
59  .NumInputs(1)
60  .NumOutputs(1)
61  .AllowInplace({{0, 0}})
62  .CostInferenceFunction(PointwiseCostInference<2>)
63  .IdenticalTypeAndShape()
64  .SetDoc(R"DOC(
65 ThresholdedRelu takes one input data (Tensor) and produces one output data
66 (Tensor) where the rectified linear function, y = x for x > alpha, y = 0
67 otherwise, is applied to the tensor elementwise.
68 )DOC")
69  .Arg("alpha", "(float) defaults to 1.0.")
70  .Input(0, "X", "1D input tensor")
71  .Output(0, "Y", "1D input tensor");
72 
73 // Input: Y, dY, output: dX
74 OPERATOR_SCHEMA(ThresholdedReluGradient)
75  .NumInputs(2)
76  .NumOutputs(1)
77  .AllowInplace({{1, 0}})
78  .SetDoc(R"DOC(
79 ThresholdedReluGradient takes both Y and dY and uses this to update dX
80 according to the chain rule and derivatives of the rectified linear function.
81 )DOC");
82 
83 class GetThresholdedReluGradient : public GradientMakerBase {
84  using GradientMakerBase::GradientMakerBase;
85  vector<OperatorDef> GetGradientDefs() override {
86  return SingleGradientDef(
87  def_.type() + "Gradient",
88  "",
89  vector<string>{O(0), GO(0)},
90  vector<string>{GI(0)});
91  }
92 };
93 REGISTER_GRADIENT(ThresholdedRelu, GetThresholdedReluGradient);
94 
95 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13