| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" |
| |
| #include "tensorflow/core/common_runtime/process_util.h" |
| #include "tensorflow/core/framework/control_flow.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/notification.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| |
| // string -> Tensor<string> |
| Tensor V(const string& content) { |
| Tensor tensor(DT_STRING, TensorShape({})); |
| tensor.scalar<string>()() = content; |
| return tensor; |
| } |
| |
| // Tensor<string> -> string |
| string V(const Tensor& tensor) { |
| CHECK_EQ(tensor.dtype(), DT_STRING); |
| CHECK(TensorShapeUtils::IsScalar(tensor.shape())); |
| return tensor.scalar<string>()(); |
| } |
| |
| Rendezvous::ParsedKey MakeKey(const string& s) { |
| Rendezvous::ParsedKey key; |
| CHECK(Rendezvous::ParseKey(s, &key).ok()); |
| return key; |
| } |
| |
| namespace { |
| // Fake cache implementation for WorkerEnv. |
| class DummyWorkerCache : public WorkerCacheInterface { |
| void ListWorkers(std::vector<string>* workers) const override {} |
| void ListWorkersInJob(const string& job_name, |
| std::vector<string>* workers) const override {} |
| WorkerInterface* GetOrCreateWorker(const string& target) override { |
| return nullptr; |
| } |
| Status GetEagerClientCache( |
| std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { |
| return errors::Unimplemented("Unimplemented."); |
| } |
| bool GetDeviceLocalityNonBlocking(const string& device, |
| DeviceLocality* locality) override { |
| return false; |
| } |
| void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, |
| StatusCallback done) override {} |
| }; |
| } // namespace |
| |
| class RpcRendezvousMgrTest : public ::testing::Test { |
| protected: |
| RpcRendezvousMgrTest() |
| : cache_(new DummyWorkerCache), |
| worker_session_("rpc_session", "/job:mnist/replica:1/task:2", |
| std::unique_ptr<WorkerCacheInterface>(cache_), |
| std::unique_ptr<DeviceMgr>(), |
| std::unique_ptr<GraphMgr>(), nullptr), |
| rmgr_(&env) { |
| env.env = Env::Default(); |
| } |
| |
| DummyWorkerCache* cache_; // Managed by worker_session. |
| WorkerEnv env; |
| |
| WorkerSession worker_session_; |
| RpcRendezvousMgr rmgr_; |
| }; |
| |
| TEST_F(RpcRendezvousMgrTest, LocalSendRecv) { |
| const int64 step_id = 123; |
| const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( |
| "/job:mnist/replica:1/task:2/cpu:0", 7890, |
| "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); |
| { |
| RemoteRendezvous* rendez = rmgr_.Find(step_id); |
| TF_ASSERT_OK(rendez->Initialize(&worker_session_)); |
| core::ScopedUnref unref(rendez); |
| Rendezvous::Args args; |
| TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); |
| } |
| { |
| Tensor val(DT_FLOAT); |
| bool val_dead = false; |
| TF_ASSERT_OK(rmgr_.RecvLocal(step_id, key, &val, &val_dead)); |
| EXPECT_EQ(V(val), "peach"); |
| } |
| rmgr_.Cleanup(step_id); |
| } |
| |
| TEST_F(RpcRendezvousMgrTest, LocalAbort) { |
| const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( |
| "/job:mnist/replica:1/task:2/cpu:0", 7890, |
| "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); |
| { // Explicit Abort(). |
| const int64 step_id = 123; |
| RemoteRendezvous* rendez = rmgr_.Find(step_id); |
| core::ScopedUnref unref(rendez); |
| SchedClosure([this, rendez]() { |
| env.env->SleepForMicroseconds(100 * 1000); |
| rendez->StartAbort(errors::Aborted("")); |
| }); |
| Tensor val(DT_STRING); |
| bool val_dead = false; |
| Rendezvous::Args args; |
| TF_ASSERT_OK(rendez->Initialize(&worker_session_)); |
| EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); |
| } |
| { // Cleanup causes Abort(). |
| const int64 step_id = 321; |
| RemoteRendezvous* rendez = rmgr_.Find(step_id); |
| core::ScopedUnref unref(rendez); |
| SchedClosure([this, step_id]() { |
| env.env->SleepForMicroseconds(100 * 1000); |
| rmgr_.Cleanup(step_id); |
| }); |
| Tensor val(DT_STRING); |
| bool val_dead = false; |
| Rendezvous::Args args; |
| TF_ASSERT_OK(rendez->Initialize(&worker_session_)); |
| EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); |
| } |
| } |
| |
| TEST_F(RpcRendezvousMgrTest, CleanupAll) { |
| const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( |
| "/job:mnist/replica:1/task:2/cpu:0", 7890, |
| "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); |
| { |
| const int64 step_id = 123; |
| RemoteRendezvous* rendez = rmgr_.Find(step_id); |
| TF_ASSERT_OK(rendez->Initialize(&worker_session_)); |
| core::ScopedUnref unref(rendez); |
| Rendezvous::Args args; |
| TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); |
| rmgr_.CleanupAll(); |
| Tensor val(DT_STRING); |
| bool val_dead = false; |
| EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); |
| } |
| } |
| |
| class DummyDeviceContext : public DeviceContext { |
| public: |
| explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} |
| ~DummyDeviceContext() override {} |
| int stream_id() const { return stream_id_; } |
| |
| private: |
| const int stream_id_; |
| }; |
| |
| TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { |
| DummyDeviceContext* dc = new DummyDeviceContext(123); |
| |
| const int64 step_id = 123; |
| const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( |
| "/job:mnist/replica:1/task:2/cpu:0", 7890, |
| "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); |
| { |
| RemoteRendezvous* rendez = rmgr_.Find(step_id); |
| core::ScopedUnref unref(rendez); |
| Rendezvous::Args args; |
| args.device_context = dc; |
| TF_ASSERT_OK(rendez->Initialize(&worker_session_)); |
| TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); |
| } |
| { |
| Notification n; |
| rmgr_.RecvLocalAsync( |
| step_id, key, |
| [&n](const Status& s, const Rendezvous::Args send_args, |
| const Rendezvous::Args recv_args, const Tensor& val, |
| bool is_dead) { |
| auto send_dev_context = |
| static_cast<DummyDeviceContext*>(send_args.device_context); |
| CHECK_EQ(123, send_dev_context->stream_id()); |
| CHECK_EQ(V(val), "peach"); |
| n.Notify(); |
| }); |
| n.WaitForNotification(); |
| } |
| rmgr_.Cleanup(step_id); |
| dc->Unref(); |
| } |
| |
| // NOTE: Remote Send/Recv is better tested in worker_test.cc |
| |
| } // namespace tensorflow |