Remove distribute test libraries from python:no_contrib
They shouldn't be dependencies of non test targets.
PiperOrigin-RevId: 308874149
Change-Id: I50785a60c5ba58fb99e46ca34ed089b991989b55
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 622185e..6941276 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -144,6 +144,7 @@
":confusion_matrix",
":control_flow_ops",
":cudnn_rnn_ops_gen",
+ ":distributed_framework_test_lib",
":errors",
":framework",
":framework_combinations",
@@ -202,11 +203,8 @@
"//tensorflow/python/data",
"//tensorflow/python/debug:debug_py",
"//tensorflow/python/distribute",
- "//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:distribute_config",
"//tensorflow/python/distribute:estimator_training",
- "//tensorflow/python/distribute:multi_worker_test_base",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/dlpack",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:monitoring",
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index aa303b9..04adc01 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -707,6 +707,10 @@
name = "combinations",
srcs = ["combinations.py"],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:internal",
+ "//tensorflow_models:__subpackages__",
+ ],
deps = [
":multi_process_runner",
":multi_worker_test_base",
@@ -738,6 +742,10 @@
name = "strategy_combinations",
srcs = ["strategy_combinations.py"],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:internal",
+ "//tensorflow_models:__subpackages__",
+ ],
deps = [
":central_storage_strategy",
":collective_all_reduce_strategy",
@@ -1578,7 +1586,6 @@
srcs = ["multi_process_runner.py"],
deps = [
":multi_process_lib",
- ":multi_worker_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:tf2",
"//tensorflow/python/compat:v2_compat",
@@ -1599,6 +1606,7 @@
shard_count = 12,
deps = [
":multi_process_runner",
+ ":multi_worker_test_base",
"//tensorflow/python/eager:test",
],
)
@@ -1609,6 +1617,7 @@
python_version = "PY3",
deps = [
":multi_process_runner",
+ ":multi_worker_test_base",
"//tensorflow/python/eager:test",
],
)