blob: 4e8a2ba1cd2cf90b01999cb09aa0ef3fd912e9d9 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fft operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops_test_util
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
VALID_FFT_RANKS = (1, 2, 3)
class BaseFFTOpsTest(test.TestCase):
def _compare(self, x, rank, fft_length=None, use_placeholder=False,
rtol=1e-4, atol=1e-4):
self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol)
self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol)
def _compareForward(self, x, rank, fft_length=None, use_placeholder=False,
rtol=1e-4, atol=1e-4):
x_np = self._npFFT(x, rank, fft_length)
if use_placeholder:
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
x_tf = self._tfFFT(x_ph, rank, fft_length, feed_dict={x_ph: x})
else:
x_tf = self._tfFFT(x, rank, fft_length)
self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False,
rtol=1e-4, atol=1e-4):
x_np = self._npIFFT(x, rank, fft_length)
if use_placeholder:
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
x_tf = self._tfIFFT(x_ph, rank, fft_length, feed_dict={x_ph: x})
else:
x_tf = self._tfIFFT(x, rank, fft_length)
self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
def _checkMemoryFail(self, x, rank):
config = config_pb2.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 1e-2
with self.cached_session(config=config, force_gpu=True):
self._tfFFT(x, rank, fft_length=None)
def _checkGradComplex(self, func, x, y, result_is_complex=True,
rtol=1e-2, atol=1e-2):
with self.cached_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
# func is a forward or inverse, real or complex, batched or unbatched FFT
# function with a complex input.
z = func(math_ops.complex(inx, iny))
# loss = sum(|z|^2)
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
((x_jacob_t, x_jacob_n),
(y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient(
[inx, iny], [list(x.shape), list(y.shape)],
loss, [1],
x_init_value=[x, y],
delta=1e-2)
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
def _checkGradReal(self, func, x, rtol=1e-2, atol=1e-2):
with self.cached_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
# func is a forward RFFT function (batched or unbatched).
z = func(inx)
# loss = sum(|z|^2)
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
x_jacob_t, x_jacob_n = test.compute_gradient(
inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2)
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
class FFTOpsTest(BaseFFTOpsTest):
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
with self.cached_session(use_gpu=True) as sess:
return sess.run(self._tfFFTForRank(rank)(x), feed_dict=feed_dict)
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
with self.cached_session(use_gpu=True) as sess:
return sess.run(self._tfIFFTForRank(rank)(x), feed_dict=feed_dict)
def _npFFT(self, x, rank, fft_length=None):
if rank == 1:
return np.fft.fft2(x, s=fft_length, axes=(-1,))
elif rank == 2:
return np.fft.fft2(x, s=fft_length, axes=(-2, -1))
elif rank == 3:
return np.fft.fft2(x, s=fft_length, axes=(-3, -2, -1))
else:
raise ValueError("invalid rank")
def _npIFFT(self, x, rank, fft_length=None):
if rank == 1:
return np.fft.ifft2(x, s=fft_length, axes=(-1,))
elif rank == 2:
return np.fft.ifft2(x, s=fft_length, axes=(-2, -1))
elif rank == 3:
return np.fft.ifft2(x, s=fft_length, axes=(-3, -2, -1))
else:
raise ValueError("invalid rank")
def _tfFFTForRank(self, rank):
if rank == 1:
return fft_ops.fft
elif rank == 2:
return fft_ops.fft2d
elif rank == 3:
return fft_ops.fft3d
else:
raise ValueError("invalid rank")
def _tfIFFTForRank(self, rank):
if rank == 1:
return fft_ops.ifft
elif rank == 2:
return fft_ops.ifft2d
elif rank == 3:
return fft_ops.ifft3d
else:
raise ValueError("invalid rank")
@test_util.run_deprecated_v1
def testEmpty(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type in (np.complex64, np.complex128):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
x = np.zeros((0,) * dims).astype(np_type)
self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
@test_util.run_deprecated_v1
def testBasic(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
self._compare(
np.mod(np.arange(np.power(4, dims)), 10).reshape(
(4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
def testLargeBatch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
for dims in xrange(rank, rank + 3):
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
self._compare(
np.mod(np.arange(np.power(128, dims)), 10).reshape(
(128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
# TODO(yangzihao): Disable before we can figure out a way to
# properly test memory fail for large batch fft.
# def testLargeBatchMemoryFail(self):
# if test.is_gpu_available(cuda_only=True):
# rank = 1
# for dims in xrange(rank, rank + 3):
# self._checkMemoryFail(
# np.mod(np.arange(np.power(128, dims)), 64).reshape(
# (128,) * dims).astype(np.complex64), rank)
@test_util.run_deprecated_v1
def testBasicPlaceholder(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
self._compare(
np.mod(np.arange(np.power(4, dims)), 10).reshape(
(4,) * dims).astype(np_type),
rank, use_placeholder=True, rtol=tol, atol=tol)
@test_util.run_deprecated_v1
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-6)):
def gen(shape):
n = np.prod(shape)
re = np.random.uniform(size=n)
im = np.random.uniform(size=n)
return (re + im * 1j).reshape(shape)
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
self._compare(gen((4,) * dims).astype(np_type), rank,
rtol=tol, atol=tol)
@test_util.run_deprecated_v1
def testRandom1D(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type in (np.complex64, np.complex128):
has_gpu = test.is_gpu_available(cuda_only=True)
tol = {(np.complex64, True): 1e-4,
(np.complex64, False): 1e-2,
(np.complex128, True): 1e-4,
(np.complex128, False): 1e-2}[(np_type, has_gpu)]
def gen(shape):
n = np.prod(shape)
re = np.random.uniform(size=n)
im = np.random.uniform(size=n)
return (re + im * 1j).reshape(shape)
# Check a variety of power-of-2 FFT sizes.
for dim in (128, 256, 512, 1024):
self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
# Check a variety of non-power-of-2 FFT sizes.
for dim in (127, 255, 511, 1023):
self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
@test_util.run_deprecated_v1
def testError(self):
for rank in VALID_FFT_RANKS:
for dims in xrange(0, rank):
x = np.zeros((1,) * dims).astype(np.complex64)
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape must be .*rank {}.*".format(rank)):
self._tfFFT(x, rank)
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape must be .*rank {}.*".format(rank)):
self._tfIFFT(x, rank)
@test_util.run_deprecated_v1
def testGrad_Simple(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.float32, 1e-4), (np.float64, 1e-10)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 2):
re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0
im = np.zeros(shape=(4,) * dims, dtype=np_type)
self._checkGradComplex(self._tfFFTForRank(rank), re, im,
rtol=tol, atol=tol)
self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
rtol=tol, atol=tol)
@test_util.run_deprecated_v1
def testGrad_Random(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 2):
re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
self._checkGradComplex(self._tfFFTForRank(rank), re, im,
rtol=tol, atol=tol)
self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
rtol=tol, atol=tol)
class RFFTOpsTest(BaseFFTOpsTest):
def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
super(RFFTOpsTest, self)._compareBackward(x, rank, fft_length,
use_placeholder)
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tfFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tfIFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
def _npFFT(self, x, rank, fft_length=None):
if rank == 1:
return np.fft.rfft2(x, s=fft_length, axes=(-1,))
elif rank == 2:
return np.fft.rfft2(x, s=fft_length, axes=(-2, -1))
elif rank == 3:
return np.fft.rfft2(x, s=fft_length, axes=(-3, -2, -1))
else:
raise ValueError("invalid rank")
def _npIFFT(self, x, rank, fft_length=None):
if rank == 1:
return np.fft.irfft2(x, s=fft_length, axes=(-1,))
elif rank == 2:
return np.fft.irfft2(x, s=fft_length, axes=(-2, -1))
elif rank == 3:
return np.fft.irfft2(x, s=fft_length, axes=(-3, -2, -1))
else:
raise ValueError("invalid rank")
def _tfFFTForRank(self, rank):
if rank == 1:
return fft_ops.rfft
elif rank == 2:
return fft_ops.rfft2d
elif rank == 3:
return fft_ops.rfft3d
else:
raise ValueError("invalid rank")
def _tfIFFTForRank(self, rank):
if rank == 1:
return fft_ops.irfft
elif rank == 2:
return fft_ops.irfft2d
elif rank == 3:
return fft_ops.irfft3d
else:
raise ValueError("invalid rank")
@test_util.run_deprecated_v1
def testEmpty(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
x = np.zeros((0,) * dims).astype(np.float32)
self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
x = np.zeros((0,) * dims).astype(np.complex64)
self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
@test_util.run_deprecated_v1
def testBasic(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compareForward(r2c.astype(np.float32), rank, (size,) * rank)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compareBackward(
c2r.astype(np.complex64), rank, (size,) * rank)
def testLargeBatch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
for dims in xrange(rank, rank + 3):
for size in (64, 128):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compareForward(r2c.astype(np.float32), rank, (size,) * rank)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compareBackward(c2r.astype(np.complex64), rank, (size,) * rank)
@test_util.run_deprecated_v1
def testBasicPlaceholder(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compareForward(
r2c.astype(np.float32),
rank, (size,) * rank,
use_placeholder=True)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compareBackward(
c2r.astype(np.complex64),
rank, (size,) * rank,
use_placeholder=True)
@test_util.run_deprecated_v1
def testFftLength(self):
if test.is_gpu_available(cuda_only=True):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
# Test truncation (FFT size < dimensions).
fft_length = (size - 2,) * rank
self._compareForward(r2c.astype(np.float32), rank, fft_length)
self._compareBackward(c2r.astype(np.complex64), rank, fft_length)
# Confirm it works with unknown shapes as well.
self._compareForward(
r2c.astype(np.float32),
rank,
fft_length,
use_placeholder=True)
self._compareBackward(
c2r.astype(np.complex64),
rank,
fft_length,
use_placeholder=True)
# Test padding (FFT size > dimensions).
fft_length = (size + 2,) * rank
self._compareForward(r2c.astype(np.float32), rank, fft_length)
self._compareBackward(c2r.astype(np.complex64), rank, fft_length)
# Confirm it works with unknown shapes as well.
self._compareForward(
r2c.astype(np.float32),
rank,
fft_length,
use_placeholder=True)
self._compareBackward(
c2r.astype(np.complex64),
rank,
fft_length,
use_placeholder=True)
@test_util.run_deprecated_v1
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
def gen_real(shape):
n = np.prod(shape)
re = np.random.uniform(size=n)
ret = re.reshape(shape)
return ret
def gen_complex(shape):
n = np.prod(shape)
re = np.random.uniform(size=n)
im = np.random.uniform(size=n)
ret = (re + im * 1j).reshape(shape)
return ret
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
self._compareForward(gen_real((size,) * dims), rank, (size,) * rank)
complex_dims = (size,) * (dims - 1) + (inner_dim,)
self._compareBackward(
gen_complex(complex_dims), rank, (size,) * rank)
@test_util.run_deprecated_v1
def testError(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
for dims in xrange(0, rank):
x = np.zeros((1,) * dims).astype(np.complex64)
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfFFT(x, rank)
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfIFFT(x, rank)
for dims in xrange(rank, rank + 2):
x = np.zeros((1,) * rank)
# Test non-rank-1 fft_length produces an error.
fft_length = np.zeros((1, 1)).astype(np.int32)
with self.assertRaisesWithPredicateMatch(ValueError,
"Shape .* must have rank 1"):
self._tfFFT(x, rank, fft_length)
with self.assertRaisesWithPredicateMatch(ValueError,
"Shape .* must have rank 1"):
self._tfIFFT(x, rank, fft_length)
# Test wrong fft_length length.
fft_length = np.zeros((rank + 1,)).astype(np.int32)
with self.assertRaisesWithPredicateMatch(
ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
self._tfFFT(x, rank, fft_length)
with self.assertRaisesWithPredicateMatch(
ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
self._tfIFFT(x, rank, fft_length)
# Test that calling the kernel directly without padding to fft_length
# produces an error.
rffts_for_rank = {
1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]
}
rfft_fn, irfft_fn = rffts_for_rank[rank]
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"Input dimension .* must have length of at least 6 but got: 5"):
x = np.zeros((5,) * rank).astype(np.float32)
fft_length = [6] * rank
with self.cached_session():
self.evaluate(rfft_fn(x, fft_length))
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"Input dimension .* must have length of at least .* but got: 3"):
x = np.zeros((3,) * rank).astype(np.complex64)
fft_length = [6] * rank
with self.cached_session():
self.evaluate(irfft_fn(x, fft_length))
@test_util.run_deprecated_v1
def testGrad_Simple(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.ones(shape=(size,) * dims, dtype=np.float32)
im = -np.ones(shape=(size,) * dims, dtype=np.float32)
self._checkGradReal(self._tfFFTForRank(rank), re)
self._checkGradComplex(
self._tfIFFTForRank(rank), re, im, result_is_complex=False)
@test_util.run_deprecated_v1
def testGrad_Random(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
self._checkGradReal(self._tfFFTForRank(rank), re)
self._checkGradComplex(
self._tfIFFTForRank(rank), re, im, result_is_complex=False)
class FFTShiftTest(test.TestCase):
@test_util.run_deprecated_v1
def testDefinition(self):
with self.session():
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertAllEqual(fft_ops.fftshift(x).eval(), y)
self.assertAllEqual(fft_ops.ifftshift(y).eval(), x)
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertAllEqual(fft_ops.fftshift(x).eval(), y)
self.assertAllEqual(fft_ops.ifftshift(y).eval(), x)
@test_util.run_deprecated_v1
def testAxesKeyword(self):
with self.session():
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).eval(), shifted)
self.assertAllEqual(
fft_ops.fftshift(freqs, axes=0).eval(),
fft_ops.fftshift(freqs, axes=(0,)).eval())
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).eval(), freqs)
self.assertAllEqual(
fft_ops.ifftshift(shifted, axes=0).eval(),
fft_ops.ifftshift(shifted, axes=(0,)).eval())
self.assertAllEqual(fft_ops.fftshift(freqs).eval(), shifted)
self.assertAllEqual(fft_ops.ifftshift(shifted).eval(), freqs)
@test_util.run_deprecated_v1
def testNumpyCompatibility(self):
with self.session():
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertAllEqual(fft_ops.fftshift(x).eval(), np.fft.fftshift(x))
self.assertAllEqual(fft_ops.ifftshift(y).eval(), np.fft.ifftshift(y))
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertAllEqual(fft_ops.fftshift(x).eval(), np.fft.fftshift(x))
self.assertAllEqual(fft_ops.ifftshift(y).eval(), np.fft.ifftshift(y))
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertAllEqual(
fft_ops.fftshift(freqs, axes=(0, 1)).eval(),
np.fft.fftshift(freqs, axes=(0, 1)))
self.assertAllEqual(
fft_ops.ifftshift(shifted, axes=(0, 1)).eval(),
np.fft.ifftshift(shifted, axes=(0, 1)))
@test_util.run_deprecated_v1
def testNoneAxes(self):
x = array_ops.placeholder(shape=[None, None, None], dtype='float32')
axes_to_test = [None, 1, [1, 2]]
for axes in axes_to_test:
y_fftshift = fft_ops.fftshift(x, axes=axes)
y_ifftshift = fft_ops.ifftshift(x, axes=axes)
with self.session() as sess:
sess.run([y_fftshift, y_ifftshift], feed_dict={x: np.zeros((16, 256, 256))})
if __name__ == "__main__":
test.main()