| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/common_runtime/rendezvous_util.h" |
| |
| #include "tensorflow/core/lib/core/notification.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| class RendezvousUtilTest : public ::testing::Test { |
| public: |
| RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); } |
| |
| ~RendezvousUtilTest() override { rendez_->Unref(); } |
| |
| Rendezvous* rendez_; |
| }; |
| |
| // string -> Tensor<string> |
| Tensor V(const string& content) { |
| Tensor tensor(DT_STRING, TensorShape({})); |
| tensor.scalar<string>()() = content; |
| return tensor; |
| } |
| |
| // Tensor<string> -> string |
| string V(const Tensor& tensor) { |
| CHECK_EQ(tensor.dtype(), DT_STRING); |
| CHECK(TensorShapeUtils::IsScalar(tensor.shape())); |
| return tensor.scalar<string>()(); |
| } |
| |
| string MakeStringKey(const string& name) { |
| return Rendezvous::CreateKey( |
| "/job:localhost/replica:0/task:0/device:CPU:0", 0, |
| "/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0)); |
| } |
| |
| TEST_F(RendezvousUtilTest, SendBeforeRecv) { |
| // Fire off sends before receive the tensors. |
| TF_ASSERT_OK(SendTensorsToRendezvous( |
| rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, |
| {V("hello1"), V("hello2")})); |
| |
| Notification n; |
| std::vector<Tensor> received_keys; |
| RecvOutputsFromRendezvousAsync( |
| rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, |
| &received_keys, [&n](const Status& status) { n.Notify(); }); |
| n.WaitForNotification(); |
| |
| EXPECT_EQ(2, received_keys.size()); |
| EXPECT_EQ("hello1", V(received_keys[0])); |
| EXPECT_EQ("hello2", V(received_keys[1])); |
| } |
| |
| TEST_F(RendezvousUtilTest, RecvBeforeSend) { |
| // Fire off recvs, wait for a notification in the callback. |
| Notification n; |
| std::vector<Tensor> received_keys; |
| RecvOutputsFromRendezvousAsync( |
| rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, |
| &received_keys, [&n](const Status& status) { n.Notify(); }); |
| |
| TF_ASSERT_OK(SendTensorsToRendezvous( |
| rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, |
| {V("hello1"), V("hello2")})); |
| |
| n.WaitForNotification(); |
| |
| EXPECT_EQ(2, received_keys.size()); |
| EXPECT_EQ("hello1", V(received_keys[0])); |
| EXPECT_EQ("hello2", V(received_keys[1])); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |