[tf.data] Reduce flakiness of tests.
Recent consolidation of tf.data tests increased runtime of some of them. This CL introduces sharding for tests that became flaky as a result of the consolidation.
PiperOrigin-RevId: 363491524
Change-Id: I03e6b50808193c29d6410791ce6bb4ac0ef217cc
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 9fc66b3..6299bb6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -1,5 +1,6 @@
-load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
package(
default_visibility = ["//tensorflow:internal"],
@@ -673,6 +674,7 @@
name = "rebatch_dataset_test",
size = "small",
srcs = ["rebatch_dataset_test.py"],
+ shard_count = 4,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:image_ops",
@@ -786,21 +788,6 @@
)
tf_py_test(
- name = "sql_dataset_test",
- size = "medium",
- srcs = ["sql_dataset_test.py"],
- tags = ["no_pip"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/experimental/ops:readers",
- "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
- "@org_sqlite//:python",
- ],
-)
-
-tf_py_test(
name = "snapshot_test",
size = "medium",
timeout = "long",
@@ -826,9 +813,25 @@
)
tf_py_test(
+ name = "sql_dataset_test",
+ size = "medium",
+ srcs = ["sql_dataset_test.py"],
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
+ "@org_sqlite//:python",
+ ],
+)
+
+tf_py_test(
name = "stats_dataset_ops_test",
size = "small",
srcs = ["stats_dataset_ops_test.py"],
+ shard_count = 4,
tags = [
"no_oss", # TODO(b/155795733): Note that this functionality is deprecated.
"no_pip",
@@ -868,6 +871,7 @@
name = "take_while_test",
size = "small",
srcs = ["take_while_test.py"],
+ shard_count = 4,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index f548665d..dca38ac 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -1,5 +1,6 @@
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
+load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
package(
default_visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index da1a093..fc9a871f 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -1,7 +1,6 @@
-# Tests of TensorFlow kernels written using the Python API.
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
+# Definitions are loaded separately so that copybara can pattern match (and modify) each definition.
load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
package(
default_visibility = ["//tensorflow:internal"],
@@ -244,7 +243,7 @@
name = "from_generator_test",
size = "medium",
srcs = ["from_generator_test.py"],
- shard_count = 10,
+ shard_count = 8,
deps = [
":test_base",
"//tensorflow/python:client_testlib",
@@ -332,6 +331,7 @@
name = "interleave_test",
size = "medium",
srcs = ["interleave_test.py"],
+ shard_count = 4,
deps = [
":checkpoint_test_base",
":test_base",
@@ -730,6 +730,7 @@
name = "skip_test",
size = "small",
srcs = ["skip_test.py"],
+ shard_count = 4,
deps = [
":checkpoint_test_base",
":test_base",
@@ -746,6 +747,7 @@
name = "take_test",
size = "small",
srcs = ["take_test.py"],
+ shard_count = 4,
deps = [
":checkpoint_test_base",
":test_base",
@@ -781,7 +783,7 @@
name = "tf_record_dataset_test",
size = "small",
srcs = ["tf_record_dataset_test.py"],
- shard_count = 4,
+ shard_count = 8,
deps = [
":checkpoint_test_base",
":test_base",