# blob: 20e1a886d4edd3f97c07475a1e24596e8f3715ea [file] [log] [blame]
# Description:
# Contains integration tests that verify Keras works with other TF high-level APIs.
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
# Package-wide defaults: unless a target overrides it, visibility is limited to
# //tensorflow/tools/pip_package, and the package carries a "notice" license.
package(
    default_visibility = [
        "//tensorflow/tools/pip_package:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)
# Export the LICENSE file so that other packages may reference it by label.
exports_files(
    ["LICENSE"],
)
# Python 3 test target built from forwardprop_test.py.
tf_py_test(
    name = "forwardprop_test",
    srcs = ["forwardprop_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from function_test.py.
tf_py_test(
    name = "function_test",
    srcs = ["function_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from gradients_test.py.
tf_py_test(
    name = "gradients_test",
    srcs = ["gradients_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from saved_model_test.py. Declared through the
# cuda_py_test macro (loaded from //tensorflow:tensorflow.bzl) rather than
# tf_py_test.
cuda_py_test(
    name = "saved_model_test",
    srcs = ["saved_model_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from legacy_rnn_test.py.
# TODO: Remove this target when TF 1 is deprecated.
tf_py_test(
    name = "legacy_rnn_test",
    srcs = ["legacy_rnn_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from module_test.py.
# python_version is pinned explicitly, matching every other target in this
# package, so the interpreter version does not depend on macro/Bazel defaults.
tf_py_test(
    name = "module_test",
    srcs = ["module_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from vectorized_map_test.py.
tf_py_test(
    name = "vectorized_map_test",
    srcs = ["vectorized_map_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from gradient_checkpoint_test.py. Declared through
# the cuda_py_test macro and tagged "no_rocm".
# NOTE(review): presumably the tag excludes this test from ROCm builds — confirm
# against the CI tag filters.
cuda_py_test(
    name = "gradient_checkpoint_test",
    srcs = ["gradient_checkpoint_test.py"],
    python_version = "PY3",
    tags = ["no_rocm"],
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
# Python 3 test target built from tpu_strategy_test.py. Declared through the
# tpu_py_test macro (loaded from //tensorflow/python/tpu:tpu.bzl) and tagged
# "no_oss".
# NOTE(review): the exact effect of disable_experimental / disable_mlir_bridge
# is defined in tpu.bzl, not here — confirm there before changing either flag.
tpu_py_test(
    name = "tpu_strategy_test",
    srcs = ["tpu_strategy_test.py"],
    disable_experimental = True,
    disable_mlir_bridge = False,
    python_version = "PY3",
    tags = ["no_oss"],
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)