Fix finished task handling for restarted workers.

This reflects the recent change where workers now receive their previous tasks on restart, as opposed to being assigned new tasks.

PiperOrigin-RevId: 325456693
Change-Id: I2c99998c1310983ecc70e57f9c7c0a362d42c9d6
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 0e955e1..d17acff 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -76,7 +76,11 @@
                                   [this, dispatcher = dispatcher.release()]() {
                                     BackgroundThread(dispatcher);
                                   });
+  LOG(INFO) << "Worker registered with dispatcher running at "
+            << config_.dispatcher_address();
   background_thread_.reset(thread);
+  mutex_lock l(mu_);
+  registered_ = true;
   return Status::OK();
 }
 
@@ -118,6 +122,13 @@
   std::vector<tensorflow::Tensor> outputs;
   {
     mutex_lock l(mu_);
+    if (!registered_) {
+      // We need to reject requests until the worker has registered with the
+      // dispatcher, so that we don't return NOT_FOUND for tasks that the worker
+      // had before preemption.
+      return errors::Unavailable(
+          "Worker has not yet registered with dispatcher.");
+    }
     auto it = tasks_.find(request->task_id());
     if (it == tasks_.end()) {
       return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 8353d11..36edbe5 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -84,6 +84,8 @@
   // Completed tasks which haven't yet been communicated to the dispatcher.
   absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
+  // Whether the worker has registered with the dispatcher yet.
+  bool registered_ TF_GUARDED_BY(mu_) = false;
   // Condition variable for notifying the background thread.
   condition_variable background_cv_ TF_GUARDED_BY(mu_);
   std::unique_ptr<Thread> background_thread_;
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index ca73799..8a160aa 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -401,7 +401,6 @@
           mutex_lock l(mu_);
           num_running_worker_threads_--;
           outstanding_requests_--;
-          VLOG(3) << "Exiting worker thread";
         };
         worker_threads_.push_back(ctx->StartThread(
             "tf-data-service-task_thread", [this, done = std::move(done)]() {
@@ -437,10 +436,10 @@
             }
             worker_thread_cv_.wait(l);
           }
+          outstanding_requests_++;
           if (cancelled_) {
             return;
           }
-          outstanding_requests_++;
           // Search for a task to update.
           int num_tasks = tasks_.size();
           for (int i = 0; i < num_tasks; ++i) {
@@ -461,6 +460,9 @@
         Status s = GetElement(task_to_process.get(), deadline_micros);
         if (!s.ok()) {
           mutex_lock l(mu_);
+          VLOG(1) << "Failed to get element for task "
+                  << task_to_process->task_id << ": " << s;
+          task_to_process->in_use = false;
           status_ = s;
           get_next_cv_.notify_all();
           return;
@@ -486,14 +488,6 @@
         if (s.ok()) {
           break;
         }
-        if (errors::IsNotFound(s)) {
-          // This indicates that the worker was restarted. The restarted worker
-          // will get a new task, and the old task is lost.
-          mutex_lock l(mu_);
-          finished_tasks_++;
-          task->end_of_sequence = true;
-          return Status::OK();
-        }
         // Retry all errors that could indicate preemption.
         if (!errors::IsUnavailable(s) && !errors::IsCancelled(s) &&
             !errors::IsAborted(s)) {
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 639c07b..210b6f5 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -94,7 +94,6 @@
     name = "data_service_ops_test",
     size = "medium",
     srcs = ["data_service_ops_test.py"],
-    tags = ["notap"],  # "b/163085430"
     deps = [
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:client_testlib",