| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" |
| |
| #include <string.h> |
| |
| #include "absl/types/span.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h" |
| #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
| #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" |
| #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" |
| #include "tensorflow/core/distributed_runtime/session_mgr.h" |
| #include "tensorflow/core/distributed_runtime/test_utils.h" |
| #include "tensorflow/core/distributed_runtime/worker_env.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/protobuf/eager_service.pb.h" |
| #include "tensorflow/core/protobuf/tensorflow_server.pb.h" |
| |
| namespace tensorflow { |
| namespace eager { |
| namespace { |
| |
| class TestEagerServiceImpl : public EagerServiceImpl { |
| public: |
| explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {} |
| Status GetEagerContext(const uint64 context_id, EagerContext** ctx) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); |
| core::ScopedUnref context_unref(context); |
| *ctx = context->Context(); |
| return Status::OK(); |
| } |
| Status GetTensorHandle(const uint64 context_id, |
| const RemoteTensorHandleInternal& remote_handle, |
| tensorflow::TensorHandle** handle) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); |
| core::ScopedUnref context_unref(context); |
| |
| return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle, |
| handle); |
| } |
| }; |
| |
| class FakeEagerClient : public EagerClient { |
| public: |
| FakeEagerClient() {} |
| ~FakeEagerClient() override {} |
| |
| void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; } |
| |
| #define CLIENT_METHOD(method) \ |
| void method##Async(const method##Request* request, \ |
| method##Response* response, StatusCallback done) \ |
| override { \ |
| done(impl_->method(request, response)); \ |
| } |
| |
| CLIENT_METHOD(CreateContext); |
| CLIENT_METHOD(UpdateContext); |
| CLIENT_METHOD(Enqueue); |
| CLIENT_METHOD(WaitQueueDone); |
| CLIENT_METHOD(KeepAlive); |
| CLIENT_METHOD(CloseContext); |
| #undef CLIENT_METHOD |
| |
| void StreamingEnqueueAsync(const EnqueueRequest* request, |
| EnqueueResponse* response, |
| StatusCallback done) override { |
| done(impl_->Enqueue(request, response)); |
| } |
| |
| private: |
| TestEagerServiceImpl* impl_; |
| }; |
| |
| class DummyEagerClientCache : public EagerClientCache { |
| public: |
| DummyEagerClientCache() : client_(new FakeEagerClient) {} |
| Status GetClient(const string& target, EagerClient** client) override { |
| *client = client_.get(); |
| return Status::OK(); |
| } |
| |
| private: |
| std::unique_ptr<EagerClient> client_; |
| }; |
| |
| class FakeCache : public TestWorkerCache { |
| Status GetEagerClientCache( |
| std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { |
| eager_client_cache->reset(new DummyEagerClientCache); |
| return Status::OK(); |
| } |
| |
| void ListWorkers(std::vector<string>* workers) const override { |
| workers->push_back("/job:localhost/replica:0/task:0"); |
| } |
| }; |
| |
// Test fixture that wires up `worker_env_` for the eager service:
// - an RpcRendezvousMgr,
// - a SessionMgr whose worker caches are FakeCache instances,
// - a StaticDeviceMgr holding one CPU device for task 0 of job "localhost".
class EagerServiceImplTest : public ::testing::Test {
 public:
  EagerServiceImplTest()
      : rendezvous_mgr_(&worker_env_),
        // NOTE(review): session_mgr_ is constructed here, before the body
        // below fills in worker_env_ (env, rendezvous_mgr, device_mgr, ...).
        // Keep this ordering in mind if SessionMgr reads those fields at
        // construction time.
        session_mgr_(new SessionMgr(
            &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
            // Worker-cache callback: ignores the ServerDef and always hands
            // back a newly allocated FakeCache.
            [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
              *worker_cache = new FakeCache;
              return Status::OK();
            })) {
    worker_env_.env = Env::Default();

    worker_env_.rendezvous_mgr = &rendezvous_mgr_;
    worker_env_.session_mgr = session_mgr_.get();

    // Single local CPU device; worker_env_ only borrows it from device_mgr_.
    device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
    worker_env_.local_devices = device_mgr_->ListDevices();
    worker_env_.device_mgr = device_mgr_.get();
  }

 protected:
  // Declared first, so it already exists when rendezvous_mgr_ and
  // session_mgr_ take its address in the initializer list.
  WorkerEnv worker_env_;
  tensorflow::RpcRendezvousMgr rendezvous_mgr_;
  std::unique_ptr<SessionMgr> session_mgr_;
  // Owns the CPU device referenced by worker_env_.local_devices/device_mgr.
  // Declared last, so it is destroyed first. NOTE(review): confirm nothing in
  // session_mgr_ touches these devices during its own destruction.
  std::unique_ptr<DeviceMgr> device_mgr_;
};
| |
| void SetTensorProto(TensorProto* tensor_proto) { |
| int64_t dims[] = {2, 2}; |
| float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; |
| TF_Tensor* t = TF_AllocateTensor( |
| TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); |
| memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); |
| tensorflow::Tensor tensor; |
| TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor)); |
| tensor.AsProtoTensorContent(tensor_proto); |
| TF_DeleteTensor(t); |
| } |
| |
| void AddOperationToEnqueueRequest( |
| int64 id, const string& name, |
| const std::vector<std::pair<int64, int32>>& inputs, |
| const std::unordered_map<string, AttrValue>& attrs, const string& device, |
| EnqueueRequest* request) { |
| auto* operation = request->add_queue()->mutable_operation(); |
| |
| operation->set_id(id); |
| operation->set_name(name); |
| operation->set_device(device); |
| |
| for (const auto& tensor_handle_pair : inputs) { |
| auto* input = operation->add_inputs(); |
| input->set_op_id(tensor_handle_pair.first); |
| input->set_output_num(tensor_handle_pair.second); |
| input->set_op_device(device); |
| input->set_device(device); |
| } |
| |
| for (const auto& attr_entry : attrs) { |
| (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second; |
| } |
| } |
| |
| tensorflow::FunctionDef MatMulFunction() { |
| tensorflow::FunctionDef def; |
| CHECK(tensorflow::protobuf::TextFormat::ParseFromString( |
| " signature {" |
| " name: 'MatMulFunction'" |
| " input_arg {" |
| " name: 'a'" |
| " type: DT_FLOAT" |
| " }" |
| " output_arg {" |
| " name: 'm'" |
| " type: DT_FLOAT" |
| " }" |
| " }" |
| " node_def {" |
| " name: 'matmul'" |
| " op: 'MatMul'" |
| " input: 'a'" |
| " input: 'a'" |
| " attr {" |
| " key: 'T'" |
| " value {" |
| " type: DT_FLOAT" |
| " }" |
| " }" |
| " }" |
| " ret {" |
| " key: 'm'" |
| " value: 'matmul:product'" |
| " }", |
| &def)); |
| return def; |
| } |
| |
| // Test creates a context and attempts to execute some ops. |
| TEST_F(EagerServiceImplTest, BasicTest) { |
| TestEagerServiceImpl eager_service_impl(&worker_env_); |
| |
| uint64 context_id = random::New64(); |
| |
| CreateContextRequest request; |
| request.mutable_server_def()->set_job_name("localhost"); |
| request.mutable_server_def()->set_task_index(0); |
| request.set_context_id(context_id); |
| CreateContextResponse response; |
| |
| TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); |
| |
| EnqueueRequest remote_enqueue_request; |
| remote_enqueue_request.set_context_id(context_id); |
| EnqueueResponse remote_enqueue_response; |
| |
| std::unordered_map<string, AttrValue> const_attrs; |
| AttrValue val; |
| val.set_type(tensorflow::DataType::DT_FLOAT); |
| const_attrs.insert({"dtype", val}); |
| val.Clear(); |
| SetTensorProto(val.mutable_tensor()); |
| const_attrs.insert({"value", val}); |
| |
| AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, |
| "/job:localhost/replica:0/task:0/device:CPU:0", |
| &remote_enqueue_request); |
| |
| std::unordered_map<string, AttrValue> attrs; |
| val.Clear(); |
| val.set_type(tensorflow::DataType::DT_FLOAT); |
| attrs.insert({"T", val}); |
| val.Clear(); |
| val.set_b(false); |
| attrs.insert({"transpose_a", val}); |
| attrs.insert({"transpose_b", val}); |
| |
| AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs, |
| "/job:localhost/replica:0/task:0/device:CPU:0", |
| &remote_enqueue_request); |
| |
| TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, |
| &remote_enqueue_response)); |
| |
| auto& matmul_result_shape = |
| remote_enqueue_response.queue_response(1).shape(0); |
| EXPECT_EQ(matmul_result_shape.dim(0).size(), 2); |
| EXPECT_EQ(matmul_result_shape.dim(1).size(), 2); |
| |
| tensorflow::TensorHandle* tensor_handle; |
| TF_ASSERT_OK(eager_service_impl.GetTensorHandle( |
| context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); |
| |
| // This should be OK to do since we've placed all computation on the CPU |
| // device. |
| const tensorflow::Tensor* t = nullptr; |
| TF_ASSERT_OK(tensor_handle->Tensor(&t)); |
| |
| auto actual = t->flat<float>(); |
| |
| EXPECT_EQ(4, actual.size()); |
| |
| EXPECT_EQ(7, actual(0)); |
| EXPECT_EQ(10, actual(1)); |
| EXPECT_EQ(15, actual(2)); |
| EXPECT_EQ(22, actual(3)); |
| |
| CloseContextRequest close_context_request; |
| close_context_request.set_context_id(context_id); |
| CloseContextResponse close_context_response; |
| TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, |
| &close_context_response)); |
| } |
| |
| // Test creates a context and attempts to execute a function. |
| TEST_F(EagerServiceImplTest, BasicFunctionTest) { |
| TestEagerServiceImpl eager_service_impl(&worker_env_); |
| |
| uint64 context_id = random::New64(); |
| |
| CreateContextRequest request; |
| request.mutable_server_def()->set_job_name("localhost"); |
| request.mutable_server_def()->set_task_index(0); |
| request.set_context_id(context_id); |
| CreateContextResponse response; |
| |
| TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); |
| |
| EnqueueRequest enqueue_request; |
| enqueue_request.set_context_id(context_id); |
| RegisterFunctionOp* register_function = |
| enqueue_request.add_queue()->mutable_register_function(); |
| *register_function->mutable_function_def() = MatMulFunction(); |
| EnqueueResponse enqueue_response; |
| |
| TF_ASSERT_OK(eager_service_impl.Enqueue(&enqueue_request, &enqueue_response)); |
| |
| EnqueueRequest remote_enqueue_request; |
| remote_enqueue_request.set_context_id(context_id); |
| EnqueueResponse remote_enqueue_response; |
| |
| std::unordered_map<string, AttrValue> const_attrs; |
| AttrValue val; |
| val.set_type(tensorflow::DataType::DT_FLOAT); |
| const_attrs.insert({"dtype", val}); |
| val.Clear(); |
| |
| SetTensorProto(val.mutable_tensor()); |
| const_attrs.insert({"value", val}); |
| |
| AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, |
| "/job:localhost/replica:0/task:0/device:CPU:0", |
| &remote_enqueue_request); |
| AddOperationToEnqueueRequest( |
| 2, "MatMulFunction", {{1, 0}}, std::unordered_map<string, AttrValue>(), |
| "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request); |
| |
| TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, |
| &remote_enqueue_response)); |
| |
| const tensorflow::Tensor* t = nullptr; |
| tensorflow::TensorHandle* tensor_handle; |
| TF_ASSERT_OK(eager_service_impl.GetTensorHandle( |
| context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); |
| TF_ASSERT_OK(tensor_handle->Tensor(&t)); |
| |
| auto actual = t->flat<float>(); |
| EXPECT_EQ(4, actual.size()); |
| |
| EXPECT_EQ(7, actual(0)); |
| EXPECT_EQ(10, actual(1)); |
| EXPECT_EQ(15, actual(2)); |
| EXPECT_EQ(22, actual(3)); |
| |
| CloseContextRequest close_context_request; |
| close_context_request.set_context_id(context_id); |
| CloseContextResponse close_context_response; |
| TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, |
| &close_context_response)); |
| } |
| |
// Test executes a remote function through
// EagerProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
// The fake EagerClient handed out by the context is pointed back at the local
// eager_service_impl, so "remote" calls are served in-process.
TEST_F(EagerServiceImplTest, EagerPFLRTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;
  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  const string local_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device = "/job:localhost/replica:0/task:1/device:CPU:0";

  // Make the fake EagerClient use the local eager_service_impl.
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl.GetEagerContext(context_id, &ctx));

  Device* device;
  TF_ASSERT_OK(ctx->FindDeviceFromName(local_device.c_str(), &device));
  EagerClient* client;
  TF_ASSERT_OK(ctx->GetClient(device, &client));
  // The downcast relies on the context's clients coming from the test's
  // FakeCache / DummyEagerClientCache.
  FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client);
  fake_client->SetServiceImpl(&eager_service_impl);

  auto eager_cluster_flr =
      absl::make_unique<EagerClusterFunctionLibraryRuntime>(ctx,
                                                            device_mgr_.get());

  // The PFLR gets its own device manager holding a task:1 CPU device, plus
  // eager_cluster_flr for calls it dispatches through the cluster runtime.
  // Declaration order matters: eager_pflr is destroyed before the objects it
  // was given pointers to (func_lib_def, device_mgr, eager_cluster_flr).
  FunctionLibraryDefinition func_lib_def{OpRegistry::Global(), {}};
  auto device_mgr = absl::make_unique<StaticDeviceMgr>(
      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
  auto eager_pflr = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, &func_lib_def, OptimizerOptions(), nullptr,
      eager_cluster_flr.get(), nullptr);

  tensorflow::FunctionDef fdef = MatMulFunction();
  TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));

  // Create an input on local_device for MatMulFunction.
  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;
  std::unordered_map<string, AttrValue> const_attrs;
  AttrValue val;
  val.set_type(tensorflow::DataType::DT_FLOAT);
  const_attrs.insert({"dtype", val});
  val.Clear();
  SetTensorProto(val.mutable_tensor());
  const_attrs.insert({"value", val});
  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device,
                               &remote_enqueue_request);
  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
                                          &remote_enqueue_response));

  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = remote_device;
  options.is_multi_device_function = true;
  options.input_devices.push_back(local_device);
  FunctionLibraryRuntime::Handle handle;
  TF_ASSERT_OK(eager_pflr->Instantiate(
      fdef.signature().name(), AttrSlice(&fdef.attr()), options, &handle));

  // Run MatMulFunction on remote_device.
  FunctionLibraryRuntime::Options opts;
  // Remembered so the CleanUp call below targets the same step.
  const int64 step_id = opts.step_id;
  // The output is looked up below as RemoteTensorHandleInternal(2, 0).
  opts.op_id = 2;
  Notification done;
  Status status;
  // The argument is a remote handle to output 0 of op 1 (the Const above).
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device);
  input.set_device(local_device);
  VariantFunctionArg arg(&input);
  std::vector<Tensor> outputs;
  eager_pflr->Run(opts, handle, {arg}, &outputs,
                  [&status, &done](const Status& s) {
                    status = s;
                    done.Notify();
                  });
  done.WaitForNotification();
  TF_ASSERT_OK(status);

  // [[1, 2], [3, 4]] x [[1, 2], [3, 4]] == [[7, 10], [15, 22]].
  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));
  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  // NOTE(review): EXPECT_TRUE(callback_is_called) right after CleanUp()
  // assumes the callback runs synchronously (presumably true with the
  // in-process fake client); confirm if the client changes.
  Status cleanup_status;
  bool callback_is_called = false;
  eager_cluster_flr->CleanUp(
      step_id, handle, [&cleanup_status, &callback_is_called](const Status& s) {
        callback_is_called = true;
        cleanup_status.Update(s);
      });
  EXPECT_TRUE(callback_is_called);
  TF_ASSERT_OK(cleanup_status);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}
| |
| // Test creates a context and attempts to send a tensor (using the RPC), and |
| // then use the tensor. |
| TEST_F(EagerServiceImplTest, SendTensorTest) { |
| TestEagerServiceImpl eager_service_impl(&worker_env_); |
| |
| uint64 context_id = random::New64(); |
| |
| CreateContextRequest request; |
| request.mutable_server_def()->set_job_name("localhost"); |
| request.mutable_server_def()->set_task_index(0); |
| request.set_context_id(context_id); |
| CreateContextResponse response; |
| |
| TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); |
| |
| EnqueueRequest remote_enqueue_request; |
| remote_enqueue_request.set_context_id(context_id); |
| EnqueueResponse remote_enqueue_response; |
| |
| auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor(); |
| send_tensor->set_op_id(1); |
| SetTensorProto(send_tensor->add_tensors()); |
| |
| std::unordered_map<string, AttrValue> attrs; |
| AttrValue val; |
| val.Clear(); |
| val.set_type(tensorflow::DataType::DT_FLOAT); |
| attrs.insert({"T", val}); |
| val.Clear(); |
| val.set_b(false); |
| attrs.insert({"transpose_a", val}); |
| attrs.insert({"transpose_b", val}); |
| |
| AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs, |
| "/job:localhost/replica:0/task:0/device:CPU:0", |
| &remote_enqueue_request); |
| |
| TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, |
| &remote_enqueue_response)); |
| |
| const tensorflow::Tensor* t = nullptr; |
| tensorflow::TensorHandle* tensor_handle; |
| TF_ASSERT_OK(eager_service_impl.GetTensorHandle( |
| context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); |
| TF_ASSERT_OK(tensor_handle->Tensor(&t)); |
| |
| Device* device = tensor_handle->device(); |
| EXPECT_EQ(device, nullptr); |
| |
| auto actual = t->flat<float>(); |
| EXPECT_EQ(4, actual.size()); |
| |
| EXPECT_EQ(7, actual(0)); |
| EXPECT_EQ(10, actual(1)); |
| EXPECT_EQ(15, actual(2)); |
| EXPECT_EQ(22, actual(3)); |
| |
| CloseContextRequest close_context_request; |
| close_context_request.set_context_id(context_id); |
| CloseContextResponse close_context_response; |
| TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, |
| &close_context_response)); |
| } |
| |
| // Test requests sent to the eager service on master. |
| TEST_F(EagerServiceImplTest, RequestsToMasterTest) { |
| tensorflow::Rendezvous* rendezvous = |
| new tensorflow::IntraProcessRendezvous(device_mgr_.get()); |
| // Create a master eager context. |
| tensorflow::EagerContext* ctx = new tensorflow::EagerContext( |
| SessionOptions(), |
| tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, |
| tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, |
| device_mgr_.get(), false, rendezvous, GetDefaultCustomKernelCreator(), |
| nullptr); |
| const uint64 context_id = random::New64(); |
| |
| // Set RemoteMgr to ctx. |
| auto remote_mgr = |
| absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx); |
| TF_ASSERT_OK(ctx->InitializeRemoteWorker(nullptr, nullptr, {}, context_id, |
| nullptr, std::move(remote_mgr))); |
| |
| TestEagerServiceImpl eager_service_impl(&worker_env_); |
| |
| EnqueueRequest remote_enqueue_request; |
| remote_enqueue_request.set_context_id(context_id); |
| EnqueueResponse remote_enqueue_response; |
| |
| auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor(); |
| send_tensor->set_op_id(1); |
| SetTensorProto(send_tensor->add_tensors()); |
| |
| // Unable to handle the request since there is no eager context. |
| Status status = eager_service_impl.Enqueue(&remote_enqueue_request, |
| &remote_enqueue_response); |
| EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); |
| EXPECT_TRUE(absl::StrContains( |
| status.error_message(), |
| "Unable to find a context_id matching the specified one")); |
| |
| // The request can be handled after adding the master eager context to |
| // service. |
| TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx)); |
| TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, |
| &remote_enqueue_response)); |
| ctx->Unref(); |
| } |
| |
| TEST_F(EagerServiceImplTest, KeepAliveTest) { |
| TestEagerServiceImpl eager_service_impl(&worker_env_); |
| |
| uint64 context_id = random::New64(); |
| CreateContextRequest request; |
| request.mutable_server_def()->set_job_name("localhost"); |
| request.mutable_server_def()->set_task_index(0); |
| request.set_context_id(context_id); |
| request.set_keep_alive_secs(3); |
| CreateContextResponse response; |
| |
| TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); |
| |
| worker_env_.env->SleepForMicroseconds(5 * |
| tensorflow::EnvTime::kSecondsToMicros); |
| |
| KeepAliveRequest keep_alive_request; |
| KeepAliveResponse keep_alive_response; |
| |
| keep_alive_request.set_context_id(context_id); |
| |
| Status status = |
| eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); |
| |
| EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); |
| EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", |
| status.error_message()); |
| |
| uint64 new_context_id = random::New64(); |
| // Create a new context. |
| request.set_context_id(new_context_id); |
| TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); |
| |
| // The context should not be GC'd. |
| worker_env_.env->SleepForMicroseconds(1 * |
| tensorflow::EnvTime::kSecondsToMicros); |
| |
| keep_alive_request.set_context_id(new_context_id); |
| |
| TF_ASSERT_OK( |
| eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response)); |
| } |
| |
| } // namespace |
| } // namespace eager |
| } // namespace tensorflow |