blob: 516d4d1961bbf3d220b0b1d5ce705a8833e8a579 [file] [log] [blame]
# Description:
# Contains Keras integration tests that verify interoperability with other TF high-level APIs.
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
# Package-level defaults applied to every target declared in this file:
# default visibility is limited to the pip_package package, and the license
# type is "notice".
package(
    default_visibility = ["//tensorflow/tools/pip_package:__pkg__"],
    licenses = ["notice"],
)
# Python 3 test target for forwardprop_test.py, declared via the tf_py_test
# macro. Depends on the TensorFlow Python package (no contrib) and absl's
# parameterized testing library.
tf_py_test(
    name = "forwardprop_test",
    srcs = ["forwardprop_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for function_test.py, declared via the tf_py_test macro.
# Its only dependency is the TensorFlow Python package (no contrib).
tf_py_test(
    name = "function_test",
    srcs = ["function_test.py"],
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for gradients_test.py, declared via the tf_py_test macro.
# Its only dependency is the TensorFlow Python package (no contrib).
tf_py_test(
    name = "gradients_test",
    srcs = ["gradients_test.py"],
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for saved_model_test.py. Declared with cuda_py_test
# rather than tf_py_test.
# NOTE(review): presumably so the test is also built/run for GPU -- confirm
# against the macro definition in //tensorflow:tensorflow.bzl.
cuda_py_test(
    name = "saved_model_test",
    srcs = ["saved_model_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for legacy_rnn_test.py, declared via the tf_py_test
# macro.
# TODO: Remove this target when TF 1 is deprecated.
tf_py_test(
    name = "legacy_rnn_test",
    srcs = ["legacy_rnn_test.py"],
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Test target for module_test.py, declared via the tf_py_test macro.
# Its only dependency is the TensorFlow Python package (no contrib).
tf_py_test(
    name = "module_test",
    srcs = ["module_test.py"],
    # Pinned explicitly so this target matches the sibling test targets in
    # this file instead of relying on the macro/Bazel default Python version.
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for vectorized_map_test.py, declared via the tf_py_test
# macro. Its only dependency is the TensorFlow Python package (no contrib).
tf_py_test(
    name = "vectorized_map_test",
    srcs = ["vectorized_map_test.py"],
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for gradient_checkpoint_test.py. Declared with
# cuda_py_test rather than tf_py_test.
# NOTE(review): presumably so the test is also built/run for GPU -- confirm
# against the macro definition in //tensorflow:tensorflow.bzl.
cuda_py_test(
    name = "gradient_checkpoint_test",
    srcs = ["gradient_checkpoint_test.py"],
    python_version = "PY3",
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for central_storage_strategy_test.py, declared with
# cuda_py_test. In addition to the TensorFlow Python package it depends on the
# distribute combinations / strategy_combinations libraries, the Keras
# kpl_test_utils helper, and absl's parameterized testing library.
cuda_py_test(
    name = "central_storage_strategy_test",
    srcs = ["central_storage_strategy_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "no_windows_gpu",  # TODO(b/130551176)
    ],
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/keras/utils:kpl_test_utils",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for tpu_strategy_test.py, declared via the tpu_py_test
# macro and tagged "no_oss".
# NOTE(review): the exact effect of disable_experimental / disable_mlir_bridge
# is defined by the macro in //tensorflow/python/tpu:tpu.bzl -- confirm there
# before changing either flag.
tpu_py_test(
    name = "tpu_strategy_test",
    srcs = ["tpu_strategy_test.py"],
    disable_experimental = True,
    disable_mlir_bridge = False,
    python_version = "PY3",
    tags = ["no_oss"],
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for multi_worker_tutorial_test.py, declared via the
# tf_py_test macro and split into 6 shards. Each exclusion tag below carries
# the tracking bug that motivated it.
tf_py_test(
    name = "multi_worker_tutorial_test",
    srcs = ["multi_worker_tutorial_test.py"],
    python_version = "PY3",
    shard_count = 6,
    tags = [
        "no_windows",  # TODO(b/183102726)
        "noasan",  # TODO(b/156029134)
        "nomac",  # TODO(b/182567880)
        "nomsan",  # TODO(b/156029134)
        "notsan",  # TODO(b/156029134)
    ],
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for parameter_server_keras_preprocessing_test.py,
# declared via the distribute_py_test macro and tagged "multi_and_single_gpu".
distribute_py_test(
    name = "parameter_server_keras_preprocessing_test",
    srcs = ["parameter_server_keras_preprocessing_test.py"],
    python_version = "PY3",
    # TODO(b/184290570): Investigate why only 1 shard times out.
    shard_count = 4,
    tags = ["multi_and_single_gpu"],
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Python 3 test target for distributed_training_test.py, declared via the
# distribute_py_test macro and split into 50 shards. Each exclusion tag that
# has a tracking bug lists it inline.
distribute_py_test(
    name = "distributed_training_test",
    srcs = ["distributed_training_test.py"],
    python_version = "PY3",
    shard_count = 50,
    tags = [
        "multi_gpu",
        "no_oss",  # TODO(b/183640564): Reenable
        "no_rocm",
        "noasan",  # TODO(b/184542721)
        "nomsan",  # TODO(b/184542721)
        "nomultivm",  # TODO(b/170502145)
        "notsan",  # TODO(b/184542721)
    ],
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
# Python 3 test target for parameter_server_custom_training_loop_test.py,
# declared via the distribute_py_test macro. No shard_count is set here. Each
# exclusion tag that has a tracking bug lists it inline.
distribute_py_test(
    name = "parameter_server_custom_training_loop_test",
    srcs = ["parameter_server_custom_training_loop_test.py"],
    python_version = "PY3",
    tags = [
        "multi_gpu",
        "no_oss",  # TODO(b/183640564): Reenable
        "no_rocm",
        "noasan",  # TODO(b/184542721)
        "nomsan",  # TODO(b/184542721)
        "nomultivm",  # TODO(b/170502145)
        "notsan",  # TODO(b/184542721)
    ],
    deps = ["//tensorflow:tensorflow_py_no_contrib"],
)