2 #include <ATen/cpu/vec256/vec256.h> 4 namespace at {
namespace vec256 {
7 template <
typename scalar_t,
typename Op>
8 inline scalar_t vec_reduce_all(
10 vec256::Vec256<scalar_t> acc_vec,
12 using Vec = vec256::Vec256<scalar_t>;
13 scalar_t acc_arr[Vec::size()];
14 acc_vec.store(acc_arr);
15 for (int64_t i = 1; i < size; i++) {
16 scalar_t acc_arr_next[Vec::size()];
17 acc_arr_next[0] = acc_arr[i];
18 Vec acc_vec_next = Vec::loadu(acc_arr_next);
19 acc_vec = vec_fun(acc_vec, acc_vec_next);
21 acc_vec.store(acc_arr);
25 template <
typename scalar_t,
typename Op>
26 inline scalar_t reduce_all(
const Op& vec_fun, scalar_t* data, int64_t size) {
27 using Vec = vec256::Vec256<scalar_t>;
28 if (size < Vec::size())
29 return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
30 int64_t d = Vec::size();
31 Vec acc_vec = Vec::loadu(data);
32 for (; d < size - (size % Vec::size()); d += Vec::size()) {
33 Vec data_vec = Vec::loadu(data + d);
34 acc_vec = vec_fun(acc_vec, data_vec);
37 Vec data_vec = Vec::loadu(data + d, size - d);
38 acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
40 return vec_reduce_all(vec_fun, acc_vec, Vec::size());
43 template <
typename scalar_t,
typename MapOp,
typename ReduceOp>
44 inline scalar_t map_reduce_all(
46 const ReduceOp& red_fun,
49 using Vec = vec256::Vec256<scalar_t>;
50 if (size < Vec::size())
51 return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
52 int64_t d = Vec::size();
53 Vec acc_vec = map_fun(Vec::loadu(data));
54 for (; d < size - (size % Vec::size()); d += Vec::size()) {
55 Vec data_vec = Vec::loadu(data + d);
56 data_vec = map_fun(data_vec);
57 acc_vec = red_fun(acc_vec, data_vec);
60 Vec data_vec = Vec::loadu(data + d, size - d);
61 data_vec = map_fun(data_vec);
62 acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
64 return vec_reduce_all(red_fun, acc_vec, Vec::size());
67 template <
typename scalar_t,
typename MapOp,
typename ReduceOp>
68 inline scalar_t map2_reduce_all(
70 const ReduceOp& red_fun,
72 const scalar_t* data2,
74 using Vec = vec256::Vec256<scalar_t>;
75 if (size < Vec::size()) {
76 Vec data_vec = Vec::loadu(data, size);
77 Vec data2_vec = Vec::loadu(data2, size);
78 data_vec = map_fun(data_vec, data2_vec);
79 return vec_reduce_all(red_fun, data_vec, size);
81 int64_t d = Vec::size();
82 Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
83 for (; d < size - (size % Vec::size()); d += Vec::size()) {
84 Vec data_vec = Vec::loadu(data + d);
85 Vec data2_vec = Vec::loadu(data2 + d);
86 data_vec = map_fun(data_vec, data2_vec);
87 acc_vec = red_fun(acc_vec, data_vec);
90 Vec data_vec = Vec::loadu(data + d, size - d);
91 Vec data2_vec = Vec::loadu(data2 + d, size - d);
92 data_vec = map_fun(data_vec, data2_vec);
93 acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
95 return vec_reduce_all(red_fun, acc_vec, Vec::size());
98 template <
typename scalar_t,
typename Op>
101 scalar_t* output_data,
102 const scalar_t* input_data,
104 using Vec = vec256::Vec256<scalar_t>;
106 for (; d < size - (size % Vec::size()); d += Vec::size()) {
107 Vec output_vec = vec_fun(Vec::loadu(input_data + d));
108 output_vec.store(output_data + d);
111 Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
112 output_vec.store(output_data + d, size - d);
116 template <
typename scalar_t,
typename Op>
119 scalar_t* output_data,
120 scalar_t* input_data,
121 scalar_t* input_data2,
123 using Vec = vec256::Vec256<scalar_t>;
125 for (; d < size - (size % Vec::size()); d += Vec::size()) {
126 Vec data_vec = Vec::loadu(input_data + d);
127 Vec data_vec2 = Vec::loadu(input_data2 + d);
128 Vec output_vec = vec_fun(data_vec, data_vec2);
129 output_vec.store(output_data + d);
132 Vec data_vec = Vec::loadu(input_data + d, size - d);
133 Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
134 Vec output_vec = vec_fun(data_vec, data_vec2);
135 output_vec.store(output_data + d, size - d);
// (fragment of a truncated comment) ... Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) mode.