1 #include "caffe2/perfkernels/adagrad.h" 2 #include "caffe2/perfkernels/cvtsh_ss_bugfix.h" 10 void adagrad_update__avx_f16c(
// adagrad_update__avx_f16c: AdaGrad-style step over N float elements.
// For each element i it computes
//   nh[i] = decay * h[i] + g[i] * g[i]
//   nw[i] = w[i] + lr * g[i] / (sqrt(nh[i]) + epsilon)
// using 8-wide AVX (__m256) vectors for the bulk and a scalar loop for the
// remaining elements.
// Note that the step is *added* to w, so a descent update presumably relies
// on the caller passing a negative lr. TODO: confirm with callers.
// NOTE(review): this listing is incomplete. The embedded original line
// numbers jump (10 -> 20, 25 -> 27, 31 -> 33, 33 -> 38), so the parameter
// list, the declaration of the loop index `i`, the opening of the nw store
// call and the scalar tail-loop header are not shown here.
20 constexpr
size_t kSize = 8;
// Vector loop: kSize (8) floats per iteration, i.e. one __m256 register.
// kSize is size_t, so `i + kSize <= N` is evaluated in unsigned arithmetic.
// The types of i and N are not visible here. NOTE(review): confirm that N can
// never be negative.
22 for (; i + kSize <= N; i += kSize) {
// Unaligned loads, so g, h and w need no particular alignment.
23 __m256 gi = _mm256_loadu_ps(g + i);
24 __m256 hi = _mm256_loadu_ps(h + i);
25 __m256 wi = _mm256_loadu_ps(w + i);
// nh = decay * h + g^2, written back with an unaligned store.
27 __m256 nhi = _mm256_add_ps(
28 _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi));
29 _mm256_storeu_ps(nh + i, nhi);
// w update: w + lr * g / (sqrt(nh) + epsilon). The sqrt uses the freshly
// computed nh, not the old h.
30 __m256 vtmp = _mm256_div_ps(
31 gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
33 nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp)));
// Scalar tail for the elements not covered by the vector loop; the math is
// the same as above. The tail-loop header and the declaration of `gi` are not
// shown in this listing.
38 float hi = nh[i] = decay * h[i] + gi * gi;
39 nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
// adagrad_update_prefetch__avx_f16c: thin wrapper that forwards all of its
// work to internal::adagrad_update_prefetch_inlined. That function is defined
// outside this listing, presumably in caffe2/perfkernels/adagrad.h (verify).
// NOTE(review): the parameter list (original lines 44-60) is not shown.
// The *_n pointers (w_n, h_n, nw_n, nh_n) look like prefetch targets. The
// fp16 kernel in this file passes pointers of the same names to _mm_prefetch.
// TODO: confirm against the inlined implementation.
// Unlike adagrad_update__avx_f16c, no `decay` argument appears in the
// forwarded call.
43 void adagrad_update_prefetch__avx_f16c(
61 internal::adagrad_update_prefetch_inlined(
62 N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr);
// adagrad_fp16_update_prefetch__avx_f16c: AdaGrad-style step in which w/nw
// and h/nh hold IEEE half-precision values. They are accessed as 16-bit
// unsigned short / __m128i and converted with the F16C intrinsics. g is read
// as float.
// For each element i:
//   nh[i] = h[i] + g[i] * g[i]        (no decay factor in this variant)
//   nw[i] = w[i] + lr * g[i] / (sqrt(nh[i]) + epsilon)
// The arithmetic is done in fp32 and the results are rounded back to fp16 on
// store. The sqrt uses the fp32 nh from *before* that rounding.
// NOTE(review): the parameter list (original lines 67-78), the declaration of
// the loop index `i`, the loop-closing braces and the scalar tail-loop header
// are not shown in this listing.
66 void adagrad_fp16_update_prefetch__avx_f16c(
79 constexpr
int kSize = 8;
// kSize is int here, whereas adagrad_update__avx_f16c declares it as size_t.
81 for (; i + kSize <= N; i += kSize) {
// Prefetch the *_n addresses at index i with _MM_HINT_T0 while the current
// vector is being processed.
82 _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
83 _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
84 _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
85 _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
// Load 8 float gradients. Load 8 fp16 values each of h and w with one
// unaligned 128-bit load apiece, then widen them to fp32.
88 __m256 gi = _mm256_loadu_ps(g + i);
89 __m128i hhi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(h + i));
90 __m256 hi = _mm256_cvtph_ps(hhi);
91 __m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
92 __m256 wi = _mm256_cvtph_ps(whi);
// nh = h + g^2, narrowed to fp16 with rounding imm8 = 0 (round to nearest
// even) and stored unaligned.
94 __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
95 __m128i nhhi = _mm256_cvtps_ph(nhi, 0);
96 _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi);
// nw = w + lr * g / (sqrt(nh) + epsilon), using the unrounded fp32 nhi,
// then narrowed to fp16 the same way.
98 __m256 vtmp = _mm256_div_ps(
99 gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
100 __m256 nwi = _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp));
101 __m128i nhwi = _mm256_cvtps_ph(nwi, 0);
102 _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi);
// Scalar tail for the elements not covered by the vector loop. It uses the
// same math through _cvtsh_ss / _cvtss_sh. The loop header, the declaration
// of `gi` and the start of the `nhi` declaration (original lines 103-107)
// are not shown.
108 _cvtsh_ss(reinterpret_cast<const unsigned short*>(h)[i]) + gi * gi;
109 reinterpret_cast<unsigned short*
>(nh)[i] = _cvtss_sh(nhi, 0);
110 float nwi = _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]) +
111 lr * gi / (std::sqrt(nhi) + epsilon);
112 reinterpret_cast<unsigned short*
>(nw)[i] = _cvtss_sh(nwi, 0);
// rowwise_adagrad_update__avx_f16c: thin wrapper that forwards
// (N, w, w_n, g, h, h_n, epsilon, lr) to
// internal::rowwise_adagrad_update_inlined, which is defined outside this
// listing.
// NOTE(review): the parameter list (original lines 117-127) is not shown.
116 void rowwise_adagrad_update__avx_f16c(
128 internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
// Expands SPARSE_ADAGRAD_SPECIALIZATION, a macro defined outside this listing
// (presumably in caffe2/perfkernels/adagrad.h), once for int32_t and once for
// int64_t with the avx_f16c tag. Judging by the name, it emits sparse AdaGrad
// kernels for those index types. Confirm against the macro definition.
131 SPARSE_ADAGRAD_SPECIALIZATION(int32_t, avx_f16c);
132 SPARSE_ADAGRAD_SPECIALIZATION(int64_t, avx_f16c);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...