Fix finished task handling for restarted workers.
This accounts for the recent change whereby restarted workers are reassigned their previous tasks on restart, rather than being given new tasks.
PiperOrigin-RevId: 325456693
Change-Id: I2c99998c1310983ecc70e57f9c7c0a362d42c9d6
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 0e955e1..d17acff 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -76,7 +76,11 @@
[this, dispatcher = dispatcher.release()]() {
BackgroundThread(dispatcher);
});
+ LOG(INFO) << "Worker registered with dispatcher running at "
+ << config_.dispatcher_address();
background_thread_.reset(thread);
+ mutex_lock l(mu_);
+ registered_ = true;
return Status::OK();
}
@@ -118,6 +122,13 @@
std::vector<tensorflow::Tensor> outputs;
{
mutex_lock l(mu_);
+ if (!registered_) {
+ // We need to reject requests until the worker has registered with the
+ // dispatcher, so that we don't return NOT_FOUND for tasks that the worker
+ // had before preemption.
+ return errors::Unavailable(
+ "Worker has not yet registered with dispatcher.");
+ }
auto it = tasks_.find(request->task_id());
if (it == tasks_.end()) {
return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 8353d11..36edbe5 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -84,6 +84,8 @@
// Completed tasks which haven't yet been communicated to the dispatcher.
absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
+ // Whether the worker has registered with the dispatcher yet.
+ bool registered_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the background thread.
condition_variable background_cv_ TF_GUARDED_BY(mu_);
std::unique_ptr<Thread> background_thread_;
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index ca73799..8a160aa 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -401,7 +401,6 @@
mutex_lock l(mu_);
num_running_worker_threads_--;
outstanding_requests_--;
- VLOG(3) << "Exiting worker thread";
};
worker_threads_.push_back(ctx->StartThread(
"tf-data-service-task_thread", [this, done = std::move(done)]() {
@@ -437,10 +436,10 @@
}
worker_thread_cv_.wait(l);
}
+ outstanding_requests_++;
if (cancelled_) {
return;
}
- outstanding_requests_++;
// Search for a task to update.
int num_tasks = tasks_.size();
for (int i = 0; i < num_tasks; ++i) {
@@ -461,6 +460,9 @@
Status s = GetElement(task_to_process.get(), deadline_micros);
if (!s.ok()) {
mutex_lock l(mu_);
+ VLOG(1) << "Failed to get element for task "
+ << task_to_process->task_id << ": " << s;
+ task_to_process->in_use = false;
status_ = s;
get_next_cv_.notify_all();
return;
@@ -486,14 +488,6 @@
if (s.ok()) {
break;
}
- if (errors::IsNotFound(s)) {
- // This indicates that the worker was restarted. The restarted worker
- // will get a new task, and the old task is lost.
- mutex_lock l(mu_);
- finished_tasks_++;
- task->end_of_sequence = true;
- return Status::OK();
- }
// Retry all errors that could indicate preemption.
if (!errors::IsUnavailable(s) && !errors::IsCancelled(s) &&
!errors::IsAborted(s)) {
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 639c07b..210b6f5 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -94,7 +94,6 @@
name = "data_service_ops_test",
size = "medium",
srcs = ["data_service_ops_test.py"],
- tags = ["notap"], # "b/163085430"
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",