[tf.data] Reduce flakiness of tests.

Recent consolidation of tf.data tests increased runtime of some of them. This CL introduces sharding for tests that became flaky as a result of the consolidation.

PiperOrigin-RevId: 363491524
Change-Id: I03e6b50808193c29d6410791ce6bb4ac0ef217cc
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 9fc66b3..6299bb6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -1,5 +1,6 @@
-load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")  # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
 
 package(
     default_visibility = ["//tensorflow:internal"],
@@ -673,6 +674,7 @@
     name = "rebatch_dataset_test",
     size = "small",
     srcs = ["rebatch_dataset_test.py"],
+    shard_count = 4,
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:image_ops",
@@ -786,21 +788,6 @@
 )
 
 tf_py_test(
-    name = "sql_dataset_test",
-    size = "medium",
-    srcs = ["sql_dataset_test.py"],
-    tags = ["no_pip"],
-    deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python/data/experimental/ops:readers",
-        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
-        "@org_sqlite//:python",
-    ],
-)
-
-tf_py_test(
     name = "snapshot_test",
     size = "medium",
     timeout = "long",
@@ -826,9 +813,25 @@
 )
 
 tf_py_test(
+    name = "sql_dataset_test",
+    size = "medium",
+    srcs = ["sql_dataset_test.py"],
+    tags = ["no_pip"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/experimental/ops:readers",
+        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
+        "@org_sqlite//:python",
+    ],
+)
+
+tf_py_test(
     name = "stats_dataset_ops_test",
     size = "small",
     srcs = ["stats_dataset_ops_test.py"],
+    shard_count = 4,
     tags = [
         "no_oss",  # TODO(b/155795733): Note that this functionality is deprecated.
         "no_pip",
@@ -868,6 +871,7 @@
     name = "take_while_test",
     size = "small",
     srcs = ["take_while_test.py"],
+    shard_count = 4,
     deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index f548665d..dca38ac 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -1,5 +1,6 @@
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")  # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
 
 package(
     default_visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index da1a093..fc9a871f 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -1,7 +1,6 @@
-# Tests of TensorFlow kernels written using the Python API.
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")  # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
 
 package(
     default_visibility = ["//tensorflow:internal"],
@@ -244,7 +243,7 @@
     name = "from_generator_test",
     size = "medium",
     srcs = ["from_generator_test.py"],
-    shard_count = 10,
+    shard_count = 8,
     deps = [
         ":test_base",
         "//tensorflow/python:client_testlib",
@@ -332,6 +331,7 @@
     name = "interleave_test",
     size = "medium",
     srcs = ["interleave_test.py"],
+    shard_count = 4,
     deps = [
         ":checkpoint_test_base",
         ":test_base",
@@ -730,6 +730,7 @@
     name = "skip_test",
     size = "small",
     srcs = ["skip_test.py"],
+    shard_count = 4,
     deps = [
         ":checkpoint_test_base",
         ":test_base",
@@ -746,6 +747,7 @@
     name = "take_test",
     size = "small",
     srcs = ["take_test.py"],
+    shard_count = 4,
     deps = [
         ":checkpoint_test_base",
         ":test_base",
@@ -781,7 +783,7 @@
     name = "tf_record_dataset_test",
     size = "small",
     srcs = ["tf_record_dataset_test.py"],
-    shard_count = 4,
+    shard_count = 8,
     deps = [
         ":checkpoint_test_base",
         ":test_base",