Caffe2 - C++ API
A deep learning, cross platform ML framework
typed_axpy_avx2.cc
1 
#include "caffe2/core/types.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include "caffe2/perfkernels/typed_axpy.h"
#include "caffe2/utils/math.h"

#include <cstdint>

#include <emmintrin.h>
#include <immintrin.h>
24 
25 namespace caffe2 {
26 
27 void TypedAxpy_float16_float__avx2_fma(
28  int N,
29  const float a,
30  const float16* x,
31  float* y) {
32  // if x does not start at the 16 byte boundary, we will process the first few.
33  // before we get to a real one.
34  while (((unsigned long)x % 16) && N) {
35  *(y++) += _cvtsh_ss((*(x++)).x) * a;
36  --N;
37  }
38 
39  // From now on we can do vectorized additions using __m256, which is 8 floats,
40  // so we will vectorize every 8 element and then resort to cvtsh_ss.
41  __m256 mma = _mm256_set1_ps(a);
42  int current = 0;
43  const int bound = (N % 8) ? N - 8 : N;
44 
45  for (; current < bound; current += 8) {
46  __m128i mmx_16 =
47  _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current));
48  __m256 mmx_32 = _mm256_cvtph_ps(mmx_16);
49  __m256 mmy = _mm256_loadu_ps(y + current);
50  mmy = _mm256_fmadd_ps(mmx_32, mma, mmy);
51  _mm256_storeu_ps(y + current, mmy);
52  }
53 
54  if (bound != N) {
55  while (current < N) {
56  y[current] += _cvtsh_ss(x[current].x) * a;
57  ++current;
58  }
59  }
60 }
61 
// Accumulates y[i] += a * x[i] for i in [0, N), treating every input byte as
// an unsigned value in [0, 255].
//
//   N: number of elements to process; a value <= 0 makes the call a no-op.
//   a: scalar multiplier applied to every element of x.
//   x: input array of N unsigned bytes (read only).
//   y: array of N floats, updated in place.
void TypedAxpy_uint8_float__avx2_fma(
    int N,
    const float a,
    const std::uint8_t* x,
    float* y) {
  // Nothing to do for an empty range. Rejecting negative counts here also
  // keeps the alignment prologue below from walking past the caller's data.
  if (N <= 0) {
    return;
  }

  // Handle leading elements one at a time until x sits on a 16-byte boundary
  // (or the input runs out). std::uintptr_t is wide enough to hold a pointer
  // on every data model, unlike unsigned long.
  while ((reinterpret_cast<std::uintptr_t>(x) % 16) != 0 && N > 0) {
    *y += static_cast<float>(*x) * a;
    ++x;
    ++y;
    --N;
  }

  // Main loop: 8 bytes at a time are widened to 32-bit integers, converted to
  // float, and combined with y using a fused multiply-add.
  const __m256 scale = _mm256_set1_ps(a);
  const int vec_end = N - (N % 8); // largest multiple of 8 that is <= N
  int i = 0;
  for (; i < vec_end; i += 8) {
    // Load exactly the 8 bytes that are consumed; a full 16-byte load would
    // read past the end of x on the final iterations.
    const __m128i bytes8 =
        _mm_loadl_epi64(reinterpret_cast<const __m128i*>(x + i));
    // Zero-extend (not sign-extend) so values >= 128 stay positive, matching
    // the scalar prologue and tail.
    const __m256i x_int32 = _mm256_cvtepu8_epi32(bytes8);
    const __m256 x_fp32 = _mm256_cvtepi32_ps(x_int32);
    const __m256 acc = _mm256_loadu_ps(y + i);
    _mm256_storeu_ps(y + i, _mm256_fmadd_ps(x_fp32, scale, acc));
  }

  // Scalar tail for the remaining (N % 8) elements.
  for (; i < N; ++i) {
    y[i] += static_cast<float>(x[i]) * a;
  }
}
97 
98 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.