Add MWMS test to ctl_correctness_test. Fixed an issue with reducing int values on GPUs.
PiperOrigin-RevId: 322673156
Change-Id: I60a85c0689d86ce74e233d3d8f8103b6817b7d28
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 708d5eb..23792f6 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -527,8 +527,8 @@
# TODO(b/131423105): we should be able to short-cut the all-reduce in some
# cases.
if getattr(strategy.extended, "_support_per_replica_values", True):
- # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
- # though it doesn't actually have a value per replica.
+  # `reduce` expects a `PerReplica`, so we pass it one, even
+  # though it doesn't actually have a value per replica.
worker_has_values = values.PerReplica(worker_has_values)
global_has_value = strategy.reduce(
reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 83c2556..bdd4cbc 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -198,7 +198,7 @@
name = "ctl_correctness_test",
srcs = ["ctl_correctness_test.py"],
main = "ctl_correctness_test.py",
- shard_count = 5,
+ shard_count = 10,
tags = [
"multi_and_single_gpu",
],
diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py
index eade27e..a55f80e 100644
--- a/tensorflow/python/keras/distribute/ctl_correctness_test.py
+++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py
@@ -33,6 +33,7 @@
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -230,6 +231,14 @@
mode=['eager'],
iteration_type=['iterator', 'dataset'],
inside_func=[False, True],
+ sync_batchnorm=[True, False]) +
+ combinations.combine(
+ distribution=strategy_combinations.multiworker_strategies,
+ optimizer_fn=
+ optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
+ mode=['eager'],
+ iteration_type=['iterator', 'dataset'],
+ inside_func=[False, True],
sync_batchnorm=[True, False]
))
def test_dnn_correctness_minus_tpus(self, distribution, optimizer_fn,
@@ -238,6 +247,14 @@
# TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
self.skipTest('Reduced tolerance of the order of 1e-1 required.')
+ if ('CollectiveAllReduce' in type(distribution).__name__ and
+ test_util.is_xla_enabled()):
+ self.skipTest('XLA tests fail with MWMS.')
+ # Unable to use required_gpus to check if this is a multiGPU combination
+ # since required_gpus and NamedDistribution cannot be used together.
+ if ('CollectiveAllReduce' in type(distribution).__name__
+ and not inside_func and iteration_type == 'dataset'):
+ self.skipTest('MWMS tests fail with multiple GPUs.')
self.dnn_correctness(distribution, optimizer_fn, iteration_type,
inside_func, sync_batchnorm)
@@ -263,4 +280,4 @@
if __name__ == '__main__':
- test.main()
+ combinations.main()