Update keras to use tf.__internal__ for distribute tests.
PiperOrigin-RevId: 341332973
Change-Id: I61cce014764668a040bdcedb4e5e956b13983ddb
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 8726956..95644a2 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -226,6 +226,7 @@
],
deps = [
":optimizer_combinations",
+ ":strategy_combinations",
"//tensorflow/python:platform_test",
"//tensorflow/python:util",
"//tensorflow/python/compat:v2_compat",
@@ -246,6 +247,7 @@
"multi_and_single_gpu",
],
deps = [
+ ":strategy_combinations",
"//tensorflow/python:errors",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
@@ -271,6 +273,7 @@
"notsan", # TODO(b/170869466)
],
deps = [
+ ":strategy_combinations",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
@@ -295,6 +298,7 @@
"multi_and_single_gpu",
],
deps = [
+ ":strategy_combinations",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
@@ -314,6 +318,7 @@
],
deps = [
":optimizer_combinations",
+ ":strategy_combinations",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/distribute:combinations",
@@ -400,6 +405,7 @@
"keras_stateful_lstm_model_correctness_test.py",
],
deps = [
+ ":strategy_combinations",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
@@ -510,6 +516,7 @@
"multi_and_single_gpu",
],
deps = [
+ ":strategy_combinations",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:test",
@@ -949,3 +956,11 @@
"//tensorflow/python/keras/engine:base_layer",
],
)
+
+py_library(
+ name = "strategy_combinations",
+ srcs = ["strategy_combinations.py"],
+ deps = [
+ "//tensorflow/python/distribute:strategy_combinations",
+ ],
+)
diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py
index 9b62431..629bc2a 100644
--- a/tensorflow/python/keras/distribute/ctl_correctness_test.py
+++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py
@@ -28,7 +28,6 @@
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
@@ -36,6 +35,7 @@
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import optimizer_combinations
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
index 0ad6969..08a1b7a 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py
@@ -24,11 +24,11 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import metrics
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
index 4fad158..f29bc0e 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
@@ -28,10 +28,10 @@
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
index b61534f..802d2c4 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
@@ -25,6 +25,7 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute import strategy_combinations as keras_strategy_combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -35,7 +36,7 @@
@ds_combinations.generate(
combinations.times(
combinations.combine(
- distribution=strategy_combinations.multidevice_strategies,
+ distribution=keras_strategy_combinations.multidevice_strategies,
mode=["eager"],
),
combinations.combine(
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index eae881d..7cc018e 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -51,6 +51,11 @@
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.distribute import optimizer_combinations
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_default_minus_tpu
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
+from tensorflow.python.keras.distribute.strategy_combinations import tpu_strategies
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
@@ -230,37 +235,6 @@
return model
-strategies_minus_default_minus_tpu = [
- strategy_combinations.one_device_strategy,
- strategy_combinations.one_device_strategy_gpu,
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- strategy_combinations.mirrored_strategy_with_two_gpus,
- strategy_combinations.central_storage_strategy_with_gpu_and_cpu
-]
-
-strategies_minus_tpu = [
- strategy_combinations.default_strategy,
- strategy_combinations.one_device_strategy,
- strategy_combinations.one_device_strategy_gpu,
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- strategy_combinations.mirrored_strategy_with_two_gpus,
- strategy_combinations.central_storage_strategy_with_gpu_and_cpu
-]
-
-multi_worker_mirrored_strategies = [
- strategy_combinations.multi_worker_mirrored_2x1_cpu,
- strategy_combinations.multi_worker_mirrored_2x1_gpu,
- strategy_combinations.multi_worker_mirrored_2x2_gpu,
-]
-
-tpu_strategies = [
- strategy_combinations.tpu_strategy,
-]
-
-all_strategies = (
- strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies)
-
-
def strategy_minus_tpu_combinations():
return combinations.combine(
distribution=strategies_minus_tpu, mode=['graph', 'eager'])
@@ -1704,8 +1678,7 @@
return math_ops.reduce_mean(y_pred)
@ds_combinations.generate(
- combinations.times(
- strategy_combinations.all_strategy_combinations_minus_default()))
+ combinations.times(all_strategy_combinations_minus_default()))
def test_regularizer_loss(self, distribution):
batch_size = 2
if not distributed_training_utils.global_batch_size_supported(distribution):
@@ -2648,9 +2621,7 @@
"""Tests that model creation captures the strategy."""
@ds_combinations.generate(
- combinations.combine(
- distribution=strategy_combinations.all_strategies,
- mode=['eager']))
+ combinations.combine(distribution=all_strategies, mode=['eager']))
def test_fit_and_evaluate(self, distribution):
dataset = dataset_ops.DatasetV2.from_tensor_slices(
(array_ops.ones(shape=(64,)), array_ops.ones(shape=(64,))))
diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index f40f45c..8e806f4 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -32,6 +32,9 @@
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
+from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.preprocessing import sequence
from tensorflow.python.platform import test
@@ -44,22 +47,6 @@
# Note: Please make sure the tests in this file are also covered in
# keras_backward_compat_test for features that are supported with both APIs.
-all_strategies = [
- strategy_combinations.default_strategy,
- strategy_combinations.one_device_strategy,
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- strategy_combinations.mirrored_strategy_with_two_gpus,
- strategy_combinations.tpu_strategy, # steps_per_run=2
-]
-
-
-# TODO(b/159831559): add to all_strategies once all tests pass.
-multi_worker_mirrored = [
- strategy_combinations.multi_worker_mirrored_2x1_cpu,
- strategy_combinations.multi_worker_mirrored_2x1_gpu,
- strategy_combinations.multi_worker_mirrored_2x2_gpu,
-]
-
def eager_mode_test_configuration():
return combinations.combine(
@@ -85,8 +72,7 @@
def strategy_minus_tpu_and_input_config_combinations_eager():
return (combinations.times(
- combinations.combine(
- distribution=strategy_combinations.strategies_minus_tpu),
+ combinations.combine(distribution=strategies_minus_tpu),
eager_mode_test_configuration()))
@@ -130,13 +116,13 @@
def multi_worker_mirrored_eager():
return combinations.times(
- combinations.combine(distribution=multi_worker_mirrored),
+ combinations.combine(distribution=multi_worker_mirrored_strategies),
eager_mode_test_configuration())
def multi_worker_mirrored_eager_and_graph():
return combinations.times(
- combinations.combine(distribution=multi_worker_mirrored),
+ combinations.combine(distribution=multi_worker_mirrored_strategies),
eager_mode_test_configuration() + graph_mode_test_configuration())
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index e6581a8..bcb4720 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -28,15 +28,16 @@
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.training import gradient_descent
def all_strategy_combinations_with_eager_and_graph_modes():
return (combinations.combine(
- distribution=keras_correctness_test_base.all_strategies,
+ distribution=strategy_combinations.all_strategies,
mode=['graph', 'eager']) + combinations.combine(
- distribution=keras_correctness_test_base.multi_worker_mirrored,
+ distribution=strategy_combinations.multi_worker_mirrored_strategies,
mode='eager'))
diff --git a/tensorflow/python/keras/distribute/keras_models_test.py b/tensorflow/python/keras/distribute/keras_models_test.py
index 6c82545..925648e 100644
--- a/tensorflow/python/keras/distribute/keras_models_test.py
+++ b/tensorflow/python/keras/distribute/keras_models_test.py
@@ -23,8 +23,8 @@
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.platform import test
@@ -32,7 +32,7 @@
@ds_combinations.generate(
combinations.combine(
- distribution=strategy_combinations.all_strategies, mode=["eager"]))
+ distribution=all_strategies, mode=["eager"]))
def test_lstm_model_with_dynamic_batch(self, distribution):
input_data = np.random.random([1, 32, 64, 64, 3])
input_shape = tuple(input_data.shape[1:])
diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
index 43092fc..1488d62 100644
--- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
@@ -20,6 +20,7 @@
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
+from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import tpu_strategy
@@ -117,6 +118,11 @@
def test_lstm_model_correctness_mixed_precision(self, distribution, use_numpy,
use_validation_data):
if isinstance(distribution,
+ (central_storage_strategy.CentralStorageStrategy,
+ central_storage_strategy.CentralStorageStrategyV1)):
+ self.skipTest('CentralStorageStrategy is not supported by '
+ 'mixed precision.')
+ if isinstance(distribution,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
policy_name = 'mixed_bfloat16'
else:
diff --git a/tensorflow/python/keras/distribute/strategy_combinations.py b/tensorflow/python/keras/distribute/strategy_combinations.py
new file mode 100644
index 0000000..3ea3185
--- /dev/null
+++ b/tensorflow/python/keras/distribute/strategy_combinations.py
@@ -0,0 +1,59 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Strategy combinations for combinations.combine()."""
+
+from tensorflow.python.distribute import strategy_combinations
+
+
+multidevice_strategies = [
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ strategy_combinations.tpu_strategy,
+]
+
+multiworker_strategies = [
+ strategy_combinations.multi_worker_mirrored_2x1_cpu,
+ strategy_combinations.multi_worker_mirrored_2x1_gpu,
+ strategy_combinations.multi_worker_mirrored_2x2_gpu
+]
+
+strategies_minus_default_minus_tpu = [
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+]
+
+strategies_minus_tpu = [
+ strategy_combinations.default_strategy,
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+]
+
+multi_worker_mirrored_strategies = [
+ strategy_combinations.multi_worker_mirrored_2x1_cpu,
+ strategy_combinations.multi_worker_mirrored_2x1_gpu,
+ strategy_combinations.multi_worker_mirrored_2x2_gpu,
+]
+
+tpu_strategies = [
+ strategy_combinations.tpu_strategy,
+]
+
+all_strategies = strategies_minus_tpu + tpu_strategies
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d131001..177bfa4 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -404,9 +404,9 @@
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -430,9 +430,9 @@
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -453,8 +453,8 @@
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -489,8 +489,8 @@
"//tensorflow/python:config",
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -526,8 +526,8 @@
deps = [
":hashing",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -557,8 +557,8 @@
deps = [
":index_lookup",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -637,9 +637,9 @@
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@@ -694,9 +694,9 @@
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
- "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/distribute:strategy_combinations",
],
)
diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
index 867d1c6..3a00e96 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py
@@ -23,12 +23,12 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@@ -49,7 +49,7 @@
@ds_combinations.generate(
combinations.combine(
# Investigate why crossing is not supported with TPU.
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=['eager', 'graph']))
class CategoryCrossingDistributionTest(
keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
index 7d6cff9..c95be0a 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
@@ -23,12 +23,12 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
index 208aca9..12e08de 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_distribution_test.py
@@ -22,10 +22,10 @@
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
index 2698259..78d9400 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_distribution_test.py
@@ -23,11 +23,11 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@@ -35,7 +35,7 @@
@ds_combinations.generate(
combinations.combine(
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=["eager", "graph"]))
class HashingDistributionTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
index 2932ca3..2d1f3a6 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
@@ -23,10 +23,10 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@@ -34,7 +34,7 @@
@ds_combinations.generate(
combinations.combine(
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=["eager", "graph"]))
class ImagePreprocessingDistributionTest(
keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
index b421990..c4ef047 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py
@@ -23,12 +23,12 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@@ -44,7 +44,7 @@
@ds_combinations.generate(
combinations.combine(
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=["eager"])) # Eager-only, no graph: b/158793009
class IndexLookupDistributionTest(
keras_parameterized.TestCase,
diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
index 4bf15da..478529c 100644
--- a/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/normalization_distribution_test.py
@@ -23,10 +23,10 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@@ -108,7 +108,7 @@
@ds_combinations.generate(
combinations.times(
combinations.combine(
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=["eager", "graph"]), _get_layer_computation_test_cases()))
class NormalizationTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
index 222bcd6..9dc8bcc 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py
@@ -23,12 +23,12 @@
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.layers.preprocessing import text_vectorization
from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1
@@ -44,7 +44,7 @@
@ds_combinations.generate(
combinations.combine(
- distribution=strategy_combinations.all_strategies,
+ distribution=all_strategies,
mode=["eager", "graph"]))
class TextVectorizationDistributionTest(
keras_parameterized.TestCase,