[tf.data service] Improve cancellation for tf.data service requests.

1. If a DataServiceDataset iterator is cancelled, it will now call TryCancel on its outstanding RPCs.
2. As a result, we can reduce the frequency of returning from blocked round-robin requests to check whether the iterator is cancelled. This may avoid delays in GetNext() that could happen if one consumer reads from a round earlier than others, and needs to perform multiple retries with exponential backoff.
3. Because of (2), server shutdown may take up to 1 minute if a round-robin request is blocked waiting for other consumers. To prevent slow unit tests, certain tests store their servers globally so that they are destroyed immediately at process exit without waiting for their outstanding RPCs to finish.

When data_service_ops_test.py is run locally, this CL reduces the test's running time from 27 seconds to 20 seconds.

PiperOrigin-RevId: 351825888
Change-Id: Iba20a456bdabf251d03b94f090fe760616d3da4d
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index eec9745..0228289 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -90,6 +90,8 @@
         ":grpc_util",
         ":worker_cc_grpc_proto",
         "//tensorflow/core:framework",
+        "//tensorflow/core/platform:errors",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
         tf_grpc_cc_dependency(),
     ],
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index 78435cb..eb1f335 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace data {
@@ -249,6 +250,12 @@
                                            CompressedElement& element,
                                            bool& end_of_sequence) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
+  {
+    mutex_lock l(mu_);
+    if (cancelled_) {
+      return errors::Cancelled("Client was cancelled.");
+    }
+  }
   GetElementRequest req;
   req.set_task_id(task_id);
   if (consumer_index.has_value()) {
@@ -259,7 +266,15 @@
   }
   GetElementResponse resp;
   grpc::ClientContext ctx;
+  {
+    mutex_lock l(mu_);
+    active_contexts_.insert(&ctx);
+  }
   grpc::Status s = stub_->GetElement(&ctx, req, &resp);
+  {
+    mutex_lock l(mu_);
+    active_contexts_.erase(&ctx);
+  }
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get element", s);
   }
@@ -285,6 +300,14 @@
   return Status::OK();
 }
 
+void DataServiceWorkerClient::TryCancel() {
+  mutex_lock l(mu_);  // Guards both `cancelled_` and `active_contexts_`.
+  cancelled_ = true;  // Later element requests return Cancelled.
+  for (const auto& ctx : active_contexts_) {
+    ctx->TryCancel();  // Best-effort cancel of each RPC still registered.
+  }
+}
+
 Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceDispatcherClient>& out) {
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index 3bd135e..1ef926c 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 
+#include "grpcpp/impl/codegen/client_context.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
@@ -148,6 +150,10 @@
                     absl::optional<int64> round_index,
                     CompressedElement& element, bool& end_of_sequence);
 
+  // Makes a best effort to cancel all outstanding calls in progress for the
+  // client, and causes further calls to return Cancelled status.
+  void TryCancel();
+
  protected:
   Status EnsureInitialized() override;
 
@@ -156,6 +162,12 @@
   // Initialization is guarded by `mu_`, but using the stub does not require
   // holding `mu_`
   std::unique_ptr<WorkerService::Stub> stub_;
+  // Set of all currently active client contexts. Used to support
+  // cancellation.
+  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
+  // Indicates that the client has been cancelled, so no further requests should
+  // be accepted.
+  bool cancelled_ GUARDED_BY(mu_) = false;
 };
 
 // Creates and initializes a new tf.data service dispatcher client.
diff --git a/tensorflow/core/data/service/task_runner.cc b/tensorflow/core/data/service/task_runner.cc
index 334f3c7..68a530e 100644
--- a/tensorflow/core/data/service/task_runner.cc
+++ b/tensorflow/core/data/service/task_runner.cc
@@ -26,9 +26,9 @@
 namespace data {
 namespace {
 // How long to wait for other round-robin consumers before returning with an
-// Unavailable error. The unavailable error gives the client an opportunity to
-// either give up or retry to continue waiting.
-const int64 kDefaultTimeoutUs = 2 * 1000 * 1000;  // 2 seconds.
+// Unavailable error. This prevents the server from hanging on shutdown when
+// some round-robin consumers exit earlier than others.
+const int64 kTimeoutUs = 60 * 1000 * 1000;  // 1 minute.
 }  // namespace
 
 StandaloneTaskIterator::StandaloneTaskIterator(
@@ -58,8 +58,8 @@
           cardinality,
           ". Consider adding a `.repeat()` transformation to the dataset.");
     }
-    out = absl::make_unique<RoundRobinTaskRunner>(
-        std::move(iterator), task_def.num_consumers(), kDefaultTimeoutUs);
+    out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
+                                                  task_def.num_consumers());
   } else {
     out =
         absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
@@ -78,10 +78,8 @@
 }
 
 RoundRobinTaskRunner::RoundRobinTaskRunner(
-    std::unique_ptr<TaskIterator> iterator, int64 num_consumers,
-    int64 timeout_us)
+    std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
     : num_consumers_(num_consumers),
-      timeout_us_(timeout_us),
       iterator_(std::move(iterator)),
       buffer_(num_consumers_) {
   VLOG(1) << "Creating task runner for distributing data round-robin to "
@@ -128,7 +126,7 @@
   }
   while (current_round_ < request.round_index) {
     std::cv_status s =
-        new_round_cv_.wait_for(l, std::chrono::microseconds(timeout_us_));
+        new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
     if (s == std::cv_status::timeout) {
       // Clients will retry Unavailable.
       return errors::Unavailable(
diff --git a/tensorflow/core/data/service/task_runner.h b/tensorflow/core/data/service/task_runner.h
index 2a79123..7d1fa3a 100644
--- a/tensorflow/core/data/service/task_runner.h
+++ b/tensorflow/core/data/service/task_runner.h
@@ -112,7 +112,7 @@
 class RoundRobinTaskRunner : public TaskRunner {
  public:
   RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
-                       int64 num_consumers, int64 timeout_us);
+                       int64 num_consumers);
   Status GetNext(const Request& request, std::vector<Tensor>& element,
                  bool& end_of_task) override;
 
@@ -121,7 +121,6 @@
   Status FillBuffer();
 
   const int64 num_consumers_;
-  const int64 timeout_us_;
   std::unique_ptr<TaskIterator> iterator_;
   mutex mu_;
   // Condition variable notified whenever we start a new round of round-robin.
diff --git a/tensorflow/core/data/service/task_runner_test.cc b/tensorflow/core/data/service/task_runner_test.cc
index 4e454d9..16c19fb 100644
--- a/tensorflow/core/data/service/task_runner_test.cc
+++ b/tensorflow/core/data/service/task_runner_test.cc
@@ -22,7 +22,6 @@
 namespace tensorflow {
 namespace data {
 namespace {
-const int64 kNoTimeoutUs = 60ull * 60 * 1000 * 1000;  // 60 minutes.
 
 class TestTaskIterator : public TaskIterator {
  public:
@@ -97,7 +96,7 @@
     elements.push_back(element);
   }
   RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
-                              num_consumers, kNoTimeoutUs);
+                              num_consumers);
   std::vector<std::vector<int64>> per_consumer_results;
   std::vector<std::unique_ptr<Thread>> consumers;
   mutex mu;
@@ -150,7 +149,7 @@
     elements.push_back(element);
   }
   RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
-                              num_consumers, kNoTimeoutUs);
+                              num_consumers);
   std::vector<std::vector<int64>> per_consumer_results;
   std::vector<std::unique_ptr<Thread>> consumers;
   mutex mu;
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 4032e32..4574ffa 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -234,6 +234,9 @@
 
     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
       mutex_lock l(mu_);
+      for (const auto& task : tasks_) {
+        task->worker->TryCancel();
+      }
       cancelled_ = true;
       worker_thread_cv_.notify_all();
       manager_thread_cv_.notify_all();
diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
index 8a07b0e..ea7abb4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
@@ -47,6 +47,10 @@
 
 TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
 NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
+# Some clusters may take a long time to shut down due to blocked outstanding
+# RPCs. We store the clusters here so that they are destroyed at end of process
+# instead of slowing down unit tests.
+GLOBAL_CLUSTERS = set()
 
 
 class DataServiceOpsTest(data_service_test_base.TestBase,
@@ -289,6 +293,8 @@
           combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
   def testRoundRobin(self, num_workers, num_consumers):
     cluster = self.create_cluster(num_workers=num_workers)
+    # Round robin reads can cause slow cluster shutdown.
+    GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.repeat()
@@ -325,6 +331,8 @@
     # Tests a common use case for round robin reads. At each step, all
     # consumers should get batches with the same bucket size.
     cluster = self.create_cluster(num_workers=4)
+    # Round robin reads can cause slow cluster shutdown.
+    GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
     ds = ds.shuffle(num_elements)