Caffe2 - C++ API
A deep learning, cross platform ML framework
typed_axpy_avx2.cc
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"

#include <c10/util/Half.h>
#include <emmintrin.h>
#include <immintrin.h>

#include <cstdint>
6 
7 namespace caffe2 {
8 
9 void TypedAxpyHalffloat__avx2_fma(
10  int N,
11  const float a,
12  const at::Half* x,
13  float* y) {
14  // if x does not start at the 16 byte boundary, we will process the first few.
15  // before we get to a real one.
16  while ((reinterpret_cast<unsigned long>(x) % 16) && N) {
17  *(y++) += _cvtsh_ss((*(x++)).x) * a;
18  --N;
19  }
20 
21  // From now on we can do vectorized additions using __m256, which is 8 floats,
22  // so we will vectorize every 8 element and then resort to cvtsh_ss.
23  __m256 mma = _mm256_set1_ps(a);
24  int current = 0;
25  const int bound = (N % 8) ? N - 8 : N;
26 
27  for (; current < bound; current += 8) {
28  __m128i mmx_16 =
29  _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current));
30  __m256 mmx_32 = _mm256_cvtph_ps(mmx_16);
31  __m256 mmy = _mm256_loadu_ps(y + current);
32  mmy = _mm256_fmadd_ps(mmx_32, mma, mmy);
33  _mm256_storeu_ps(y + current, mmy);
34  }
35 
36  if (bound != N) {
37  while (current < N) {
38  y[current] += _cvtsh_ss(x[current].x) * a;
39  ++current;
40  }
41  }
42 }
43 
/// Computes y[i] += a * float(x[i]) for i in [0, N), where x holds unsigned
/// 8-bit values, using AVX2/FMA intrinsics.
///
/// @param N  Number of elements to process; values <= 0 leave y untouched.
/// @param a  Scalar multiplier applied to every element of x.
/// @param x  Input array of N unsigned bytes.
/// @param y  In/out array of N floats, updated in place.
void TypedAxpy_uint8_float__avx2_fma(
    int N,
    const float a,
    const std::uint8_t* x,
    float* y) {
  // Process leading elements one at a time until x sits on a 16-byte boundary
  // (or the input runs out). std::uintptr_t is used because `unsigned long`
  // is narrower than a pointer on LLP64 targets, and the N > 0 test keeps a
  // negative count from walking past the buffers.
  while ((reinterpret_cast<std::uintptr_t>(x) % 16) && N > 0) {
    *(y++) += static_cast<float>(*(x++)) * a;
    --N;
  }

  // Vectorized part: 8 bytes -> 8 int32 -> 8 floats, then one fused
  // multiply-add against y per iteration.
  const __m256 mma = _mm256_set1_ps(a);
  const int vec_end = N - (N % 8);
  int current = 0;
  for (; current < vec_end; current += 8) {
    // Load exactly the 8 bytes that are consumed; a 16-byte load here would
    // read up to 8 bytes past the end of x on the last iteration.
    const __m128i mmx_u8 =
        _mm_loadl_epi64(reinterpret_cast<const __m128i*>(x + current));
    // Zero-extend: the input is unsigned, so bytes >= 128 must stay positive,
    // matching the scalar paths above and below.
    const __m256i mmx_int32 = _mm256_cvtepu8_epi32(mmx_u8);
    const __m256 mmx_fp32 = _mm256_cvtepi32_ps(mmx_int32);

    __m256 mmy = _mm256_loadu_ps(y + current);
    mmy = _mm256_fmadd_ps(mmx_fp32, mma, mmy);
    _mm256_storeu_ps(y + current, mmy);
  }

  // Scalar tail: the remaining (N % 8) elements.
  for (; current < N; ++current) {
    y[current] += static_cast<float>(x[current]) * a;
  }
}
79 
80 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13