[XLA:TPU] Enable the TopK tests for F64 on TPU

TopK on float64 has worked on TPU for a while, but the tests were still disabled.
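For context, a minimal sketch of the kind of call the re-enabled tests cover, assuming TensorFlow 2.x eager mode. The actual tests run through the XLA test harness (placeholders inside test_scope), so this standalone snippet, including its shapes and choice of k, is purely illustrative:

    # Illustrative only; shapes and k are not taken from the test file.
    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.randn(1000), dtype=tf.float64)
    values, indices = tf.math.top_k(x, k=5, sorted=True)
    # With the device guards removed below, the XLA tests now exercise this
    # float64 TopK path on TPU as well, instead of skipping it.
    print(values.numpy(), indices.numpy())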

PiperOrigin-RevId: 381142245
Change-Id: I3eddfff399ce3ac38e21d5bbd44b418896f50a82
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 3fadad6..3b1441a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1598,7 +1598,7 @@
     srcs = ["sort_ops_test.py"],
     enable_mlir_bridge = True,
     python_version = "PY3",
-    shard_count = 2,
+    shard_count = 10,
     # Times out in fastbuild mode.
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index b4b3b4f..761706d 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -245,8 +245,6 @@
         np.int32, np.uint32
     ])
     for dtype in supported_types.intersection(self.numeric_types):
-      if dtype == np.float64 and self.device == "TPU":
-        continue
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -278,9 +276,6 @@
   )
   def testTopK2D(self, dtype):
     if dtype in self.numeric_types:
-      # TPU implementation is not supported for double precision
-      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
-        return
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -310,9 +305,6 @@
     supported_types = set(
         [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
-      # TPU implementation is not supported for double precision
-      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
-        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():
@@ -328,9 +320,6 @@
     supported_types = set(
         [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
-      # TPU implementation is not supported for double precision
-      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
-        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():