#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/Config.h>

#if !AT_MKL_ENABLED()

namespace at { namespace native {
Tensor _fft_mkl(const Tensor& input, int64_t signal_ndim,
                bool complex_input, bool complex_output,
                bool inverse, IntArrayRef checked_signal_sizes,
                bool normalized, bool onesided,
                IntArrayRef output_sizes) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

}} // namespace at::native
#else // AT_MKL_ENABLED

#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>

#include <algorithm>
#include <cmath>
#include <sstream>
#include <vector>

#ifdef _OPENMP
#include <omp.h>
#endif

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

namespace at { namespace native {
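// In a real-to-complex transform, MKL only computes the first
// floor(n/2) + 1 entries along the last signal dimension; the remaining
// entries are determined by conjugate symmetry. The two helpers below
// reconstruct that redundant half in place when onesided=False is
// requested.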
template <typename scalar_t>
static inline void _fft_fill_with_conjugate_symmetry_slice(Tensor& output,
                       int64_t signal_ndim, int64_t size_last_dim,
                       int64_t start_last_dim_idx, int64_t i, int64_t num) {
  scalar_t *data = output.data<scalar_t>();
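  // A "slice" is one row of complex values along the last signal dimension,
  // i.e., the output viewed as a [num_slices][size_last_dim][2] array of
  // scalars, where the trailing 2 holds (real, imag). This function fills
  // entries [start_last_dim_idx, size_last_dim) of the `num` slices starting
  // at flat slice index `i`, reading from each slice's mirror.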
  // Decompose the flat slice index `i` into per-dimension indices, and
  // compute the element offset of the mirror ("from") slice supplying the
  // values. The batch dimension (d == 0) is shared by both slices; every
  // other dimension is reflected: index k mirrors to
  // (dim_size - k) % dim_size.
  std::vector<int64_t> from_slice_indices(signal_ndim);
  int64_t remainder = i;
  int64_t from_slice_offset = 0;
  for (int64_t d = signal_ndim - 1; d >= 0; d--) {
    int64_t dim_size = output.size(d);
    int64_t dim_idx = remainder % dim_size;
    remainder = remainder / dim_size;
    from_slice_indices[d] = dim_idx;
    if (d == 0) {
      from_slice_offset += dim_idx * output.stride(d);
    } else if (dim_idx != 0) {
      from_slice_offset += (dim_size - dim_idx) * output.stride(d);
    }
  }
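  // Worked example along the last dimension: for size_last_dim = 5 the
  // onesided transform stores indices 0..2, so start_last_dim_idx = 3.
  // Entry j = 3 is filled from conj(entry 5 - 3 = 2) and j = 4 from
  // conj(entry 1), matching X[n - j] = conj(X[j]) for a real signal.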
  scalar_t *to_slice_data = data + i * size_last_dim * 2;
  scalar_t *from_slice_data = data + from_slice_offset;
  while (num > 0) {
    // Conjugate copy: take the real part as-is and negate the imaginary
    // part of the mirrored entry.
    for (int64_t j = start_last_dim_idx; j < size_last_dim; j++) {
      int64_t to_idx = j * 2;
      int64_t from_idx = (size_last_dim - j) * 2;
      to_slice_data[to_idx] = from_slice_data[from_idx];
      to_slice_data[to_idx + 1] = -from_slice_data[from_idx + 1];
    }
    // advance to the next destination slice
    to_slice_data += size_last_dim * 2;
    // Advance the mirror slice: increment the destination indices with
    // carry, updating from_slice_data according to the reflection at each
    // dimension.
    for (int64_t d = signal_ndim - 1; d >= 0; d--) {
      from_slice_indices[d] = (from_slice_indices[d] + 1) % output.size(d);
      if (d == 0) {
        // Carried into the batch dimension: all signal indices are zero
        // again, and the all-zero index is its own mirror, so the mirror
        // slice coincides with the destination slice.
        from_slice_data = to_slice_data;
      } else if (from_slice_indices[d] == 0) {
        // Wrapped around at this dimension (mirror moved from 1 to 0); the
        // carry continues into dimension d - 1.
        from_slice_data -= output.stride(d);
        continue;
      } else if (from_slice_indices[d] == 1) {
        // Destination moved 0 -> 1, so the mirror jumps from index 0 to the
        // last index of this dimension.
        from_slice_data += (output.size(d) - 1) * output.stride(d);
      } else {
        // Ordinary step: the mirror index decreases by one.
        from_slice_data -= output.stride(d);
      }
      break;
    }
    num--;
  }
}
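// Driver for the slice filler above: computes how many slices there are and
// dispatches on the floating-point type, splitting the work across OpenMP
// threads when available.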
static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input,
                      int64_t signal_ndim, int64_t size_last_dim,
                      int64_t last_dim_start_slice) {
  if (last_dim_start_slice >= size_last_dim) {
    return;  // the onesided half already covers the full dimension
  }

  // number of slices = product of all dimensions before the last signal dim
  int64_t num = 1;
  for (int64_t d = 0; d < signal_ndim; d++) {
    num *= input.size(d);
  }
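  // For example, a batch of 8 two-dimensional 16x16 signals yields
  // num = 8 * 16 = 128 slices (batch times all but the last signal dim).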
#ifdef _OPENMP
  // Parallel path: divide the slices evenly among the OpenMP threads.
  #pragma omp parallel
  {
    int nthreads = omp_get_num_threads();
    int64_t num_slices_per_thread = num / nthreads + 1;
    int tid = omp_get_thread_num();
    int64_t start = tid * num_slices_per_thread;
    if (start < num) {  // trailing threads may have no slices left
      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
        _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
          last_dim_start_slice, start, std::min(num_slices_per_thread, num - start));
      });
    }
  }
#else
  // Serial path: a single call covers all slices.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
    _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
      last_dim_start_slice, 0, num);
  });
#endif
}
// MKL DFTI
Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim,
                bool complex_input, bool complex_output,
                bool inverse, IntArrayRef checked_signal_sizes,
                bool normalized, bool onesided,
                IntArrayRef output_sizes) {
  int64_t batch = self.size(0);
  Tensor input = self;
  // When the input is complex, the (real, imag) pairs must stay aligned
  // once the tensor is reinterpreted as complex values: the last dim must
  // be dense and every other stride even. Otherwise fall back to a copy.
  if (complex_input) {
    bool need_contiguous = input.stride(-1) != 1;
    for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) {
      need_contiguous |= input.stride(i) % 2 != 0;
    }
    if (need_contiguous) {
      input = input.contiguous();
    }
  }
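  // E.g., a complex signal stored as [..., 2] scalars with an odd stride
  // cannot be viewed as an array of MKL complex elements, since each
  // complex element must start at an even scalar offset.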
  // MKL_LONG is 32-bit on some platforms (e.g., Windows), so every size and
  // stride handed to MKL must be range-checked first. In the complex
  // domain, strides are counted in complex elements, hence the >> 1. All
  // values are non-negative, so only the upper bound needs testing.
  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    bool need_contiguous = false;
    int64_t inumel = 1, onumel = 1;
    int64_t isize, osize, istride, ostride;
    for (int64_t i = signal_ndim; i >= 0; i--) {
      isize = input.size(i);
      osize = output_sizes[i];
      istride = complex_input ? input.stride(i) >> 1 : input.stride(i);
      ostride = onumel;  // output is allocated contiguous below
      AT_CHECK(isize <= MKL_LONG_MAX && osize <= MKL_LONG_MAX && ostride <= MKL_LONG_MAX,
               "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      if (!need_contiguous && istride > MKL_LONG_MAX) {
        // The given stride is out of range, but a contiguous copy may bring
        // it back within bounds; from this point on validate the would-be
        // contiguous stride (inumel) instead.
        need_contiguous = true;
      }
      AT_CHECK(!need_contiguous || inumel <= MKL_LONG_MAX,
               "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      inumel *= isize;
      onumel *= osize;
    }
    if (need_contiguous) {
      input = input.contiguous();
    }
  }
  Tensor output = at::empty(output_sizes, input.options());
  // transform precision follows the input dtype
  DFTI_CONFIG_VALUE prec;
  if (input.scalar_type() == ScalarType::Float) {
    prec = DFTI_SINGLE;
  } else if (input.scalar_type() == ScalarType::Double) {
    prec = DFTI_DOUBLE;
  } else {
    std::ostringstream ss;
    ss << "MKL FFT doesn't support tensor of type: "
       << toString(input.scalar_type());
    AT_ERROR(ss.str());
  }
  // forward domain: decided by the complex side of the transform
  // (input for a forward transform, output for an inverse one)
  DFTI_CONFIG_VALUE signal_type;
  if (!inverse) {
    signal_type = complex_input ? DFTI_COMPLEX : DFTI_REAL;
  } else {
    signal_type = complex_output ? DFTI_COMPLEX : DFTI_REAL;
  }
  // create the descriptor with the signal sizes
  std::vector<MKL_LONG> mkl_signal_sizes(checked_signal_sizes.begin(), checked_signal_sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
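  // DftiDescriptor (from ATen/mkl/Descriptors.h) is an owning wrapper
  // around the raw MKL handle, which is why no explicit DftiFreeDescriptor
  // call appears on any of the paths below.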
  // out-of-place transform
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batched transform
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch));
  auto istrides = input.strides();
  auto ostrides = output.strides();
  // distance between consecutive signals in the batch, i.e., the batch-dim
  // stride, halved in the complex domain (MKL counts complex elements)
  MKL_LONG idist = complex_input ? istrides[0] >> 1 : istrides[0];
  MKL_LONG odist = complex_output ? ostrides[0] >> 1 : ostrides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
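  // E.g., a contiguous complex input of shape [batch, n] is stored as
  // [batch, n, 2] scalars with istrides[0] == 2 * n, giving idist == n
  // complex elements between consecutive signals.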
  // per-dimension signal strides; element 0 is the displacement (offset),
  // which stays 0 here
  std::vector<MKL_LONG> mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    mkl_istrides[i] = complex_input ? istrides[i] >> 1 : istrides[i];
    mkl_ostrides[i] = complex_output ? ostrides[i] >> 1 : ostrides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // if either domain is real, request the standard CCE storage layout for
  // the conjugate-even side
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
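  // DFTI_COMPLEX_COMPLEX stores the conjugate-even result as ordinary
  // complex values (only the non-redundant half is written), rather than in
  // one of MKL's packed real formats; the symmetry fill at the end of this
  // function relies on that layout.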
  // rescale if requested (normalized) or required (inverse transform)
  if (normalized || inverse) {
    auto signal_numel = at::prod_intlist(checked_signal_sizes);
    double double_scale;
    if (normalized) {
      // orthonormal transform: scale both directions by 1 / sqrt(N)
      double_scale = 1.0 / std::sqrt(static_cast<double>(signal_numel));
    } else {
      // plain inverse: scale by 1 / N
      double_scale = 1.0 / static_cast<double>(signal_numel);
    }
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(),
      inverse ? DFTI_BACKWARD_SCALE : DFTI_FORWARD_SCALE,
      prec == DFTI_DOUBLE ? double_scale : static_cast<float>(double_scale)));
  }
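  // E.g., for a length-8 inverse transform the backward scale is
  // 1/8 = 0.125; with normalized=true both directions use
  // 1/sqrt(8) ~= 0.3536, so a forward-then-inverse round trip still
  // recovers the input.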
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));
  // run the transform
  if (!inverse) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  }
  // For a real-to-complex transform with onesided=false, MKL wrote only the
  // non-redundant half; reconstruct the rest via conjugate symmetry.
  if (!complex_input && complex_output && !onesided) {
    auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
    auto start_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim);
    _fft_fill_with_conjugate_symmetry_(output, signal_ndim, size_last_signal_dim, start_slice);
  }
  return output;
}

}} // namespace at::native

#endif
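// Usage sketch (illustrative only; argument values are hypothetical): this
// kernel is normally reached through the shape-checking FFT wrappers,
// roughly as
//   Tensor out = _fft_mkl(self, /*signal_ndim=*/1,
//                         /*complex_input=*/false, /*complex_output=*/true,
//                         /*inverse=*/false, checked_signal_sizes,
//                         /*normalized=*/false, /*onesided=*/true,
//                         output_sizes);
// where checked_signal_sizes and output_sizes have already been validated
// by the caller.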