/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "google/protobuf/any.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
// The only interesting method on CollectiveRemoteAccessDistributed
// that's not on CollectiveRemoteAccessLocal is RecvFromPeer, which
// issues a RecvBufAsync call against a WorkerInterface. That's all
// that's tested here. Note that RecvFromPeer may make a
// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
// for the RecvBufAsync.
namespace tensorflow {
namespace {
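
// Returns a fake device whose locality is set to a distinctive
// (non-default) NUMA node and whose incarnation is freshly randomized,
// so tests can observe locality propagation and detect restarts.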
static std::unique_ptr<Device> NewDevice(const string& type,
const string& name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};
DeviceAttributes attr;
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
attr.set_incarnation(random::New64());
return absl::make_unique<FakeDevice>(attr);
}
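
// Fixed step id shared by every BufRendezvous in these tests.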
static int64 kStepId = 123;
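
// A WorkerInterface stub that serves GetStatusAsync from its DeviceMgr
// and RecvBufAsync from a process-local BufRendezvous, standing in for
// a worker on the far side of a network hop.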
class FakeWorker : public TestWorkerInterface {
public:
FakeWorker(const string& name, DeviceMgr* dev_mgr,
DeviceResolverDistributed* dres)
: name_(name),
device_mgr_(dev_mgr),
device_resolver_(dres),
buf_rendezvous_(kStepId, dev_mgr) {}
// Direct access to a BufRendezvous that holds whatever the remote
// worker is supposed to have.
BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
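
  // Responds with the attributes of every device owned by this worker;
  // FakeCache::GetDeviceLocalityAsync relies on this to resolve
  // localities.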
void GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response,
StatusCallback done) override {
std::vector<DeviceAttributes> dev_attr;
device_mgr_->ListDeviceAttributes(&dev_attr);
for (const auto& da : dev_attr) {
*response->add_device_attributes() = da;
}
done(Status::OK());
}
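
  // Serves the producer side of a RecvBuf RPC: consumes the requested
  // buffer from the local BufRendezvous and, since there is no real
  // RDMA here, returns the bytes inline in the response's
  // transport_options.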
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) override {
opts->SetCancelCallback([this]() {
      // Within this test the call is satisfied by a process-local
      // BufRendezvous table. In a real application the BufRendezvous
      // would be on the other side of a network hop, so call
      // BufRendezvous::StartAbort() from a separate thread to better
      // match that situation and to avoid mutex deadlock.
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(100);
buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
});
});
VLOG(2) << "ConsumeBuf key=" << request->buf_rendezvous_key()
<< " src_device=" << request->src_device()
<< " src_incarnation=" << request->src_incarnation();
buf_rendezvous_.ConsumeBuf(
request->buf_rendezvous_key(), request->src_device(),
request->src_incarnation(),
[opts, response, done](const Status& s, BufRendezvous::Hook* h) {
if (s.ok()) {
opts->ClearCancelCallback();
            // Since this is not really RDMA into pre-allocated memory,
            // send the bytes in the response.
RecvBufRespExtra extra;
int64 num_bytes = h->prod_value->TotalBytes();
extra.add_tensor_content(string(
reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
num_bytes));
response->mutable_transport_options()->PackFrom(extra);
}
done(s);
if (h) BufRendezvous::DoneWithHook(h);
});
}
private:
string name_;
DeviceMgr* device_mgr_;
DeviceResolverDistributed* device_resolver_;
BufRendezvous buf_rendezvous_;
};
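
// A TestWorkerCache whose device-locality lookups are answered by
// issuing a real GetStatus call to the FakeWorker registered for the
// device's task.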
class FakeCache : public TestWorkerCache {
public:
// Override the Locality methods to actually pass through to the
// worker.
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return false;
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {
string task_name;
string dev_part;
if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
done(errors::Internal("failed to parse device name"));
return;
}
auto it = workers_.find(task_name);
if (it == workers_.end()) {
done(errors::Internal("failed to find worker ", task_name));
return;
}
WorkerInterface* wi = it->second;
GetStatusRequest req;
GetStatusResponse resp;
Status status = wi->GetStatus(&req, &resp);
if (!status.ok()) {
done(status);
return;
}
for (const auto& it : resp.device_attributes()) {
if (it.name() == device) {
*locality = it.locality();
done(Status::OK());
return;
}
}
done(errors::Internal("device not found: ", device));
}
};
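
// Fixture that stands up two single-CPU-device workers plus a
// CollectiveRemoteAccessDistributed acting on behalf of worker 0, and
// prepares the expected and destination tensors for the RecvFromPeer
// tests.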
class CollRMADistTest : public ::testing::Test {
protected:
CollRMADistTest()
: work_queue_(
std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
~CollRMADistTest() override {
for (DeviceMgr* dm : device_mgrs_) {
delete dm;
}
    for (const auto& it : dev_resolvers_) {
delete it.second;
}
for (FakeWorker* w : workers_) {
delete w;
}
}
void SetUp() override {
const int num_workers = 2;
const int num_devices = 1;
string device_type = "CPU";
string dev0_worker_name;
for (int w = 0; w < num_workers; ++w) {
string name = strings::StrCat("/job:worker/replica:0/task:", w);
if (w == 0) {
dev0_worker_name = name;
}
DefineWorker(name, device_type, num_devices);
}
// All tests simulate requests from worker 0 to worker 1.
rma_.reset(new CollectiveRemoteAccessDistributed(
device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
kStepId));
const int kNumElts = 8;
expected_value_ = Tensor(DT_FLOAT, {kNumElts});
to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
auto exp_alias = expected_value_.flat<float>();
auto to_alias = to_tensor_.flat<float>();
for (int i = 0; i < kNumElts; ++i) {
exp_alias(i) = i;
to_alias(i) = -1;
}
}
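
  // Creates num_devices devices for worker_name together with its
  // DeviceMgr, DeviceResolverDistributed and FakeWorker, and registers
  // the worker with the cache.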
void DefineWorker(const string& worker_name, const string& device_type,
int num_devices) {
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
dv->clear();
for (auto d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =
new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
dev_resolvers_[worker_name] = dev_res;
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
workers_.push_back(fw);
wc_.AddWorker(worker_name, fw);
}
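
  // Simulates a task restart: the worker is rebuilt from scratch, so
  // its devices receive fresh incarnation numbers.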
void RestartWorker(const string& worker_name, const string& device_type,
int num_devices) {
auto it = dev_resolvers_.find(worker_name);
if (it != dev_resolvers_.end()) {
delete it->second;
dev_resolvers_.erase(it);
}
DefineWorker(worker_name, device_type, num_devices);
}
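
  // Verifies that the received tensor matches expected_value_
  // elementwise.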
void ValidateResultTensor() {
ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
for (int i = 0; i < to_tensor_.NumElements(); ++i) {
EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
to_tensor_.flat<float>()(i));
}
}
FakeCache wc_;
CancellationManager cm_;
std::vector<DeviceMgr*> device_mgrs_;
std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
std::unordered_map<string, std::vector<string>> dev_by_task_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<FakeWorker*> workers_;
std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
mutex mu_;
  int num_done_ GUARDED_BY(mu_) = 0;
condition_variable done_;
Tensor expected_value_;
Tensor to_tensor_;
CallOptions opts_;
DeviceLocality device_locality_;
AllocatorAttributes alloc_attr_;
};
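
// Producer provides the buffer before the consumer asks for it; the
// transfer should complete and deliver the expected values.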
TEST_F(CollRMADistTest, ProdFirstOK) {
Notification consumer_note;
Notification producer_note;
Status consumer_status;
Status producer_status;
FakeWorker* wi = workers_[1];
const string kBufKey = "fake_buf_key";
wi->buf_rendezvous()->ProvideBuf(
kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
AllocatorAttributes(),
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
});
consumer_note.WaitForNotification();
TF_EXPECT_OK(consumer_status);
producer_note.WaitForNotification();
TF_EXPECT_OK(producer_status);
ValidateResultTensor();
}
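
// Consumer issues RecvFromPeer before the producer provides the
// buffer; the request is held by the rendezvous until the buffer
// arrives, then completes successfully.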
TEST_F(CollRMADistTest, ConsFirstOK) {
Notification consumer_note;
Notification producer_note;
Status consumer_status;
Status producer_status;
FakeWorker* wi = workers_[1];
const string kBufKey = "fake_buf_key";
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
});
wi->buf_rendezvous()->ProvideBuf(
kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
AllocatorAttributes(),
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
consumer_note.WaitForNotification();
TF_EXPECT_OK(consumer_status);
producer_note.WaitForNotification();
TF_EXPECT_OK(producer_status);
ValidateResultTensor();
}
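
// Aborting the RMA object while a consumer is waiting should cancel
// the pending RecvBuf call, surfacing the "Cancelled" status set by
// the worker's cancellation callback.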
TEST_F(CollRMADistTest, ConsFirstAbort) {
Notification consumer_note;
Status consumer_status;
const string kBufKey = "fake_buf_key";
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
});
rma_->StartAbort(errors::Internal("Deliberate Failure"));
consumer_note.WaitForNotification();
EXPECT_EQ(consumer_status.error_message(), "Cancelled");
}
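
// After one successful transfer, restarting the remote task gives its
// devices new incarnations, so a subsequent RecvFromPeer using the
// stale incarnation is expected to fail with FailedPrecondition.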
TEST_F(CollRMADistTest, WorkerRestart) {
Notification consumer_note;
Notification producer_note;
Status consumer_status;
Status producer_status;
FakeWorker* wi = workers_[1];
const string buf_key = "fake_buf_key";
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
DeviceContext* to_device_ctx = nullptr;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
});
wi->buf_rendezvous()->ProvideBuf(
buf_key, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
AllocatorAttributes(),
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
consumer_note.WaitForNotification();
TF_EXPECT_OK(consumer_status);
producer_note.WaitForNotification();
TF_EXPECT_OK(producer_status);
ValidateResultTensor();
// Restart task 1 and check that recv from task 1 to task 0 fails.
RestartWorker("/job:worker/replica:0/task:1", "CPU", 1);
Notification post_restart_note;
rma_->RecvFromPeer(
"/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
[&consumer_status, &post_restart_note](const Status& s) {
consumer_status = s;
post_restart_note.Notify();
});
post_restart_note.WaitForNotification();
EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
}
} // namespace
} // namespace tensorflow