Add CheckPeerHealth to PeerAccessInterface
This method will be used to detect peer failures. It sends a GetStatus RPC to the peer and verifies that the peer's device incarnations match the local record. Verifying the incarnation is necessary to detect the case where a worker fails and restarts quickly.
This change also adds a GetTaskCached method to DeviceResolverInterface in order to get the local record of device incarnations. We cannot use GetDeviceAttributesAsync because it sends an RPC with fail_fast=false if the device is not in the cache; such an RPC may wait forever if the peer is down and never comes back.
PiperOrigin-RevId: 327711672
Change-Id: I870b7079b32f00803a1f509dfa009141c67fdf49
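
A minimal caller-side sketch of the intended usage, mirroring the new tests below. This is illustrative only: rma stands in for a CollectiveRemoteAccessDistributed instance owned by the caller, the task name is an example, and the snippet assumes TensorFlow's Status, Notification, and logging headers are available.

  Status health;
  Notification done;
  rma->CheckPeerHealth("/job:worker/replica:0/task:1",
                       [&health, &done](const Status& s) {
                         health = s;
                         done.Notify();
                       });
  done.WaitForNotification();
  // FailedPrecondition: the peer restarted with new device incarnations.
  // Unavailable: the peer did not respond to GetStatus.
  if (!health.ok()) {
    LOG(WARNING) << "peer unhealthy: " << health;
  }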
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 4cd9f82..ec875d0 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -108,6 +108,13 @@
from_alloc_attr, done);
}
+void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+ const StatusCallback& done) {
+ // Assume local devices are always healthy.
+ done(errors::Internal(
+ "CheckPeerHealth is not supposed to be called for local collectives"));
+}
+
/*static*/
void CollectiveRemoteAccessLocal::MemCpyAsync(
DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 8a0bbd5..12aca90 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -53,6 +53,9 @@
const DeviceLocality& client_locality,
const StatusCallback& done) override;
+ void CheckPeerHealth(const string& peer_task,
+ const StatusCallback& done) override;
+
BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
// Copy utility that always copies bytes from src to dst even if
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index d721fc3..2c60614 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -151,5 +151,16 @@
EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
}
+TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
+ Status status;
+ Notification done;
+ rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
+ status = s;
+ done.Notify();
+ });
+ done.WaitForNotification();
+ EXPECT_TRUE(errors::IsInternal(status));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc
index 12e1e28..9a898e7 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.cc
+++ b/tensorflow/core/common_runtime/device_resolver_local.cc
@@ -46,4 +46,10 @@
done(s);
}
+Status DeviceResolverLocal::GetTaskCached(
+ const string& task, std::vector<DeviceAttributes>* attributes) {
+ return errors::Internal(
+ "GetTaskCached is not supposed to be called in local collectives");
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
index 53a3c87..12b7dce 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.h
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -39,6 +39,9 @@
DeviceAttributes* attributes,
const StatusCallback& done) override;
+ Status GetTaskCached(const string& task,
+ std::vector<DeviceAttributes>* attributes) override;
+
void ClearTask(const string& task) override {}
void ClearCache() override {}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 505e0c3..94570c1 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -538,6 +538,7 @@
"//tensorflow/core:lib_internal", # protobuf::Any
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index dbc9417..4215b16 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -185,6 +185,62 @@
dev_attributes_callback);
}
+void CollectiveRemoteAccessDistributed::CheckPeerHealth(
+ const string& peer_task, const StatusCallback& done) {
+ if (peer_task == task_name_) {
+ // Fast path if the peer is the worker itself.
+ done(Status::OK());
+ return;
+ }
+  // We send a GetStatus RPC with fail_fast=true to check the health of a peer
+  // task. If the RPC succeeds, we verify that the peer's device incarnations
+  // match the local record, if we have one. Note that DeviceResolverInterface
+  // always caches the device attributes.
+ WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
+ if (wi == nullptr) {
+    done(errors::InvalidArgument(peer_task,
+                                 " not found. It's probably invalid. The "
+                                 "valid form is /job:xxx/replica:0/task:N"));
+ return;
+ }
+ auto req = new GetStatusRequest();
+ auto resp = new GetStatusResponse();
+  // We're not using CancellableCall because GetStatusAsync doesn't support
+  // cancellation yet.
+ wi->GetStatusAsync(
+ req, resp, /*fail_fast*/ true,
+ [this, req, resp, wi, peer_task, done](Status s) {
+ std::vector<DeviceAttributes> cached_attrs;
+ if (s.ok()) {
+ s = dev_resolver_->GetTaskCached(peer_task, &cached_attrs);
+ }
+ if (s.ok()) {
+ absl::flat_hash_set<uint64> remote_incarnations;
+ for (const DeviceAttributes& da : resp->device_attributes()) {
+ remote_incarnations.insert(da.incarnation());
+ }
+ for (const DeviceAttributes& attr : cached_attrs) {
+ if (!remote_incarnations.contains(attr.incarnation())) {
+ s = errors::FailedPrecondition(
+ attr.name(), " with incarnation ", attr.incarnation(),
+ " is not available. This usually means ", peer_task,
+ " has restarted");
+ break;
+ }
+ }
+ } else if (errors::IsNotFound(s)) {
+ // Skip validating device incarnation if we don't know what the
+ // incarnation should be. The device attribute is cached after the
+ // first collective.
+ s = Status::OK();
+ }
+ delete req;
+ delete resp;
+ worker_cache_->ReleaseWorker(peer_task, wi);
+ done(s);
+ });
+}
+
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
CollectiveRemoteAccessLocal::StartAbort(s);
cancel_mgr_.StartCancel();
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index d6546e3..ed4d448 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -28,10 +28,11 @@
CollectiveRemoteAccessDistributed(
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
std::shared_ptr<UnboundedWorkQueue> work_queue,
- WorkerCacheInterface* worker_cache, int64 step_id)
+ WorkerCacheInterface* worker_cache, int64 step_id, string task_name)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
worker_cache_(worker_cache),
- work_queue_(std::move(work_queue)) {}
+ work_queue_(std::move(work_queue)),
+ task_name_(std::move(task_name)) {}
~CollectiveRemoteAccessDistributed() override {}
@@ -43,6 +44,9 @@
int dev_to_dev_stream_index,
const StatusCallback& done) override;
+ void CheckPeerHealth(const string& peer_task,
+ const StatusCallback& done) override;
+
void StartAbort(const Status& s) override;
protected:
@@ -51,6 +55,7 @@
// `CollectiveExecutorMgr`.
std::shared_ptr<UnboundedWorkQueue> work_queue_;
CancellationManager cancel_mgr_;
+ string task_name_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 2975442..b6975e4 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -63,11 +63,12 @@
class FakeWorker : public TestWorkerInterface {
public:
FakeWorker(const string& name, DeviceMgr* dev_mgr,
- DeviceResolverDistributed* dres)
+ DeviceResolverDistributed* dres, bool is_failed)
: name_(name),
device_mgr_(dev_mgr),
device_resolver_(dres),
- buf_rendezvous_(kStepId, dev_mgr) {}
+ buf_rendezvous_(kStepId, dev_mgr),
+ is_failed_(is_failed) {}
// Direct access to a BufRendezvous that holds whatever the remote
// worker is supposed to have.
@@ -76,6 +77,10 @@
void GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response, bool fail_fast,
StatusCallback done) override {
+ if (is_failed_) {
+ done(errors::Unavailable("peer down"));
+ return;
+ }
std::vector<DeviceAttributes> dev_attr;
device_mgr_->ListDeviceAttributes(&dev_attr);
for (const auto& da : dev_attr) {
@@ -86,6 +91,10 @@
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) override {
+ if (is_failed_) {
+ done(errors::Unavailable("peer down"));
+ return;
+ }
opts->SetCancelCallback([this]() {
// Within this test the call is satisfied by a process-local
// BufRendezvous table. In real application the BufRendezvous
@@ -125,6 +134,7 @@
DeviceMgr* device_mgr_;
DeviceResolverDistributed* device_resolver_;
BufRendezvous buf_rendezvous_;
+ bool is_failed_;
};
class FakeCache : public TestWorkerCache {
@@ -201,7 +211,7 @@
// All tests simulate requests from worker 0 to worker 1.
rma_.reset(new CollectiveRemoteAccessDistributed(
device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
- kStepId));
+ kStepId, "/job:worker/replica:0/task:0"));
const int kNumElts = 8;
expected_value_ = Tensor(DT_FLOAT, {kNumElts});
@@ -215,7 +225,7 @@
}
void DefineWorker(const string& worker_name, const string& device_type,
- int num_devices) {
+ int num_devices, bool is_failed = false) {
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
@@ -232,19 +242,19 @@
DeviceResolverDistributed* dev_res =
new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
dev_resolvers_[worker_name] = dev_res;
- FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+ FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
workers_.push_back(fw);
wc_.AddWorker(worker_name, fw);
}
void RestartWorker(const string& worker_name, const string& device_type,
- int num_devices) {
+ int num_devices, bool is_failed = false) {
auto it = dev_resolvers_.find(worker_name);
if (it != dev_resolvers_.end()) {
delete it->second;
dev_resolvers_.erase(it);
}
- DefineWorker(worker_name, device_type, num_devices);
+ DefineWorker(worker_name, device_type, num_devices, is_failed);
}
void ValidateResultTensor() {
@@ -401,7 +411,7 @@
ValidateResultTensor();
// Restart task 1 and check that recv from task 1 to task 0 fails.
- RestartWorker("/job:worker/replica:0/task:1", "CPU", 1);
+ RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
Notification post_restart_note;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
@@ -417,5 +427,139 @@
EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
}
+TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
+ DeviceAttributes attr;
+ Status get_attr_status;
+ Notification get_attr_done;
+ // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+ // worker.
+ dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+ "/job:worker/replica:0/task:1/device:CPU:0",
+ "/job:worker/replica:0/task:1", &attr,
+ [&get_attr_status, &get_attr_done](const Status& s) {
+ get_attr_status = s;
+ get_attr_done.Notify();
+ });
+ get_attr_done.WaitForNotification();
+ TF_ASSERT_OK(get_attr_status);
+
+ Status check_health_status;
+ Notification check_health_done;
+ rma_->CheckPeerHealth(
+ "/job:worker/replica:0/task:1",
+ [&check_health_status, &check_health_done](const Status s) {
+ check_health_status = s;
+ check_health_done.Notify();
+ });
+ check_health_done.WaitForNotification();
+ TF_EXPECT_OK(check_health_status);
+}
+
+TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
+ Status check_health_status;
+ Notification check_health_done;
+ rma_->CheckPeerHealth(
+ "/job:worker/replica:0/task:1",
+ [&check_health_status, &check_health_done](const Status s) {
+ check_health_status = s;
+ check_health_done.Notify();
+ });
+ check_health_done.WaitForNotification();
+ EXPECT_TRUE(check_health_status.ok());
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestarted) {
+ DeviceAttributes attr;
+ Status get_attr_status;
+ Notification get_attr_done;
+ // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+ // worker.
+ dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+ "/job:worker/replica:0/task:1/device:CPU:0",
+ "/job:worker/replica:0/task:1", &attr,
+ [&get_attr_status, &get_attr_done](const Status& s) {
+ get_attr_status = s;
+ get_attr_done.Notify();
+ });
+ get_attr_done.WaitForNotification();
+ TF_ASSERT_OK(get_attr_status);
+
+ RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+ Status check_health_status;
+ Notification check_health_done;
+ rma_->CheckPeerHealth(
+ "/job:worker/replica:0/task:1",
+ [&check_health_status, &check_health_done](const Status s) {
+ check_health_status = s;
+ check_health_done.Notify();
+ });
+ check_health_done.WaitForNotification();
+ EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
+ DeviceAttributes attr;
+ Status get_attr_status;
+ Notification get_attr_done;
+ // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+ // worker.
+ dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+ "/job:worker/replica:0/task:1/device:CPU:0",
+ "/job:worker/replica:0/task:1", &attr,
+ [&get_attr_status, &get_attr_done](const Status& s) {
+ get_attr_status = s;
+ get_attr_done.Notify();
+ });
+ get_attr_done.WaitForNotification();
+ TF_ASSERT_OK(get_attr_status);
+
+ RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
+ /*is_failed*/ true);
+
+ Status check_health_status;
+ Notification check_health_done;
+ rma_->CheckPeerHealth(
+ "/job:worker/replica:0/task:1",
+ [&check_health_status, &check_health_done](const Status s) {
+ check_health_status = s;
+ check_health_done.Notify();
+ });
+ check_health_done.WaitForNotification();
+ EXPECT_TRUE(errors::IsUnavailable(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+ RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
+
+ DeviceAttributes attr;
+ Status get_attr_status;
+ Notification get_attr_done;
+ // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+ // worker.
+ dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+ "/job:worker/replica:0/task:1/device:GPU:0",
+ "/job:worker/replica:0/task:1", &attr,
+ [&get_attr_status, &get_attr_done](const Status& s) {
+ get_attr_status = s;
+ get_attr_done.Notify();
+ });
+ get_attr_done.WaitForNotification();
+ TF_ASSERT_OK(get_attr_status);
+
+ RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+ Status check_health_status;
+ Notification check_health_done;
+ rma_->CheckPeerHealth(
+ "/job:worker/replica:0/task:1",
+ [&check_health_status, &check_health_done](const Status s) {
+ check_health_status = s;
+ check_health_done.Notify();
+ });
+ check_health_done.WaitForNotification();
+ EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
index 927925c..ab0b3a6 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -113,6 +113,22 @@
});
}
+Status DeviceResolverDistributed::GetTaskCached(
+ const string& task, std::vector<DeviceAttributes>* attributes) {
+ mutex_lock l(mu_);
+ attributes->clear();
+ for (const auto& it : attr_table_) {
+ const string& device_name = it.first;
+ if (DeviceNameUtils::IsSameAddressSpace(task, device_name)) {
+ attributes->push_back(it.second);
+ }
+ }
+ if (attributes->empty()) {
+ return errors::NotFound(task, " not found in the cache");
+ }
+ return Status::OK();
+}
+
void DeviceResolverDistributed::ClearTask(const string& task) {
mutex_lock l(mu_);
// First find all the keys belonging to the task.
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
index 93d51a5..d400fb5 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -43,6 +43,9 @@
DeviceAttributes* attributes,
const StatusCallback& done) override;
+ Status GetTaskCached(const string& task,
+ std::vector<DeviceAttributes>* attributes) override;
+
void ClearTask(const string& task) override;
void ClearCache() override;
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 4fbc4bb..62a67b5 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -47,8 +47,9 @@
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
CollectiveRemoteAccessDistributed* rma =
- new CollectiveRemoteAccessDistributed(
- dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
+ new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+ work_queue_, worker_cache_, step_id,
+ task_name_);
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
&gpu_ring_order_, work_queue_);
}
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 05eefed..d0c5323 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -161,17 +161,20 @@
std::vector<DeviceAttributes>* attributes,
const StatusCallback& done) = 0;
- // Populate *attributes with the DeviceAttributes of the specified
- // device.
+ // Populates *attributes with the DeviceAttributes of the specified device.
virtual void GetDeviceAttributesAsync(const string& device,
const string& task,
DeviceAttributes* attributes,
const StatusCallback& done) = 0;
- // Clear the cache of device data belonging to the specified task.
+ // Returns the cached device attributes of a task.
+ virtual Status GetTaskCached(const string& task,
+ std::vector<DeviceAttributes>* attributes) = 0;
+
+ // Clears the cache of device data belonging to the specified task.
virtual void ClearTask(const string& task) = 0;
- // Clear the cache of all device data.
+ // Clears the cache of all device data.
virtual void ClearCache() = 0;
};
@@ -279,6 +282,12 @@
const DeviceLocality& client_locality,
const StatusCallback& done) = 0;
+ // Checks the health of a collective peer. It probes the peer to see if it is
+ // alive. Note that if a peer has restarted, it's considered a different one,
+ // so CheckPeerHealth fails.
+ virtual void CheckPeerHealth(const string& peer_task,
+ const StatusCallback& done) = 0;
+
virtual BufRendezvous* buf_rendezvous() = 0;
virtual void StartAbort(const Status& s) = 0;