| import torch |
| import unittest |
| import math |
| from contextlib import contextmanager |
| from itertools import product |
| |
| from torch.testing._internal.common_utils import \ |
| (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA) |
| from torch.testing._internal.common_device_type import \ |
| (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, precisionOverride, |
| skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA) |
| |
| if TEST_NUMPY: |
| import numpy as np |
| |
| |
| if TEST_LIBROSA: |
| import librosa |
| |
| # Saves a reference to the torch.fft function, which is clobbered by importing the torch.fft module below
| fft_fn = torch.fft |
| import torch.fft |
| |
| # Tests of functions related to Fourier analysis in the torch.fft namespace |
| class TestFFT(TestCase): |
| exact_dtype = True |
| |
| @skipCPUIfNoMkl |
| @skipCUDAIfRocm |
| def test_fft_function_clobbered(self, device): |
| t = torch.randn((100, 2), device=device) |
| eager_result = fft_fn(t, 1) |
| |
| def method_fn(t): |
| return t.fft(1) |
| scripted_method_fn = torch.jit.script(method_fn) |
| |
| self.assertEqual(scripted_method_fn(t), eager_result) |
| |
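| # After `import torch.fft`, the name `torch.fft` resolves to the module rather
| # than the original function, so calling it directly raises TypeError.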
| with self.assertRaisesRegex(TypeError, "'module' object is not callable"): |
| torch.fft(t, 1) |
| |
| @skipCPUIfNoMkl |
| @skipCUDAIfRocm |
| @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') |
| @precisionOverride({torch.complex64: 1e-4}) |
| @dtypes(torch.complex64, torch.complex128) |
| def test_fft(self, device, dtype): |
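| # Odd, non-power-of-two lengths (presumably chosen to exercise the generic
| # FFT code paths rather than the radix-2 fast path).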
| test_inputs = (torch.randn(67, device=device, dtype=dtype), |
| torch.randn(4029, device=device, dtype=dtype)) |
| |
| def fn(t): |
| return torch.fft.fft(t) |
| scripted_fn = torch.jit.script(fn) |
| |
| # TODO: revisit the following function if t.fft() becomes torch.fft.fft |
| # def method_fn(t): |
| # return t.fft() |
| # scripted_method_fn = torch.jit.script(method_fn) |
| |
| # TODO: revisit the following function if t.fft() becomes torch.fft.fft |
| # torch_fns = (torch.fft.fft, torch.Tensor.fft, scripted_fn, scripted_method_fn) |
| torch_fns = (torch.fft.fft, scripted_fn) |
| |
| for input in test_inputs: |
| expected = np.fft.fft(input.cpu().numpy()) |
| for fn in torch_fns: |
| actual = fn(input) |
| self.assertEqual(actual, expected, exact_dtype=(dtype is torch.complex128)) |
| |
| # Note: NumPy will throw a ValueError for an empty input |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.complex64, torch.complex128) |
| def test_empty_fft(self, device, dtype): |
| t = torch.empty(0, device=device, dtype=dtype) |
| |
| if self.device_type == 'cuda': |
| with self.assertRaisesRegex(RuntimeError, "cuFFT error"): |
| torch.fft.fft(t) |
| return |
| |
| # CPU (MKL) |
| with self.assertRaisesRegex(RuntimeError, "MKL FFT error"): |
| torch.fft.fft(t) |
| |
| @dtypes(torch.int64, torch.float32) |
| def test_fft_invalid_dtypes(self, device, dtype): |
| if dtype.is_floating_point: |
| t = torch.randn(64, device=device, dtype=dtype) |
| else: |
| t = torch.randint(-2, 2, (64,), device=device, dtype=dtype) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a complex tensor"): |
| torch.fft.fft(t) |
| |
| # Legacy fft tests (the older Tensor.fft/ifft/rfft/irfft methods taking signal_ndim)
| def _test_fft_ifft_rfft_irfft(self, device, dtype): |
| def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): |
| x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) |
| for normalized in (True, False): |
| res = x.fft(signal_ndim, normalized=normalized) |
| rec = res.ifft(signal_ndim, normalized=normalized) |
| self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft') |
| res = x.ifft(signal_ndim, normalized=normalized) |
| rec = res.fft(signal_ndim, normalized=normalized) |
| self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft') |
| |
| def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): |
| x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) |
| signal_sizes = x.size()[-signal_ndim:] |
| for normalized, onesided in product((True, False), repeat=2): |
| res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided) |
| if not onesided: # check Hermitian symmetry |
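| # For a real input of length N, the two-sided DFT is Hermitian-symmetric:
| # X[k] == conj(X[N - k]). The sampled-index check below verifies that the
| # real parts match and the imaginary parts are negated at reflected indices.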
| def test_one_sample(res, test_num=10): |
| idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes] |
| for idx in zip(*idxs_per_dim): |
| reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size())) |
| idx_val = res[idx]
| reflected_val = res[reflected_idx]
| self.assertEqual(idx_val[0], reflected_val[0], msg='rfft hermitian symmetry on real part') |
| self.assertEqual(idx_val[1], -reflected_val[1], msg='rfft hermitian symmetry on imaginary part') |
| if len(sizes) == signal_ndim: |
| test_one_sample(res) |
| else: |
| output_non_batch_shape = res.size()[-(signal_ndim + 1):] |
| flatten_batch_res = res.view(-1, *output_non_batch_shape) |
| nb = flatten_batch_res.size(0) |
| test_idxs = torch.LongTensor(min(nb, 4)).random_(nb) |
| for test_idx in test_idxs.tolist(): |
| test_one_sample(flatten_batch_res[test_idx]) |
| # compare against the complex-to-complex (C2C) FFT of the same signal
| xc = torch.stack([x, torch.zeros_like(x)], -1) |
| xc_res = xc.fft(signal_ndim, normalized=normalized) |
| self.assertEqual(res, xc_res) |
| rec = res.irfft(signal_ndim, normalized=normalized, |
| onesided=onesided, signal_sizes=signal_sizes) |
| self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft') |
| if not onesided: # check that we can use C2C ifft |
| rec = res.ifft(signal_ndim, normalized=normalized) |
| self.assertEqual(x, rec.select(-1, 0), atol=1e-8, rtol=0, msg='twosided rfft and ifft real') |
| self.assertEqual(rec.select(-1, 1).abs().mean(), 0, atol=1e-8, |
| rtol=0, msg='twosided rfft and ifft imaginary') |
| |
| # contiguous case |
| _test_real((100,), 1) |
| _test_real((10, 1, 10, 100), 1) |
| _test_real((100, 100), 2) |
| _test_real((2, 2, 5, 80, 60), 2) |
| _test_real((50, 40, 70), 3) |
| _test_real((30, 1, 50, 25, 20), 3) |
| |
| _test_complex((100, 2), 1) |
| _test_complex((100, 100, 2), 1) |
| _test_complex((100, 100, 2), 2) |
| _test_complex((1, 20, 80, 60, 2), 2) |
| _test_complex((50, 40, 70, 2), 3) |
| _test_complex((6, 5, 50, 25, 20, 2), 3) |
| |
| # non-contiguous case |
| _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type |
| _test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) |
| _test_real((100, 100), 2, lambda x: x.t()) |
| _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) |
| _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) |
| _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) |
| |
| _test_complex((2, 100), 1, lambda x: x.t()) |
| _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2)) |
| _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:]) # input is not aligned to complex type |
| _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) |
| _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) |
| _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21]) |
| |
| # non-contiguous with strides not representable as aligned with complex type |
| _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1])) |
| _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) |
| _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) |
| _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1])) |
| _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) |
| _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) |
| |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.double) |
| def test_fft_ifft_rfft_irfft(self, device, dtype): |
| self._test_fft_ifft_rfft_irfft(device, dtype) |
| |
| @deviceCountAtLeast(1) |
| @skipCUDAIfRocm |
| @onlyCUDA |
| @dtypes(torch.double) |
| def test_cufft_plan_cache(self, devices, dtype): |
| @contextmanager |
| def plan_cache_max_size(device, n): |
| if device is None: |
| plan_cache = torch.backends.cuda.cufft_plan_cache |
| else: |
| plan_cache = torch.backends.cuda.cufft_plan_cache[device] |
| original = plan_cache.max_size
| plan_cache.max_size = n
| try:
| yield
| finally:
| plan_cache.max_size = original
| |
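| # Run the FFT round-trip tests with the plan cache restricted to a small
| # size, and then with a max size of 0 (caching effectively disabled), so
| # that both the cache-eviction and no-cache paths are exercised.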
| with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)): |
| self._test_fft_ifft_rfft_irfft(devices[0], dtype) |
| |
| with plan_cache_max_size(devices[0], 0): |
| self._test_fft_ifft_rfft_irfft(devices[0], dtype) |
| |
| torch.backends.cuda.cufft_plan_cache.clear() |
| |
| # check that it still works after clearing the cache
| with plan_cache_max_size(devices[0], 10): |
| self._test_fft_ifft_rfft_irfft(devices[0], dtype) |
| |
| with self.assertRaisesRegex(RuntimeError, r"must be non-negative"): |
| torch.backends.cuda.cufft_plan_cache.max_size = -1 |
| |
| with self.assertRaisesRegex(RuntimeError, r"read-only property"): |
| torch.backends.cuda.cufft_plan_cache.size = -1 |
| |
| with self.assertRaisesRegex(RuntimeError, r"but got device with index"): |
| torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10] |
| |
| # Multigpu tests |
| if len(devices) > 1: |
| # Test that different GPUs have different plan caches
| x0 = torch.randn(2, 3, 3, device=devices[0]) |
| x1 = x0.to(devices[1]) |
| self.assertEqual(x0.rfft(2), x1.rfft(2)) |
| # If a plan is used across different devices, the following line (or
| # the assertEqual above) would trigger an illegal memory access. Other
| # ways to trigger the error include
| # (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and |
| # (2) printing a device 1 tensor. |
| x0.copy_(x1) |
| |
| # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses the current device
| with plan_cache_max_size(devices[0], 10): |
| with plan_cache_max_size(devices[1], 11): |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) |
| |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 |
| with torch.cuda.device(devices[1]): |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 |
| with torch.cuda.device(devices[0]): |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 |
| |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) |
| with torch.cuda.device(devices[1]): |
| with plan_cache_max_size(None, 11): # default is cuda:1 |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) |
| |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 |
| with torch.cuda.device(devices[0]): |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 |
| self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 |
| |
| # passes on ROCm w/ python 2.7, fails w/ python 3.6 |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| @dtypes(torch.double) |
| def test_stft(self, device, dtype): |
| if not TEST_LIBROSA: |
| raise unittest.SkipTest('librosa not found') |
| |
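| # Reference STFT computed with librosa; the real and imaginary parts are
| # stacked in the last dimension to match the layout of the legacy
| # (non-complex) torch.stft output.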
| def librosa_stft(x, n_fft, hop_length, win_length, window, center): |
| if window is None: |
| window = np.ones(n_fft if win_length is None else win_length) |
| else: |
| window = window.cpu().numpy() |
| input_1d = x.dim() == 1 |
| if input_1d: |
| x = x.view(1, -1) |
| result = [] |
| for xi in x: |
| ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center) |
| result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) |
| result = torch.stack(result, 0) |
| if input_1d: |
| result = result[0] |
| return result |
| |
| def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, |
| center=True, expected_error=None): |
| x = torch.randn(*sizes, dtype=dtype, device=device) |
| if win_sizes is not None: |
| window = torch.randn(*win_sizes, dtype=dtype, device=device) |
| else: |
| window = None |
| if expected_error is None: |
| result = x.stft(n_fft, hop_length, win_length, window, center=center) |
| # NB: librosa defaults to np.complex64 output, no matter what |
| # the input dtype |
| ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) |
| self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False) |
| else: |
| self.assertRaises(expected_error, |
| lambda: x.stft(n_fft, hop_length, win_length, window, center=center)) |
| |
| for center in [True, False]: |
| _test((10,), 7, center=center) |
| _test((10, 4000), 1024, center=center) |
| |
| _test((10,), 7, 2, center=center) |
| _test((10, 4000), 1024, 512, center=center) |
| |
| _test((10,), 7, 2, win_sizes=(7,), center=center) |
| _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) |
| |
| # spectral oversampling (win_length smaller than n_fft)
| _test((10,), 7, 2, win_length=5, center=center) |
| _test((10, 4000), 1024, 512, win_length=100, center=center) |
| |
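| # Invalid arguments: a 3D input signal, an n_fft larger than the
| # (uncentered) signal, a negative n_fft, a win_length larger than n_fft,
| # and windows whose shape does not match the expected window length.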
| _test((10, 4, 2), 1, 1, expected_error=RuntimeError) |
| _test((10,), 11, 1, center=False, expected_error=RuntimeError) |
| _test((10,), -1, 1, expected_error=RuntimeError) |
| _test((10,), 3, win_length=5, expected_error=RuntimeError) |
| _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) |
| _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) |
| |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| def test_fft_input_modification(self, device): |
| # FFT functions should not modify their input (gh-34551) |
| |
| signal = torch.ones((2, 2, 2), device=device) |
| signal_copy = signal.clone() |
| spectrum = signal.fft(2) |
| self.assertEqual(signal, signal_copy) |
| |
| spectrum_copy = spectrum.clone() |
| _ = torch.ifft(spectrum, 2) |
| self.assertEqual(spectrum, spectrum_copy) |
| |
| half_spectrum = torch.rfft(signal, 2) |
| self.assertEqual(signal, signal_copy) |
| |
| half_spectrum_copy = half_spectrum.clone() |
| _ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2)) |
| self.assertEqual(half_spectrum, half_spectrum_copy) |
| |
| @onlyOnCPUAndCUDA |
| @skipCPUIfNoMkl |
| @dtypes(torch.double) |
| def test_istft_round_trip_simple_cases(self, device, dtype): |
| """stft -> istft should recover the original signale""" |
| def _test(input, n_fft, length): |
| stft = torch.stft(input, n_fft=n_fft) |
| inverse = torch.istft(stft, n_fft=n_fft, length=length) |
| self.assertEqual(input, inverse, exact_dtype=True) |
| |
| _test(torch.ones(4, dtype=dtype, device=device), 4, 4) |
| _test(torch.zeros(4, dtype=dtype, device=device), 4, 4) |
| |
| @onlyOnCPUAndCUDA |
| @skipCPUIfNoMkl |
| @dtypes(torch.double) |
| def test_istft_round_trip_various_params(self, device, dtype): |
| """stft -> istft should recover the original signale""" |
| def _test_istft_is_inverse_of_stft(stft_kwargs): |
| # generates a random signal for each trial and then runs stft/istft
| # to check whether the original signal can be reconstructed
| data_sizes = [(2, 20), (3, 15), (4, 10)] |
| num_trials = 100 |
| istft_kwargs = stft_kwargs.copy() |
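| # torch.istft does not accept pad_mode, so drop it from the kwargs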
| del istft_kwargs['pad_mode'] |
| for sizes in data_sizes: |
| for i in range(num_trials): |
| original = torch.randn(*sizes, dtype=dtype, device=device) |
| stft = torch.stft(original, **stft_kwargs) |
| inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) |
| |
| # trim the original for the case when the reconstructed signal is shorter than the original
| original = original[..., :inversed.size(-1)] |
| self.assertEqual( |
| inversed, original, msg='istft comparison against original', |
| atol=7e-6, rtol=0, exact_dtype=True) |
| |
| patterns = [ |
| # hann_window, centered, normalized, onesided |
| { |
| 'n_fft': 12, |
| 'hop_length': 4, |
| 'win_length': 12, |
| 'window': torch.hann_window(12, dtype=dtype, device=device), |
| 'center': True, |
| 'pad_mode': 'reflect', |
| 'normalized': True, |
| 'onesided': True, |
| }, |
| # hann_window, centered, not normalized, not onesided |
| { |
| 'n_fft': 12, |
| 'hop_length': 2, |
| 'win_length': 8, |
| 'window': torch.hann_window(8, dtype=dtype, device=device), |
| 'center': True, |
| 'pad_mode': 'reflect', |
| 'normalized': False, |
| 'onesided': False, |
| }, |
| # hamming_window, centered, normalized, not onesided |
| { |
| 'n_fft': 15, |
| 'hop_length': 3, |
| 'win_length': 11, |
| 'window': torch.hamming_window(11, dtype=dtype, device=device), |
| 'center': True, |
| 'pad_mode': 'constant', |
| 'normalized': True, |
| 'onesided': False, |
| }, |
| # hamming_window, not centered, not normalized, onesided |
| # window same size as n_fft |
| { |
| 'n_fft': 5, |
| 'hop_length': 2, |
| 'win_length': 5, |
| 'window': torch.hamming_window(5, dtype=dtype, device=device), |
| 'center': False, |
| 'pad_mode': 'constant', |
| 'normalized': False, |
| 'onesided': True, |
| }, |
| # hamming_window, not centered, not normalized, not onesided |
| # window same size as n_fft |
| { |
| 'n_fft': 3, |
| 'hop_length': 2, |
| 'win_length': 3, |
| 'window': torch.hamming_window(3, dtype=dtype, device=device), |
| 'center': False, |
| 'pad_mode': 'reflect', |
| 'normalized': False, |
| 'onesided': False, |
| }, |
| ] |
| for i, pattern in enumerate(patterns): |
| _test_istft_is_inverse_of_stft(pattern) |
| |
| @onlyOnCPUAndCUDA |
| def test_istft_throws(self, device): |
| """istft should throw exception for invalid parameters""" |
| stft = torch.zeros((3, 5, 2), device=device) |
| # the window has size 1 but the hop length is 20, leaving gaps between frames, which throws an error
| self.assertRaises( |
| RuntimeError, torch.istft, stft, n_fft=4, |
| hop_length=20, win_length=1, window=torch.ones(1)) |
| # A window of all zeros violates the NOLA (nonzero overlap-add) condition,
| # which requires the sum of squared, overlapped windows to be nonzero everywhere
| invalid_window = torch.zeros(4, device=device) |
| self.assertRaises( |
| RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window) |
| # Input cannot be empty |
| self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2) |
| self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2) |
| |
| @onlyOnCPUAndCUDA |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| @dtypes(torch.double) |
| def test_istft_of_sine(self, device, dtype): |
| def _test(amplitude, L, n): |
| # stft of amplitude*sin(2*pi/L*n*x) with hop length and window size both equal to L
| x = torch.arange(2 * L + 1, device=device, dtype=dtype) |
| original = amplitude * torch.sin(2 * math.pi / L * x * n) |
| # stft = torch.stft(original, L, hop_length=L, win_length=L, |
| # window=torch.ones(L), center=False, normalized=False) |
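| # Expected spectrum, constructed analytically: each frame covers one full
| # period of the sine, so its DFT is purely imaginary with value
| # -amplitude*L/2 at bin n and +amplitude*L/2 at the mirrored bin L - n
| # (kept only when it falls inside the onesided range).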
| stft = torch.zeros((L // 2 + 1, 2, 2), device=device, dtype=dtype) |
| stft_largest_val = (amplitude * L) / 2.0 |
| if n < stft.size(0): |
| stft[n, :, 1] = -stft_largest_val |
| |
| if 0 <= L - n < stft.size(0): |
| # symmetric about L // 2 |
| stft[L - n, :, 1] = stft_largest_val |
| |
| inverse = torch.istft( |
| stft, L, hop_length=L, win_length=L, |
| window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False) |
| # A larger tolerance is used because the error scales with the amplitude
| original = original[..., :inverse.size(-1)] |
| self.assertEqual(inverse, original, atol=1e-3, rtol=0) |
| |
| _test(amplitude=123, L=5, n=1) |
| _test(amplitude=150, L=5, n=2) |
| _test(amplitude=111, L=5, n=3) |
| _test(amplitude=160, L=7, n=4) |
| _test(amplitude=145, L=8, n=5) |
| _test(amplitude=80, L=9, n=6) |
| _test(amplitude=99, L=10, n=7) |
| |
| @onlyOnCPUAndCUDA |
| @skipCUDAIfRocm |
| @skipCPUIfNoMkl |
| @dtypes(torch.double) |
| def test_istft_linearity(self, device, dtype): |
| num_trials = 100 |
| |
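| # istft is linear in its input (the inverse DFT, windowing, overlap-add, and
| # the window-envelope normalization are all linear or input-independent), so
| # istft(a*X + b*Y) should equal a*istft(X) + b*istft(Y).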
| def _test(data_size, kwargs): |
| for i in range(num_trials): |
| tensor1 = torch.randn(data_size, device=device, dtype=dtype) |
| tensor2 = torch.randn(data_size, device=device, dtype=dtype) |
| a, b = torch.rand(2, dtype=dtype, device=device) |
| istft1 = torch.istft(tensor1, **kwargs) |
| istft2 = torch.istft(tensor2, **kwargs) |
| istft = a * istft1 + b * istft2 |
| estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs) |
| self.assertEqual(istft, estimate, atol=1e-5, rtol=0) |
| patterns = [ |
| # hann_window, centered, normalized, onesided |
| ( |
| (2, 7, 7, 2), |
| { |
| 'n_fft': 12, |
| 'window': torch.hann_window(12, device=device, dtype=dtype), |
| 'center': True, |
| 'normalized': True, |
| 'onesided': True, |
| }, |
| ), |
| # hann_window, centered, not normalized, not onesided |
| ( |
| (2, 12, 7, 2), |
| { |
| 'n_fft': 12, |
| 'window': torch.hann_window(12, device=device, dtype=dtype), |
| 'center': True, |
| 'normalized': False, |
| 'onesided': False, |
| }, |
| ), |
| # hamming_window, centered, normalized, not onesided |
| ( |
| (2, 12, 7, 2), |
| { |
| 'n_fft': 12, |
| 'window': torch.hamming_window(12, device=device, dtype=dtype), |
| 'center': True, |
| 'normalized': True, |
| 'onesided': False, |
| }, |
| ), |
| # hamming_window, not centered, not normalized, onesided |
| ( |
| (2, 7, 3, 2), |
| { |
| 'n_fft': 12, |
| 'window': torch.hamming_window(12, device=device, dtype=dtype), |
| 'center': False, |
| 'normalized': False, |
| 'onesided': True, |
| }, |
| ) |
| ] |
| for data_size, kwargs in patterns: |
| _test(data_size, kwargs) |
| |
| @onlyOnCPUAndCUDA |
| @skipCPUIfNoMkl |
| @skipCUDAIfRocm |
| def test_batch_istft(self, device): |
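| # The same spectrogram is inverted unbatched, as a batch of one, and as a
| # batch of four; the batched results must match the unbatched result
| # repeated along the batch dimension.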
| original = torch.tensor([ |
| [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], |
| [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], |
| [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] |
| ], device=device) |
| |
| single = original.repeat(1, 1, 1, 1) |
| multi = original.repeat(4, 1, 1, 1) |
| |
| i_original = torch.istft(original, n_fft=4, length=4) |
| i_single = torch.istft(single, n_fft=4, length=4) |
| i_multi = torch.istft(multi, n_fft=4, length=4) |
| |
| self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True) |
| self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True) |
| |
| instantiate_device_type_tests(TestFFT, globals()) |
| |
| if __name__ == '__main__': |
| run_tests() |