Add timeout to CollectiveReduceV2
Collective{Reduce,Gather,GatherV2} all already have this argument.
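
For reference, a minimal sketch of how the new argument is used from the Python
API. It mirrors the test added below; the two-logical-CPU setup, the group/instance
keys, and the tensor values are illustrative only, not part of the change itself:

    # Sketch only: exercise the new `timeout` argument of collective all_reduce.
    from tensorflow.python.eager import context
    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import config
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import collective_ops

    # Split the single physical CPU into two logical devices, as the test does.
    cpus = config.list_physical_devices('CPU')
    config.set_logical_device_configuration(
        cpus[0], [context.LogicalDeviceConfiguration(),
                  context.LogicalDeviceConfiguration()])

    @def_function.function
    def run_all_reduce():
      results = []
      for i in range(2):
        with ops.device('/CPU:{}'.format(i)):
          t = constant_op.constant([1., 2., 3., 4.])
          # timeout is in seconds; 0 (the default) disables the deadline.
          results.append(
              collective_ops.all_reduce(
                  t, group_size=2, group_key=1, instance_key=1,
                  merge_op='Add', final_op='Id', timeout=5.0))
      return results

    print(run_all_reduce())  # Both replicas finish well before the deadline.
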
PiperOrigin-RevId: 334432651
Change-Id: I578415a30098cc8db8d3f04871427085e50dc710
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 18976bd..a3db45d 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -456,6 +456,9 @@
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_->instance.impl_details.communication_hint));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("timeout_seconds",
+ &col_params_->instance.impl_details.timeout_seconds));
// Prepare OpKernels for reduction and final operations.
// The merge_op takes two inputs
NodeDef sub_node;
@@ -510,7 +513,8 @@
col_params->instance.data_type = col_params_->instance.data_type;
col_params->instance.impl_details.communication_hint =
col_params_->instance.impl_details.communication_hint;
- col_params->instance.impl_details.timeout_seconds = 0;
+ col_params->instance.impl_details.timeout_seconds =
+ col_params_->instance.impl_details.timeout_seconds;
col_params->instance.impl_details.subdiv_offsets =
col_params_->instance.impl_details.subdiv_offsets;
col_params->merge_op = std::move(col_params_->merge_op);
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
index fc9010a..ecaab00 100644
--- a/tensorflow/core/ops/collective_ops.cc
+++ b/tensorflow/core/ops/collective_ops.cc
@@ -114,6 +114,7 @@
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
.Attr("final_op: {'Id', 'Div'}")
.Attr("communication_hint: string = 'auto'")
+ .Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 583f114..0e3e161 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -345,6 +345,139 @@
def_function.function(collective_fn)()
+@combinations.generate(
+ combinations.combine(
+ collective_op=[
+ combinations.NamedObject('all_reduce', _collective_ops.all_reduce),
+ combinations.NamedObject('all_reduce_v2',
+ _collective_ops.all_reduce_v2),
+ combinations.NamedObject('all_gather', _collective_ops.all_gather),
+ combinations.NamedObject('all_gather_v2',
+ _collective_ops.all_gather_v2),
+ ],
+ mode='eager',
+ communication=['ring']))
+class TimeoutTest(test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ _setup_context()
+ super().setUp()
+
+ def testTimeout(self, collective_op, communication):
+ timeout = 4.5
+
+ @def_function.function
+ def run(group_size, reported_group_size=None):
+ group_key = 20
+ instance_key = 30
+ tensor = [1, 2, 3, 4]
+ results = []
+ if reported_group_size is None:
+ reported_group_size = group_size
+ for i in range(group_size):
+ with ops.device('/CPU:{}'.format(i)):
+ input_data = constant_op.constant(tensor)
+ result = collective_op(
+ input_data,
+ group_size=reported_group_size,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication,
+ timeout=timeout)
+ results.append(result)
+ return results
+
+ run(2, 2)
+
+ start_time = time.time()
+ with self.assertRaisesRegex(errors.DeadlineExceededError,
+ 'Collective has timed out during execution'):
+ run(1, 2)
+ elapsed = time.time() - start_time
+ self.assertAllGreaterEqual(elapsed, timeout)
+
+ def testParamResolutionAfterTimeoutV2(self, collective_op, communication):
+ timeout = 1.5
+
+ group_key = 20
+ instance_key = 30
+ input_data = constant_op.constant([1, 2, 3, 4])
+
+ # This timeout comes from param resolution.
+ with self.assertRaisesRegex(
+ errors.DeadlineExceededError,
+ 'Collective has timed out waiting for other workers'):
+ with ops.device('CPU:0'):
+ collective_op(
+ input_data,
+ group_size=2,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication,
+ timeout=timeout)
+
+ # We launch the second device after the first device times out. This is to
+ # simulate the situation when other workers are slow and the timeout is
+ # short. Since CPU:0 times out in the param resolution phase, CPU:1
+ # should time out as well, but in the execution phase.
+ with self.assertRaisesRegex(errors.DeadlineExceededError,
+ 'Collective has timed out during execution'):
+ with ops.device('CPU:1'):
+ collective_op(
+ input_data,
+ group_size=2,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication,
+ timeout=timeout)
+
+ def testExecutionAfterTimeoutV2(self, collective_op, communication):
+ timeout = 1.5
+ group_key = 20
+ instance_key = 30
+ input_data = constant_op.constant([1, 2, 3, 4])
+
+ @def_function.function
+ def run():
+ for device in ['CPU:0', 'CPU:1']:
+ with ops.device(device):
+ collective_op(
+ input_data,
+ group_size=2,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication,
+ timeout=timeout)
+
+ # Run a normal collective to complete param resolution.
+ run()
+
+ with self.assertRaisesRegex(errors.DeadlineExceededError,
+ 'Collective has timed out during execution'):
+ with ops.device('CPU:0'):
+ collective_op(
+ input_data,
+ group_size=2,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication,
+ timeout=timeout)
+
+ # We launch the second device after the first device times out. This is to
+ # simulate the situation when other workers are slow and the timeout is
+ # short. It should error immediately.
+ with self.assertRaisesRegex(errors.DeadlineExceededError,
+ 'Collective has timed out during execution'):
+ with ops.device('CPU:1'):
+ # No timeout.
+ collective_op(
+ input_data,
+ group_size=2,
+ group_key=group_key,
+ instance_key=instance_key,
+ communication_hint=communication)
+
+
def _setup_context():
context._reset_context()
cpus = config.list_physical_devices('CPU')
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
index 5786915..6afe923 100644
--- a/tensorflow/python/ops/collective_ops.py
+++ b/tensorflow/python/ops/collective_ops.py
@@ -77,7 +77,8 @@
instance_key,
merge_op='Add',
final_op='Id',
- communication_hint='auto'):
+ communication_hint='auto',
+ timeout=0):
"""Reduces tensors collectively, across devices.
Args:
@@ -94,6 +95,9 @@
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
+ timeout: a float. If set to a non-zero value, enables a completion timeout
+ of that many seconds, used to detect staleness. If the timer goes off, a
+ DeadlineExceededError is raised. This feature is experimental.
Returns:
An Op implementing the distributed reduction.
@@ -105,7 +109,8 @@
instance_key=instance_key,
merge_op=merge_op,
final_op=final_op,
- communication_hint=communication_hint.lower())
+ communication_hint=communication_hint.lower(),
+ timeout_seconds=timeout)
def all_gather(t,
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index d3d110f..d3988d0 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -166,151 +166,6 @@
elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout)
- @test_util.run_v2_only
- def testCollectiveTimeoutV2(self):
- timeout = 4.5
- cpus = config.list_physical_devices('CPU')
- self.assertEqual(len(cpus), 1)
- config.set_logical_device_configuration(cpus[0], [
- context.LogicalDeviceConfiguration(),
- context.LogicalDeviceConfiguration()
- ])
- context.ensure_initialized()
-
- @def_function.function
- def run_all_reduce(group_size, reported_group_size=None):
- group_key = 20
- instance_key = 30
- tensor = [1, 2, 3, 4]
- results = []
- if reported_group_size is None:
- reported_group_size = group_size
- for i in range(group_size):
- with ops.device('/CPU:{}'.format(i)):
- input_data = constant_op.constant(tensor)
- collective_op = collective_ops.all_reduce(
- input_data,
- group_size=reported_group_size,
- group_key=group_key,
- instance_key=instance_key,
- merge_op='Add',
- final_op='Id',
- timeout=timeout)
- results.append(collective_op)
- return results
-
- run_all_reduce(2, 2)
-
- start_time = time.time()
- with self.assertRaisesRegex(errors.DeadlineExceededError,
- 'Collective has timed out during execution'):
- run_all_reduce(1, 2)
- elapsed = time.time() - start_time
- self.assertAllGreaterEqual(elapsed, timeout)
-
- @test_util.run_v2_only
- def testParamResolutionAfterTimeoutV2(self):
- timeout = 1.5
- cpus = config.list_physical_devices('CPU')
- self.assertEqual(len(cpus), 1)
- config.set_logical_device_configuration(cpus[0], [
- context.LogicalDeviceConfiguration(),
- context.LogicalDeviceConfiguration()
- ])
- context.ensure_initialized()
-
- group_key = 20
- instance_key = 30
- input_data = constant_op.constant([1, 2, 3, 4])
-
- # This timeout comes from param solution.
- with self.assertRaisesRegex(
- errors.DeadlineExceededError,
- 'Collective has timed out waiting for other workers'):
- with ops.device('CPU:0'):
- collective_ops.all_reduce(
- input_data,
- group_size=2,
- group_key=group_key,
- instance_key=instance_key,
- merge_op='Add',
- final_op='Id',
- timeout=timeout)
-
- # We launch the second device after the first device times out. This is to
- # simulate the situation when other workers are slow and the timeout is
- # short. Since the CPU:0 times out in the param resolution phase, CPU:1
- # should times out as well, but in the execute phase.
- with self.assertRaisesRegex(errors.DeadlineExceededError,
- 'Collective has timed out during execution'):
- with ops.device('CPU:1'):
- collective_ops.all_reduce(
- input_data,
- group_size=2,
- group_key=group_key,
- instance_key=instance_key,
- merge_op='Add',
- final_op='Id',
- timeout=timeout)
-
- @test_util.run_v2_only
- def testExecutionAfterTimeoutV2(self):
- timeout = 1.5
- cpus = config.list_physical_devices('CPU')
- self.assertEqual(len(cpus), 1)
- config.set_logical_device_configuration(cpus[0], [
- context.LogicalDeviceConfiguration(),
- context.LogicalDeviceConfiguration()
- ])
- context.ensure_initialized()
-
- group_key = 20
- instance_key = 30
- input_data = constant_op.constant([1, 2, 3, 4])
-
- @def_function.function
- def run_all_reduce():
- for device in ['CPU:0', 'CPU:1']:
- with ops.device(device):
- collective_ops.all_reduce(
- input_data,
- group_size=2,
- group_key=group_key,
- instance_key=instance_key,
- merge_op='Add',
- final_op='Id',
- timeout=timeout)
-
- # Run a normal all-reduce to complete param resolution.
- run_all_reduce()
-
- with self.assertRaisesRegex(errors.DeadlineExceededError,
- 'Collective has timed out during execution'):
- with ops.device('CPU:0'):
- collective_ops.all_reduce(
- input_data,
- group_size=2,
- group_key=group_key,
- instance_key=instance_key,
- merge_op='Add',
- final_op='Id',
- timeout=timeout)
-
- # We launch the second device after the first device times out. This is to
- # simulate the situation when other workers are slow and the timeout is
- # short. It should error immediately.
- with self.assertRaisesRegex(errors.DeadlineExceededError,
- 'Collective has timed out during execution'):
- with ops.device('CPU:1'):
- # No timeout.
- collective_ops.all_reduce(
- input_data,
- group_size=2,
- group_key=group_key,
- merge_op='Add',
- final_op='Id',
- instance_key=instance_key)
-
def testNcclHintFallbackToRingReduce(self):
"""Tests that setting `communication_hint=nccl` works on non-GPU builds."""
if kernels.get_registered_kernels_for_op('NcclAllReduce'):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 2a2c310..938aa69 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -774,7 +774,7 @@
}
member_method {
name: "CollectiveReduceV2"
- argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+ argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CombinedNonMaxSuppression"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 2a2c310..938aa69 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -774,7 +774,7 @@
}
member_method {
name: "CollectiveReduceV2"
- argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+ argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CombinedNonMaxSuppression"