Avoid partially creating/updating cluster when some workers fail during update.

PiperOrigin-RevId: 309276774
Change-Id: Id9191f529f0531f4db1e8a59ea3e35f9295c2bce
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 07ac760..3a4acd8 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -340,7 +340,7 @@
 
 tf_cuda_cc_test(
     name = "c_api_remote_test",
-    size = "small",
+    size = "medium",
     srcs = [
         "c_api_remote_test.cc",
     ],
@@ -361,6 +361,7 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "@com_google_absl//absl/debugging:leak_check",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 540efe9..559d047 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -500,6 +500,17 @@
       grpc_server->master_env()->worker_cache->GetEagerClientCache(
           &remote_eager_workers));
 
+  // For cluster update, use a status group to aggregate statuses from
+  //   * adding and removing remote devices
+  //   * creating remote contexts on newly added workers
+  //   * updating remote contexts on existing workers
+  //   * updating the master context
+  // Note that we should not return immediately on errors in the middle of these
+  // updates to prevent cluster from having inconsistent context views.
+  //
+  // Unused if `reset_context` is True.
+  tensorflow::StatusGroup sg;
+
   // When updating an existing context, populate the following lists with:
   // * added_workers: set(remote_workers) - set(curr_remote_workers)
   // * removed_workers: set(curr_remote_workers) - set(remote_workers)
@@ -535,7 +546,7 @@
     DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                              &added_workers, &removed_workers,
                              &existing_workers);
-    LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
+    sg.Update(GetReplacedFromExistingWorkers(
         &existing_workers, context_id, context->GetContextViewId(), server_def,
         remote_eager_workers.get(), &replaced_workers));
     if (VLOG_IS_ON(1)) {
@@ -559,11 +570,10 @@
             existing_workers.end());
       }
     }
-    LOG_AND_RETURN_IF_ERROR(
-        RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
-    LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
-        added_workers, grpc_server->master_env()->worker_cache,
-        remote_device_mgr));
+    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
+    sg.Update(AddRemoteDevicesToMgr(added_workers,
+                                    grpc_server->master_env()->worker_cache,
+                                    remote_device_mgr));
   }
 
   std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
@@ -584,7 +594,6 @@
   }
 
   // Initialize remote eager workers.
-  // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
@@ -596,7 +605,7 @@
     // existing workers to also have the updated context_view_id, so
     // we must set their context_view_id to the existing master's
     // context_view_id + 1.
-    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
+    sg.Update(CreateRemoteContexts(
         ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
         server_def, remote_eager_workers.get(), context->Executor().Async(),
         context->LazyCopyFunctionRemoteInputs(), base_request));
@@ -606,10 +615,10 @@
           VLOG(1) << "Updating cluster with existing worker " << w;
         }
       }
-      LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-          ctx, existing_workers, added_workers, removed_workers, context_id,
-          context_view_id + 1, server_def, remote_eager_workers.get(),
-          base_request));
+      sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
+                                     removed_workers, context_id,
+                                     context_view_id + 1, server_def,
+                                     remote_eager_workers.get(), base_request));
     }
   }
 
@@ -645,13 +654,13 @@
     // GrpcServer cannot be destroyed after it is started.
     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
   } else {
-    LOG_AND_RETURN_IF_ERROR(
-        grpc_server->worker_env()->session_mgr->UpdateSession(
-            session_name, server_def, base_request.cluster_device_attributes(),
-            /*isolate_session_state=*/true));
-    LOG_AND_RETURN_IF_ERROR(
-        context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
-                                    added_workers, removed_workers));
+    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
+        session_name, server_def, base_request.cluster_device_attributes(),
+        /*isolate_session_state=*/true));
+    sg.Update(context->UpdateRemoteMaster(context_id,
+                                          std::move(remote_eager_workers),
+                                          added_workers, removed_workers));
+    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
   }
 #undef LOG_AND_RETURN_IF_ERROR
 
diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc
index 7c6836a..9b83023 100644
--- a/tensorflow/c/eager/c_api_remote_test.cc
+++ b/tensorflow/c/eager/c_api_remote_test.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/debugging/leak_check.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -21,6 +22,7 @@
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/cluster.pb.h"
@@ -527,4 +529,124 @@
   TestRemoteExecuteChangeServerDef(true);
 }
 
+void TestRemoteExecuteUpdateServerDef(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDef) {
+  TestRemoteExecuteUpdateServerDef(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
+  TestRemoteExecuteUpdateServerDef(true);
+}
+
+void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+  // Fail fast on GetStatus requests so we can get errors instead of timeout
+  // when updating cluster with non-exsitent worker
+  setenv("GRPC_FAIL_FAST", "TRUE", 1);
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  // Adding a non-existent remote worker to cluster def. This should cause the
+  // UpdateServerDef call to fail.
+  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+  tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->insert(
+      {2, tensorflow::strings::StrCat("localhost:", port)});
+  string serialized_update = server_def.SerializeAsString();
+  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+                             serialized_update.size(), status);
+  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  // Even after the prevoiusly failed cluster update, another update and op
+  // execution should work fine as long as the provided server_def is valid.
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+  unsetenv("GRPC_FAIL_FAST");
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
+  TestRemoteExecuteUpdateServerDefWithFailures(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
+  TestRemoteExecuteUpdateServerDefWithFailures(true);
+}
+
 }  // namespace