Reenable change to share resource variables across TFEContexts.
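
Worker sessions created for an eager context now honor the
isolate_session_state setting from the context's session config and from
the ServerDef's default_session_config, instead of always being created
with isolated state. Two TFE contexts that target the same cluster with
isolate_session_state set to false can therefore resolve each other's
resource variables by shared_name.

Minimal sketch of the client-side setup (it mirrors the CreateContext
helper added in the test below; `serialized_server_def` is assumed to be
a serialized ServerDef whose default_session_config also sets
isolate_session_state to false):

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Internal-only access to the session config, as in the test helper.
  opts->session_options.options.config.set_isolate_session_state(false);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0,
                          serialized_server_def.data(),
                          serialized_server_def.size(), status);
  TFE_DeleteContextOptions(opts);
  TF_DeleteStatus(status);
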
PiperOrigin-RevId: 406242257
Change-Id: Ie82405025a90d5e2aeaef8c65d7e1bb1bc121497
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index d138bde..e793410 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -102,6 +102,7 @@
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core:gpu_runtime",
+ "@com_google_absl//absl/strings:str_format",
] + internal_tfrt_deps(),
alwayslink = 1,
)
@@ -681,6 +682,7 @@
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
# copybara:uncomment_begin
# "//tensorflow/core/tfrt/eager:c_api_tfrt",
# "@tf_runtime//backends/cpu:tf_ops_alwayslink",
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index c697540..a36156d 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -33,6 +33,7 @@
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
@@ -1960,4 +1961,217 @@
TFE_DeleteContext(ctx);
}
+tensorflow::ServerDef ReplaceTaskInServerDef(
+ const tensorflow::ServerDef& server_def, int task_index) {
+ tensorflow::ServerDef server_def_copy = server_def;
+ tensorflow::ClusterDef* cluster_def = server_def_copy.mutable_cluster();
+ tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+ const int port = tensorflow::testing::PickUnusedPortOrDie();
+ job_def->mutable_tasks()->at(task_index) =
+ tensorflow::strings::StrCat("localhost:", port);
+ return server_def_copy;
+}
+
+TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx,
+ const tensorflow::string& device_name,
+ const tensorflow::string& variable_name) {
+ TF_Status* status = TF_NewStatus();
+ // Create the variable handle.
+ TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+ TFE_OpSetAttrShape(op, "shape", {}, 0, status);
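+ // A length of 0 leaves the "container" attribute empty, so the variable
+ // is placed in the default container.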
+ TFE_OpSetAttrString(op, "container", "localhost", 0);
+ TFE_OpSetAttrString(op, "shared_name", variable_name.data(),
+ variable_name.size());
+ if (!device_name.empty()) {
+ TFE_OpSetDevice(op, device_name.c_str(), status);
+ }
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_TensorHandle* var_handle = nullptr;
+ int num_retvals = 1;
+ TFE_Execute(op, &var_handle, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_DeleteOp(op);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ CHECK_EQ(1, num_retvals);
+ TF_DeleteStatus(status);
+ return var_handle;
+}
+
+TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
+ const tensorflow::string& device_name,
+ const tensorflow::string& variable_name) {
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* var_handle =
+ CreateVarHandle(ctx, device_name, variable_name);
+
+ // Assign 'value' to it.
+ TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+ TFE_OpAddInput(op, var_handle, status);
+ if (!device_name.empty()) {
+ TFE_OpSetDevice(op, device_name.c_str(), status);
+ }
+
+ // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+ TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
+ memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ value_handle(TFE_NewTensorHandle(t.get(), status),
+ TFE_DeleteTensorHandle);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+
+ TFE_OpAddInput(op, value_handle.get(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+
+ int num_retvals = 0;
+ TFE_Execute(op, nullptr, &num_retvals, status);
+ TFE_DeleteOp(op);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ CHECK_EQ(0, num_retvals);
+ TF_DeleteStatus(status);
+ return var_handle;
+}
+
+TFE_Context* CreateContext(const string& serialized_server_def,
+ bool isolate_session_state) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ opts->session_options.options.config.set_isolate_session_state(
+ isolate_session_state);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+ TFE_ContextSetServerDef(ctx, 0, serialized_server_def.data(),
+ serialized_server_def.size(), status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+ TF_DeleteStatus(status);
+ return ctx;
+}
+
+std::vector<std::string> ListDeviceNames(TFE_Context* ctx) {
+ TF_Status* status = TF_NewStatus();
+ std::vector<std::string> device_names;
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+ const int num_devices = TF_DeviceListCount(devices);
+ for (int i = 0; i < num_devices; ++i) {
+ device_names.emplace_back(TF_DeviceListName(devices, i, status));
+ EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+ }
+ TF_DeleteDeviceList(devices);
+ TF_DeleteStatus(status);
+ return device_names;
+}
+
+TEST(CAPI, ShareVariableAcrossContextsWorks) {
+ // TODO(shreepadma): Add a test case with isolate_session_state set to true.
+ tensorflow::ServerDef server_def_0 = GetServerDef(3);
+ server_def_0.mutable_default_session_config()->set_isolate_session_state(
+ false);
+ tensorflow::ServerDef server_def_1 =
+ ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);
+
+ // These server defs have task index set to 0.
+ string serialized_server_def_0 = server_def_0.SerializeAsString();
+ string serialized_server_def_1 = server_def_1.SerializeAsString();
+
+ // Create two worker tasks.
+ server_def_0.set_task_index(1);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def_0, tensorflow::Env::Default(), &worker_server1)
+ .ok());
+ ASSERT_TRUE(worker_server1->Start().ok());
+ server_def_0.set_task_index(2);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server2;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def_0, tensorflow::Env::Default(), &worker_server2)
+ .ok());
+ ASSERT_TRUE(worker_server2->Start().ok());
+
+ TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
+ /*isolate_session_state=*/false);
+ TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
+ /*isolate_session_state=*/false);
+
+ // Remote device on `worker1`.
+ const char remote_device[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+ // Both `ctx_0` and `ctx_1` contain `remote_device`.
+ {
+ const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
+ ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
+ remote_device) != device_names.end());
+ }
+
+ {
+ const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
+ ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
+ remote_device) != device_names.end());
+ }
+
+ // Create a variable using `ctx_0` and read it using `ctx_1`; the read
+ // should succeed because the two contexts share session state.
+ // 1. Create a variable on `remote_device`, using `ctx_0`.
+ TFE_TensorHandle* handle_0 =
+ CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var2");
+
+ // 2. Wait for `var2` to be created and initialized on the worker.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextAsyncWait(ctx_0, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ // 3. Read `var2` using `ctx_1`. This read should succeed since `ctx_1` was
+ // created with `isolate_session_state` set to false.
+ {
+ // Create a handle to `var2`, using `ctx_1`.
+ TFE_TensorHandle* var_handle =
+ CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var2");
+
+ TFE_TensorHandle* handle_1 = nullptr;
+ int num_retvals = 1;
+ TF_Status* status = TF_NewStatus();
+ TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+ TFE_OpAddInput(op, var_handle, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Execute(op, &handle_1, &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(op);
+
+ ASSERT_EQ(1, num_retvals);
+ EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
+ EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Read the value of tensor handle `handle_1`.
+ float value = 0.0f;
+ TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
+ memcpy(&value, TF_TensorData(t), sizeof(float));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(1.2f, value);
+ TFE_DeleteTensorHandle(handle_1);
+ TF_DeleteStatus(status);
+ TFE_DeleteTensorHandle(var_handle);
+ }
+
+ TFE_DeleteTensorHandle(handle_0);
+
+ TFE_DeleteContext(ctx_0);
+ TFE_DeleteContext(ctx_1);
+
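+ // Release ownership so the GrpcServers are intentionally leaked rather
+ // than destroyed; they are reclaimed at process teardown.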
+ worker_server1.release();
+ worker_server2.release();
+}
+
} // namespace
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index 083454f..6f38dcb 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -598,7 +598,7 @@
std::shared_ptr<WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
- true));
+ context->session_options().config.isolate_session_state()));
LOG_AND_RETURN_IF_ERROR(
server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
@@ -708,7 +708,8 @@
auto session_name = strings::StrCat("eager_", context_->GetContextId());
std::shared_ptr<WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
- session_name, server_def, true));
+ session_name, server_def,
+ context_->session_options().config.isolate_session_state()));
LOG_AND_RETURN_IF_ERROR(
server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 0e076d8..5c1c379 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -243,7 +243,7 @@
}
TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
session_name, request->server_def(), request->cluster_device_attributes(),
- true));
+ request->server_def().default_session_config().isolate_session_state()));
int64_t context_id = request->context_id();
std::function<void()> session_destroyer = [this, context_id, session_name]() {
env_->rendezvous_mgr->Cleanup(context_id);