1 #include "caffe2/operators/prelu_op.h" 2 #include "caffe2/utils/eigen_utils.h" 3 #include "caffe2/utils/math.h" 5 #include "caffe2/core/types.h" 6 #include "caffe2/utils/cpu_neon.h" 10 #if defined(__ARM_NEON__) || defined(__ARM_NEON) 13 void runNeonPrelu(
float* out,
const float* in,
int size,
float w) {
  float32x4_t vZero = vdupq_n_f32(0.0f);
  float32x4_t vW = vdupq_n_f32(w);

  constexpr int kVecSizeInFloat = sizeof(float32x4_t) / sizeof(float);
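  // A 128-bit NEON register holds four 32-bit floats, so kVecSizeInFloat == 4;
  // the alignment prologue below relies on the matching 16-byte load boundary.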
  if (size < kVecSizeInFloat) {
    // Too small to vectorize; plain scalar loop
    for (int i = 0; i < size; ++i) {
      float v = in[i];
      out[i] = v > 0 ? v : v * w;
    }
    return;
  }
  // Advance to a 16-byte boundary with scalar code so the main loop can use
  // aligned loads
  auto prologue = kVecSizeInFloat -
      // remainder in floats
      (((uintptr_t)in) % (sizeof(float32x4_t))) / sizeof(float);
  int i = 0;

  for (; i < prologue; ++i) {
    float v = in[i];
    out[i] = v > 0 ? v : v * w;
  }
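  // Worked example for the prologue arithmetic (addresses are illustrative):
  // if ((uintptr_t)in) % 16 == 8, two floats remain before the next 16-byte
  // boundary, so prologue == 4 - 8 / 4 == 2 and the aligned vector loop below
  // starts at i == 2.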
  // Unroll by 6: each main-loop iteration handles 24 floats in 6 registers
  constexpr int kUnroll = 6;
  constexpr int kFloatsPerLoop = kUnroll * kVecSizeInFloat;

  int remainder = size - prologue;
  int vectorizable = prologue + (remainder / kFloatsPerLoop) * kFloatsPerLoop;
  for (; i < vectorizable; i += kFloatsPerLoop) {
    // Aligned loads of six vectors (24 floats)
    float32x4_t v0 = vld1q_f32_aligned(in + i + 0);
    float32x4_t v1 = vld1q_f32_aligned(in + i + 4);
    float32x4_t v2 = vld1q_f32_aligned(in + i + 8);
    float32x4_t v3 = vld1q_f32_aligned(in + i + 12);
    float32x4_t v4 = vld1q_f32_aligned(in + i + 16);
    float32x4_t v5 = vld1q_f32_aligned(in + i + 20);

    // Lane masks: all-ones where v > 0
    uint32x4_t gz0 = vcgtq_f32(v0, vZero);
    uint32x4_t gz1 = vcgtq_f32(v1, vZero);
    uint32x4_t gz2 = vcgtq_f32(v2, vZero);
    uint32x4_t gz3 = vcgtq_f32(v3, vZero);
    uint32x4_t gz4 = vcgtq_f32(v4, vZero);
    uint32x4_t gz5 = vcgtq_f32(v5, vZero);

    // Negative-side values: v * w
    float32x4_t v0neg = vmulq_f32(v0, vW);
    float32x4_t v1neg = vmulq_f32(v1, vW);
    float32x4_t v2neg = vmulq_f32(v2, vW);
    float32x4_t v3neg = vmulq_f32(v3, vW);
    float32x4_t v4neg = vmulq_f32(v4, vW);
    float32x4_t v5neg = vmulq_f32(v5, vW);

    // v > 0 ? v : v * w
    v0 = vbslq_f32(gz0, v0, v0neg);
    v1 = vbslq_f32(gz1, v1, v1neg);
    v2 = vbslq_f32(gz2, v2, v2neg);
    v3 = vbslq_f32(gz3, v3, v3neg);
    v4 = vbslq_f32(gz4, v4, v4neg);
    v5 = vbslq_f32(gz5, v5, v5neg);

    vst1q_f32(out + i + 0, v0);
    vst1q_f32(out + i + 4, v1);
    vst1q_f32(out + i + 8, v2);
    vst1q_f32(out + i + 12, v3);
    vst1q_f32(out + i + 16, v4);
    vst1q_f32(out + i + 20, v5);
  }
  // Epilogue: remaining tail elements, scalar
  for (; i < size; ++i) {
    float v = in[i];
    out[i] = v > 0 ? v : v * w;
  }
}
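// A minimal scalar reference for runNeonPrelu, handy for spot-checking the
// vectorized path in a test. This helper is an illustrative addition rather
// than part of the original operator; it assumes the same out/in/size/w
// contract as runNeonPrelu above.
inline void referenceScalarPrelu(float* out, const float* in, int size, float w) {
  for (int i = 0; i < size; ++i) {
    float v = in[i];
    out[i] = v > 0 ? v : v * w; // same per-element rule as the NEON path
  }
}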
} // namespace

#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)

template <>
bool PReluOp<float, CPUContext>::RunOnDevice() {
  const auto& X = Input(0);
  const auto& W = Input(1);

  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const auto* Xdata = X.template data<float>();
  const auto* Wdata = W.template data<float>();
  auto* Ydata = Y->template mutable_data<float>();

  const auto C = order_ == StorageOrder::NCHW ? X.size(1) : X.size(X.dim() - 1);
  const auto C_shared = (W.numel() == 1);

  // With a per-channel slope, W must supply exactly one value per channel
  if (!C_shared) {
    CAFFE_ENFORCE_EQ(C, W.numel());
  }
  if (C_shared) {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
    // Shared slope: the op is completely pointwise, so handle the whole
    // tensor in one call
    runNeonPrelu(Ydata, Xdata, X.size(), Wdata[0]);
#else
    ConstEigenVectorMap<float> Xvec(Xdata, X.numel());
    EigenVectorMap<float> Yvec(Ydata, Y->numel());
    Yvec = Xvec.cwiseMax(0.f) + Xvec.cwiseMin(0.f) * Wdata[0];
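    // Why this works: for x > 0, cwiseMax(0) keeps x while cwiseMin(0) is 0;
    // for x <= 0, cwiseMax(0) is 0 while cwiseMin(0) * w gives w * x. Exactly
    // one term is live per element, reproducing the PReLU definition.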
#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
    return true;
  }

  switch (order_) {
    case StorageOrder::NCHW: {
      const auto N = X.size(0);
      const auto dim = X.size_from_dim(2);
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
      // Pointwise within each (n, c) plane of dim = H * W elements
      for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
          runNeonPrelu(
              Ydata + (n * C + c) * dim,
              Xdata + (n * C + c) * dim,
              dim,
              Wdata[c]);
        }
      }
#else
      int nc = 0;
      for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
          ConstEigenVectorMap<float> Xvec(Xdata + nc * dim, dim);
          EigenVectorMap<float>(Ydata + nc * dim, dim) =
              Xvec.cwiseMax(0.f) + Xvec.cwiseMin(0.f) * Wdata[c];
          nc++;
        }
      }
#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
      break;
    }
    case StorageOrder::NHWC: {
      // Lay the data out as a C x (N*H*W) matrix and scale by the channel
      // slopes
      const auto NHW = X.numel() / C;
      ConstEigenArrayMap<float> Xmat(Xdata, C, NHW);
      ConstEigenVectorArrayMap<float> Wvec(Wdata, C);
      EigenArrayMap<float> Ymat(Ydata, C, NHW);
      Ymat = (Xmat > 0).select(Xmat, Xmat.colwise() * Wvec);
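      // In this column-major map each column is one spatial position, so
      // Xmat.colwise() * Wvec applies channel c's slope to row c at every
      // position, and select() keeps the original value where x > 0.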
      break;
    }
    default:
      CAFFE_THROW("Unknown storage order: ", order_);
  }
  return true;
}
template <>
bool PReluGradientOp<float, CPUContext>::RunOnDevice() {
  const auto& Y = Input(0);
  const auto& dY = Input(1);
  const auto& X = Input(2);
  const auto& W = Input(3);

  CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU");

  DCHECK_EQ(dY.numel(), Y.numel());
  auto* dX = Output(0, Y.sizes(), at::dtype<float>());
  auto* dW = Output(1, W.sizes(), at::dtype<float>());
  const auto C = order_ == StorageOrder::NCHW ? X.size(1) : X.size(X.dim() - 1);
  const auto C_shared = (W.numel() == 1);

  const float* Ydata = Y.data<float>();
  const float* dYdata = dY.data<float>();
  const float* Xdata = X.data<float>();
  const float* Wdata = W.data<float>();
  float* dXdata = dX->template mutable_data<float>();
  float* dWdata = dW->template mutable_data<float>();
  switch (order_) {
    case StorageOrder::NCHW: {
      const auto dim = X.size_from_dim(2);
      // With a shared slope, every channel maps to the single accumulator 0
      const auto div_factor = C_shared ? C : 1;
      for (auto c = 0; c < W.numel(); ++c) {
        dWdata[c] = 0;
      }

      // dW: accumulate dY * X over the elements where X is non-positive
      for (int i = 0; i < Y.numel(); ++i) {
        if (Xdata[i] <= 0) {
          int c = (i / dim) % C / div_factor;
          dWdata[c] += dYdata[i] * Xdata[i];
        }
      }
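      // Worked example for the index arithmetic (illustrative sizes): with
      // C == 3 and dim == 4, element i == 21 lies in image n == 1 and channel
      // (21 / 4) % 3 == 2; if the slope is shared, div_factor == C collapses
      // every channel into accumulator 0.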
      // dX: pass dY through where X is positive, otherwise scale by the slope
      for (int i = 0; i < Y.numel(); ++i) {
        if (Xdata[i] > 0) {
          dXdata[i] = dYdata[i];
        } else {
          int c = (i / dim) % C / div_factor;
          dXdata[i] = Wdata[c] * dYdata[i];
        }
      }
      break;
    }
    case StorageOrder::NHWC: {
      const auto NHW = X.numel() / C;
      ConstEigenVectorArrayMap<float> Wvec(Wdata, W.numel());
      EigenVectorArrayMap<float> dWvec(dWdata, dW->numel());

      ConstEigenArrayMap<float> Ymat(Ydata, C, NHW);
      ConstEigenArrayMap<float> dYmat(dYdata, C, NHW);
      ConstEigenArrayMap<float> Xmat(Xdata, C, NHW);
      EigenArrayMap<float> dXmat(dXdata, C, NHW);
      if (C_shared) {
        dXmat = (Xmat > 0).select(dYmat, dYmat * Wdata[0]);
        dWdata[0] =
            (Xmat > 0)
                .select(
                    Xmat.cwiseMin(0.0f), // zero contribution on the positive path
                    dYmat * Xmat)
                .sum();
      } else {
        dXmat = (Xmat > 0).select(dYmat, dYmat.colwise() * Wvec);
        dWvec = (Xmat > 0)
                    .select(
                        Xmat.cwiseMin(0.0f), // zero contribution on the positive path
                        dYmat * Xmat)
                    .rowwise()
                    .sum();
      }
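      // Note on the select trick above: where Xmat > 0, Xmat.cwiseMin(0.0f)
      // evaluates to exactly 0, providing the zero-gradient branch without
      // materializing a zero matrix; rows are channels, so .rowwise().sum()
      // reduces over all N*H*W positions per channel.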
      break;
    }
    default:
      CAFFE_THROW("Unknown storage order: ", order_);
  }
  return true;
}
REGISTER_CPU_OPERATOR(PRelu, PReluOp<float, CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(
    PReluGradient,
    PReluGradientOp<float, CPUContext>);
// Input: X, Slope; output: Y
OPERATOR_SCHEMA(PRelu)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShapeOfInput(0)
    .SetDoc(R"DOC(

The *PRelu* op takes an input data tensor $X$ and an input slope tensor $slope$, and produces one output tensor $Y$ with the same shape as $X$. The op performs the element-wise *PRelu* operation, defined as

$$y=prelu(x) =\begin{cases}slope * x & x < 0\\x & otherwise\end{cases}$$

Note: if $slope$ is of size 1, the value is shared across the channels; otherwise $slope$ must have one value per channel of $X$. See [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/abs/1502.01852) for more information.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/prelu_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/prelu_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "PRelu",
    ["X", "Slope"],
    ["Y"],
)

workspace.FeedBlob("X", np.random.randn(3, 3).astype(np.float32))
print("X:\n", workspace.FetchBlob("X"), "\n")

workspace.FeedBlob("Slope", np.array([0.1]).astype(np.float32))
print("Slope:\n", workspace.FetchBlob("Slope"), "\n")

workspace.RunOperatorOnce(op)
print("Y:\n", workspace.FetchBlob("Y"))

```

**Result**

```

X:
 [[ 0.3957382  -0.19725518 -0.26991343]
 [ 1.5513182  -0.27427664 -0.14584002]
 [-0.4121164   0.9292345   0.96426094]]

Slope:
 [0.1]

Y:
 [[ 0.3957382  -0.01972552 -0.02699134]
 [ 1.5513182  -0.02742766 -0.014584  ]
 [-0.04121164  0.9292345   0.96426094]]

```

</details>

)DOC")
    .Input(0, "X", "Input tensor of data to be operated on.")
    .Input(
        1,
        "Slope",
        "1D input slope tensor. If `Slope` is of size 1, the value is shared across different channels.")
    .Output(0, "Y", "Output tensor, with same shape as $X$.")
    .InheritOnnxSchema();
GRADIENT_OPERATOR_SCHEMA(PReluGradient).NumInputs(4).NumOutputs(2).SetDoc(R"DOC(

PReluGradient takes both Y and dY and uses this to update dX and dW according
to the chain rule and derivatives of the rectified linear function.

)DOC");

class GetPReluGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        def_.type() + "Gradient",
        "",
        vector<string>{O(0), GO(0), I(0), I(1)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(PRelu, GetPReluGradient);

} // namespace caffe2