Add timeout to CollectiveReduceV2

Collective{Reduce,Gather,GatherV2} already have this argument.

PiperOrigin-RevId: 334432651
Change-Id: I578415a30098cc8db8d3f04871427085e50dc710
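
A minimal usage sketch of the new timeout plumbing through the Python
all_reduce wrapper. The device setup, group/instance keys, and tensor values
below are illustrative assumptions, not part of this change, and
tensorflow.python.ops.collective_ops is an internal module, not public API:

    # Sketch: two logical CPUs on one host, mirroring the test setup.
    import tensorflow as tf
    from tensorflow.python.ops import collective_ops  # internal module

    # Split the single physical CPU into two logical devices before any ops run.
    cpus = tf.config.list_physical_devices('CPU')
    tf.config.set_logical_device_configuration(cpus[0], [
        tf.config.LogicalDeviceConfiguration(),
        tf.config.LogicalDeviceConfiguration(),
    ])

    @tf.function
    def all_reduce_with_timeout():
      results = []
      for i in range(2):
        with tf.device('/CPU:{}'.format(i)):
          results.append(
              collective_ops.all_reduce(
                  tf.constant([1.0, 2.0, 3.0, 4.0]),
                  group_size=2,
                  group_key=1,
                  instance_key=1,
                  merge_op='Add',
                  final_op='Id',
                  # New in this change: raises DeadlineExceededError if the
                  # collective does not complete within 5 seconds.
                  timeout=5.0))
      return results

    print(all_reduce_with_timeout())
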
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 18976bd..a3db45d 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -456,6 +456,9 @@
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_->instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_->instance.impl_details.timeout_seconds));
     // Prepare OpKernels for reduction and final operations.
     // The merge_op takes two inputs
     NodeDef sub_node;
@@ -510,7 +513,8 @@
     col_params->instance.data_type = col_params_->instance.data_type;
     col_params->instance.impl_details.communication_hint =
         col_params_->instance.impl_details.communication_hint;
-    col_params->instance.impl_details.timeout_seconds = 0;
+    col_params->instance.impl_details.timeout_seconds =
+        col_params_->instance.impl_details.timeout_seconds;
     col_params->instance.impl_details.subdiv_offsets =
         col_params_->instance.impl_details.subdiv_offsets;
     col_params->merge_op = std::move(col_params_->merge_op);
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
index fc9010a..ecaab00 100644
--- a/tensorflow/core/ops/collective_ops.cc
+++ b/tensorflow/core/ops/collective_ops.cc
@@ -114,6 +114,7 @@
     .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
     .Attr("final_op: {'Id', 'Div'}")
     .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnchangedShape);
 
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 583f114..0e3e161 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -345,6 +345,139 @@
     def_function.function(collective_fn)()
 
 
+@combinations.generate(
+    combinations.combine(
+        collective_op=[
+            combinations.NamedObject('all_reduce', _collective_ops.all_reduce),
+            combinations.NamedObject('all_reduce_v2',
+                                     _collective_ops.all_reduce_v2),
+            combinations.NamedObject('all_gather', _collective_ops.all_gather),
+            combinations.NamedObject('all_gather_v2',
+                                     _collective_ops.all_gather_v2),
+        ],
+        mode='eager',
+        communication=['ring']))
+class TimeoutTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    _setup_context()
+    super().setUp()
+
+  def testTimeout(self, collective_op, communication):
+    timeout = 4.5
+
+    @def_function.function
+    def run(group_size, reported_group_size=None):
+      group_key = 20
+      instance_key = 30
+      tensor = [1, 2, 3, 4]
+      results = []
+      if reported_group_size is None:
+        reported_group_size = group_size
+      for i in range(group_size):
+        with ops.device('/CPU:{}'.format(i)):
+          input_data = constant_op.constant(tensor)
+          result = collective_op(
+              input_data,
+              group_size=reported_group_size,
+              group_key=group_key,
+              instance_key=instance_key,
+              communication_hint=communication,
+              timeout=timeout)
+          results.append(result)
+      return results
+
+    run(2, 2)
+
+    start_time = time.time()
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      run(1, 2)
+    elapsed = time.time() - start_time
+    self.assertAllGreaterEqual(elapsed, timeout)
+
+  def testParamResolutionAfterTimeoutV2(self, collective_op, communication):
+    timeout = 1.5
+
+    group_key = 20
+    instance_key = 30
+    input_data = constant_op.constant([1, 2, 3, 4])
+
+    # This timeout comes from param resolution.
+    with self.assertRaisesRegex(
+        errors.DeadlineExceededError,
+        'Collective has timed out waiting for other workers'):
+      with ops.device('CPU:0'):
+        collective_op(
+            input_data,
+            group_size=2,
+            group_key=group_key,
+            instance_key=instance_key,
+            communication_hint=communication,
+            timeout=timeout)
+
+    # We launch the second device after the first device times out. This
+    # simulates the situation where other workers are slow and the timeout is
+    # short. Since CPU:0 times out during param resolution, CPU:1 should also
+    # time out, but during execution.
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      with ops.device('CPU:1'):
+        collective_op(
+            input_data,
+            group_size=2,
+            group_key=group_key,
+            instance_key=instance_key,
+            communication_hint=communication,
+            timeout=timeout)
+
+  def testExecutionAfterTimeoutV2(self, collective_op, communication):
+    timeout = 1.5
+    group_key = 20
+    instance_key = 30
+    input_data = constant_op.constant([1, 2, 3, 4])
+
+    @def_function.function
+    def run():
+      for device in ['CPU:0', 'CPU:1']:
+        with ops.device(device):
+          collective_op(
+              input_data,
+              group_size=2,
+              group_key=group_key,
+              instance_key=instance_key,
+              communication_hint=communication,
+              timeout=timeout)
+
+    # Run a normal all-reduce to complete param resolution.
+    run()
+
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      with ops.device('CPU:0'):
+        collective_op(
+            input_data,
+            group_size=2,
+            group_key=group_key,
+            instance_key=instance_key,
+            communication_hint=communication,
+            timeout=timeout)
+
+    # We launch the second device after the first device times out. This
+    # simulates the situation where other workers are slow and the timeout is
+    # short. It should error immediately.
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      with ops.device('CPU:1'):
+        # No timeout.
+        collective_op(
+            input_data,
+            group_size=2,
+            group_key=group_key,
+            instance_key=instance_key,
+            communication_hint=communication)
+
+
 def _setup_context():
   context._reset_context()
   cpus = config.list_physical_devices('CPU')
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
index 5786915..6afe923 100644
--- a/tensorflow/python/ops/collective_ops.py
+++ b/tensorflow/python/ops/collective_ops.py
@@ -77,7 +77,8 @@
                   instance_key,
                   merge_op='Add',
                   final_op='Id',
-                  communication_hint='auto'):
+                  communication_hint='auto',
+                  timeout=0):
   """Reduces tensors collectively, across devices.
 
   Args:
@@ -94,6 +95,9 @@
     communication_hint: preferred collective communication.  The implementation
       may fall back to another mechanism.  Options include `auto`, `ring`, and
       `nccl`.
+    timeout: a float. The timeout in seconds. If non-zero, sets a completion
+      timeout to detect staleness: a DeadlineExceededError is raised if the
+      collective does not complete in time. This feature is experimental.
 
   Returns:
     An Op implementing the distributed reduction.
@@ -105,7 +109,8 @@
       instance_key=instance_key,
       merge_op=merge_op,
       final_op=final_op,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
 def all_gather(t,
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index d3d110f..d3988d0 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -166,151 +166,6 @@
     elapsed = time.time() - start_time
     self.assertAllGreaterEqual(elapsed, timeout)
 
-  @test_util.run_v2_only
-  def testCollectiveTimeoutV2(self):
-    timeout = 4.5
-    cpus = config.list_physical_devices('CPU')
-    self.assertEqual(len(cpus), 1)
-    config.set_logical_device_configuration(cpus[0], [
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration()
-    ])
-    context.ensure_initialized()
-
-    @def_function.function
-    def run_all_reduce(group_size, reported_group_size=None):
-      group_key = 20
-      instance_key = 30
-      tensor = [1, 2, 3, 4]
-      results = []
-      if reported_group_size is None:
-        reported_group_size = group_size
-      for i in range(group_size):
-        with ops.device('/CPU:{}'.format(i)):
-          input_data = constant_op.constant(tensor)
-          collective_op = collective_ops.all_reduce(
-              input_data,
-              group_size=reported_group_size,
-              group_key=group_key,
-              instance_key=instance_key,
-              merge_op='Add',
-              final_op='Id',
-              timeout=timeout)
-          results.append(collective_op)
-      return results
-
-    run_all_reduce(2, 2)
-
-    start_time = time.time()
-    with self.assertRaisesRegex(errors.DeadlineExceededError,
-                                'Collective has timed out during execution'):
-      run_all_reduce(1, 2)
-    elapsed = time.time() - start_time
-    self.assertAllGreaterEqual(elapsed, timeout)
-
-  @test_util.run_v2_only
-  def testParamResolutionAfterTimeoutV2(self):
-    timeout = 1.5
-    cpus = config.list_physical_devices('CPU')
-    self.assertEqual(len(cpus), 1)
-    config.set_logical_device_configuration(cpus[0], [
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration()
-    ])
-    context.ensure_initialized()
-
-    group_key = 20
-    instance_key = 30
-    input_data = constant_op.constant([1, 2, 3, 4])
-
-    # This timeout comes from param solution.
-    with self.assertRaisesRegex(
-        errors.DeadlineExceededError,
-        'Collective has timed out waiting for other workers'):
-      with ops.device('CPU:0'):
-        collective_ops.all_reduce(
-            input_data,
-            group_size=2,
-            group_key=group_key,
-            instance_key=instance_key,
-            merge_op='Add',
-            final_op='Id',
-            timeout=timeout)
-
-    # We launch the second device after the first device times out. This is to
-    # simulate the situation when other workers are slow and the timeout is
-    # short. Since the CPU:0 times out in the param resolution phase, CPU:1
-    # should times out as well, but in the execute phase.
-    with self.assertRaisesRegex(errors.DeadlineExceededError,
-                                'Collective has timed out during execution'):
-      with ops.device('CPU:1'):
-        collective_ops.all_reduce(
-            input_data,
-            group_size=2,
-            group_key=group_key,
-            instance_key=instance_key,
-            merge_op='Add',
-            final_op='Id',
-            timeout=timeout)
-
-  @test_util.run_v2_only
-  def testExecutionAfterTimeoutV2(self):
-    timeout = 1.5
-    cpus = config.list_physical_devices('CPU')
-    self.assertEqual(len(cpus), 1)
-    config.set_logical_device_configuration(cpus[0], [
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration()
-    ])
-    context.ensure_initialized()
-
-    group_key = 20
-    instance_key = 30
-    input_data = constant_op.constant([1, 2, 3, 4])
-
-    @def_function.function
-    def run_all_reduce():
-      for device in ['CPU:0', 'CPU:1']:
-        with ops.device(device):
-          collective_ops.all_reduce(
-              input_data,
-              group_size=2,
-              group_key=group_key,
-              instance_key=instance_key,
-              merge_op='Add',
-              final_op='Id',
-              timeout=timeout)
-
-    # Run a normal all-reduce to complete param resolution.
-    run_all_reduce()
-
-    with self.assertRaisesRegex(errors.DeadlineExceededError,
-                                'Collective has timed out during execution'):
-      with ops.device('CPU:0'):
-        collective_ops.all_reduce(
-            input_data,
-            group_size=2,
-            group_key=group_key,
-            instance_key=instance_key,
-            merge_op='Add',
-            final_op='Id',
-            timeout=timeout)
-
-    # We launch the second device after the first device times out. This is to
-    # simulate the situation when other workers are slow and the timeout is
-    # short. It should error immediately.
-    with self.assertRaisesRegex(errors.DeadlineExceededError,
-                                'Collective has timed out during execution'):
-      with ops.device('CPU:1'):
-        # No timeout.
-        collective_ops.all_reduce(
-            input_data,
-            group_size=2,
-            group_key=group_key,
-            merge_op='Add',
-            final_op='Id',
-            instance_key=instance_key)
-
   def testNcclHintFallbackToRingReduce(self):
     """Tests that setting `communication_hint=nccl` works on non-GPU builds."""
     if kernels.get_registered_kernels_for_op('NcclAllReduce'):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 2a2c310..938aa69 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -774,7 +774,7 @@
   }
   member_method {
     name: "CollectiveReduceV2"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 2a2c310..938aa69 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -774,7 +774,7 @@
   }
   member_method {
     name: "CollectiveReduceV2"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"