| # Description: |
| # Contains the Keras OptimizerV2 API (internal TensorFlow version). |
| |
| load("//tensorflow:tensorflow.bzl", "cuda_py_test") |
| |
| # Package-level defaults applied to the targets declared in this file. |
| package( |
| # TODO(scottzhu): Remove non-keras deps from TF. |
| # Targets that do not set their own `visibility` are visible only to the |
| # packages listed here. |
| default_visibility = [ |
| "//tensorflow/python:__pkg__", |
| "//tensorflow/python/distribute:__pkg__", |
| "//tensorflow/python/keras:__subpackages__", |
| "//tensorflow/python/training/tracking:__pkg__", |
| ], |
| licenses = ["notice"], # Apache 2.0 |
| ) |
| |
| # Groups every `*.py` file matched by the glob in this package under a |
| # single label. Sets its own `visibility`, so the package default does not |
| # apply: only the private_tf_api_test package may reference it. |
| filegroup( |
| name = "all_py_srcs", |
| srcs = glob(["*.py"]), |
| visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"], |
| ) |
| |
| # Python library made of the optimizer modules listed in `srcs`, together |
| # with `optimizer_v2.py` and `utils.py`. Every test target in this file |
| # except `legacy_learning_rate_decay_test` depends on it. |
| py_library( |
| name = "optimizer_v2", |
| srcs = [ |
| "adadelta.py", |
| "adagrad.py", |
| "adam.py", |
| "adamax.py", |
| "ftrl.py", |
| "gradient_descent.py", |
| "nadam.py", |
| "optimizer_v2.py", |
| "rmsprop.py", |
| "utils.py", |
| ], |
| # Sources are declared compatible with both Python 2 and Python 3. |
| srcs_version = "PY2AND3", |
| # One dependency from this package (`:learning_rate_schedule`); the rest |
| # are TensorFlow core, distribute, and Keras targets. |
| deps = [ |
| ":learning_rate_schedule", |
| "//tensorflow/python:control_flow_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:state_ops", |
| "//tensorflow/python:variable_scope", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/distribute:central_storage_strategy", |
| "//tensorflow/python/distribute:distribute_lib", |
| "//tensorflow/python/distribute:parameter_server_strategy", |
| "//tensorflow/python/distribute:parameter_server_strategy_v2", |
| "//tensorflow/python/distribute:reduce_util", |
| "//tensorflow/python/distribute:values", |
| "//tensorflow/python/keras:backend", |
| "//tensorflow/python/keras:backend_config", |
| "//tensorflow/python/keras:initializers", |
| "//tensorflow/python/keras/engine:base_layer_utils", |
| "//tensorflow/python/keras/utils:layer_utils", |
| "//tensorflow/python/keras/utils:tf_utils", |
| ], |
| ) |
| |
| # Python library containing only `learning_rate_schedule.py`. It is a |
| # direct dependency of both `:optimizer_v2` and |
| # `:legacy_learning_rate_decay` in this file. |
| py_library( |
| name = "learning_rate_schedule", |
| srcs = [ |
| "learning_rate_schedule.py", |
| ], |
| # Sources are declared compatible with both Python 2 and Python 3. |
| srcs_version = "PY2AND3", |
| deps = [ |
| "//tensorflow/python:control_flow_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:random_ops", |
| "//tensorflow/python/keras/utils:generic_utils", |
| ], |
| ) |
| |
| # Python library containing only `legacy_learning_rate_decay.py`. Depends |
| # on `:learning_rate_schedule` from this package; it does not depend on |
| # `:optimizer_v2`. |
| py_library( |
| name = "legacy_learning_rate_decay", |
| srcs = ["legacy_learning_rate_decay.py"], |
| # Sources are declared compatible with both Python 2 and Python 3. |
| srcs_version = "PY2AND3", |
| deps = [ |
| ":learning_rate_schedule", |
| "//tensorflow/python:dtypes", |
| "//tensorflow/python:framework_ops", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/util:tf_export", |
| ], |
| ) |
| |
| # Test target for `adagrad_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 4; |
| # the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "adagrad_test", |
| size = "medium", |
| srcs = ["adagrad_test.py"], |
| shard_count = 4, |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `adam_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 4; |
| # the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "adam_test", |
| size = "medium", |
| srcs = ["adam_test.py"], |
| shard_count = 4, |
| # Presumably these tags exclude the test from ROCm and Windows runs -- |
| # confirm against the CI tag filters. The Windows tag is tracked by the |
| # bug referenced on its line. |
| tags = [ |
| "no_rocm", |
| "no_windows", # TODO(b/171384138) |
| ], |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `adamax_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 4; |
| # the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "adamax_test", |
| size = "medium", |
| srcs = ["adamax_test.py"], |
| shard_count = 4, |
| # NOTE(review): the TODO below sits directly above `deps` but no |
| # TFRT-related attribute is visible in this target -- confirm whether the |
| # comment is stale or an attribute is missing. |
| # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `adadelta_test.py`, declared with the `cuda_py_test` |
| # macro loaded from //tensorflow:tensorflow.bzl. Size "medium", |
| # shard_count 4; the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "adadelta_test", |
| size = "medium", |
| srcs = ["adadelta_test.py"], |
| shard_count = 4, |
| # Presumably excludes the test from ROCm runs -- confirm against the CI |
| # tag filters. |
| tags = ["no_rocm"], |
| # NOTE(review): the TODO below sits directly above `deps` but no |
| # TFRT-related attribute is visible in this target -- confirm whether the |
| # comment is stale or an attribute is missing. |
| # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `ftrl_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 4; |
| # the code under test comes from `:optimizer_v2`. Unlike most optimizer |
| # tests in this file, it does not list //tensorflow/python/keras:combinations. |
| cuda_py_test( |
| name = "ftrl_test", |
| size = "medium", |
| srcs = ["ftrl_test.py"], |
| shard_count = 4, |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| ], |
| ) |
| |
| # Test target for `gradient_descent_test.py`, declared with the |
| # `cuda_py_test` macro loaded from //tensorflow:tensorflow.bzl. Size |
| # "medium", shard_count 4; the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "gradient_descent_test", |
| size = "medium", |
| srcs = ["gradient_descent_test.py"], |
| shard_count = 4, |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `nadam_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 4; |
| # the code under test comes from `:optimizer_v2`. Unlike most optimizer |
| # tests in this file, it does not list //tensorflow/python/keras:combinations. |
| cuda_py_test( |
| name = "nadam_test", |
| size = "medium", |
| srcs = ["nadam_test.py"], |
| shard_count = 4, |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:embedding_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:math_ops", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:resources", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| ], |
| ) |
| |
| # Test target for `optimizer_v2_test.py`, declared with the `cuda_py_test` |
| # macro loaded from //tensorflow:tensorflow.bzl. Size "medium"; uses |
| # shard_count 8, the highest in this file. Depends on `:optimizer_v2` and |
| # on the full //tensorflow/python/keras target. |
| cuda_py_test( |
| name = "optimizer_v2_test", |
| size = "medium", |
| srcs = ["optimizer_v2_test.py"], |
| shard_count = 8, |
| # Presumably these tags exclude the test from ROCm and Windows runs -- |
| # confirm against the CI tag filters. |
| # NOTE(review): unlike `adam_test`, the `no_windows` tag here carries no |
| # tracking bug -- confirm whether one should be referenced. |
| tags = [ |
| "no_rocm", |
| "no_windows", |
| ], |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:array_ops", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:clip_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:framework_test_lib", |
| "//tensorflow/python:gradients", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:state_ops", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:def_function", |
| "//tensorflow/python/keras", |
| "//tensorflow/python/keras:combinations", |
| "@absl_py//absl/testing:parameterized", |
| ], |
| ) |
| |
| # Test target for `learning_rate_schedule_test.py`, declared with the |
| # `cuda_py_test` macro loaded from //tensorflow:tensorflow.bzl. Size |
| # "medium", shard_count 4. |
| cuda_py_test( |
| name = "learning_rate_schedule_test", |
| size = "medium", |
| srcs = ["learning_rate_schedule_test.py"], |
| shard_count = 4, |
| # NOTE(review): this lists `:optimizer_v2` rather than |
| # `:learning_rate_schedule`; the latter is only reached through |
| # `:optimizer_v2`'s own deps -- confirm whether a direct dep is wanted. |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python/keras", |
| "//tensorflow/python/keras:combinations", |
| "//third_party/py/numpy", |
| "@absl_py//absl/testing:parameterized", |
| ], |
| ) |
| |
| # Test target for `legacy_learning_rate_decay_test.py`, declared with the |
| # `cuda_py_test` macro loaded from //tensorflow:tensorflow.bzl. Size |
| # "medium". It is the only test in this file that sets no `shard_count` |
| # and that depends on `:legacy_learning_rate_decay` instead of |
| # `:optimizer_v2`. |
| cuda_py_test( |
| name = "legacy_learning_rate_decay_test", |
| size = "medium", |
| srcs = ["legacy_learning_rate_decay_test.py"], |
| deps = [ |
| ":legacy_learning_rate_decay", |
| "//tensorflow/python:framework_test_lib", |
| "//tensorflow/python:platform_test", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:training_lib", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:context", |
| "//tensorflow/python/keras", |
| "//tensorflow/python/keras:combinations", |
| ], |
| ) |
| |
| # Test target for `rmsprop_test.py`, declared with the `cuda_py_test` macro |
| # loaded from //tensorflow:tensorflow.bzl. Size "medium", shard_count 2; |
| # the code under test comes from `:optimizer_v2`. |
| cuda_py_test( |
| name = "rmsprop_test", |
| size = "medium", |
| srcs = ["rmsprop_test.py"], |
| shard_count = 2, |
| # Presumably excludes the test from ROCm runs -- confirm against the CI |
| # tag filters. |
| tags = ["no_rocm"], |
| # Presumably tags applied only to an XLA variant of this test created by |
| # the macro -- confirm in tensorflow.bzl. The inline comment gives a |
| # timeout as the reason for `no_cuda_asan`. |
| xla_tags = [ |
| "no_cuda_asan", # times out |
| ], |
| # NOTE(review): the TODO below sits directly above `deps` but no |
| # TFRT-related attribute is visible in this target -- confirm whether the |
| # comment is stale or an attribute is missing. |
| # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. |
| deps = [ |
| ":optimizer_v2", |
| "//tensorflow/python:array_ops", |
| "//tensorflow/python:client_testlib", |
| "//tensorflow/python:clip_ops", |
| "//tensorflow/python:framework", |
| "//tensorflow/python:framework_test_lib", |
| "//tensorflow/python:gradients", |
| "//tensorflow/python:resource_variable_ops", |
| "//tensorflow/python:state_ops", |
| "//tensorflow/python:variables", |
| "//tensorflow/python/eager:def_function", |
| "//tensorflow/python/keras:combinations", |
| "@absl_py//absl/testing:parameterized", |
| ], |
| ) |