Caffe2 - C++ API
A deep learning, cross-platform ML framework
adagrad_avx.cc
1 #include "caffe2/perfkernels/adagrad.h"
2 #include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
3 
4 #include <emmintrin.h>
5 #include <immintrin.h>
6 
7 namespace caffe2 {
8 
// Dense Adagrad update (AVX), version without prefetching.
//
// For every k in [0, N):
//   nh[k] = decay * h[k] + g[k] * g[k]
//   nw[k] = w[k] + lr * g[k] / (sqrt(nh[k]) + epsilon)
//
// The main loop handles 8 floats per iteration with unaligned AVX loads and
// stores; the remaining N % 8 elements go through a scalar tail loop.
// N <= 0 performs no work. The scaled gradient is ADDED to w, so the sign
// convention for lr is the caller's (NOTE(review): presumably a negative lr
// is passed to descend -- confirm at call sites).
//
// @param N        number of elements to update
// @param w        current weights (read only)
// @param g        gradient (read only)
// @param h        current squared-gradient history (read only)
// @param nw       output: updated weights
// @param nh       output: updated history
// @param epsilon  added to sqrt(nh[k]) in the denominator
// @param decay    multiplier applied to the previous history h[k]
// @param lr       learning rate scaling the step
void adagrad_update__avx_f16c(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    float decay,
    float lr) {
  // Signed on purpose: with a size_t kSize, `i + kSize <= N` became an
  // unsigned comparison, so a negative N converted to a huge value and the
  // vector loop read/wrote far past the buffers. Also matches the int kSize
  // used by adagrad_fp16_update_prefetch__avx_f16c.
  constexpr int kSize = 8;
  const __m256 decay_v = _mm256_set1_ps(decay);
  const __m256 epsilon_v = _mm256_set1_ps(epsilon);
  const __m256 lr_v = _mm256_set1_ps(lr);

  int i = 0;
  for (; i + kSize <= N; i += kSize) {
    const __m256 gi = _mm256_loadu_ps(g + i);
    const __m256 hi = _mm256_loadu_ps(h + i);
    const __m256 wi = _mm256_loadu_ps(w + i);

    const __m256 nhi =
        _mm256_add_ps(_mm256_mul_ps(decay_v, hi), _mm256_mul_ps(gi, gi));
    _mm256_storeu_ps(nh + i, nhi);

    // (lr * g) / (sqrt(nh) + eps): same operation order as the scalar tail,
    // so an element's rounding does not depend on which loop processed it.
    const __m256 step = _mm256_div_ps(
        _mm256_mul_ps(lr_v, gi),
        _mm256_add_ps(_mm256_sqrt_ps(nhi), epsilon_v));
    _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, step));
  }

  // Scalar tail for the last N % kSize elements.
  for (; i < N; ++i) {
    const float gi = g[i];
    const float hi = nh[i] = decay * h[i] + gi * gi;
    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
  }
}
42 
43 void adagrad_update_prefetch__avx_f16c(
44  int N,
45  const float* w,
46  const float* w_n, // prefetch ptr
47 
48  const float* g,
49 
50  const float* h,
51  const float* h_n, // prefetch ptr
52 
53  float* nw,
54  float* nw_n, // prefetch ptr
55 
56  float* nh,
57  float* nh_n, // prefetch ptr
58 
59  float epsilon,
60  float lr) {
61  internal::adagrad_update_prefetch_inlined(
62  N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr);
63 }
64 
65 // Compute adagrad sparse, assumes embedding and momentum are at::Half
66 void adagrad_fp16_update_prefetch__avx_f16c(
67  int N,
68  const at::Half* w,
69  const at::Half* w_n, // prefetch ptr
70  const float* g,
71  const at::Half* h,
72  const at::Half* h_n, // prefetch ptr
73  at::Half* nw,
74  at::Half* nw_n, // prefetch ptr
75  at::Half* nh,
76  at::Half* nh_n, // prefetch ptr
77  float epsilon,
78  float lr) {
79  constexpr int kSize = 8;
80  auto i = 0;
81  for (; i + kSize <= N; i += kSize) {
82  _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
83  _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
84  _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
85  _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
86 
87  // only convert momentum and embedding, gradient is fp32
88  __m256 gi = _mm256_loadu_ps(g + i);
89  __m128i hhi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(h + i));
90  __m256 hi = _mm256_cvtph_ps(hhi);
91  __m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
92  __m256 wi = _mm256_cvtph_ps(whi);
93 
94  __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
95  __m128i nhhi = _mm256_cvtps_ph(nhi, 0);
96  _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi);
97 
98  __m256 vtmp = _mm256_div_ps(
99  gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
100  __m256 nwi = _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp));
101  __m128i nhwi = _mm256_cvtps_ph(nwi, 0);
102  _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi);
103  }
104 
105  for (; i < N; ++i) {
106  float gi = g[i];
107  float nhi =
108  _cvtsh_ss(reinterpret_cast<const unsigned short*>(h)[i]) + gi * gi;
109  reinterpret_cast<unsigned short*>(nh)[i] = _cvtss_sh(nhi, 0);
110  float nwi = _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]) +
111  lr * gi / (std::sqrt(nhi) + epsilon);
112  reinterpret_cast<unsigned short*>(nw)[i] = _cvtss_sh(nwi, 0);
113  }
114 }
115 
116 void rowwise_adagrad_update__avx_f16c(
117  int N,
118  float* w,
119  float* w_n, // prefetch ptr
120 
121  const float* g,
122 
123  float* h,
124  float* h_n, // prefetch ptr
125 
126  float epsilon,
127  float lr) {
128  internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
129 }
130 
// SPARSE_ADAGRAD_SPECIALIZATION is not defined in this file (presumably it
// comes from caffe2/perfkernels/adagrad.h -- confirm). It is expanded once per
// index type, int32_t and int64_t, with the avx_f16c tag that matches the
// function-name suffix used in this file. Presumably each expansion emits the
// sparse Adagrad entry point for that index type and ISA -- TODO confirm
// against the macro definition.
131 SPARSE_ADAGRAD_SPECIALIZATION(int32_t, avx_f16c);
132 SPARSE_ADAGRAD_SPECIALIZATION(int64_t, avx_f16c);
133 
134 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13