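Update keras to use tf.__internal__ for distribute tests.

Add tensorflow/python/keras/distribute/strategy_combinations.py, which
defines the strategy lists shared by the Keras distribute and preprocessing
tests, and switch those tests and their BUILD deps over to it. Also skip
CentralStorageStrategy in the mixed-precision LSTM correctness test.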

PiperOrigin-RevId: 341332973
Change-Id: I61cce014764668a040bdcedb4e5e956b13983ddb
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 8726956..95644a2 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -226,6 +226,7 @@
     ],
     deps = [
         ":optimizer_combinations",
+        ":strategy_combinations",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:util",
         "//tensorflow/python/compat:v2_compat",
@@ -246,6 +247,7 @@
         "multi_and_single_gpu",
     ],
     deps = [
+        ":strategy_combinations",
         "//tensorflow/python:errors",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
@@ -271,6 +273,7 @@
         "notsan",  # TODO(b/170869466)
     ],
     deps = [
+        ":strategy_combinations",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:util",
         "//tensorflow/python/data/ops:dataset_ops",
@@ -295,6 +298,7 @@
         "multi_and_single_gpu",
     ],
     deps = [
+        ":strategy_combinations",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:combinations",
@@ -314,6 +318,7 @@
     ],
     deps = [
         ":optimizer_combinations",
+        ":strategy_combinations",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:training",
         "//tensorflow/python/distribute:combinations",
@@ -400,6 +405,7 @@
         "keras_stateful_lstm_model_correctness_test.py",
     ],
     deps = [
+        ":strategy_combinations",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:training",
         "//tensorflow/python/distribute:collective_all_reduce_strategy",
@@ -510,6 +516,7 @@
         "multi_and_single_gpu",
     ],
     deps = [
+        ":strategy_combinations",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:test",
@@ -949,3 +956,11 @@
         "//tensorflow/python/keras/engine:base_layer",
     ],
 )
+
+py_library(
+    name = "strategy_combinations",
+    srcs = ["strategy_combinations.py"],
+    deps = [
+        "//tensorflow/python/distribute:strategy_combinations",
+    ],
+)
diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py
index 9b62431..629bc2a 100644
--- a/tensorflow/python/keras/distribute/ctl_correctness_test.py
+++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py
@@ -28,7 +28,6 @@
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
@@ -36,6 +35,7 @@
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.distribute import optimizer_combinations
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
index 0ad6969..08a1b7a 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
@@ -24,11 +24,11 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import metrics
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
index 4fad158..f29bc0e 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
@@ -28,10 +28,10 @@
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.module import module
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
index b61534f..802d2c4 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
@@ -25,6 +25,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute import strategy_combinations as keras_strategy_combinations
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -35,7 +36,7 @@
   @ds_combinations.generate(
       combinations.times(
           combinations.combine(
-              distribution=strategy_combinations.multidevice_strategies,
+              distribution=keras_strategy_combinations.multidevice_strategies,
               mode=["eager"],
           ),
           combinations.combine(
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index eae881d..7cc018e 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -51,6 +51,11 @@
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.distribute import optimizer_combinations
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_default_minus_tpu
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
+from tensorflow.python.keras.distribute.strategy_combinations import tpu_strategies
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
@@ -230,37 +235,6 @@
   return model
 
 
-strategies_minus_default_minus_tpu = [
-    strategy_combinations.one_device_strategy,
-    strategy_combinations.one_device_strategy_gpu,
-    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus,
-    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
-]
-
-strategies_minus_tpu = [
-    strategy_combinations.default_strategy,
-    strategy_combinations.one_device_strategy,
-    strategy_combinations.one_device_strategy_gpu,
-    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus,
-    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
-]
-
-multi_worker_mirrored_strategies = [
-    strategy_combinations.multi_worker_mirrored_2x1_cpu,
-    strategy_combinations.multi_worker_mirrored_2x1_gpu,
-    strategy_combinations.multi_worker_mirrored_2x2_gpu,
-]
-
-tpu_strategies = [
-    strategy_combinations.tpu_strategy,
-]
-
-all_strategies = (
-    strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies)
-
-
 def strategy_minus_tpu_combinations():
   return combinations.combine(
       distribution=strategies_minus_tpu, mode=['graph', 'eager'])
@@ -1704,8 +1678,7 @@
     return math_ops.reduce_mean(y_pred)
 
   @ds_combinations.generate(
-      combinations.times(
-          strategy_combinations.all_strategy_combinations_minus_default()))
+      combinations.times(all_strategy_combinations_minus_default()))
   def test_regularizer_loss(self, distribution):
     batch_size = 2
     if not distributed_training_utils.global_batch_size_supported(distribution):
@@ -2648,9 +2621,7 @@
   """Tests that model creation captures the strategy."""
 
   @ds_combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=['eager']))
+      combinations.combine(distribution=all_strategies, mode=['eager']))
   def test_fit_and_evaluate(self, distribution):
     dataset = dataset_ops.DatasetV2.from_tensor_slices(
         (array_ops.ones(shape=(64,)), array_ops.ones(shape=(64,))))
diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index f40f45c..8e806f4 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -32,6 +32,9 @@
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.preprocessing import sequence
 from tensorflow.python.platform import test
@@ -44,22 +47,6 @@
 # Note: Please make sure the tests in this file are also covered in
 # keras_backward_compat_test for features that are supported with both APIs.
 
-all_strategies = [
-    strategy_combinations.default_strategy,
-    strategy_combinations.one_device_strategy,
-    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus,
-    strategy_combinations.tpu_strategy,  # steps_per_run=2
-]
-
-
-# TODO(b/159831559): add to all_strategies once all tests pass.
-multi_worker_mirrored = [
-    strategy_combinations.multi_worker_mirrored_2x1_cpu,
-    strategy_combinations.multi_worker_mirrored_2x1_gpu,
-    strategy_combinations.multi_worker_mirrored_2x2_gpu,
-]
-
 
 def eager_mode_test_configuration():
   return combinations.combine(
@@ -85,8 +72,7 @@
 
 def strategy_minus_tpu_and_input_config_combinations_eager():
   return (combinations.times(
-      combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu),
+      combinations.combine(distribution=strategies_minus_tpu),
       eager_mode_test_configuration()))
 
 
@@ -130,13 +116,13 @@
 
 def multi_worker_mirrored_eager():
   return combinations.times(
-      combinations.combine(distribution=multi_worker_mirrored),
+      combinations.combine(distribution=multi_worker_mirrored_strategies),
       eager_mode_test_configuration())
 
 
 def multi_worker_mirrored_eager_and_graph():
   return combinations.times(
-      combinations.combine(distribution=multi_worker_mirrored),
+      combinations.combine(distribution=multi_worker_mirrored_strategies),
       eager_mode_test_configuration() + graph_mode_test_configuration())
 
 
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index e6581a8..bcb4720 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -28,15 +28,16 @@
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
 from tensorflow.python.training import gradient_descent
 
 
 def all_strategy_combinations_with_eager_and_graph_modes():
   return (combinations.combine(
-      distribution=keras_correctness_test_base.all_strategies,
+      distribution=strategy_combinations.all_strategies,
       mode=['graph', 'eager']) + combinations.combine(
-          distribution=keras_correctness_test_base.multi_worker_mirrored,
+          distribution=strategy_combinations.multi_worker_mirrored_strategies,
           mode='eager'))
 
 
diff --git a/tensorflow/python/keras/distribute/keras_models_test.py b/tensorflow/python/keras/distribute/keras_models_test.py
index 6c82545..925648e 100644
--- a/tensorflow/python/keras/distribute/keras_models_test.py
+++ b/tensorflow/python/keras/distribute/keras_models_test.py
@@ -23,8 +23,8 @@
 
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.platform import test
 
 
@@ -32,7 +32,7 @@
 
   @ds_combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+          distribution=all_strategies, mode=["eager"]))
   def test_lstm_model_with_dynamic_batch(self, distribution):
     input_data = np.random.random([1, 32, 64, 64, 3])
     input_shape = tuple(input_data.shape[1:])
diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
index 43092fc..1488d62 100644
--- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
@@ -20,6 +20,7 @@
 import numpy as np
 from tensorflow.python import keras
 from tensorflow.python import tf2
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import tpu_strategy
@@ -117,6 +118,11 @@
   def test_lstm_model_correctness_mixed_precision(self, distribution, use_numpy,
                                                   use_validation_data):
     if isinstance(distribution,
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
+      self.skipTest('CentralStorageStrategy is not supported by '
+                    'mixed precision.')
+    if isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
       policy_name = 'mixed_bfloat16'
     else:
diff --git a/tensorflow/python/keras/distribute/strategy_combinations.py b/tensorflow/python/keras/distribute/strategy_combinations.py
new file mode 100644
index 0000000..3ea3185
--- /dev/null
+++ b/tensorflow/python/keras/distribute/strategy_combinations.py
@@ -0,0 +1,59 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Strategy combinations for combinations.combine()."""
+
+from tensorflow.python.distribute import strategy_combinations
+
+
+multidevice_strategies = [
+    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.tpu_strategy,
+]
+
+multiworker_strategies = [
+    strategy_combinations.multi_worker_mirrored_2x1_cpu,
+    strategy_combinations.multi_worker_mirrored_2x1_gpu,
+    strategy_combinations.multi_worker_mirrored_2x2_gpu
+]
+
+strategies_minus_default_minus_tpu = [
+    strategy_combinations.one_device_strategy,
+    strategy_combinations.one_device_strategy_gpu,
+    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+]
+
+strategies_minus_tpu = [
+    strategy_combinations.default_strategy,
+    strategy_combinations.one_device_strategy,
+    strategy_combinations.one_device_strategy_gpu,
+    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+]
+
+multi_worker_mirrored_strategies = [
+    strategy_combinations.multi_worker_mirrored_2x1_cpu,
+    strategy_combinations.multi_worker_mirrored_2x1_gpu,
+    strategy_combinations.multi_worker_mirrored_2x2_gpu,
+]
+
+tpu_strategies = [
+    strategy_combinations.tpu_strategy,
+]
+
+all_strategies = strategies_minus_tpu + tpu_strategies
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d131001..177bfa4 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -404,9 +404,9 @@
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -430,9 +430,9 @@
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -453,8 +453,8 @@
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -489,8 +489,8 @@
         "//tensorflow/python:config",
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -526,8 +526,8 @@
     deps = [
         ":hashing",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -557,8 +557,8 @@
     deps = [
         ":index_lookup",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -637,9 +637,9 @@
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
@@ -694,9 +694,9 @@
         "//tensorflow/python:framework_test_combinations_lib",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras/distribute:strategy_combinations",
     ],
 )
 
diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
index 867d1c6..3a00e96 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
@@ -23,12 +23,12 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import category_crossing
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
@@ -49,7 +49,7 @@
 @ds_combinations.generate(
     combinations.combine(
         # Investigate why crossing is not supported with TPU.
-        distribution=strategy_combinations.all_strategies,
+        distribution=all_strategies,
         mode=['eager', 'graph']))
 class CategoryCrossingDistributionTest(
     keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
index 7d6cff9..c95be0a 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
@@ -23,12 +23,12 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.keras.layers.preprocessing import category_encoding
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
index 208aca9..12e08de 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
@@ -22,10 +22,10 @@
 
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import config
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.keras.layers.preprocessing import discretization
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
index 2698259..78d9400 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
@@ -23,11 +23,11 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import hashing
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
@@ -35,7 +35,7 @@
 
 @ds_combinations.generate(
     combinations.combine(
-        distribution=strategy_combinations.all_strategies,
+        distribution=all_strategies,
         mode=["eager", "graph"]))
 class HashingDistributionTest(keras_parameterized.TestCase,
                               preprocessing_test_utils.PreprocessingLayerTest):
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
index 2932ca3..2d1f3a6 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
@@ -23,10 +23,10 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import image_preprocessing
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
@@ -34,7 +34,7 @@
 
 @ds_combinations.generate(
     combinations.combine(
-        distribution=strategy_combinations.all_strategies,
+        distribution=all_strategies,
         mode=["eager", "graph"]))
 class ImagePreprocessingDistributionTest(
     keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
index b421990..c4ef047 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
@@ -23,12 +23,12 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import index_lookup
 from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@@ -44,7 +44,7 @@
 
 @ds_combinations.generate(
     combinations.combine(
-        distribution=strategy_combinations.all_strategies,
+        distribution=all_strategies,
         mode=["eager"]))  # Eager-only, no graph: b/158793009
 class IndexLookupDistributionTest(
     keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
index 4bf15da..478529c 100644
--- a/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
@@ -23,10 +23,10 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import context
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import normalization
 from tensorflow.python.keras.layers.preprocessing import normalization_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@@ -108,7 +108,7 @@
 @ds_combinations.generate(
     combinations.times(
         combinations.combine(
-            distribution=strategy_combinations.all_strategies,
+            distribution=all_strategies,
             mode=["eager", "graph"]), _get_layer_computation_test_cases()))
 class NormalizationTest(keras_parameterized.TestCase,
                         preprocessing_test_utils.PreprocessingLayerTest):
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
index 222bcd6..9dc8bcc 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
@@ -23,12 +23,12 @@
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.keras.layers.preprocessing import text_vectorization
 from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1
@@ -44,7 +44,7 @@
 
 @ds_combinations.generate(
     combinations.combine(
-        distribution=strategy_combinations.all_strategies,
+        distribution=all_strategies,
         mode=["eager", "graph"]))
 class TextVectorizationDistributionTest(
     keras_parameterized.TestCase,