Caffe2 - C++ API
A deep learning, cross-platform ML framework
swish_op.cc
1 
17 #include "swish_op.h"
18 #include "caffe2/core/types.h"
19 #include "caffe2/operators/elementwise_op.h"
20 #include "caffe2/utils/math.h"
21 
22 namespace caffe2 {
// Forward swish kernel: for each of the n elements, writes
// y[i] = x[i] / (1 + exp(-x[i])), i.e. x * sigmoid(x).
//
// n - number of elements read from x and written to y.
// x - input buffer (read-only).
// y - output buffer.
// The CPUContext pointer is accepted but unused; its name is commented out.
//
// NOTE(review): the line that opens the enclosing struct is not present in
// this listing (embedded numbering jumps from 22 to 24).
24  template <typename T>
25  inline void
26  operator()(const int n, const T* x, T* y, CPUContext* /*device_context*/) {
// Wrap the raw input pointer as an Eigen array view of length n.
27  ConstEigenVectorArrayMap<T> xM(x, n);
// Coefficient-wise evaluation, written directly into y through a temporary map.
28  EigenVectorArrayMap<T>(y, n) = xM / (1. + (-xM).exp());
29  }
30 };
31 
// Backward pass for swish: computes dX from X, Y and dY.
//
// NOTE(review): the line carrying this function's signature is missing from
// this listing (embedded numbering jumps from 33 to 35). Presumably this is the
// CPUContext specialization of the SwishGradient operator's typed run method;
// confirm against swish_op.h.
// NOTE(review): although the function is templated on T, the visible body
// reads and writes every tensor as float and never uses T.
32 template <>
33 template <typename T>
// Fetch the three inputs and the single output by their named indices.
35  auto& Xin = Input(X);
36  auto& Yin = Input(Y);
37  auto& DYin = Input(DY);
38  auto* DXout = Output(DX);
// All three inputs must hold the same number of elements.
39  CAFFE_ENFORCE_EQ(Xin.size(), Yin.size());
40  CAFFE_ENFORCE_EQ(DYin.size(), Yin.size());
// Resize the output to match Y before its buffer is requested.
41  DXout->ResizeLike(Yin);
42 
43  const float* Xdata = Xin.template data<float>();
44  const float* Ydata = Yin.template data<float>();
45  const float* dYdata = DYin.template data<float>();
46  float* dXdata = DXout->template mutable_data<float>();
47 
// View the raw buffers as flat Eigen arrays so the formula below is applied
// coefficient-wise.
48  EigenVectorArrayMap<float> dXvec(dXdata, DXout->size());
49  ConstEigenVectorArrayMap<float> Xvec(Xdata, Xin.size());
50  ConstEigenVectorArrayMap<float> Yvec(Ydata, Yin.size());
51  ConstEigenVectorArrayMap<float> dYvec(dYdata, DYin.size());
52 
53  // dx = dy * (y + sigmoid(x)*(1-y))
// Here sigmoid(x) is spelled out as 1 / (1 + exp(-x)).
54  dXvec = dYvec * (Yvec + (1. / (1. + (-Xvec).exp())) * (1. - Yvec));
// Always reports success; size mismatches are handled by the enforce checks above.
55  return true;
56 }
57 
// CPU operator registrations for Swish and SwishGradient.
// NOTE(review): the Swish registration is truncated in this listing (embedded
// lines 60, 61 and 63 are missing), so its full argument list and closing
// parenthesis are not visible here. The SwishGradient registration binds the
// name to SwishGradientOp<CPUContext>.
58 REGISTER_CPU_OPERATOR(
59  Swish,
62  CPUContext,
64 REGISTER_CPU_OPERATOR(SwishGradient, SwishGradientOp<CPUContext>);
65 
// Schema for the forward Swish operator: exactly one input (X) and one
// output (Y), declared with IdenticalTypeAndShape().
// NOTE(review): the Input/Output descriptions say "1D" while the doc text
// describes an elementwise operation on a tensor; confirm whether the "1D"
// wording is intended.
66 // Input: X, output: Y
67 OPERATOR_SCHEMA(Swish)
68  .NumInputs(1)
69  .NumOutputs(1)
70  .IdenticalTypeAndShape()
71  .SetDoc(R"DOC(
72 Swish takes one input data (Tensor<T>) and produces one output data
73 (Tensor<T>) where the swish function, y = x / (1 + exp(-x)), is applied to the
74 tensor elementwise.
75 )DOC")
76  .Input(0, "X", "1D input tensor")
77  .Output(0, "Y", "1D output tensor");
// Schema for the SwishGradient operator: exactly three inputs and one output.
// NOTE(review): AllowInplace({{2, 0}}) presumably lets input index 2 (dY, going
// by the ordering in the comment below) share its buffer with output index 0
// (dX); confirm against the OpSchema::AllowInplace documentation.
78 // Input: X, Y, dY, output: dX
79 OPERATOR_SCHEMA(SwishGradient)
80  .NumInputs(3)
81  .NumOutputs(1)
82  .AllowInplace({{2, 0}})
83  .SetDoc(R"DOC(
84 SwishGradient takes X, Y and dY and uses this to update dX according to the
85 chain rule and derivatives of the swish function.
86 )DOC");
87 
// Associates the Swish operator with GetSwishGradient as its gradient maker.
// NOTE(review): GetSwishGradient is not defined in this listing; presumably it
// is declared in swish_op.h. Verify.
88 REGISTER_GRADIENT(Swish, GetSwishGradient);
89 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:82
Copyright (c) 2016-present, Facebook, Inc.