| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
#include <algorithm>
#include <cstring>

#include "absl/debugging/leak_check.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
| |
| namespace { |
| |
| using ::tensorflow::string; |
| |
| tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { |
| tensorflow::ServerDef server_def; |
| server_def.set_protocol("grpc"); |
| server_def.set_job_name(job_name); |
| server_def.set_task_index(0); |
| tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); |
| tensorflow::JobDef* job_def = cluster_def->add_job(); |
| job_def->set_name(job_name); |
| for (int i = 0; i < num_tasks; i++) { |
| int port = tensorflow::testing::PickUnusedPortOrDie(); |
| job_def->mutable_tasks()->insert( |
| {i, tensorflow::strings::StrCat("localhost:", port)}); |
| } |
| return server_def; |
| } |
| |
| tensorflow::ServerDef GetServerDef(int num_tasks) { |
| return GetServerDef("localhost", num_tasks); |
| } |
| |
// Runs a 2x2 MatMul on a remote worker (task 1) with EXPLICIT device
// placement and verifies the product values, in sync or async eager mode.
void TestRemoteExecute(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  // Start the worker (task 1) in-process; the eager context below acts as
  // task 0 of the same cluster.
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  // EXPLICIT placement: cross-device transfers must be requested by the test
  // itself (no silent copying).
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create the inputs locally, then explicitly copy them to task 1.
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Run the matmul on the remote device, feeding it the remote copies.
  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Resolve fetches the remote result back to the client as a TF_Tensor.
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // Expected product of the test matrix with itself (see
  // TestMatrixTensorHandle for the input values).
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Drain all pending (async) nodes before tearing the context down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
| |
// Sync and async instantiations of the basic remote-execute test.
TEST(CAPI, RemoteExecute) { TestRemoteExecute(/*async=*/false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(/*async=*/true); }
| |
// Returns a serialized FunctionDef named "MatMulFunction" that takes two
// float inputs (a, b) and returns m = MatMul(a, b). Used by the `func`
// variants of the silent-copies test below.
string MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " signature {"
      " name: 'MatMulFunction'"
      " input_arg {"
      " name: 'a'"
      " type: DT_FLOAT"
      " }"
      " input_arg {"
      " name: 'b'"
      " type: DT_FLOAT"
      " }"
      " output_arg {"
      " name: 'm'"
      " type: DT_FLOAT"
      " }"
      " }"
      " node_def {"
      " name: 'matmul'"
      " op: 'MatMul'"
      " input: 'a'"
      " input: 'b'"
      " attr {"
      " key: 'T'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }"
      " }"
      " ret {"
      " key: 'm'"
      " value: 'matmul:product'"
      " }",
      &def));
  return def.SerializeAsString();
}
| |
// Exercises implicit (silent) input copies in a 3-task cluster: one input
// lives on task 0, the other on task 2, while the matmul runs on task 1
// (when `remote`) or on the local CPU (when `!remote`). `func` switches
// between the raw MatMul op and the function-wrapped MatMulFunction.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
  tensorflow::ServerDef server_def = GetServerDef(3);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Bring up in-process workers for tasks 1 and 2; the client is task 0.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  // SILENT placement: mismatched inputs are copied to the op's device
  // automatically instead of raising an error.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  // Explicitly place the second input on task 2.
  auto* h1_task2 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* matmul = nullptr;
  if (func) {
    // Register MatMulFunction and invoke it instead of the raw MatMul op.
    string function_def = MatMulFunction();
    TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                              status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    matmul = TFE_NewOp(ctx, "MatMulFunction", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, h0_task0, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, h1_task2, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  } else {
    // Handles are on task0 (local), and task2, but op is on task1.
    matmul = MatMulOp(ctx, h0_task0, h1_task2);
  }
  if (remote) {
    TFE_OpSetDevice(matmul, task1_name, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  } else if (!async) {
    // Set the local device to CPU to easily validate mirroring
    string cpu_device_name;
    ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
    TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    auto remote_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
    // Before execution there should be no local mirror of the task2 input.
    ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
  }

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // TODO(gjn): Add support for waiting on async local mirrors
  if (!remote && !async) {
    auto remote_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
    // After execution the task2 input should have been silently mirrored
    // locally (the handle itself is unchanged).
    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
  }

  // Pull the result back to task 0 and verify the product values.
  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retval_task0);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h1_task2);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Drain pending nodes before tearing the context down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  if (func) {
    TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
  }
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}
| |
// All exercised (async, remote, func) combinations of the silent-copies
// test; the "Local" variants run the matmul on the local CPU.
TEST(CAPI, RemoteExecuteSilentCopies) {
  TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
                                /*func=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true,
                                /*func=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true,
                                /*func=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
  TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
                                /*func=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
                                /*func=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
  TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
                                /*func=*/true);
}
| |
// Regression test: deleting the eager context while remote MatMul RPCs may
// still be in flight must not crash or hang. Note the result is deliberately
// never resolved and the executor is never drained before TFE_DeleteContext.
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Use large matrices so that RPCs don't return before we get a chance
  // to call TFE_DeleteContext.
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Kick off the remote execution but do not wait for or resolve the result.
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Delete the context while the MatMul RPC may still be outstanding.
  TFE_DeleteContext(ctx);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
| |
// Sync and async instantiations of the outstanding-RPC teardown test.
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(/*async=*/false);
}

TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(/*async=*/true);
}
| |
| void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, |
| const std::vector<float>& expected_values) { |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( |
| TF_NewStatus(), TF_DeleteStatus); |
| TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); |
| ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); |
| std::unique_ptr<float[]> actual_values(new float[expected_values.size()]); |
| EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); |
| memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); |
| TF_DeleteTensor(t); |
| |
| for (int i = 0; i < expected_values.size(); i++) { |
| EXPECT_EQ(expected_values[i], actual_values[i]) |
| << "Mismatch in expected values at (zero-based) index " << i; |
| } |
| } |
| |
// Runs MatMul(h, h) on `remote_device_name` with a locally created input,
// copies the result back to `local_device_name`, and checks the expected
// product values {7, 10, 15, 22}. Drains the thread executor before return.
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
                                 const char* remote_device_name,
                                 const char* local_device_name) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);

  // Multiply the test matrix by itself on the remote device.
  TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Copy the result back so it can be resolved and inspected locally.
  auto* retval_task0 =
      TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});

  TFE_DeleteTensorHandle(retval_task0);
  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Wait for any pending (async) work before returning to the caller.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TF_DeleteStatus(status);
}
| |
// Verifies that TFE_ContextSetServerDef can re-point an existing context at a
// brand-new cluster (job "worker" instead of "localhost"): afterwards the old
// remote device must be unreachable while the new one executes fine.
void TestRemoteExecuteChangeServerDef(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Sanity-check remote execution against the initial "localhost" cluster.
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  const char local_device_name[] =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);

  // Make sure nothing is still in flight before swapping the cluster.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();

  // Update the server def with a new set of names (worker instead of
  // localhost).
  tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
  serialized = updated_server_def.SerializeAsString();

  updated_server_def.set_task_index(1);
  tensorflow::Status s = tensorflow::GrpcServer::Create(
      updated_server_def, tensorflow::Env::Default(), &worker_server);
  ASSERT_TRUE(s.ok()) << s.error_message();
  ASSERT_TRUE(worker_server->Start().ok());

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create a new tensor_handle.
  TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);

  // Check that copying it to the old remote device (named localhost) fails.
  TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Copying and executing on the new remote device works.
  const char new_remote_device_name[] =
      "/job:worker/replica:0/task:1/device:CPU:0";
  const char new_local_device_name[] =
      "/job:worker/replica:0/task:0/device:CPU:0";

  auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
      h0_task0_new, ctx, new_remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteTensorHandle(h0_task0_new);
  TFE_DeleteTensorHandle(h0_task1_new);

  CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
                              new_local_device_name);

  // Drain and release the executor acquired before the cluster swap.
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);

  TF_DeleteStatus(status);

  TFE_DeleteContext(ctx);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
| |
// Sync and async instantiations of the cluster-swap test.
TEST(CAPI, RemoteExecuteChangeServerDef) {
  TestRemoteExecuteChangeServerDef(/*async=*/false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
  TestRemoteExecuteChangeServerDef(/*async=*/true);
}
| |
// Verifies that TFE_ContextUpdateServerDef with the same (unchanged) cluster
// definition keeps remote execution working.
void TestRemoteExecuteUpdateServerDef(bool async) {
  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
  absl::LeakCheckDisabler disabler;

  tensorflow::ServerDef server_def = GetServerDef(2);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // Remote execution works before the update...
  const char local_device_name[] =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);

  // ... and still works after updating with the identical server def.
  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
                             status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
| |
// Sync and async instantiations of the server-def update test.
TEST(CAPI, RemoteExecuteUpdateServerDef) {
  TestRemoteExecuteUpdateServerDef(/*async=*/false);
}

TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
  TestRemoteExecuteUpdateServerDef(/*async=*/true);
}
| |
// Verifies that a failed TFE_ContextUpdateServerDef (cluster containing an
// unreachable worker) leaves the context in a usable state: a subsequent
// valid update and remote execution must still succeed.
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
  absl::LeakCheckDisabler disabler;
  // Fail fast on GetStatus requests so we can get errors instead of timeout
  // when updating cluster with non-existent worker
  setenv("GRPC_FAIL_FAST", "TRUE", 1);

  tensorflow::ServerDef server_def = GetServerDef(2);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char local_device_name[] =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);

  // Adding a non-existent remote worker to cluster def. This should cause the
  // UpdateServerDef call to fail.
  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
  int port = tensorflow::testing::PickUnusedPortOrDie();
  // Task 2 gets an address nobody is listening on.
  job_def->mutable_tasks()->insert(
      {2, tensorflow::strings::StrCat("localhost:", port)});
  string serialized_update = server_def.SerializeAsString();
  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
                             serialized_update.size(), status);
  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Even after the previously failed cluster update, another update and op
  // execution should work fine as long as the provided server_def is valid.
  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
                             status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
  // Restore default gRPC behavior for subsequent tests in this process.
  unsetenv("GRPC_FAIL_FAST");
}
| |
// Sync and async instantiations of the failed-update recovery test.
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/false);
}

TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/true);
}
| |
| } // namespace |