Add support in collective_rma_distributed for worker interfaces that write the remote tensor directly into the buffer pointer provided in the RecvBuf request, rather than populating transport_options with a serialized proto containing the remote tensor. This allows a more efficient implementation that avoids the extra serialization / deserialization cost of going from the transport format into protos and finally into the tensor buffer.
PiperOrigin-RevId: 372976144
Change-Id: I2f46d9c6b7dd30ea7003de7587703dc8666e7c27
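
To make the mechanism concrete before the diff: the receiver advertises the destination buffer's address in the RecvBuf request (buf_ptr), and the sender either writes into that buffer directly or falls back to packing the bytes into the response's transport_options proto. The standalone C++ sketch below illustrates the receiver-side handling of those two cases; it is an illustration only, not TensorFlow code. FakeResponse, FakeTensorBuffer, and ReceiveIntoBuffer are hypothetical stand-ins for RecvBufResponse / RecvBufRespExtra, the tensor buffer behind buf_ptr, and the PopulateTensorFromResponse helper added in the diff.

    // Minimal standalone sketch (NOT TensorFlow code): illustrates the two
    // delivery paths that the receiver has to handle. All names here are
    // hypothetical stand-ins.
    #include <cstring>
    #include <iostream>
    #include <optional>
    #include <string>
    #include <vector>

    // Stand-in for the destination buffer whose address the receiver
    // advertised to the sender via the RecvBuf request's buf_ptr field.
    struct FakeTensorBuffer {
      std::vector<char> bytes;
    };

    // Stand-in for RecvBufResponse: the sender either wrote the bytes
    // directly into the advertised buffer (no chunks present) or packed them
    // into the response, analogous to RecvBufRespExtra's tensor_content.
    struct FakeResponse {
      std::optional<std::vector<std::string>> tensor_content_chunks;
    };

    // Analogous to the PopulateTensorFromResponse helper in the diff: a no-op
    // when the bytes already sit in the destination buffer, otherwise
    // validate the total size and copy the chunks in.
    bool ReceiveIntoBuffer(const FakeResponse& response, FakeTensorBuffer* dst) {
      if (!response.tensor_content_chunks) {
        // Direct path: the sender already filled *dst through buf_ptr.
        return true;
      }
      size_t num_bytes = 0;
      for (const std::string& chunk : *response.tensor_content_chunks) {
        num_bytes += chunk.size();
      }
      if (num_bytes != dst->bytes.size()) {
        return false;  // Size mismatch, analogous to the errors::Internal path.
      }
      char* head = dst->bytes.data();
      for (const std::string& chunk : *response.tensor_content_chunks) {
        std::memcpy(head, chunk.data(), chunk.size());
        head += chunk.size();
      }
      return true;
    }

    int main() {
      FakeTensorBuffer dst;
      dst.bytes.resize(8);

      // Proto path: the bytes arrive in the response and must be copied out.
      FakeResponse proto_path;
      proto_path.tensor_content_chunks = std::vector<std::string>{"abcd", "efgh"};
      std::cout << "proto path ok:  " << ReceiveIntoBuffer(proto_path, &dst) << "\n";

      // Direct path: nothing to copy; the buffer was populated remotely.
      FakeResponse direct_path;
      std::cout << "direct path ok: " << ReceiveIntoBuffer(direct_path, &dst) << "\n";
      return 0;
    }

The real implementation additionally stages the bytes in a GPU-compatible CPU tensor and DMA-copies them when the destination device is a GPU, as shown in recv_buf_callback in the diff below.
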
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index 29fcd82..05032da 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -14,6 +14,8 @@
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+#include <memory>
+
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -24,6 +26,7 @@
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -73,6 +76,32 @@
head += tensor_content_chunk.size();
}
}
+
+Status PopulateTensorFromResponse(const RecvBufResponse& response,
+ Tensor* cpu_tensor) {
+ const bool has_transport_options = response.has_transport_options();
+
+ // If there are no transport options, then the tensor has already been
+ // copied into request.buf_ptr.
+ if (!has_transport_options) return Status::OK();
+
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+ int64 num_bytes = 0;
+ RecvBufRespExtra extra;
+ response.transport_options().UnpackTo(&extra);
+ for (const auto& chunk : extra.tensor_content()) {
+ num_bytes += chunk.size();
+ }
+
+ if (num_bytes != total_bytes) {
+ return errors::Internal("Tensor Size Mismatch: RecvBufResponse returned ",
+ num_bytes,
+ " bytes, expected: ", cpu_tensor->TotalBytes());
+ }
+ PopulateTensorFromExtra(extra, cpu_tensor);
+ return Status::OK();
+}
+
} // namespace
void CollectiveRemoteAccessDistributed::RecvFromPeer(
@@ -94,83 +123,95 @@
struct State {
DeviceAttributes server_attributes;
std::unique_ptr<RecvBufCall> call;
+ std::unique_ptr<Tensor> cpu_tensor;
};
State* state = new State;
- // Logic to be executed on the RecvBufAsync callback.
- auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
- to_device_ctx, to_tensor, dev_to_dev_stream_index,
- done](const Status& s) {
- if (s.ok()) {
- // In this generic implementation the bytes come back in the
- // RPC response protobuf rather than via RDMA so we need to copy
- // them into the destination tensor here.
- RecvBufRespExtra extra;
- state->call->resp_.transport_options().UnpackTo(&extra);
- int64 num_bytes = 0;
- for (const auto& chunk : extra.tensor_content()) {
- num_bytes += chunk.size();
- }
- const int64 total_bytes = to_tensor->TotalBytes();
- if (num_bytes != total_bytes) {
- done(errors::Internal("RecvBufResponse returned ", num_bytes,
- " bytes where to_tensor expected ",
- to_tensor->TotalBytes()));
- delete state;
- return;
- }
- if (to_device->tensorflow_gpu_device_info()) {
- // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
- // Use GPU-registered memory for the CPU tensor so the transfer
- // goes faster.
- Device* cpu_dev = nullptr;
- Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
- if (!status.ok()) {
- done(status);
- delete state;
- return;
- }
- AllocatorAttributes cpu_attr;
- cpu_attr.set_gpu_compatible(true);
- ScopedMemoryDebugAnnotation op_annotation(
- "CollectiveRemoteAccessDistributed::RecvFromPeer"
- "::recv_buf_callback",
- step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape());
- Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
- to_tensor->dtype(), to_tensor->shape());
- PopulateTensorFromExtra(extra, cpu_tensor);
- // Then copy it to the GPU.
- CopyTensor::ViaDMA("", // edge name (non-existent)
- nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
- to_device, cpu_attr, to_alloc_attr, cpu_tensor,
- to_tensor, dev_to_dev_stream_index,
- [this, cpu_tensor, done](const Status& s) {
- delete cpu_tensor;
- // This callback must not block, so execute
- // done in another thread.
- work_queue_->Schedule([s, done] { done(s); });
- });
- delete state;
- return;
- } else {
- // CPU device
- PopulateTensorFromExtra(extra, to_tensor);
- }
- }
-
- delete state;
- done(s);
- };
-
Status s = dev_resolver_->GetDeviceAttributes(peer_device,
&state->server_attributes);
if (!s.ok()) {
- recv_buf_callback(s);
+ delete state;
+ done(s);
return;
}
+
+ Tensor* dst_tensor = nullptr;
+ Device* cpu_dev = nullptr;
+ if (to_device->tensorflow_gpu_device_info()) {
+ // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
+ // Use GPU-registered memory for the CPU tensor so the transfer
+ // goes faster.
+
+ Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
+ if (!status.ok()) {
+ delete state;
+      done(status);
+ return;
+ }
+ AllocatorAttributes cpu_attr;
+ cpu_attr.set_gpu_compatible(true);
+ ScopedMemoryDebugAnnotation op_annotation(
+ "CollectiveRemoteAccessDistributed::RecvFromPeer"
+ "::recv_buf_callback",
+ step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape());
+
+ state->cpu_tensor =
+ std::make_unique<Tensor>(cpu_dev->GetAllocator(cpu_attr),
+ to_tensor->dtype(), to_tensor->shape());
+ dst_tensor = state->cpu_tensor.get();
+ } else {
+ dst_tensor = to_tensor;
+ }
+
+ // Logic to be executed on the RecvBufAsync callback.
+ auto recv_buf_callback =
+ [this, state, to_device, to_alloc_attr, to_device_ctx, to_tensor, cpu_dev,
+ dev_to_dev_stream_index, dst_tensor, done](const Status& s) {
+ if (s.ok()) {
+          // In this generic implementation the bytes come back in one of two
+          // ways:
+          //   1. in the response protobuf's transport_options field, or
+          //   2. already copied into the buffer at
+          //      RecvBufCall::req_.buf_ptr(), which was set to dst_tensor:
+          //      the temporary cpu_tensor when to_device is a GPU device, or
+          //      to_tensor directly otherwise.
+          //
+          // PopulateTensorFromResponse handles both cases (it is a no-op in
+          // the second case). If the final to_tensor is on a GPU, dst_tensor
+          // is a temporary CPU buffer whose contents still need to be copied
+          // to to_tensor.
+ Status status =
+ PopulateTensorFromResponse(state->call->resp_, dst_tensor);
+ if (!status.ok()) {
+ done(status);
+ delete state;
+ return;
+ }
+
+ if (to_device->tensorflow_gpu_device_info()) {
+ AllocatorAttributes cpu_attr;
+ cpu_attr.set_gpu_compatible(true);
+ CopyTensor::ViaDMA("", // edge name (non-existent)
+ nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
+ to_device, cpu_attr, to_alloc_attr, dst_tensor,
+ to_tensor, dev_to_dev_stream_index,
+ [this, state, done](const Status& s) {
+ delete state;
+ // This callback must not block, so execute
+ // done in another thread.
+ work_queue_->Schedule([s, done] { done(s); });
+ });
+ return;
+ }
+ }
+ delete state;
+ done(s);
+ };
+
state->call.reset(new RecvBufCall(
step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
- to_alloc_attr, to_tensor, client_locality, state->server_attributes,
+ to_alloc_attr, dst_tensor, client_locality, state->server_attributes,
cancellation_manager, worker_cache_));
CancellationToken abortion_token =
abortion_cancel_mgr_.get_cancellation_token();
@@ -194,10 +235,10 @@
done(Status::OK());
return;
}
- // We send a GetStatus RPC with fail_fast=false to check the health of a peer
- // task. If the RPC succeeds, we verify if the peer_device incarnation matches
- // the local record if we have it. Note that DeviceResolverInterface always
- // caches the device attributes.
+  // We send a GetStatus RPC to check the health of a peer task. If the RPC
+  // succeeds, we verify that the peer_device incarnation matches the local
+  // record, if we have one. Note that DeviceResolverInterface always caches
+  // the device attributes.
WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
if (wi == nullptr) {
done(errors::InvalidArgument(peer_task,
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 74282be..35ae27f 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -21,13 +21,16 @@
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -42,20 +45,33 @@
namespace tensorflow {
namespace {
-static std::unique_ptr<Device> NewDevice(const string& type,
- const string& name) {
+class FakeAllocator : public Allocator {
+ public:
+ string Name() override { return "fake"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ return port::AlignedMalloc(num_bytes, alignment);
+ }
+ void DeallocateRaw(void* ptr) override { return port::AlignedFree(ptr); }
+};
+
+static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
+ Allocator* allocator) {
class FakeDevice : public Device {
public:
- explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+ explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator)
+ : Device(nullptr, attr), allocator_(allocator) {}
Status Sync() override { return Status::OK(); }
- Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; }
+
+ private:
+ Allocator* const allocator_;
};
DeviceAttributes attr;
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
attr.set_incarnation(random::New64());
- return absl::make_unique<FakeDevice>(attr);
+ return absl::make_unique<FakeDevice>(attr, allocator);
}
static int64 kStepId = 123;
@@ -63,12 +79,14 @@
class FakeWorker : public TestWorkerInterface {
public:
FakeWorker(const string& name, DeviceMgr* dev_mgr,
- DeviceResolverDistributed* dres, bool is_failed)
+ DeviceResolverDistributed* dres, bool is_failed,
+ bool set_tensor_in_extra)
: name_(name),
device_mgr_(dev_mgr),
device_resolver_(dres),
buf_rendezvous_(kStepId, dev_mgr),
- is_failed_(is_failed) {}
+ is_failed_(is_failed),
+ set_tensor_in_extra_(set_tensor_in_extra) {}
// Direct access to a BufRendezvous that holds whatever the remote
// worker is supposed to have.
@@ -112,17 +130,29 @@
buf_rendezvous_.ConsumeBuf(
request->buf_rendezvous_key(), request->src_device(),
request->src_incarnation(),
- [opts, response, done](const Status& s, BufRendezvous::Hook* h) {
+ [this, opts, request, response, done](const Status& status,
+ BufRendezvous::Hook* h) {
+ Status s = status;
if (s.ok()) {
opts->ClearCancelCallback();
- // Since this is not really RDMA into pre-allocated memory send the
- // bytes in the response.
- RecvBufRespExtra extra;
int64 num_bytes = h->prod_value->TotalBytes();
- extra.add_tensor_content(string(
- reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
- num_bytes));
- response->mutable_transport_options()->PackFrom(extra);
+
+ if (set_tensor_in_extra_) {
+ // Since this is not really RDMA into pre-allocated memory send
+ // the bytes in the response.
+ RecvBufRespExtra extra;
+ extra.add_tensor_content(string(
+ reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
+ num_bytes));
+ response->mutable_transport_options()->PackFrom(extra);
+ } else {
+ if (request->num_bytes() != num_bytes) {
+ s = errors::Internal("Tensor Size Mismatch.");
+ } else {
+ memcpy(reinterpret_cast<void*>(request->buf_ptr()),
+ DMAHelper::base(h->prod_value), num_bytes);
+ }
+ }
}
done(s);
if (h) BufRendezvous::DoneWithHook(h);
@@ -136,6 +166,7 @@
DeviceResolverDistributed* device_resolver_;
BufRendezvous buf_rendezvous_;
bool is_failed_;
+ const bool set_tensor_in_extra_;
};
class FakeCache : public TestWorkerCache {
@@ -179,7 +210,19 @@
}
};
-class CollRMADistTest : public ::testing::Test {
+enum TEST_PARAM_DEVICE_TYPE {
+ TEST_PARAM_DEVICE_TYPE_CPU = 0,
+ TEST_PARAM_DEVICE_TYPE_GPU,
+};
+
+enum TEST_PARAM_TENSOR_LOC {
+ TEST_PARAM_TENSOR_LOC_AT_BUF_PTR = 0,
+ TEST_PARAM_TENSOR_LOC_IN_EXTRA,
+};
+
+class CollRMADistTest
+ : public ::testing::TestWithParam<
+ std::tuple<TEST_PARAM_DEVICE_TYPE, TEST_PARAM_TENSOR_LOC>> {
protected:
CollRMADistTest()
: work_queue_(
@@ -217,12 +260,17 @@
const int kNumElts = 8;
expected_value_ = Tensor(DT_FLOAT, {kNumElts});
to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+ large_response_ = Tensor(DT_FLOAT, {2 * kNumElts});
auto exp_alias = expected_value_.flat<float>();
auto to_alias = to_tensor_.flat<float>();
+ auto large_response_alias = large_response_.flat<float>();
for (int i = 0; i < kNumElts; ++i) {
exp_alias(i) = i;
to_alias(i) = -1;
}
+ for (int i = 0; i < 2 * kNumElts; ++i) {
+ large_response_alias(i) = -2;
+ }
}
// Populates all device resolvers with device attributes of the cluster. This
@@ -243,7 +291,8 @@
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
- strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+ strings::StrCat(worker_name, "/device:", device_type, ":", i),
+ &fake_allocator_));
}
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
@@ -254,7 +303,12 @@
}
DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
dev_resolvers_[worker_name] = dev_res;
- FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
+ FakeWorker* fw =
+ new FakeWorker(worker_name, dev_mgr, dev_res, is_failed,
+ /*set_tensor_in_extra=*/
+ std::get<TEST_PARAM_TENSOR_LOC>(GetParam()) ==
+ TEST_PARAM_TENSOR_LOC_IN_EXTRA);
+
workers_.push_back(fw);
wc_.AddWorker(worker_name, fw);
}
@@ -280,6 +334,19 @@
}
}
+ void ValidateResultTensorUnchanged() {
+ for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+ EXPECT_FLOAT_EQ(-1, to_tensor_.flat<float>()(i));
+ }
+ }
+
+ void MaybeSetGPUDevice(Device* dst_device) {
+ if (std::get<TEST_PARAM_DEVICE_TYPE>(GetParam()) ==
+ TEST_PARAM_DEVICE_TYPE_GPU) {
+ dst_device->set_tensorflow_gpu_device_info(&gpu_device_info_);
+ }
+ }
+
FakeCache wc_;
CancellationManager cm_;
std::vector<DeviceMgr*> device_mgrs_;
@@ -292,13 +359,16 @@
int num_done_ TF_GUARDED_BY(mu_);
condition_variable done_;
Tensor expected_value_;
+ Tensor large_response_;
Tensor to_tensor_;
CallOptions opts_;
DeviceLocality device_locality_;
AllocatorAttributes alloc_attr_;
+ FakeAllocator fake_allocator_;
+ DeviceBase::GpuDeviceInfo gpu_device_info_;
};
-TEST_F(CollRMADistTest, ProdFirstOK) {
+TEST_P(CollRMADistTest, ProdFirstOK) {
ResolveDeviceAttributes();
Notification consumer_note;
Notification producer_note;
@@ -318,6 +388,7 @@
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
DeviceContext* to_device_ctx = nullptr;
+ MaybeSetGPUDevice(dst_device);
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
@@ -336,7 +407,7 @@
ValidateResultTensor();
}
-TEST_F(CollRMADistTest, ConsFirstOK) {
+TEST_P(CollRMADistTest, ConsFirstOK) {
ResolveDeviceAttributes();
Notification consumer_note;
Notification producer_note;
@@ -347,6 +418,7 @@
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ MaybeSetGPUDevice(dst_device);
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
@@ -374,7 +446,7 @@
ValidateResultTensor();
}
-TEST_F(CollRMADistTest, ConsFirstAbort) {
+TEST_P(CollRMADistTest, ConsFirstAbort) {
ResolveDeviceAttributes();
Notification consumer_note;
Status consumer_status;
@@ -382,6 +454,7 @@
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ MaybeSetGPUDevice(dst_device);
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
@@ -399,7 +472,47 @@
EXPECT_EQ(consumer_status.error_message(), "Cancelled");
}
-TEST_F(CollRMADistTest, WorkerRestart) {
+TEST_P(CollRMADistTest, ResponseTooLarge) {
+ ResolveDeviceAttributes();
+ Notification consumer_note;
+ Notification producer_note;
+ Status consumer_status;
+ Status producer_status;
+ FakeWorker* wi = workers_[1];
+ const string kBufKey = "fake_buf_key";
+ wi->buf_rendezvous()->ProvideBuf(
+ kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &large_response_,
+ AllocatorAttributes(),
+ [&producer_note, &producer_status](const Status& s) {
+ producer_status.Update(s);
+ producer_note.Notify();
+ },
+ nullptr /*cancellation_manager*/);
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ MaybeSetGPUDevice(dst_device);
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
+ nullptr /*cancellation_manager*/,
+ [&consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ consumer_note.WaitForNotification();
+ EXPECT_THAT(consumer_status.error_message(),
+ ::testing::HasSubstr("Tensor Size Mismatch"));
+ producer_note.WaitForNotification();
+ TF_EXPECT_OK(producer_status);
+ ValidateResultTensorUnchanged();
+}
+
+TEST_P(CollRMADistTest, WorkerRestart) {
ResolveDeviceAttributes();
Notification consumer_note;
Notification producer_note;
@@ -410,6 +523,7 @@
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ MaybeSetGPUDevice(dst_device);
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
@@ -454,7 +568,7 @@
EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
}
-TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
+TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) {
ResolveDeviceAttributes();
Status check_health_status;
Notification check_health_done;
@@ -468,7 +582,7 @@
TF_EXPECT_OK(check_health_status);
}
-TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
+TEST_P(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
Status check_health_status;
Notification check_health_done;
rma_->CheckPeerHealth(
@@ -481,7 +595,7 @@
EXPECT_TRUE(check_health_status.ok());
}
-TEST_F(CollRMADistTest, CheckHealthRestarted) {
+TEST_P(CollRMADistTest, CheckHealthRestarted) {
ResolveDeviceAttributes();
RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
@@ -497,7 +611,7 @@
EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}
-TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
+TEST_P(CollRMADistTest, CheckHealthFailedPeer) {
ResolveDeviceAttributes();
RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
/*is_failed*/ true);
@@ -514,7 +628,7 @@
EXPECT_TRUE(errors::IsUnavailable(check_health_status));
}
-TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
ResolveDeviceAttributes();
RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
Status check_health_status;
@@ -529,5 +643,12 @@
EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}
+INSTANTIATE_TEST_SUITE_P(
+ TensorInBufPtrOrExtra, CollRMADistTest,
+    ::testing::Combine(::testing::Values(TEST_PARAM_DEVICE_TYPE_CPU,
+                                         TEST_PARAM_DEVICE_TYPE_GPU),
+                       ::testing::Values(TEST_PARAM_TENSOR_LOC_AT_BUF_PTR,
+                                         TEST_PARAM_TENSOR_LOC_IN_EXTRA)));
+
} // namespace
} // namespace tensorflow