Add distribution strategy tests for Keras preprocessing layers.

PiperOrigin-RevId: 309027642
Change-Id: Id38df0a26c9e134034f9b8ceadc7f9b91c219158
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index bca49b9..2a9971f 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -266,17 +266,19 @@
     ],
 )
 
-tpu_py_test(
-    name = "categorical_encoding_tpu_test",
-    srcs = ["categorical_encoding_tpu_test.py"],
-    disable_experimental = True,
+distribute_py_test(
+    name = "categorical_encoding_distribution_test",
+    srcs = ["categorical_encoding_distribution_test.py"],
+    main = "categorical_encoding_distribution_test.py",
     python_version = "PY3",
-    tags = ["no_oss"],
+    tags = [
+        "multi_and_single_gpu",
+    ],
     deps = [
         ":categorical_encoding",
-        "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
-        "//tensorflow/python/keras/distribute:tpu_strategy_test_utils",
     ],
 )
 
@@ -293,17 +295,17 @@
     ],
 )
 
-tpu_py_test(
-    name = "discretization_tpu_test",
-    srcs = ["discretization_tpu_test.py"],
-    disable_experimental = True,
+distribute_py_test(
+    name = "discretization_distribution_test",
+    srcs = ["discretization_distribution_test.py"],
+    main = "discretization_distribution_test.py",
     python_version = "PY3",
-    tags = ["no_oss"],
+    tags = ["multi_and_single_gpu"],
     deps = [
         ":discretization",
-        "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
-        "//tensorflow/python/keras/distribute:tpu_strategy_test_utils",
     ],
 )
 
@@ -322,16 +324,16 @@
 )
 
 tpu_py_test(
-    name = "hashing_tpu_test",
-    srcs = ["hashing_tpu_test.py"],
-    disable_experimental = True,
+    name = "hashing_distribution_test",
+    srcs = ["hashing_distribution_test.py"],
+    main = "hashing_distribution_test.py",
     python_version = "PY3",
-    tags = ["no_oss"],
+    tags = ["multi_and_single_gpu"],
     deps = [
         ":hashing",
-        "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
-        "//tensorflow/python/keras/distribute:tpu_strategy_test_utils",
     ],
 )
 
@@ -352,16 +354,16 @@
 )
 
 tpu_py_test(
-    name = "index_lookup_tpu_test",
-    srcs = ["index_lookup_tpu_test.py"],
-    disable_experimental = True,
+    name = "index_lookup_distribution_test",
+    srcs = ["index_lookup_distribution_test.py"],
+    main = "index_lookup_distribution_test.py",
     python_version = "PY3",
     tags = ["no_oss"],
     deps = [
         ":index_lookup",
-        "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/keras",
-        "//tensorflow/python/keras/distribute:tpu_strategy_test_utils",
     ],
 )
 
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_tpu_test.py b/tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py
similarity index 85%
rename from tensorflow/python/keras/layers/preprocessing/categorical_encoding_tpu_test.py
rename to tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py
index c3bba2f..c521453 100644
--- a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_tpu_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py
@@ -21,21 +21,24 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras.distribute import tpu_strategy_test_utils
 from tensorflow.python.keras.layers.preprocessing import categorical_encoding
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
 
 
-@keras_parameterized.run_all_keras_modes(
-    always_skip_v1=True, always_skip_eager=True)
+@combinations.generate(
+    combinations.combine(
+        distribution=strategy_combinations.all_strategies,
+        mode=["eager", "graph"]))
 class CategoricalEncodingDistributionTest(
     keras_parameterized.TestCase,
     preprocessing_test_utils.PreprocessingLayerTest):
 
-  def test_tpu_distribution(self):
+  def test_distribution(self, distribution):
     input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
 
     # pyformat: disable
@@ -44,9 +47,7 @@
     # pyformat: enable
     max_tokens = 6
 
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
-
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(4,), dtype=dtypes.int32)
       layer = categorical_encoding.CategoricalEncoding(
           max_tokens=max_tokens, output_mode=categorical_encoding.BINARY)
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_tpu_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
similarity index 85%
rename from tensorflow/python/keras/layers/preprocessing/discretization_tpu_test.py
rename to tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
index 005f8b0..7da40b8 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_tpu_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
@@ -21,27 +21,29 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras.distribute import tpu_strategy_test_utils
 from tensorflow.python.keras.layers.preprocessing import discretization
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
 
 
-@keras_parameterized.run_all_keras_modes(
-    always_skip_v1=True, always_skip_eager=True)
+@combinations.generate(
+    combinations.combine(
+        distribution=strategy_combinations.all_strategies,
+        mode=["eager", "graph"]))
 class DiscretizationDistributionTest(
     keras_parameterized.TestCase,
     preprocessing_test_utils.PreprocessingLayerTest):
 
-  def test_tpu_distribution(self):
+  def test_distribution(self, distribution):
     input_array = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
 
     expected_output = [[0, 2, 3, 1], [1, 3, 2, 1]]
     expected_output_shape = [None, None]
 
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(None,))
       layer = discretization.Discretization(
           bins=[0., 1., 2.], output_mode=discretization.INTEGER)
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_tpu_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
similarity index 85%
rename from tensorflow/python/keras/layers/preprocessing/hashing_tpu_test.py
rename to tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
index e2e6d98..0cfd1ab 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_tpu_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
@@ -22,30 +22,32 @@
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras.distribute import tpu_strategy_test_utils
 from tensorflow.python.keras.layers.preprocessing import hashing
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
 
 
-@keras_parameterized.run_all_keras_modes(
-    always_skip_v1=True, always_skip_eager=True)
+@combinations.generate(
+    combinations.combine(
+        distribution=strategy_combinations.all_strategies,
+        mode=["eager", "graph"]))
 class HashingDistributionTest(keras_parameterized.TestCase,
                               preprocessing_test_utils.PreprocessingLayerTest):
 
-  def test_tpu_distribution(self):
+  def test_distribution(self, distribution):
     input_data = np.asarray([["omar"], ["stringer"], ["marlo"], ["wire"]])
     input_dataset = dataset_ops.Dataset.from_tensor_slices(input_data).batch(
         2, drop_remainder=True)
     expected_output = [[0], [0], [1], [0]]
 
     config.set_soft_device_placement(True)
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
 
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(None,), dtype=dtypes.string)
       layer = hashing.Hashing(num_bins=2)
       int_data = layer(input_data)
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_tpu_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
similarity index 78%
rename from tensorflow/python/keras/layers/preprocessing/index_lookup_tpu_test.py
rename to tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
index b371eec..3360dad 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_tpu_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
@@ -22,22 +22,34 @@
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras.distribute import tpu_strategy_test_utils
 from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.platform import test
 
 
-@keras_parameterized.run_all_keras_modes(
-    always_skip_v1=True, always_skip_eager=True)
+def get_layer_class():
+  if context.executing_eagerly():
+    return index_lookup.IndexLookup
+  else:
+    return index_lookup_v1.IndexLookup
+
+
+@combinations.generate(
+    combinations.combine(
+        distribution=strategy_combinations.all_strategies,
+        mode=["eager", "graph"]))
 class IndexLookupDistributionTest(
     keras_parameterized.TestCase,
     preprocessing_test_utils.PreprocessingLayerTest):
 
-  def test_tpu_distribution(self):
+  def test_tpu_distribution(self, distribution):
     vocab_data = [[
         "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and",
         "and", "fire"
@@ -50,11 +62,10 @@
     expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
 
     config.set_soft_device_placement(True)
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
 
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(None,), dtype=dtypes.string)
-      layer = index_lookup.IndexLookup()
+      layer = get_layer_class()()
       layer.adapt(vocab_dataset)
       int_data = layer(input_data)
       model = keras.Model(inputs=input_data, outputs=int_data)