3 #if defined(__AVX__) && !defined(__NVCC__) && \ 4 (defined(__x86_64__) || defined(_M_X64) || defined(__i386__)) 5 #define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC 8 #include <c10/util/Half.h> 18 static inline void adagrad_update_base_inlined(
28 for (
auto i = 0; i < N; ++i) {
30 float hi = decay * h[i] + gi * gi;
32 nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
53 inline void adagrad_update_prefetch_inlined(
56 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
65 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
72 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
79 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
89 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC 90 constexpr
int kSize = 8;
91 for (; i + kSize <= N; i += kSize) {
92 _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
93 _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
94 _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
95 _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
97 __m256 gi = _mm256_loadu_ps(g + i);
98 __m256 hi = _mm256_loadu_ps(h + i);
99 __m256 wi = _mm256_loadu_ps(w + i);
101 __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
102 _mm256_storeu_ps(nh + i, nhi);
103 __m256 vtmp = _mm256_div_ps(
104 gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
106 nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp)));
110 adagrad_update_base_inlined(
111 N - i, w + i, g + i, h + i, nw + i, nh + i, 1.0f, epsilon, lr);
// Row-wise AdaGrad: one shared history value h per row.
//   *h += mean(g^2);  w[i] += lr * g[i] / (sqrt(*h) + epsilon)
// Updates w and *h in place. w_n / h_n are prefetch pointers for the next row
// and are only used when AVX intrinsics are enabled.
inline void rowwise_adagrad_update_inlined(
    int N,
    float* w,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
    float* w_n, // prefetch ptr
#else
    float* /* unused */,
#endif
    const float* g,
    float* h,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
    float* h_n, // prefetch ptr
#else
    float* /* unused */,
#endif
    float epsilon,
    float lr) {
  if (N <= 0) {
    return; // guard: an empty row would otherwise divide by zero below
  }

  auto i = 0;

#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
  constexpr int kSize = 8; // floats per __m256 lane
  _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
  // Vectorized partial sum of squared gradients.
  __m256 partial_sum = _mm256_setzero_ps();
  for (; i + kSize <= N; i += kSize) {
    __m256 gi = _mm256_loadu_ps(g + i);
    partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
  }
  // Horizontal reduction of the 8 lanes down to a single float.
  __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
  __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
  float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
#else
  float final_sum = 0.0f;
#endif

  // Scalar tail of the squared-gradient sum.
  for (; i < N; ++i) {
    final_sum += g[i] * g[i];
  }
  final_sum /= N; // mean of the squared gradients over the row

  float hi = *h = *h + final_sum;
  float float_step = lr / (std::sqrt(hi) + epsilon);

  i = 0;
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
  __m256 step = _mm256_set1_ps(float_step);

  for (i = 0; i + kSize <= N; i += kSize) {
    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);

    __m256 gi = _mm256_loadu_ps(g + i);
    __m256 wi = _mm256_loadu_ps(w + i);

    _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step)));
  }
#endif

  // Scalar tail of the weight update.
  for (; i < N; ++i) {
    float gi = g[i];
    w[i] = w[i] + gi * float_step;
  }
}
// NOTE(review): the declarations below were truncated by extraction — only
// fragments of their parameter lists survive here. They appear to declare the
// dispatching entry points whose definitions live in ISA-specific translation
// units — TODO confirm against the corresponding .cc files.
200 void adagrad_update_prefetch(
// fp16 variant — presumably operates on at::Half weights/history (see the
// <c10/util/Half.h> include at the top of this header) — TODO confirm.
221 void adagrad_fp16_update_prefetch(
235 void rowwise_adagrad_update(
// Sparse AdaGrad over `indices` rows; by analogy with the specialization
// macro below it likely returns num_rows on success, or the first offending
// row when an index is out of bounds — TODO confirm.
264 template <
typename SIndex>
268 std::uint64_t param_size,
272 const SIndex* indices,
// Stamps out a non-template sparse AdaGrad entry point for a given index type
// and ISA (e.g. sparse_adagrad_int32_t__avx2). For each of num_rows rows it
// looks up the parameter row at indices[i], performs a dense AdaGrad update
// on that block, and prefetches the row needed prefdist_T0 iterations ahead.
// Returns num_rows on success, or the offending row number i when a row
// would read/write past param_size (the caller treats that as an error).
#define SPARSE_ADAGRAD_SPECIALIZATION(SIndex, ISA)                        \
  int sparse_adagrad_##SIndex##__##ISA(                                   \
      int num_rows,                                                       \
      int block_size,                                                     \
      std::uint64_t param_size,                                           \
      const float* w,                                                     \
      const float* g,                                                     \
      const float* h,                                                     \
      const SIndex* indices,                                              \
      float* nw,                                                          \
      float* nh,                                                          \
      float epsilon,                                                      \
      float lr) {                                                         \
    for (int i = 0; i < num_rows; ++i) {                                  \
      std::uint64_t idx = indices[i];                                     \
      auto offsetI = i * block_size;                                      \
      auto offsetIdx = idx * block_size;                                  \
                                                                          \
      /* Bail out (reporting the row) on an out-of-range index. */        \
      if (block_size + offsetIdx > param_size) {                          \
        return i;                                                         \
      }                                                                   \
                                                                          \
      if (block_size == 1) {                                              \
        /* Single-element rows: plain scalar update, no prefetch. */      \
        float gi = g[i];                                                  \
        float hi = nh[idx] = h[idx] + gi * gi;                            \
        nw[idx] = w[idx] + lr * gi / (std::sqrt(hi) + epsilon);           \
      } else {                                                            \
        /* Prefetch the row prefdist_T0 iterations ahead (clamped). */    \
        const int prefdist_T0 = 16;                                       \
        int i_pref = (i < num_rows - prefdist_T0) ? i + prefdist_T0 : i;  \
        std::uint64_t idx_pref = indices[i_pref];                         \
                                                                          \
        adagrad_update_prefetch__##ISA(                                   \
            block_size,                                                   \
            w + offsetIdx,                                                \
            &w[idx_pref * block_size],                                    \
            g + offsetI,                                                  \
            h + offsetIdx,                                                \
            &h[idx_pref * block_size],                                    \
            nw + offsetIdx,                                               \
            &nw[idx_pref * block_size],                                   \
            nh + offsetIdx,                                               \
            &nh[idx_pref * block_size],                                   \
            epsilon,                                                      \
            lr);                                                          \
      }                                                                   \
    }                                                                     \
    return num_rows;                                                      \
  }

// The intrinsic switch is header-local; do not leak it to includers.
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
#endif
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, together with utility functions for loading such modules.