Add MWMS test to ctl_correctness_test. Fix an issue with reducing int values on GPUs.

PiperOrigin-RevId: 322673156
Change-Id: I60a85c0689d86ce74e233d3d8f8103b6817b7d28
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 708d5eb..23792f6 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -527,8 +527,8 @@
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
   # cases.
   if getattr(strategy.extended, "_support_per_replica_values", True):
-    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
-    # though it doesn't actually have a value per replica.
+    # `reduce` expects a `PerReplica`, so we pass it one, even
+    # though it doesn't actually have a value per replica.
     worker_has_values = values.PerReplica(worker_has_values)
     global_has_value = strategy.reduce(
         reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 83c2556..bdd4cbc 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -198,7 +198,7 @@
     name = "ctl_correctness_test",
     srcs = ["ctl_correctness_test.py"],
     main = "ctl_correctness_test.py",
-    shard_count = 5,
+    shard_count = 10,
     tags = [
         "multi_and_single_gpu",
     ],
diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py
index eade27e..a55f80e 100644
--- a/tensorflow/python/keras/distribute/ctl_correctness_test.py
+++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
@@ -230,6 +231,14 @@
           mode=['eager'],
           iteration_type=['iterator', 'dataset'],
           inside_func=[False, True],
+          sync_batchnorm=[True, False]) +
+      combinations.combine(
+          distribution=strategy_combinations.multiworker_strategies,
+          optimizer_fn=
+          optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
+          mode=['eager'],
+          iteration_type=['iterator', 'dataset'],
+          inside_func=[False, True],
           sync_batchnorm=[True, False]
       ))
   def test_dnn_correctness_minus_tpus(self, distribution, optimizer_fn,
@@ -238,6 +247,14 @@
     # TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
     if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
       self.skipTest('Reduced tolerance of the order of 1e-1 required.')
+    if ('CollectiveAllReduce' in type(distribution).__name__ and
+        test_util.is_xla_enabled()):
+      self.skipTest('XLA tests fail with MWMS.')
+    # Unable to use required_gpus to check if this is a multi-GPU combination
+    # since required_gpus and NamedDistribution cannot be used together.
+    if ('CollectiveAllReduce' in type(distribution).__name__
+        and not inside_func and iteration_type == 'dataset'):
+      self.skipTest('MWMS tests fail with multiple GPUs.')
     self.dnn_correctness(distribution, optimizer_fn, iteration_type,
                          inside_func, sync_batchnorm)
 
@@ -263,4 +280,4 @@
 
 
 if __name__ == '__main__':
-  test.main()
+  combinations.main()