Reenable change to share resource variables across TFEContexts.

PiperOrigin-RevId: 406242257
Change-Id: Ie82405025a90d5e2aeaef8c65d7e1bb1bc121497
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index d138bde..e793410 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -102,6 +102,7 @@
         "//tensorflow/core/distributed_runtime:worker_env",
         "//tensorflow/core/distributed_runtime:worker_interface",
         "//tensorflow/core:gpu_runtime",
+        "@com_google_absl//absl/strings:str_format",
     ] + internal_tfrt_deps(),
     alwayslink = 1,
 )
@@ -681,6 +682,7 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         # copybara:uncomment_begin
         # "//tensorflow/core/tfrt/eager:c_api_tfrt",
         # "@tf_runtime//backends/cpu:tf_ops_alwayslink",
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index c697540..a36156d 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -33,6 +33,7 @@
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/logging.h"
@@ -1960,4 +1961,217 @@
   TFE_DeleteContext(ctx);
 }
 
+tensorflow::ServerDef ReplaceTaskInServerDef(
+    const tensorflow::ServerDef& server_def, int task_index) {
+  tensorflow::ServerDef server_def_copy = server_def;
+  tensorflow::ClusterDef* cluster_def = server_def_copy.mutable_cluster();
+  tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+  const int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->at(task_index) =
+      tensorflow::strings::StrCat("localhost:", port);
+  return server_def_copy;
+}
+
+TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx,
+                                  const tensorflow::string& device_name,
+                                  const tensorflow::string& variable_name) {
+  TF_Status* status = TF_NewStatus();
+  // Create the variable handle.
+  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
+  TFE_OpSetAttrString(op, "container", "localhost", 0);
+  TFE_OpSetAttrString(op, "shared_name", variable_name.data(),
+                      variable_name.size());
+  if (!device_name.empty()) {
+    TFE_OpSetDevice(op, device_name.c_str(), status);
+  }
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_TensorHandle* var_handle = nullptr;
+  int num_retvals = 1;
+  TFE_Execute(op, &var_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(1, num_retvals);
+  TF_DeleteStatus(status);
+  return var_handle;
+}
+
+TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
+                                 const tensorflow::string& device_name,
+                                 const tensorflow::string& variable_name) {
+  TF_Status* status = TF_NewStatus();
+  TFE_TensorHandle* var_handle =
+      CreateVarHandle(ctx, device_name, variable_name);
+
+  // Assign 'value' to it.
+  TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var_handle, status);
+  if (!device_name.empty()) {
+    TFE_OpSetDevice(op, device_name.c_str(), status);
+  }
+
+  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
+  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      value_handle(TFE_NewTensorHandle(t.get(), status),
+                   TFE_DeleteTensorHandle);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_OpAddInput(op, value_handle.get(), status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  int num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(0, num_retvals);
+  TF_DeleteStatus(status);
+  return var_handle;
+}
+
+TFE_Context* CreateContext(const string& serialized_server_def,
+                           bool isolate_session_state) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  opts->session_options.options.config.set_isolate_session_state(
+      isolate_session_state);
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_ContextSetServerDef(ctx, 0, serialized_server_def.data(),
+                          serialized_server_def.size(), status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+  TF_DeleteStatus(status);
+  return ctx;
+}
+
+std::vector<std::string> ListDeviceNames(TFE_Context* ctx) {
+  TF_Status* status = TF_NewStatus();
+  std::vector<std::string> device_names;
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  const int num_devices = TF_DeviceListCount(devices);
+  for (int i = 0; i < num_devices; ++i) {
+    device_names.emplace_back(TF_DeviceListName(devices, i, status));
+    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  }
+  TF_DeleteDeviceList(devices);
+  TF_DeleteStatus(status);
+  return device_names;
+}
+
+TEST(CAPI, ShareVariableAcrossContextsWorks) {
+  // TODO(shreepadma): Add a test case with isolate_session_state set to true.
+  tensorflow::ServerDef server_def_0 = GetServerDef(3);
+  server_def_0.mutable_default_session_config()->set_isolate_session_state(
+      false);
+  tensorflow::ServerDef server_def_1 =
+      ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);
+
+  // These server defs have task index set to 0.
+  string serialized_server_def_0 = server_def_0.SerializeAsString();
+  string serialized_server_def_1 = server_def_1.SerializeAsString();
+
+  // Create two worker tasks.
+  server_def_0.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def_0, tensorflow::Env::Default(), &worker_server1)
+                  .ok());
+  ASSERT_TRUE(worker_server1->Start().ok());
+  server_def_0.set_task_index(2);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def_0, tensorflow::Env::Default(), &worker_server2)
+                  .ok());
+  ASSERT_TRUE(worker_server2->Start().ok());
+
+  TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
+                                     /*isolate_session_state=*/false);
+  TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
+                                     /*isolate_session_state=*/false);
+
+  // Remote device on `worker1`.
+  const char remote_device[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+  // `ctx_0`, `ctx_1`, `ctx_2` contains `remote_device`.
+  {
+    const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
+    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
+                          remote_device) != device_names.end());
+  }
+
+  {
+    const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
+    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
+                          remote_device) != device_names.end());
+  }
+
+  // Create a variable using `ctx_0`.
+  // Read the variable using `ctx_1`. This read should succeed.
+  // 1. Create a variable on `remote_device`, using `ctx_0`.
+  TFE_TensorHandle* handle_0 =
+      CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var2");
+
+  // 2. Wait for `var2` to be created and initialized on the worker.
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextAsyncWait(ctx_0, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+
+  // 3. Read `var_2` using `ctx_1`. This read should succeed since `ctx_1` was
+  // created with `isolate_session_state` set to false.
+  {
+    // Create a handle to `var2`, using `ctx_1`.
+    TFE_TensorHandle* var_handle =
+        CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var2");
+
+    TFE_TensorHandle* handle_1 = nullptr;
+    int num_retvals = 1;
+    TF_Status* status = TF_NewStatus();
+    TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+    TFE_OpAddInput(op, var_handle, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_Execute(op, &handle_1, &num_retvals, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteOp(op);
+
+    ASSERT_EQ(1, num_retvals);
+    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
+    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    // Read the value of tensor handle `handle_1`.
+    float value = 0.0f;
+    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
+    memcpy(&value, TF_TensorData(t), sizeof(float));
+    TF_DeleteTensor(t);
+    EXPECT_EQ(1.2f, value);
+    TFE_DeleteTensorHandle(handle_1);
+    TF_DeleteStatus(status);
+    TFE_DeleteTensorHandle(var_handle);
+  }
+
+  TFE_DeleteTensorHandle(handle_0);
+
+  TFE_DeleteContext(ctx_0);
+  TFE_DeleteContext(ctx_1);
+
+  worker_server1.release();
+  worker_server2.release();
+}
+
 }  // namespace
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index 083454f..6f38dcb 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -598,7 +598,7 @@
     std::shared_ptr<WorkerSession> worker_session;
     LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
         session_name, server_def, base_request.cluster_device_attributes(),
-        true));
+        context->session_options().config.isolate_session_state()));
     LOG_AND_RETURN_IF_ERROR(
         server->worker_env()->session_mgr->WorkerSessionForSession(
             session_name, &worker_session));
@@ -708,7 +708,8 @@
       auto session_name = strings::StrCat("eager_", context_->GetContextId());
       std::shared_ptr<WorkerSession> worker_session;
       LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
-          session_name, server_def, true));
+          session_name, server_def,
+          context_->session_options().config.isolate_session_state()));
       LOG_AND_RETURN_IF_ERROR(
           server->worker_env()->session_mgr->WorkerSessionForSession(
               session_name, &worker_session));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 0e076d8..5c1c379 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -243,7 +243,7 @@
   }
   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
       session_name, request->server_def(), request->cluster_device_attributes(),
-      true));
+      request->server_def().default_session_config().isolate_session_state()));
   int64_t context_id = request->context_id();
   std::function<void()> session_destroyer = [this, context_id, session_name]() {
     env_->rendezvous_mgr->Cleanup(context_id);