Caffe2 - C++ API
A deep learning, cross-platform ML framework
swish_op.cc
1 #include "caffe2/operators/swish_op.h"
2 
3 #include <string>
4 #include <vector>
5 
6 #include "caffe2/core/types.h"
7 #include "caffe2/utils/eigen_utils.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <>
13 template <typename T>
14 bool SwishFunctor<CPUContext>::
15 operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
16  ConstEigenVectorArrayMap<T> X_arr(X, N);
17  EigenVectorArrayMap<T>(Y, N) = X_arr / (T(1) + (-X_arr).exp());
18  return true;
19 }
20 
21 template <>
22 template <typename T>
23 bool SwishGradientOp<CPUContext>::DoRunWithType() {
24  auto& Xin = Input(X);
25  auto& Yin = Input(Y);
26  auto& DYin = Input(DY);
27 
28  CAFFE_ENFORCE_EQ(Xin.numel(), Yin.numel());
29  CAFFE_ENFORCE_EQ(DYin.numel(), Yin.numel());
30  auto* DXout = Output(DX, Yin.sizes(), at::dtype<float>());
31 
32  const float* Xdata = Xin.template data<float>();
33  const float* Ydata = Yin.template data<float>();
34  const float* dYdata = DYin.template data<float>();
35  float* dXdata = DXout->template mutable_data<float>();
36 
37  EigenVectorArrayMap<float> dXvec(dXdata, DXout->numel());
38  ConstEigenVectorArrayMap<float> Xvec(Xdata, Xin.numel());
39  ConstEigenVectorArrayMap<float> Yvec(Ydata, Yin.numel());
40  ConstEigenVectorArrayMap<float> dYvec(dYdata, DYin.numel());
41 
42  // dx = dy * (y + sigmoid(x)*(1-y))
43  dXvec = dYvec * (Yvec + (T(1) / (T(1) + (-Xvec).exp())) * (T(1) - Yvec));
44  return true;
45 }
46 
47 REGISTER_CPU_OPERATOR(
48  Swish,
49  UnaryElementwiseOp<
50  TensorTypes<float>,
51  CPUContext,
52  SwishFunctor<CPUContext>>);
53 REGISTER_CPU_OPERATOR(SwishGradient, SwishGradientOp<CPUContext>);
54 
55 // Input: X, output: Y
56 OPERATOR_SCHEMA(Swish)
57  .NumInputs(1)
58  .NumOutputs(1)
59  .IdenticalTypeAndShape()
60  .SetDoc(R"DOC(
61 Swish takes one input data (Tensor) and produces one output data
62 (Tensor) where the swish function, y = x / (1 + exp(-x)), is applied to the
63 tensor elementwise.
64 )DOC")
65  .Input(0, "X", "1D input tensor")
66  .Output(0, "Y", "1D output tensor");
67 // Input: X, Y, dY, output: dX
68 OPERATOR_SCHEMA(SwishGradient)
69  .NumInputs(3)
70  .NumOutputs(1)
71  .AllowInplace({{2, 0}})
72  .SetDoc(R"DOC(
73 SwishGradient takes X, Y and dY and uses this to update dX according to the
74 chain rule and derivatives of the swish function.
75 )DOC");
76 
77 namespace {
78 
79 class GetSwishGradient : public GradientMakerBase {
80  using GradientMakerBase::GradientMakerBase;
81  std::vector<OperatorDef> GetGradientDefs() override {
82  return SingleGradientDef(
83  "SwishGradient",
84  "",
85  std::vector<std::string>{I(0), O(0), GO(0)},
86  std::vector<std::string>{GI(0)});
87  }
88 };
89 
90 } // namespace
91 
92 REGISTER_GRADIENT(Swish, GetSwishGradient);
93 
94 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13