// CPU implementation of the Swish forward pass.
//
// SwishFunctor<CPUContext>::operator() applies y = x / (1 + exp(-x))
// element-wise to the N values starting at X and stores the results at Y.
//   N - number of elements covered by both X and Y.
//   X - input buffer (read-only).
//   Y - output buffer.
//   The trailing CPUContext* parameter is unnamed and is not used by the
//   visible statements.
// NOTE(review): the template header, the return statement and the closing
// brace of this definition are not visible in this excerpt.
1 #include "caffe2/operators/swish_op.h" 6 #include "caffe2/core/types.h" 7 #include "caffe2/utils/eigen_utils.h" 8 #include "caffe2/utils/math.h" 14 bool SwishFunctor<CPUContext>::
15 operator()(
const int N,
const T* X,
T* Y, CPUContext* )
const {
// View the raw input pointer as an Eigen array of length N.
16 ConstEigenVectorArrayMap<T> X_arr(X, N);
// Coefficient-wise x / (1 + exp(-x)), assigned through a map over Y.
17 EigenVectorArrayMap<T>(Y, N) = X_arr / (
T(1) + (-X_arr).exp());
// CPU implementation of the Swish backward pass.
//
// Reads X, Y and dY, checks that all three have the same element count,
// allocates dX with Y's sizes, and computes
//   dX = dY * (Y + sigmoid(X) * (1 - Y)),   sigmoid(x) = 1 / (1 + exp(-x)).
// NOTE(review): the declarations of Xin and Yin, the return statement and the
// closing brace are not visible in this excerpt; Xin/Yin are presumably
// obtained the same way as DYin below (via Input(...)) -- confirm.
// NOTE(review): every buffer here is accessed as float and the output is
// created with at::dtype<float>(), while the constants are written as T(1).
// If T can be anything other than float, the typed accessors and the constants
// disagree -- confirm the intended set of types against the op's dispatch.
23 bool SwishGradientOp<CPUContext>::DoRunWithType() {
26 auto& DYin = Input(DY);
// All three inputs must contain the same number of elements.
28 CAFFE_ENFORCE_EQ(Xin.numel(), Yin.numel());
29 CAFFE_ENFORCE_EQ(DYin.numel(), Yin.numel());
// Output gradient takes its sizes from Y and is created as float.
30 auto* DXout = Output(DX, Yin.sizes(), at::dtype<float>());
// Raw typed pointers into the tensors.
32 const float* Xdata = Xin.template data<float>();
33 const float* Ydata = Yin.template data<float>();
34 const float* dYdata = DYin.template data<float>();
35 float* dXdata = DXout->template mutable_data<float>();
// Eigen array views over the raw buffers, each sized by its own tensor's
// element count.
37 EigenVectorArrayMap<float> dXvec(dXdata, DXout->numel());
38 ConstEigenVectorArrayMap<float> Xvec(Xdata, Xin.numel());
39 ConstEigenVectorArrayMap<float> Yvec(Ydata, Yin.numel());
40 ConstEigenVectorArrayMap<float> dYvec(dYdata, DYin.numel());
// With s = sigmoid(x) and y = x * s:
//   dy/dx = s + x * s * (1 - s) = y + s * (1 - y),
// so the forward output Y is reused and only sigmoid(X) is recomputed.
43 dXvec = dYvec * (Yvec + (
T(1) / (
T(1) + (-Xvec).exp())) * (
T(1) - Yvec));
// CPU operator registrations.
// NOTE(review): the first registration's operator name and wrapping operator
// type are not visible in this excerpt; only its trailing template argument,
// SwishFunctor<CPUContext>, is shown.
47 REGISTER_CPU_OPERATOR(
52 SwishFunctor<CPUContext>>);
// Registers SwishGradientOp<CPUContext> under the operator name SwishGradient.
53 REGISTER_CPU_OPERATOR(SwishGradient, SwishGradientOp<CPUContext>);
// Schema for the Swish operator: declares IdenticalTypeAndShape(), a doc
// string describing y = x / (1 + exp(-x)), input 0 named "X" and output 0
// named "Y".
// NOTE(review): the input/output count declarations and the doc-string
// delimiters are not visible in this excerpt.
56 OPERATOR_SCHEMA(Swish)
59 .IdenticalTypeAndShape()
61 Swish takes one input data (Tensor) and produces one output data 62 (Tensor) where the swish function, y = x / (1 + exp(-x)), is applied to the 65 .Input(0, "X",
"1D input tensor")
66 .Output(0,
"Y",
"1D output tensor");
// Schema for the SwishGradient operator, followed by the gradient maker for
// Swish.
//
// The schema declares AllowInplace({{2, 0}}), i.e. an (input 2, output 0)
// index pair -- presumably dY and dX given the doc text below; confirm against
// the op's input/output tags.
//
// GetSwishGradient derives from GradientMakerBase and overrides
// GetGradientDefs() to return a single gradient definition whose inputs are
// {I(0), O(0), GO(0)} and whose output is {GI(0)}. Per the schema doc text
// these correspond to X, Y, dY and dX.
// NOTE(review): the input/output count declarations, the doc-string
// delimiters, the first arguments of SingleGradientDef and the closing braces
// are not visible in this excerpt.
68 OPERATOR_SCHEMA(SwishGradient)
71 .AllowInplace({{2, 0}})
73 SwishGradient takes X, Y and dY and uses this to update dX according to the 74 chain rule and derivatives of the swish function. 79 class GetSwishGradient :
public GradientMakerBase {
80 using GradientMakerBase::GradientMakerBase;
81 std::vector<OperatorDef> GetGradientDefs()
override {
// One gradient op: consumes {I(0), O(0), GO(0)}, produces {GI(0)}.
82 return SingleGradientDef(
85 std::vector<std::string>{I(0), O(0), GO(0)},
86 std::vector<std::string>{GI(0)});
// Registers GetSwishGradient as the gradient maker for the Swish operator.
92 REGISTER_GRADIENT(Swish, GetSwishGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...