Caffe2 - C++ API A deep learning, cross platform ML framework
1 #pragma once
2
3 #if defined(__AVX__) && !defined(__NVCC__) && \
4  (defined(__x86_64__) || defined(_M_X64) || defined(__i386__))
6 #include <immintrin.h>
7 #endif
8 #include <c10/util/Half.h>
9
10 namespace caffe2 {
11
12 namespace internal {
13
14 // The following functions inside internal namespace are inlined because they
15 // are performance critical.
16
// Scalar, element-wise Adagrad step over N elements. For each i in [0, N):
//   hi    = decay * h[i] + g[i] * g[i]      (computed in float)
//   nh[i] = hi                              (converted to T on store)
//   nw[i] = w[i] + lr * g[i] / (sqrt(hi) + epsilon)
//
//   N        number of elements
//   w, h     input parameters / accumulated squared gradients (read only)
//   g        input gradients (always float)
//   nw, nh   output parameters / output accumulator (written)
//   decay    multiplier applied to the previous accumulator value
//   epsilon  added to the square root in the denominator
//   lr       step multiplier; the step is *added* to w
//
// NOTE(review): listing line 18 (return type and function name) is absent
// from this extract.
// NOTE(review): std::sqrt is used, but no <cmath> include is visible in this
// header -- confirm it is pulled in transitively.
17 template <typename T>
19  int N,
20  const T* w,
21  const float* g,
22  const T* h,
23  T* nw,
24  T* nh,
25  float decay,
26  float epsilon,
27  float lr) {
28  for (auto i = 0; i < N; ++i) {
29  float gi = g[i];
30  float hi = decay * h[i] + gi * gi; // new accumulator value, kept in float
31  nh[i] = hi;
32  nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); // uses the float hi, not the stored nh[i]
33  }
34 }
35
36 // version with prefetching
37 // TODO(msmelyan)
38 // Crux of the computation is computing a / (sqrt(b) + epsilon),
39 // where a and b are vectors and epsilon is very small (e.g., 10^-5) and does not
40 // change. Today it's computed using two vector sqrt and vector divide simd
41 // instructions. It is slow. We can take advantage of existing fast vector
42 // VRSQRTPS instruction that computes approximate reciprocals of square roots
43 // of the vector. It is 6x faster than vsqrt and vdiv combinations. Since the
44 // addition of epsilon is just done to avoid division by zero, we approximate a
45 // / (sqrt(b) + epsilon) by a / sqrt(b + sqrt(epsilon)). If we do that, we can
46 // use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for
47 // the test on random numbers between 0.1 and 1 the absolute error was about
48 // 10^-3 compared to using slower but more accurate combination of vsqrt and
49 // vdiv. Extend Marat's function with more NR iterations to get more accuracy
50 // for training
51 // TODO(msmelyan)
52 // explore streaming stores, but need to have unique indices (deduplication)
// Adagrad step for float data using 256-bit SIMD (8 floats per iteration)
// with software prefetching of a second set of addresses, followed by a
// scalar call that handles the remaining N - i elements.
//
//   N                     number of elements to update
//   w, h                  input parameters / accumulated squared gradients
//   g                     input gradients
//   nw, nh                output parameters / output accumulator
//   w_n, h_n, nw_n, nh_n  only passed to _mm_prefetch, never dereferenced
//                         here; unnamed (unused) in the #else branches
//   epsilon, lr           forwarded to the tail call (and lr is broadcast in
//                         the vector loop)
//
// NOTE(review): this extract omits listing lines 53, 56, 65, 72, 79, 89, 104
// and 110: the return type/name line, the preprocessor conditions that open
// each #else/#endif pair, the operands of the _mm256_div_ps call, and the name
// of the function called for the tail. The comments here describe only what
// is visible.
54  int N,
55  const float* w,
57  const float* w_n, // prefetch ptr
58 #else
59  const float* /* unused */,
60 #endif
61
62  const float* g,
63
64  const float* h,
66  const float* h_n, // prefetch ptr
67 #else
68  const float* /* unused */,
69 #endif
70
71  float* nw,
73  float* nw_n, // prefetch ptr
74 #else
75  float* /* unused */,
76 #endif
77
78  float* nh,
80  float* nh_n, // prefetch ptr
81 #else
82  float* /* unused */,
83 #endif
84
85  float epsilon,
86  float lr) {
87  auto i = 0;
88
90  constexpr int kSize = 8; // floats per 256-bit vector
91  for (; i + kSize <= N; i += kSize) {
// Request the matching offsets of the prefetch pointers into cache (T0 hint).
92  _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
93  _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
94  _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
95  _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
96
// Unaligned loads: no alignment requirement is placed on g, h or w.
97  __m256 gi = _mm256_loadu_ps(g + i);
98  __m256 hi = _mm256_loadu_ps(h + i);
99  __m256 wi = _mm256_loadu_ps(w + i);
100
101  __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); // h + g*g, no decay factor in the vector path
102  _mm256_storeu_ps(nh + i, nhi);
103  __m256 vtmp = _mm256_div_ps(
105  _mm256_storeu_ps(
106  nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp))); // nw = w + lr * vtmp
107  }
108 #endif
109
// Tail: the remaining N - i elements go through a scalar call. The literal
// 1.0f is in the seventh argument position -- presumably the decay of the
// templated scalar helper; confirm against the missing callee line.
111  N - i, w + i, g + i, h + i, nw + i, nh + i, 1.0f, epsilon, lr);
112 }
113
// Row-wise Adagrad step: one scalar accumulator (*h) is shared by all N
// elements of the row. Visible steps:
//   1. final_sum = sum(g[i]^2 for i in [0, N)) / N
//   2. *h        = *h + final_sum
//   3. step      = lr / (sqrt(*h) + epsilon)
//   4. w[i]      = w[i] + g[i] * step, in place, for i in [0, N)
//
//   N            number of elements in the row
//   w            parameters, updated in place
//   g            input gradients
//   h            pointer to the single accumulator value, updated in place
//   w_n, h_n     only passed to _mm_prefetch; unnamed in the #else branches
//   epsilon, lr  as used in step 3
//
// NOTE(review): listing lines 114, 117, 126, 136 and 162 (return type/name
// and the preprocessor conditions opening each #else/#endif) are absent from
// this extract.
// NOTE(review): with N == 0 both summation loops are skipped and
// `final_sum /= N` evaluates 0.0f / 0, which yields NaN under IEEE arithmetic
// and is then stored to *h -- confirm callers never pass N == 0.
115  int N,
116  float* w,
118  float* w_n, // prefetch ptr
119 #else
120  float* /* unused */,
121 #endif
122
123  const float* g,
124
125  float* h,
127  float* h_n, // prefetch ptr
128 #else
129  float* /* unused */,
130 #endif
131
132  float epsilon,
133  float lr) {
134  auto i = 0;
135
137  constexpr int kSize = 8; // floats per 256-bit vector
138  _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
139  __m256 partial_sum = _mm256_setzero_ps();
// Accumulate g*g in eight independent lanes.
140  for (; i + kSize <= N; i += kSize) {
141  __m256 gi = _mm256_loadu_ps(g + i);
142  partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
143  }
144  // Reduce sum to 1 value
// Two horizontal adds leave each 128-bit half holding its own total; the low
// element of each half is then extracted and the two are added.
145  __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
146  __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
147  float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
148  _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
149 #else
150  float final_sum = 0.0f;
151 #endif
152
// Scalar remainder (or the whole row when the vector path is compiled out).
153  for (; i < N; ++i) {
154  final_sum += g[i] * g[i];
155  }
156  final_sum /= N; // mean of the squared gradients
157
158  float hi = *h = *h + final_sum; // update the shared accumulator in place
159  float float_step = lr / (std::sqrt(hi) + epsilon); // one step size for the whole row
160
161  i = 0;
163  __m256 step = _mm256_set1_ps(float_step);
164
165  for (i = 0; i + kSize <= N; i += kSize) {
166  _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
167
168  __m256 gi = _mm256_loadu_ps(g + i);
169  __m256 wi = _mm256_loadu_ps(w + i);
170
171  _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step))); // w += g * step
172  }
173 #endif
174
// Scalar remainder of the weight update.
175  for (; i < N; ++i) {
176  float gi = g[i];
177  w[i] = w[i] + gi * float_step;
178  }
179 }
180
181 } // namespace internal
182
183 // version with prefetching
184 // TODO(msmelyan)
185 // Crux of the computation is computing a / (sqrt(b) + epsilon),
186 // where a and b are vectors and epsilon is very small (e.g., 10^-5) and does not
187 // change. Today it's computed using two vector sqrt and vector divide simd
188 // instructions. It is slow. We can take advantage of existing fast vector
189 // VRSQRTPS instruction that computes approximate reciprocals of square roots
190 // of the vector. It is 6x faster than vsqrt and vdiv combinations. Since the
191 // addition of epsilon is just done to avoid division by zero, we approximate a
192 // / (sqrt(b) + epsilon) by a / sqrt(b + sqrt(epsilon)). If we do that, we can
193 // use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for
194 // the test on random numbers between 0.1 and 1 the absolute error was about
195 // 10^-3 compared to using slower but more accurate combination of vsqrt and
196 // vdiv. Extend Marat's function with more NR iterations to get more accuracy
197 // for training
198 // TODO(msmelyan)
199 // explore streaming stores, but need to have unique indices (deduplication)
// Declaration (no body in this header) of a prefetching Adagrad update for
// float data. The parameter list has the same types and order as the inline
// prefetching helper in namespace internal: inputs w/g/h, outputs nw/nh, and
// one "_n" prefetch pointer paired with each of w, h, nw and nh.
// NOTE(review): listing line 200 (return type and function name) is absent
// from this extract.
201  int N,
202  const float* w,
203  const float* w_n, // prefetch ptr
204
205  const float* g,
206
207  const float* h,
208  const float* h_n, // prefetch ptr
209
210  float* nw,
211  float* nw_n, // prefetch ptr
212
213  float* nh,
214  float* nh_n, // prefetch ptr
215
216  float epsilon,
217  float lr);
218
219 // Version with prefetching for embeddings and
220 // momentum using fp16
// Declaration (no body in this header). Parameters, accumulator and their
// prefetch pointers are at::Half; the gradient g stays float. Parameter order
// matches the float prefetching declaration: w, w_n, g, h, h_n, nw, nw_n, nh,
// nh_n, epsilon, lr.
// NOTE(review): listing line 221 (return type and function name) is absent
// from this extract.
222  int N,
223  const at::Half* w,
224  const at::Half* w_n, // prefetch ptr
225  const float* g,
226  const at::Half* h,
227  const at::Half* h_n, // prefetch ptr
228  at::Half* nw,
229  at::Half* nw_n, // prefetch ptr
230  at::Half* nh,
231  at::Half* nh_n, // prefetch ptr
232  float epsilon,
233  float lr);
234
// Declaration (no body in this header) with the same parameter list as the
// inline row-wise helper in namespace internal. w and h are non-const: in
// that helper the parameters and the single accumulator value are updated in
// place rather than written to separate outputs.
// NOTE(review): listing line 235 (return type and function name) is absent
// from this extract.
236  int N,
237  float* w,
238  float* w_n, // prefetch ptr
239
240  const float* g,
241
242  float* h,
243  float* h_n, // prefetch ptr
244
245  float epsilon,
246  float lr);
247
248 // version without prefetching
// Declaration (no body in this header) of the update without prefetch
// pointers: inputs w/g/h, outputs nw/nh, all float.
// NOTE(review): here epsilon precedes decay, whereas the templated inline
// helper earlier in this header takes (decay, epsilon, lr). Both are float,
// so swapping them when forwarding compiles silently -- double-check call
// sites.
// NOTE(review): listing line 249 (return type and function name) is absent
// from this extract.
250  int N,
251  const float* w,
252  const float* g,
253  const float* h,
254  float* nw,
255  float* nh,
256  float epsilon,
257  float decay,
258  float lr);
259
// Declaration of the sparse (row-indexed) Adagrad update, templated on the
// index type SIndex. The macro body later in this header has the same
// parameter list and returns either the position of the first row that fails
// its bounds check or num_rows -- presumably that macro generates the
// specializations of this template; confirm against the missing name line.
// NOTE(review): listing line 265 (return type and function name) is absent
// from this extract.
264 template <typename SIndex>
266  int num_rows, // number of rows reading
267  int block_size, // number of parameters per rows
268  std::uint64_t param_size, // total number of parameters
269  const float* w, // input parameters
270  const float* g, // input gradients
271  const float* h, // input momentums
272  const SIndex* indices, // indices of each row
273  float* nw, // output parameters
274  float* nh, // output momentums
275  float epsilon,
276  float lr);
277
280  int num_rows, \
281  int block_size, \
282  std::uint64_t param_size, \
283  const float* w, \
284  const float* g, \
285  const float* h, \
286  const SIndex* indices, \
287  float* nw, \
288  float* nh, \
289  float epsilon, \
290  float lr) { \
291  for (int i = 0; i < num_rows; ++i) { /* one gradient row per iteration */ \
292  std::uint64_t idx = indices[i]; /* widened to unsigned 64-bit; a negative SIndex value would wrap to a huge number */ \
293  auto offsetI = i * block_size; /* start of row i in g (int * int arithmetic) */ \
294  auto offsetIdx = idx * block_size; /* start of the indexed row in w/h/nw/nh (unsigned 64-bit arithmetic) */ \
295  \
296  if (block_size + offsetIdx > param_size) { /* bounds check. NOTE(review): the sum is unsigned and can wrap for a huge idx, letting it pass -- verify */ \
297  return i; /* position of the first out-of-range row; earlier rows are already updated */ \
298  } \
299  \
300  if (block_size == 1) { /* scalar path: one parameter per row, no prefetch */ \
301  float gi = g[i]; \
302  float hi = nh[idx] = h[idx] + gi * gi; \
303  nw[idx] = w[idx] + lr * gi / (std::sqrt(hi) + epsilon); \
304  } else { \
305  const int prefdist_T0 = 16; /* look-ahead distance in rows */ \
306  int i_pref = (i < num_rows - prefdist_T0) ? i + prefdist_T0 : i; /* 16 rows ahead, or the current row near the end */ \
307  std::uint64_t idx_pref = indices[i_pref]; /* NOTE(review): not range-checked; only used to form the prefetch addresses below */ \
308  \
310  block_size, /* NOTE(review): listing line 309 with the callee name is absent from this extract */ \
311  w + offsetIdx, \
312  &w[idx_pref * block_size], \
313  g + offsetI, \
314  h + offsetIdx, \
315  &h[idx_pref * block_size], \
316  nw + offsetIdx, \
317  &nw[idx_pref * block_size], \
318  nh + offsetIdx, \
319  &nh[idx_pref * block_size], \
320  epsilon, \
321  lr); \
322  } \
323  } \
324  return num_rows; /* every row was in range and has been processed */ \
325  };
326
327 } // namespace caffe2
328