| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/c/eager/c_api.h" |
| #include "tensorflow/c/eager/c_api_experimental.h" |
| #include "tensorflow/c/eager/c_api_internal.h" |
| #include "tensorflow/c/eager/c_api_remote_test_util.h" |
| #include "tensorflow/c/eager/c_api_test_util.h" |
| #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
| #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
| #include "tensorflow/core/common_runtime/function_optimization_registry.h" |
| #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/platform/casts.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/protobuf/cluster.pb.h" |
| #include "tensorflow/core/protobuf/config.pb.h" |
| #include "tensorflow/core/protobuf/tensorflow_server.pb.h" |
| |
| namespace { |
| |
| using ::tensorflow::string; |
| |
| void TestRemoteExecute(bool async) { |
| tensorflow::ServerDef server_def = GetServerDef(2); |
| |
| // This server def has the task index set to 0. |
| string serialized = server_def.SerializeAsString(); |
| |
| server_def.set_task_index(1); |
| |
| std::unique_ptr<tensorflow::GrpcServer> worker_server; |
| ASSERT_TRUE(tensorflow::GrpcServer::Create( |
| server_def, tensorflow::Env::Default(), &worker_server) |
| .ok()); |
| ASSERT_TRUE(worker_server->Start().ok()); |
| |
| TF_Status* status = TF_NewStatus(); |
| TFE_ContextOptions* opts = TFE_NewContextOptions(); |
| TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); |
| TFE_ContextOptionsSetDevicePlacementPolicy(opts, |
| TFE_DEVICE_PLACEMENT_EXPLICIT); |
| TFE_Context* ctx = TFE_NewContext(opts, status); |
| EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| TFE_DeleteContextOptions(opts); |
| |
| TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); |
| EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| |
| TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); |
| TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); |
| const char remote_device_name[] = |
| "/job:localhost/replica:0/task:1/device:CPU:0"; |
| auto* h0_task1 = |
| TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); |
| ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| auto* h1_task1 = |
| TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); |
| ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| |
| TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); |
| TFE_OpSetDevice(matmul, remote_device_name, status); |
| EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| |
| TFE_TensorHandle* retvals[1]; |
| int num_retvals = 1; |
| TFE_Execute(matmul, &retvals[0], &num_retvals, status); |
| EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| |
| TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); |
| ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| float product[4] = {0}; |
| EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); |
| memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); |
| TF_DeleteTensor(t); |
| EXPECT_EQ(7, product[0]); |
| EXPECT_EQ(10, product[1]); |
| EXPECT_EQ(15, product[2]); |
| EXPECT_EQ(22, product[3]); |
| |
| TFE_DeleteTensorHandle(h0_task0); |
| TFE_DeleteTensorHandle(h1_task0); |
| TFE_DeleteTensorHandle(h0_task1); |
| TFE_DeleteTensorHandle(h1_task1); |
| TFE_DeleteTensorHandle(retvals[0]); |
| |
| TFE_DeleteOp(matmul); |
| |
| TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); |
| TFE_ExecutorWaitForAllPendingNodes(executor, status); |
| ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); |
| TFE_DeleteExecutor(executor); |
| TFE_DeleteContext(ctx); |
| |
| TF_DeleteStatus(status); |
| |
| // TODO(b/136478427): Figure out how to correctly shut the server down. |
| worker_server.release(); |
| } |
| |
| TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } |
| TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } |
| |
| void TestRemoteExecuteSilentCopiesOp(bool async, bool remote, |
| bool remote_func_outputs = false) { |
| return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false, |
| /*heavy_load_on_streaming_rpc=*/false, |
| remote_func_outputs); |
| } |
| |
| TEST(CAPI, RemoteExecuteSilentCopies) { |
| TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true); |
| } |
| TEST(CAPI, RemoteExecuteSilentCopiesAsync) { |
| TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true); |
| } |
| TEST(CAPI, RemoteExecuteSilentCopiesLocal) { |
| TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false); |
| } |
| TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) { |
| TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false); |
| } |
| |
| } // namespace |