[XLA:TPU] Enable the TopK tests for F64 on TPU
This has worked for a while but the tests were disabled.
PiperOrigin-RevId: 381142245
Change-Id: I3eddfff399ce3ac38e21d5bbd44b418896f50a82
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 3fadad6..3b1441a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1598,7 +1598,7 @@
srcs = ["sort_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
- shard_count = 2,
+ shard_count = 10,
# Times out in fastbuild mode.
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index b4b3b4f..761706d 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -245,8 +245,6 @@
np.int32, np.uint32
])
for dtype in supported_types.intersection(self.numeric_types):
- if dtype == np.float64 and self.device == "TPU":
- continue
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
# after conversion to bfloat16, so the possible resulting index array is
# no longer unique.
@@ -278,9 +276,6 @@
)
def testTopK2D(self, dtype):
if dtype in self.numeric_types:
- # TPU implementation is not supported for double precision
- if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
- return
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
# after conversion to bfloat16, so the possible resulting index array is
# no longer unique.
@@ -310,9 +305,6 @@
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
for dtype in supported_types.intersection(self.numeric_types):
- # TPU implementation is not supported for double precision
- if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
- continue
with self.session() as sess:
p = array_ops.placeholder(dtype)
with self.test_scope():
@@ -328,9 +320,6 @@
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
for dtype in supported_types.intersection(self.numeric_types):
- # TPU implementation is not supported for double precision
- if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
- continue
with self.session() as sess:
p = array_ops.placeholder(dtype)
with self.test_scope():