Caffe2 - C++ API
A deep learning, cross platform ML framework
elu_op.cc
#include "caffe2/operators/elu_op.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

#include "caffe2/utils/eigen_utils.h"
8 
9 namespace caffe2 {
10 
11 template <>
12 template <typename T>
13 bool EluFunctor<CPUContext>::
14 operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
15  ConstEigenVectorArrayMap<T> X_arr(X, N);
16  EigenVectorMap<T>(Y, N) =
17  (X_arr < 0).select(alpha * (X_arr.exp() - T(1)), X_arr);
18  return true;
19 }
20 
21 template <>
22 template <typename T>
23 bool EluGradientFunctor<CPUContext>::Forward(
24  const std::vector<int>& Y_dims,
25  const std::vector<int>& /* dY_dims */,
26  const T* Y,
27  const T* dY,
28  T* dX,
29  CPUContext* /* context */) const {
30  const int size = std::accumulate(
31  Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
32  ConstEigenVectorArrayMap<T> Y_arr(Y, size);
33  ConstEigenVectorArrayMap<T> dY_arr(dY, size);
34  EigenVectorArrayMap<T>(dX, size) =
35  (Y_arr < 0).select(dY_arr * (Y_arr + alpha), dY_arr);
36  return true;
37 }
38 
39 REGISTER_CPU_OPERATOR(
40  Elu,
41  UnaryElementwiseWithArgsOp<
42  TensorTypes<float>,
43  CPUContext,
44  EluFunctor<CPUContext>>);
45 REGISTER_CPU_GRADIENT_OPERATOR(
46  EluGradient,
47  BinaryElementwiseWithArgsOp<
48  TensorTypes<float>,
49  CPUContext,
50  EluGradientFunctor<CPUContext>>);
51 
52 // Input: X, output: Y
53 OPERATOR_SCHEMA(Elu)
54  .NumInputs(1)
55  .NumOutputs(1)
56  .AllowInplace({{0, 0}})
57  .IdenticalTypeAndShape()
58  .SetDoc(R"DOC(
59 
60 This op implements the exponential linear unit (ELU) activation function as described in [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)](https://arxiv.org/abs/1511.07289). The op takes an input tensor $X$ of arbitrary shape, computes the elementwise elu operation, and returns a vector $Y$ of the same shape as output. The alpha parameter may be passed as an argument, but defaults to 1. The elu operation is defined as
61 
62 $$y=f(x) =\begin{cases}\alpha(e^x-1) & x < 0 \\ x & otherwise\end{cases}$$
63 
64 Github Links:
65 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/elu_op.h
66 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/elu_op.cc
67 
68 <details>
69 
70 <summary> <b>Example</b> </summary>
71 
72 **Code**
73 
74 ```
75 workspace.ResetWorkspace()
76 
77 op = core.CreateOperator(
78  "Elu",
79  ["X"],
80  ["Y"],
81  alpha=1.1
82 )
83 
84 workspace.FeedBlob("X", np.random.randn(3, 3).astype(np.float32))
85 print("X:\n", workspace.FetchBlob("X"), "\n")
86 
87 workspace.RunOperatorOnce(op)
88 print("Y:\n", workspace.FetchBlob("Y"))
89 
90 ```
91 
92 **Result**
93 
94 ```
95 
96 X:
97  [[ 0.35339102 1.1860217 -0.10710736]
98  [-3.1173866 -0.1889988 -0.20330353]
99  [ 1.8525308 -0.368949 0.506277 ]]
100 
101 Y:
102  [[ 0.35339102 1.1860217 -0.11172786]
103  [-1.0513 -0.18943374 -0.20236646]
104  [ 1.8525308 -0.33939326 0.506277 ]]
105 
106 ```
107 
108 </details>
109 
110 )DOC")
111  .Input(0, "X", "1D input tensor of data to be operated on.")
112  .Output(0, "Y", "1D input tensor, calculated as described above.")
113  .Arg(
114  "alpha",
115  "*(type: float; default: 1.0)* Defines alpha parameter used in calculation.")
116  .InheritOnnxSchema();
117 
118 // Input: Y, dY, output: dX
119 GRADIENT_OPERATOR_SCHEMA(EluGradient)
120  .NumInputs(2)
121  .NumOutputs(1)
122  .AllowInplace({{1, 0}})
123  .SetDoc(R"DOC(
124 EluGradient takes both Y and dY and uses this to update dX according to the
125 chain rule and derivatives of the rectified linear function.
126 )DOC");
127 
128 namespace {
129 
130 class GetEluGradient : public GradientMakerBase {
131  using GradientMakerBase::GradientMakerBase;
132  std::vector<OperatorDef> GetGradientDefs() override {
133  return SingleGradientDef(
134  def_.type() + "Gradient",
135  "",
136  std::vector<std::string>{O(0), GO(0)},
137  std::vector<std::string>{GI(0)});
138  }
139 };
140 
141 } // namespace
142 
143 REGISTER_GRADIENT(Elu, GetEluGradient);
144 
145 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13