Add MultiWorkerMirroredStrategy to custom training loop tests.

PiperOrigin-RevId: 328002300
Change-Id: I5713bc15bb0d7a8647b1097fe81570ace30cb1c5
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
index a327f87..b6b9239 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
@@ -52,15 +52,17 @@
     return x
 
 
+@combinations.generate(
+    combinations.combine(
+        distribution=(strategy_combinations.all_strategies +
+                      strategy_combinations.multiworker_strategies),
+        mode=["eager"]
+        )
+    )
 class KerasModelsTest(test.TestCase, parameterized.TestCase):
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_single_keras_layer_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_single_keras_layer_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -72,7 +74,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -83,72 +85,33 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_creation_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_keras_model_optimizer_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
-      model = self._get_model()
-
-    @def_function.function
-    def train_step(iterator):
-      def step_fn(inputs):
-        images, targets = inputs
-        with backprop.GradientTape() as tape:
-          outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
-        grads = tape.gradient(loss, model.variables)
-        return grads
-
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
-      return nest.map_structure(distribution.experimental_local_results,
-                                outputs)
-
-    train_step(input_iterator)
-
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_optimizer_experimental_run(self, distribution):
-    dataset = self._get_dataset()
-    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
-
-    with distribution.scope():
-      model = self._get_model()
+      model = _get_model()
       optimizer = keras.optimizer_v2.rmsprop.RMSprop()
 
     @def_function.function
-    def train_step(iterator):
+    def train_step(replicated_inputs):
       def step_fn(inputs):
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
+      outputs = distribution.run(step_fn, args=(replicated_inputs,))
       return nest.map_structure(distribution.experimental_local_results,
                                 outputs)
 
-    train_step(input_iterator)
+    for x in input_iterator:
+      train_step(x)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
+  def test_keras_subclass_model_optimizer_run(self, distribution):
     def get_subclass_model():
 
       class KerasSubclassModel(keras.Model):
@@ -161,7 +124,7 @@
           return self.l(x)
 
       return KerasSubclassModel()
-    dataset = self._get_dataset()
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -174,29 +137,23 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
+      outputs = distribution.run(step_fn, args=(next(iterator),))
       return nest.map_structure(distribution.experimental_local_results,
                                 outputs)
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
-    dataset = self._get_dataset()
+  def test_keras_model_optimizer_run_loop(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
-      model = self._get_model()
+      model = _get_model()
       optimizer = keras.optimizer_v2.rmsprop.RMSprop()
 
     @def_function.function
@@ -205,27 +162,22 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      for _ in range(5):
+      for _ in math_ops.range(4):
         distribution.run(step_fn, args=(next(iterator),))
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
   def test_batch_norm_with_dynamic_batch(self, distribution):
     inputs = np.zeros((10, 3, 3, 3), dtype=np.float32)
     targets = np.zeros((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
     dataset = dataset.repeat()
-    dataset = dataset.batch(10, drop_remainder=False)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -242,7 +194,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images, training=True)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
@@ -305,9 +257,6 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_nested_tf_functions(self, distribution):
     # The test builds two computations with keras layers, one with nested
     # tf.function, and the other without nested tf.function. We run these
@@ -317,7 +266,7 @@
     inputs = np.random.random((10, 3)).astype(np.float32)
     targets = np.ones((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
-    dataset = dataset.batch(10, drop_remainder=True)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     def get_model():
@@ -340,7 +289,7 @@
 
     def compute_loss(images, targets):
       outputs = model(images)
-      return math_ops.reduce_sum(outputs - targets)
+      return keras.losses.mean_squared_error(targets, outputs)
 
     @def_function.function
     def train_step_without_nested_tf_function(inputs):
@@ -357,7 +306,7 @@
     @def_function.function
     def compute_loss2(images, targets):
       outputs = model2(images)
-      return math_ops.reduce_sum(outputs - targets)
+      return keras.losses.mean_squared_error(targets, outputs)
 
     @def_function.function
     def train_step_with_nested_tf_function(inputs):
@@ -380,14 +329,11 @@
     for model_v, model2_v in zip(model.variables, model2.variables):
       self.assertAllClose(model_v.numpy(), model2_v.numpy())
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_nested_tf_functions_with_control_flow(self, distribution):
     inputs = np.random.random((10, 3)).astype(np.float32)
     targets = np.ones((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
-    dataset = dataset.batch(10, drop_remainder=True)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     def get_model():
@@ -407,7 +353,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
 
@@ -420,13 +366,8 @@
 
     train_steps(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_customized_tf_module_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_customized_tf_module_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -439,7 +380,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -450,14 +391,11 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_reduce_loss(self, distribution):
     inputs = np.zeros((10, 4), dtype=np.float32)
     targets = np.zeros((10, 1), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.batch(10, drop_remainder=False)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -479,11 +417,14 @@
     loss = train_step(input_iterator)
     loss = distribution.reduce(reduce_util.ReduceOp.MEAN, loss, axis=0)
 
+
+class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
   def test_tf_function_experimental_compile(self, distribution):
-    dataset = self._get_dataset()
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     class CustomDense(keras.layers.Layer):
@@ -511,7 +452,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -522,20 +463,21 @@
 
     train_step(input_iterator)
 
-  def _get_dataset(self):
-    inputs = np.zeros((10, 3), dtype=np.float32)
-    targets = np.zeros((10, 4), dtype=np.float32)
-    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.repeat(100)
-    dataset = dataset.batch(10, drop_remainder=True)
-    return dataset
 
-  def _get_model(self):
-    x = keras.layers.Input(shape=(3,), name="input")
-    y = keras.layers.Dense(4, name="dense")(x)
-    model = keras.Model(x, y)
-    return model
+def _get_dataset():
+  inputs = np.zeros((31, 3), dtype=np.float32)
+  targets = np.zeros((31, 4), dtype=np.float32)
+  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+  dataset = dataset.batch(10)
+  return dataset
+
+
+def _get_model():
+  x = keras.layers.Input(shape=(3,), name="input")
+  y = keras.layers.Dense(4, name="dense")(x)
+  model = keras.Model(x, y)
+  return model
 
 
 if __name__ == "__main__":
-  test.main()
+  combinations.main()