PSv2: Make multi-GPU tests compatible with guitar tests where GPU count is > 2.
PiperOrigin-RevId: 363710173
Change-Id: Ida8e3118f5cd097fe07bed375b2fc4e8bc7804ea
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
index deeec98..5f6ce2e 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
@@ -914,7 +914,7 @@
# Invoking `run` without `coordinator.schedule` again should error.
self.strategy.run(replica_fn)
- all_results = [(2, 0), (2, 0)]
+ all_results = [(2, 0)] * self.strategy.num_replicas_in_sync
expected_result = []
for i in range(self.strategy.num_replicas_in_sync):
expected_result.append(all_results[i])
@@ -1067,7 +1067,7 @@
def testAsyncScheduleWithDistributedDataset(self):
def input_fn():
- dataset = dataset_ops.DatasetV2.from_tensor_slices([2.] * 24).batch(
+ dataset = dataset_ops.DatasetV2.from_tensor_slices([2.]).repeat().batch(
self.strategy.num_replicas_in_sync)
return self.strategy.experimental_distribute_dataset(dataset)
@@ -1122,7 +1122,8 @@
self._map_fn_tracing_count += 1
return x + 10
- dataset = dataset_ops.DatasetV2.range(0, 10).batch(2).map(map_fn)
+ dataset = dataset_ops.DatasetV2.range(0, 10).batch(
+ self.strategy.num_replicas_in_sync).map(map_fn)
return self.strategy.experimental_distribute_dataset(dataset)
@def_function.function
@@ -1134,7 +1135,7 @@
worker_fn, args=(iter(distributed_dataset),))
expected_result = array_ops.split(
- math_ops.range(10., 12.),
+ math_ops.range(10., 10. + self.strategy.num_replicas_in_sync),
num_or_size_splits=self.strategy.num_replicas_in_sync,
axis=0)