Caffe2 - C++ API
A deep learning, cross platform ML framework
SpectralOps.cpp
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/Config.h>

#if !AT_MKL_ENABLED()

namespace at { namespace native {

Tensor _fft_mkl(const Tensor& input, int64_t signal_ndim,
                bool complex_input, bool complex_output,
                bool inverse, IntArrayRef checked_signal_sizes,
                bool normalized, bool onesided,
                IntArrayRef output_sizes) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

}}

#else // AT_MKL_ENABLED

#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>

#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

#ifdef _OPENMP
#include <omp.h>
#endif

namespace at { namespace native {

// In a real-to-complex transform, MKL FFT only fills half of the values due
// to conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following functions are used to fill in the other half with symmetry in
// the case of a real-to-complex transform with onesided=False.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
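//
// For illustration: for a real 1-D signal of size 5, the forward transform
// only computes bins 0..2, and the remaining bins follow from
//   X[k] = conj(X[5 - k]),  i.e.,  X[3] = conj(X[2]),  X[4] = conj(X[1]).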

template <typename scalar_t>
static inline void _fft_fill_with_conjugate_symmetry_slice(Tensor& output,
    int64_t signal_ndim, int64_t size_last_dim,
    int64_t start_last_dim_idx, int64_t i, int64_t num) {
  scalar_t *data = output.data<scalar_t>();

  // A slice means a slice of the last dimension (of size size_last_dim).

  // This function iterates through the slices to fill, i.e., to_slice_data
  // (basically data_slices[i:i+num]), and keeps track of the slices it reads
  // data from, i.e., from_slice_data, using from_slice_indices, a vector
  // containing the indices of the from_slice_data slice.

  // Compute the indices for the first from_slice_data
  std::vector<int64_t> from_slice_indices(signal_ndim);  // up to before last signal dim
  int64_t remainder = i;
  // set last signal dim values
  int64_t from_slice_offset = 0;
  for (int64_t d = signal_ndim - 1; d >= 0; d--) {
    int64_t dim_size = output.size(d);
    int64_t dim_idx = remainder % dim_size;
    remainder = remainder / dim_size;
    from_slice_indices[d] = dim_idx;
    if (d == 0) {
      from_slice_offset += dim_idx * output.stride(d);
    } else if (dim_idx != 0) {
      from_slice_offset += (dim_size - dim_idx) * output.stride(d);
    }
  }
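  // (Illustration) The loop above maps the destination slice index
  // (i_0, i_1, ..., i_{ndim-1}) to its mirrored source slice
  // (i_0, N_1 - i_1, ..., N_{ndim-1} - i_{ndim-1}), with index 0 mapping to
  // itself: the conjugate-symmetric partner in every signal dimension.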

  // First to_slice_data and from_slice_data
  scalar_t *to_slice_data = data + i * size_last_dim * 2;
  scalar_t *from_slice_data = data + from_slice_offset;

  while (num > 0) {
    // Fill to_slice_data from values in from_slice_data
    for (int64_t j = start_last_dim_idx; j < size_last_dim; j++) {
      // multiply the index by 2 because the last (complex) dim has size 2
      int64_t to_idx = j * 2;
      int64_t from_idx = (size_last_dim - j) * 2;
      to_slice_data[to_idx] = from_slice_data[from_idx];
      to_slice_data[to_idx + 1] = -from_slice_data[from_idx + 1];
    }
    // Compute the next to_slice_data and from_slice_data slices
    to_slice_data += size_last_dim * 2;
    for (int64_t d = signal_ndim - 1; d >= 0; d--) {
      // Compute the next index at this dimension using conjugate symmetry.
      // Break out of this loop if nothing carries over.
      from_slice_indices[d] = (from_slice_indices[d] + 1) % output.size(d);
      if (d > 0) {
        // At a d > 0 (non-batch) dim, to get the next from_slice_data offset:
        //   1. if this dim idx becomes 1, we need to add (size - 1) * stride
        //   2. otherwise, we need to subtract stride
        if (from_slice_indices[d] == 0) {
          // Subtract. Carries over to the previous dimension
          from_slice_data -= output.stride(d);
        } else if (from_slice_indices[d] == 1) {
          // Dimension index becomes 1
          // Doesn't carry over to the previous dimension
          from_slice_data += (output.size(d) - 1) * output.stride(d);
          break;
        } else {
          // Subtract. Doesn't carry over to the previous dimension
          from_slice_data -= output.stride(d);
          break;
        }
      } else {
        // At the d = 0 (batch) dim, to_slice_data is now at the beginning of
        // a data sample, which maps to itself by conjugate symmetry.
        from_slice_data = to_slice_data;
      }
    }
    num--;
  }
}

// input should be a contiguous batched tensor of the same size as full
// (twosided) signals, but only containing half (onesided) of the values.
// This function modifies input in place.
static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input,
    int64_t signal_ndim, int64_t size_last_dim,
    int64_t last_dim_start_slice) {
  if (last_dim_start_slice >= size_last_dim) {
    return;
  }

  int64_t num = 1;
  for (int64_t d = 0; d < signal_ndim; d++) {
    num *= input.size(d);
  }
#ifdef _OPENMP
  if (num > 500) {
    // Query the thread count before the parallel region;
    // omp_get_num_threads() would return 1 in serial code.
    int nthreads = omp_get_max_threads();
    int64_t num_slices_per_thread = num / nthreads + 1;
    #pragma omp parallel
    {
      int tid = omp_get_thread_num();
      int64_t start = tid * num_slices_per_thread;
      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
        _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
            last_dim_start_slice, start, std::min(num_slices_per_thread, num - start));
      });
    }
    return;
  }
#endif
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
    _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
        last_dim_start_slice, 0, num);
  });
}

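// (Illustration) For a real-to-complex transform whose last signal dim has
// size n, the forward pass only computes bins [0, n/2], and the caller passes
// last_dim_start_slice = n/2 + 1 (infer_ft_real_to_complex_onesided_size in
// native/SpectralOpsUtils.h), so the helper above fills bins [n/2 + 1, n).
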
// MKL DFTI
Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim,
                bool complex_input, bool complex_output,
                bool inverse, IntArrayRef checked_signal_sizes,
                bool normalized, bool onesided,
                IntArrayRef output_sizes) {
  int64_t batch = self.size(0);
  Tensor input = self;
  // real/imag dimensions must be aligned when viewed as complex type
  if (complex_input) {
    bool need_contiguous = input.stride(-1) != 1;
    for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) {
      need_contiguous |= input.stride(i) % 2 != 0;
    }
    if (need_contiguous) {
      input = input.contiguous();
    }
  }

  // Check whether we can use MKL, because MKL_LONG is 32-bit on some OSes,
  // e.g., Windows. We need to check the input and output sizes and strides.
  // Be careful about the complex domain, where the stride needs to be divided
  // by 2. We only need to test the upper bound MKL_LONG_MAX as these values
  // are non-negative.
  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    bool need_contiguous = false;
    int64_t inumel = 1 /* istride if we contiguous-fy */, onumel = 1;
    int64_t isize, osize, istride, ostride;
    for (int64_t i = signal_ndim; i >= 0; i--) {
      isize = input.size(i);
      osize = output_sizes[i];
      istride = complex_input ? input.stride(i) >> 1 : input.stride(i);
      ostride = onumel;
      AT_CHECK(isize <= MKL_LONG_MAX && osize <= MKL_LONG_MAX && ostride <= MKL_LONG_MAX,
               "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      if (!need_contiguous && istride > MKL_LONG_MAX) {
        // If we didn't plan to contiguous-fy but `istride` exceeds the bound,
        // check whether the stride we would get by contiguous-fying (which
        // equals `inumel`) is back within the bound. If so, we need to check
        // `inumel` instead for all remaining iterations; the earlier
        // iterations are fine since `inumel` is non-decreasing.
        need_contiguous = true;
      }
      AT_CHECK(!need_contiguous || inumel <= MKL_LONG_MAX,
               "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      inumel *= isize;
      onumel *= osize;
    }
  }
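  // (Illustration) With a 32-bit MKL_LONG, MKL_LONG_MAX is 2^31 - 1; a size
  // or stride that cannot be brought under that bound cannot be described to
  // MKL, so it is rejected by the checks above.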
  Tensor output = at::empty(output_sizes, input.options());

  // precision
  DFTI_CONFIG_VALUE prec;
  if (input.scalar_type() == ScalarType::Float) {
    prec = DFTI_SINGLE;
  } else if (input.scalar_type() == ScalarType::Double) {
    prec = DFTI_DOUBLE;
  } else {
    std::ostringstream ss;
    ss << "MKL FFT doesn't support tensor of type: "
       << toString(input.scalar_type());
    AT_ERROR(ss.str());
  }
  // signal type
  DFTI_CONFIG_VALUE signal_type;
  if (!inverse) {
    signal_type = complex_input ? DFTI_COMPLEX : DFTI_REAL;
  } else {
    signal_type = complex_output ? DFTI_COMPLEX : DFTI_REAL;
  }
  // create descriptor with signal size
  std::vector<MKL_LONG> mkl_signal_sizes(checked_signal_sizes.begin(), checked_signal_sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
  // out-of-place FFT
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batch mode
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch));

  auto istrides = input.strides();
  auto ostrides = output.strides();
  // batch dim stride, i.e., the distance between consecutive data samples
  MKL_LONG idist = complex_input ? istrides[0] >> 1 : istrides[0];
  MKL_LONG odist = complex_output ? ostrides[0] >> 1 : ostrides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
  // signal strides
  // the first value is the offset, set to zero (ignored)
  std::vector<MKL_LONG> mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    mkl_istrides[i] = complex_input ? istrides[i] >> 1 : istrides[i];
    mkl_ostrides[i] = complex_output ? ostrides[i] >> 1 : ostrides[i];
  }
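  // (Illustration) For a contiguous complex input of shape {batch, n, 2} the
  // scalar strides are {2n, 2, 1}; viewed as complex elements this gives
  // idist = 2n >> 1 = n and a signal stride of 2 >> 1 = 1.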
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // if the conjugate domain of real is involved, set the standard CCE storage
  // type (this will become the default in MKL in the future)
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
  // rescale if requested by the normalized flag or an inverse transform
  if (normalized || inverse) {
    auto signal_numel = at::prod_intlist(checked_signal_sizes);
    double double_scale;
    if (normalized) {
      double_scale = 1.0 / std::sqrt(static_cast<double>(signal_numel));
    } else {
      double_scale = 1.0 / static_cast<double>(signal_numel);
    }
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(),
        inverse ? DFTI_BACKWARD_SCALE : DFTI_FORWARD_SCALE,
        prec == DFTI_DOUBLE ? double_scale : static_cast<float>(double_scale)));
  }
  // finalize
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));
  // run
  if (!inverse) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  }
  // now, if needed, fill in the other half using conjugate (Hermitian) symmetry
  if (!complex_input && complex_output && !onesided) {
    auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
    auto start_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim);
    _fft_fill_with_conjugate_symmetry_(output, signal_ndim, size_last_signal_dim, start_slice);
  }
  return output;
}

}} // namespace at::native

#endif
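
For context, a minimal usage sketch (illustrative, not part of SpectralOps.cpp): on a CPU build with MKL, the public ATen FFT entry points of this release, e.g. at::rfft, are what ultimately dispatch to the _fft_mkl kernel above; the sketch assumes the rfft(self, signal_ndim, normalized, onesided) signature of this ATen version.

#include <ATen/ATen.h>

int main() {
  // Batch of 8 real 1-D signals of length 16.
  at::Tensor x = at::randn({8, 16});

  // One-sided real-to-complex FFT keeps 16/2 + 1 = 9 bins, each stored as
  // (real, imag) in a trailing dimension of size 2 -> shape {8, 9, 2}.
  at::Tensor y_onesided = at::rfft(x, /*signal_ndim=*/1, /*normalized=*/false,
                                   /*onesided=*/true);

  // With onesided=false, the remaining 7 bins are reconstructed by
  // _fft_fill_with_conjugate_symmetry_ -> shape {8, 16, 2}.
  at::Tensor y_twosided = at::rfft(x, /*signal_ndim=*/1, /*normalized=*/false,
                                   /*onesided=*/false);
  return 0;
}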