Move `lstm_model_with_dynamic_batch` test to a separate file.
PiperOrigin-RevId: 327858513
Change-Id: I907a6b808a540fc36a19d721c7361efd2f00ff99
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index adc9523..f00fbe6 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -469,6 +469,21 @@
)
distribute_py_test(
+ name = "keras_models_test",
+ srcs = ["keras_models_test.py"],
+ main = "keras_models_test.py",
+ tags = [
+ "multi_and_single_gpu",
+ ],
+ deps = [
+ "//tensorflow/python/distribute:combinations",
+ "//tensorflow/python/distribute:strategy_combinations",
+ "//tensorflow/python/eager:test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+distribute_py_test(
name = "keras_rnn_model_correctness_test",
size = "medium",
srcs = ["keras_rnn_model_correctness_test.py"],
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
index 3a32410..a327f87 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
@@ -251,33 +251,6 @@
train_step(input_iterator)
- @combinations.generate(
- combinations.combine(
- distribution=strategy_combinations.all_strategies,
- mode=["eager"]))
- def test_model_predict_with_dynamic_batch(self, distribution):
- input_data = np.random.random([1, 32, 64, 64, 3])
- input_shape = tuple(input_data.shape[1:])
-
- def build_model():
- model = keras.models.Sequential()
- model.add(
- keras.layers.ConvLSTM2D(
- 4,
- kernel_size=(4, 4),
- activation="sigmoid",
- padding="same",
- input_shape=input_shape))
- model.add(keras.layers.GlobalMaxPooling2D())
- model.add(keras.layers.Dense(2, activation="sigmoid"))
- return model
-
- with distribution.scope():
- model = build_model()
- model.compile(loss="binary_crossentropy", optimizer="adam")
- result = model.predict(input_data)
- self.assertEqual(result.shape, (1, 2))
-
# TODO(b/165912857): Re-enable.
@combinations.generate(
combinations.combine(
diff --git a/tensorflow/python/keras/distribute/keras_models_test.py b/tensorflow/python/keras/distribute/keras_models_test.py
new file mode 100644
index 0000000..da58c04
--- /dev/null
+++ b/tensorflow/python/keras/distribute/keras_models_test.py
@@ -0,0 +1,60 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras high level APIs, e.g. fit, evaluate and predict."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import test
+
+
+class KerasModelsTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=strategy_combinations.all_strategies, mode=["eager"]))
+ def test_lstm_model_with_dynamic_batch(self, distribution):
+ input_data = np.random.random([1, 32, 64, 64, 3])
+ input_shape = tuple(input_data.shape[1:])
+
+ def build_model():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.ConvLSTM2D(
+ 4,
+ kernel_size=(4, 4),
+ activation="sigmoid",
+ padding="same",
+ input_shape=input_shape))
+ model.add(keras.layers.GlobalMaxPooling2D())
+ model.add(keras.layers.Dense(2, activation="sigmoid"))
+ return model
+
+ with distribution.scope():
+ model = build_model()
+ model.compile(loss="binary_crossentropy", optimizer="adam")
+ result = model.predict(input_data)
+ self.assertEqual(result.shape, (1, 2))
+
+
+if __name__ == "__main__":
+ test.main()