3 #include <ATen/Config.h> 4 #include <ATen/Parallel.h> 5 #include <ATen/cpu/vec256/functional.h> 6 #include <ATen/cpu/vec256/vec256.h> 31 #include <type_traits> 33 #if AT_MKL_ENABLED() && !defined(__APPLE__) 41 #if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23 42 #define DL_RUNTIME_BUG(op, type) \ 43 volatile type x = (type)(1); \ 47 #define DL_RUNTIME_BUG(op, type) 54 using namespace vec256;
56 template <
typename scalar_t>
57 inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
58 parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
60 [](
const Vec256<scalar_t>& x) {
61 return Vec256<scalar_t>((scalar_t)(1)) / x.sqrt();
77 #define IMPLEMENT_VML_BUG(op) \ 78 template <typename scalar_t> \ 79 inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \ 80 DL_RUNTIME_BUG(op, scalar_t) \ 81 parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \ 82 map([](const Vec256<scalar_t>& x) { return x.op(); }, \ 89 #define IMPLEMENT_VML(op) \ 90 template <typename scalar_t> \ 91 inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \ 92 parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \ 93 map([](const Vec256<scalar_t>& x) { return x.op(); }, \ 100 IMPLEMENT_VML_BUG(abs)
// Vec256-based fallbacks for the unary element-wise helpers. Each macro
// invocation generates
//   template <typename scalar_t>
//   inline void v<op>(scalar_t* out, const scalar_t* in, int64_t size)
// which splits [0, size) with parallel_for (grain size 2048) and applies
// Vec256<scalar_t>::<op>() through map().
// IMPLEMENT_VML_BUG additionally expands DL_RUNTIME_BUG(op, scalar_t) first.
// That hook is only non-empty under __AVX__ with glibc minor version 23; the
// other definition of it is empty.
101 IMPLEMENT_VML_BUG(acos)
102 IMPLEMENT_VML_BUG(asin)
103 IMPLEMENT_VML_BUG(atan)
104 IMPLEMENT_VML_BUG(ceil)
105 IMPLEMENT_VML_BUG(cos)
107 IMPLEMENT_VML_BUG(erf)
108 IMPLEMENT_VML_BUG(erfc)
109 IMPLEMENT_VML_BUG(exp)
110 IMPLEMENT_VML_BUG(expm1)
111 IMPLEMENT_VML_BUG(floor)
// reciprocal uses the plain IMPLEMENT_VML variant, so no DL_RUNTIME_BUG
// workaround is expanded for it.
112 IMPLEMENT_VML(reciprocal)
113 IMPLEMENT_VML_BUG(log)
114 IMPLEMENT_VML_BUG(log10)
115 IMPLEMENT_VML_BUG(log1p)
116 IMPLEMENT_VML_BUG(log2)
118 IMPLEMENT_VML_BUG(sin)
120 IMPLEMENT_VML_BUG(sqrt)
121 IMPLEMENT_VML_BUG(round)
123 IMPLEMENT_VML_BUG(tan)
124 IMPLEMENT_VML_BUG(tanh)
125 IMPLEMENT_VML_BUG(trunc)
127 #if AT_MKL_ENABLED() && !defined(__APPLE__) 133 std::is_same<MKL_INT, int32_t>::value,
134 "MKL_INT is assumed to be int32_t");
135 #define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \ 137 inline void v##op(type * out, const type * in, int64_t size) { \ 138 int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \ 139 if (size <= static_cast<int64_t>(max_mkl_ind)) { \ 140 vm##mkltype##mklop( \ 141 size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ 144 int64_t chunks = size / max_mkl_ind; \ 145 int64_t rest = size % max_mkl_ind; \ 146 for (; ind < chunks; ind++) { \ 147 vm##mkltype##mklop( \ 149 in + ind * max_mkl_ind, \ 150 out + ind * max_mkl_ind, \ 151 VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ 153 vm##mkltype##mklop( \ 155 in + ind * max_mkl_ind, \ 156 out + ind * max_mkl_ind, \ 157 VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ 161 #define IMPLEMENT_VML_MKL(op, mklop) \ 162 IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \ 163 IMPLEMENT_VML_MKL_STUB(op, mklop, double, d) 167 IMPLEMENT_VML_MKL(abs, Abs)
// MKL-backed float and double versions of v<op>, generated by
// IMPLEMENT_VML_MKL(op, mklop). The second argument is the MKL VM function
// suffix. The stub calls vms<mklop> (float) or vmd<mklop> (double) with the
// mode VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE. Inputs longer than
// std::numeric_limits<MKL_INT>::max() are processed in chunks of that size,
// because MKL_INT is asserted to be int32_t.
168 IMPLEMENT_VML_MKL(acos, Acos)
169 IMPLEMENT_VML_MKL(asin, Asin)
170 IMPLEMENT_VML_MKL(atan, Atan)
171 IMPLEMENT_VML_MKL(cos, Cos)
173 IMPLEMENT_VML_MKL(erf, Erf)
174 IMPLEMENT_VML_MKL(erfc, Erfc)
175 IMPLEMENT_VML_MKL(exp, Exp)
176 IMPLEMENT_VML_MKL(expm1, Expm1)
// MKL names the natural logarithm "Ln", so vlog maps to vmsLn / vmdLn.
177 IMPLEMENT_VML_MKL(log, Ln)
178 IMPLEMENT_VML_MKL(log10, Log10)
179 IMPLEMENT_VML_MKL(log1p, Log1p)
180 IMPLEMENT_VML_MKL(sin, Sin)
182 IMPLEMENT_VML_MKL(sqrt, Sqrt)
183 IMPLEMENT_VML_MKL(tan, Tan)
184 IMPLEMENT_VML_MKL(tanh, Tanh)
185 IMPLEMENT_VML_MKL(trunc, Trunc)
187 #if INTEL_MKL_VERSION >= 20180406 188 IMPLEMENT_VML_MKL(log2, Log2)
Flush-To-Zero and Denormals-Are-Zero mode.