// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// vml.h
1 #pragma once
2 
3 #include <ATen/Config.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/cpu/vec256/functional.h>
6 #include <ATen/cpu/vec256/vec256.h>
7 
8 // This header implements various unary operations using a MKL VML style
9 // interface.
10 
11 // It implements various functions with a simple interface
12 // For example it enables the user to call vsin(float* out, const float* in,
13 // size) This functions takes a pointer to a contious output array of floats and
14 // a constant input array. It will then apply sin to each value in in the input
15 // array and write the result into the output array. out and in may point to the
16 // same memory, i.e. this fully supports in-place operations. These functions
17 // also implement their own parallelization, so take precautions when calling
18 // these from threaded functions.
19 
20 // When MKL is available it will call into MKL's VML library similar to NumPy
21 // If MKL is not available it will use SLEEF.
22 
23 // This file might be compiled under AVX or AVX2 when called from e.g.
24 // UnaryOpsKernel.cpp
25 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <type_traits>
32 
33 #if AT_MKL_ENABLED() && !defined(__APPLE__)
34 #include <mkl.h>
35 #endif
36 
37 // [Note SSE-AVX transitions]
38 // There is a bug in Glibc2.23
39 // https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall
40 // when using AVX/AVX2 code resolves this.
#if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
// Workaround for the glibc 2.23 SSE-AVX transition bug (see note above):
// evaluate the scalar std:: version of `op` once through a volatile (so the
// call cannot be optimized away), then clear the upper AVX state with
// _mm256_zeroall(). Expanded at the top of each affected function body.
#define DL_RUNTIME_BUG(op, type) \
  volatile type x = (type)(1);   \
  x = std::op(x);                \
  _mm256_zeroall();
#else
// Unaffected toolchains: the workaround expands to nothing.
#define DL_RUNTIME_BUG(op, type)
#endif
49 
50 namespace at {
51 namespace vml {
52 namespace {
53 
54 using namespace vec256;
55 
56 template <typename scalar_t>
57 inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
58  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
59  map(
60  [](const Vec256<scalar_t>& x) {
61  return Vec256<scalar_t>((scalar_t)(1)) / x.sqrt();
62  },
63  out + begin,
64  in + begin,
65  end - begin);
66  });
67 }
68 
69 // NB: We ignore numerical errors by convention and leave them to the user
70 
71 // We unfortunately need to duplicate code here to deal with the SSE-AVX
72 // transition bug (see [Note SSE-AVX transitions]). As soon as we can expect
73 // users to use a version of glibc newer than 2.23 we will be able to ditch
74 // this. This duplication is also necessary since not all functions (e.g. rsqrt)
75 // might be part of cmath.
76 
// Defines a generic v##op (e.g. vsin) that maps the matching Vec256 method
// over the input in parallel 2048-element chunks, after first expanding the
// glibc-2.23 workaround (which references std::op, so `op` must exist as a
// std:: scalar function). out/in may alias.
#define IMPLEMENT_VML_BUG(op)                                           \
  template <typename scalar_t>                                          \
  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) {  \
    DL_RUNTIME_BUG(op, scalar_t)                                        \
    parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
      map([](const Vec256<scalar_t>& x) { return x.op(); },             \
          out + begin,                                                  \
          in + begin,                                                   \
          end - begin);                                                 \
    });                                                                 \
  }
88 
// Same as IMPLEMENT_VML_BUG but without the glibc workaround, for ops that
// have no std:: scalar counterpart for DL_RUNTIME_BUG to call (e.g. rsqrt,
// as noted above).
#define IMPLEMENT_VML(op)                                               \
  template <typename scalar_t>                                          \
  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) {  \
    parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
      map([](const Vec256<scalar_t>& x) { return x.op(); },             \
          out + begin,                                                  \
          in + begin,                                                   \
          end - begin);                                                 \
    });                                                                 \
  }
99 
// Generic Vec256 implementations, used for every scalar type unless
// overridden by an MKL template specialization below. Ops routed through
// IMPLEMENT_VML_BUG need a std:: scalar counterpart (the workaround calls
// std::op); reciprocal, neg and rsqrt have none, so they use plain
// IMPLEMENT_VML.
IMPLEMENT_VML_BUG(abs)
IMPLEMENT_VML_BUG(acos)
IMPLEMENT_VML_BUG(asin)
IMPLEMENT_VML_BUG(atan)
IMPLEMENT_VML_BUG(ceil)
IMPLEMENT_VML_BUG(cos)
// IMPLEMENT_VML_BUG(cosh)
IMPLEMENT_VML_BUG(erf)
IMPLEMENT_VML_BUG(erfc)
IMPLEMENT_VML_BUG(exp)
IMPLEMENT_VML_BUG(expm1)
IMPLEMENT_VML_BUG(floor)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML_BUG(log)
IMPLEMENT_VML_BUG(log10)
IMPLEMENT_VML_BUG(log1p)
IMPLEMENT_VML_BUG(log2)
IMPLEMENT_VML(neg)
IMPLEMENT_VML_BUG(sin)
// IMPLEMENT_VML_BUG(sinh)
IMPLEMENT_VML_BUG(sqrt)
IMPLEMENT_VML_BUG(round)
IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML_BUG(tan)
IMPLEMENT_VML_BUG(tanh)
IMPLEMENT_VML_BUG(trunc)
126 
127 #if AT_MKL_ENABLED() && !defined(__APPLE__)
128 
// NB: LP64 MKL is the most commonly used and thus we assume it here. That means
// we need to expect MKL_INT to be of type int, which implies int32_t in most
// cases. The chunking logic in IMPLEMENT_VML_MKL_STUB below depends on this
// bound when splitting int64_t-sized arrays into MKL-sized calls.
static_assert(
    std::is_same<MKL_INT, int32_t>::value,
    "MKL_INT is assumed to be int32_t");
// Specializes v##op for `type` to call MKL's vm##mkltype##mklop kernel.
// MKL element counts are MKL_INT (int32_t, see static_assert above), so
// arrays longer than INT32_MAX elements are processed in max_mkl_ind-sized
// chunks plus a final remainder call. Mode flags: VML_HA (high accuracy),
// FTZ/DAZ off, error reporting suppressed (errors are left to the user by
// convention, see above).
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype)                \
  template <>                                                           \
  inline void v##op(type * out, const type * in, int64_t size) {        \
    int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max();          \
    if (size <= static_cast<int64_t>(max_mkl_ind)) {                    \
      vm##mkltype##mklop(                                               \
          size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
    } else {                                                            \
      /* int64_t, not MKL_INT: the chunk count itself can exceed */     \
      /* INT32_MAX for very large inputs. */                            \
      int64_t ind = 0;                                                  \
      int64_t chunks = size / max_mkl_ind;                              \
      int64_t rest = size % max_mkl_ind;                                \
      for (; ind < chunks; ind++) {                                     \
        vm##mkltype##mklop(                                             \
            max_mkl_ind,                                                \
            in + ind * max_mkl_ind,                                     \
            out + ind * max_mkl_ind,                                    \
            VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);              \
      }                                                                 \
      /* Skip the tail call entirely when the size divides evenly. */   \
      if (rest > 0) {                                                   \
        vm##mkltype##mklop(                                             \
            rest,                                                       \
            in + ind * max_mkl_ind,                                     \
            out + ind * max_mkl_ind,                                    \
            VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);              \
      }                                                                 \
    }                                                                   \
  }
160 
// Instantiates the MKL specialization of v##op for both float ("s" kernels)
// and double ("d" kernels).
#define IMPLEMENT_VML_MKL(op, mklop)          \
  IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
  IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
164 
// MKL-backed specializations for float/double; all other scalar types keep
// the generic Vec256 implementations above.
// NB: cosh and sinh were temporarily disabled due to issues with Apple clang

IMPLEMENT_VML_MKL(abs, Abs)
IMPLEMENT_VML_MKL(acos, Acos)
IMPLEMENT_VML_MKL(asin, Asin)
IMPLEMENT_VML_MKL(atan, Atan)
IMPLEMENT_VML_MKL(cos, Cos)
// IMPLEMENT_VML_MKL(cosh, Cosh)
IMPLEMENT_VML_MKL(erf, Erf)
IMPLEMENT_VML_MKL(erfc, Erfc)
IMPLEMENT_VML_MKL(exp, Exp)
IMPLEMENT_VML_MKL(expm1, Expm1)
IMPLEMENT_VML_MKL(log, Ln)
IMPLEMENT_VML_MKL(log10, Log10)
IMPLEMENT_VML_MKL(log1p, Log1p)
IMPLEMENT_VML_MKL(sin, Sin)
// IMPLEMENT_VML_MKL(sinh, Sinh)
IMPLEMENT_VML_MKL(sqrt, Sqrt)
IMPLEMENT_VML_MKL(tan, Tan)
IMPLEMENT_VML_MKL(tanh, Tanh)
IMPLEMENT_VML_MKL(trunc, Trunc)

// vmLog2 only exists in newer MKL releases; older ones fall back to the
// generic implementation above.
#if INTEL_MKL_VERSION >= 20180406
IMPLEMENT_VML_MKL(log2, Log2)
#endif
190 
191 #endif
192 
193 } // namespace
194 } // namespace vml
195 } // namespace at
// NOTE: VML_FTZDAZ_OFF (used in the MKL calls above) disables MKL's
// Flush-To-Zero and Denormals-Are-Zero mode.