Exchange device attributes at group resolution

Previously CollectiveParamResolver queries device attributes when initializing instance params. That has issues when the collective leader fails and restarts quickly between group resolution and instance resolution. In such case, all other workers get the incarnation of the restarted leader, thus they're unable to detect that the leader has failed; the leader will deadlock on the group resolution.

This change doesn't fully fixed the issue because it only exchanges device attributes at group resolution, but doesn't populate the device attributes to DeviceResolver. That will be done in a following change.

This change also changes the behavior when a non-leader fails and restarts. Previously it gets the cached group resolution from the leader, now it will get an error because its incarnation doesn't match with the one in the cached group parameters. This should have no actual effect since that worker will always restart again after the leader has restarted.

This change changes both the client and server without being backward compatible. It assumes that client and server are running the same version of Tensorflow. This should be true since the only way to use CollectiveParamResolverDistributed is through MultiWorkerMirroredStrategy (MWMS). For MWMS, all workers should run the same version of the program.

PiperOrigin-RevId: 329735919
Change-Id: I5c29a3ec8462c7737bcbbbf823a95693b0d27dc3
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 4d1c5c7..73c1458 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -379,12 +379,11 @@
     hdrs = ["collective_param_resolver_local.h"],
     copts = tf_copts(),
     deps = [
+        ":device",
         ":device_mgr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 6e5e5c8..a662928 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -298,8 +298,8 @@
 }
 
 void BaseCollectiveExecutor::CompleteParamsAsync(
-    const DeviceAttributes& device, CollectiveParams* cp,
-    CancellationManager* cancel_mgr, StatusCallback done) {
+    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+    StatusCallback done) {
   cp->instance.gpu_ring_order = *gpu_ring_order_;
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
   auto done_with_timeout = done;
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 4081b88..c9cea39 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -113,7 +113,7 @@
   void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
                     const string& exec_key, StatusCallback done) override;
 
-  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            StatusCallback done) override;
 
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 87fe453..ba21abc 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -20,7 +20,6 @@
 #include <unordered_map>
 #include <utility>
 
-#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
@@ -31,7 +30,6 @@
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"
@@ -76,21 +74,12 @@
       return "undef";
   }
 }
-
-string TaskNameFromDeviceName(const string& device_name) {
-  DeviceNameUtils::ParsedName parsed_device;
-  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
-  string task_name;
-  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
-  return task_name;
-}
 }  // namespace
 
 void CollectiveParamResolverLocal::CompleteGroupLocal(
-    const DeviceAttributes& device, CollectiveParams* cp,
-    const GroupRecCallback& done) {
-  VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
-          << ": " << cp->ToString();
+    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
+  VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
+          << cp->ToString();
   std::vector<StatusCallback> to_be_called;
   GroupRec* gr = nullptr;
   Status status;
@@ -150,13 +139,13 @@
     // status.
     VLOG(2) << "gr device_type=" << gr->group.device_type
             << " cp device_type=" << cp->group.device_type
-            << " current device=" << device.name();
+            << " current device=" << device;
     if (gr->status.ok()) {
       // Check for consistency with existing GroupRec.
       if (cp->group.device_type != gr->group.device_type) {
         gr->status = errors::Internal(
-            "Collective Op ", cp->name, " is assigned to device ",
-            device.name(), " with type ", cp->group.device_type.type_string(),
+            "Collective Op ", cp->name, " is assigned to device ", device,
+            " with type ", cp->group.device_type.type_string(),
             " and group_key ", cp->group.group_key, " but that group has type ",
             gr->group.device_type.type_string());
       } else if (cp->group.group_size != gr->group.group_size) {
@@ -168,47 +157,38 @@
     }
     if (gr->status.ok()) {
       // Insert device if not already present.
-      auto it = gr->devices.find(device.name());
-      if (it == gr->devices.end()) {
-        if (gr->devices.size() == gr->group.group_size) {
+      auto it = gr->device_set.find(device);
+      if (it == gr->device_set.end()) {
+        if (gr->device_set.size() == gr->group.group_size) {
           // The group is already full.
           gr->status = errors::Internal(
-              "Collective Op ", cp->name, " is assigned to device ",
-              device.name(), " and group_key ", cp->group.group_key,
+              "Collective Op ", cp->name, " is assigned to device ", device,
+              " and group_key ", cp->group.group_key,
               " but that group doesn't contain that device.");
         } else {
           // This is a new device that has not yet joined the group.
-          gr->devices[device.name()] = device;
-          if (gr->devices.size() == gr->group.group_size) {
-            // The group is full after adding this device, calculate the number
-            // of tasks.
-            absl::flat_hash_set<string> tasks;
-            for (const auto& item : gr->devices) {
-              tasks.insert(TaskNameFromDeviceName(item.first));
-            }
-            gr->group.num_tasks = static_cast<int32>(tasks.size());
-          }
+          gr->device_set.insert(device);
+          gr->device_list.push_back(device);
+          DeviceNameUtils::ParsedName parsed_device;
+          DeviceNameUtils::ParseFullName(device, &parsed_device);
+          string task_name = strings::StrCat("/job:", parsed_device.job,
+                                             "/replica:", parsed_device.replica,
+                                             "/task:", parsed_device.task);
+          gr->task_set.insert(task_name);
+          gr->task_list.push_back(task_name);
+          gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
           if (VLOG_IS_ON(1)) {
             string dev_buf;
-            for (const auto& d : gr->devices) {
-              strings::StrAppend(&dev_buf, ",", d.first);
+            for (const auto& d : gr->device_set) {
+              strings::StrAppend(&dev_buf, ",", d);
             }
             VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                     << " group_size=" << gr->group.group_size << " (current"
                     << " devices)=(" << dev_buf << ") (number of"
                     << " devices pending)="
-                    << (gr->group.group_size - gr->devices.size());
+                    << (gr->group.group_size - gr->device_set.size());
           }
         }
-      } else {
-        // If the device already exists, check if the incarnation matches.
-        if (it->second.incarnation() != device.incarnation()) {
-          gr->status = errors::FailedPrecondition(
-              "Device ", device.name(),
-              " current incarnation doesn't match with one in the group. This "
-              "usually means this worker has restarted but the collective "
-              "leader hasn't, or this worker connects to a wrong cluster.");
-        }
       }
     }
 
@@ -216,13 +196,13 @@
       cp->group.runtime_details = gr->group.runtime_details;
       // If the group is not yet complete, queue to wait for it.
       VLOG(2) << "group_size " << gr->group.group_size << " set size "
-              << gr->devices.size() << " gr " << gr;
+              << gr->device_set.size() << " gr " << gr;
 
-      if (gr->devices.size() < gr->group.group_size) {
+      if (gr->device_set.size() < gr->group.group_size) {
         gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
         return;
       }
-      CHECK_EQ(gr->devices.size(), gr->group.group_size);
+      CHECK_EQ(gr->device_set.size(), gr->group.group_size);
     }
     // At this point, we either have a full group, or an error status.  Ensure
     // that all callbacks are invoked with the appropriate status.
@@ -501,15 +481,10 @@
   {
     mutex_lock gl(gr->mu);
     ir->shared.group = gr->group;
-    ir->shared.instance.device_names.clear();
-    ir->shared.instance.task_names.clear();
-    ir->shared.instance.device_names.reserve(gr->devices.size());
-    ir->shared.instance.task_names.reserve(gr->devices.size());
-    for (const auto& item : gr->devices) {
-      ir->shared.instance.device_names.push_back(item.first);
-      ir->shared.instance.task_names.push_back(
-          TaskNameFromDeviceName(item.first));
-    }
+    ir->shared.instance.device_names.assign(gr->device_list.begin(),
+                                            gr->device_list.end());
+    ir->shared.instance.task_names.assign(gr->task_list.begin(),
+                                          gr->task_list.end());
     VLOG(2) << "Initialized names for instance: "
             << ir->shared.instance.ToString();
   }
@@ -707,15 +682,15 @@
 }
 
 void CollectiveParamResolverLocal::CompleteParamsAsync(
-    const DeviceAttributes& device, CollectiveParams* cp,
-    CancellationManager* cancel_mgr, const StatusCallback& done) {
-  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
+    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+    const StatusCallback& done) {
+  VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
           << cp->ToString();
   CompleteGroupLocal(
       device, cp,
       [this, device, cp, done](const Status& s, const GroupRec* gr) {
         if (s.ok()) {
-          CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
+          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
         } else {
           done(s);
         }
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 9d9f713..40f0f00 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -21,9 +21,7 @@
 #include <string>
 #include <vector>
 
-#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
@@ -47,7 +45,7 @@
 
   ~CollectiveParamResolverLocal() override {}
 
-  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override;
 
@@ -72,7 +70,10 @@
     CollGroupParams group;
     mutable mutex mu;
     Status status TF_GUARDED_BY(mu);
-    absl::flat_hash_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
+    std::set<string> device_set TF_GUARDED_BY(mu);
+    std::vector<string> device_list TF_GUARDED_BY(mu);
+    std::set<string> task_set TF_GUARDED_BY(mu);
+    std::vector<string> task_list TF_GUARDED_BY(mu);
     std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
   };
 
@@ -84,7 +85,7 @@
   // callback.
   typedef std::function<void(const Status& s, const GroupRec* gr)>
       GroupRecCallback;
-  void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
+  void CompleteGroupLocal(const string& device, CollectiveParams* cp,
                           const GroupRecCallback& done)
       TF_LOCKS_EXCLUDED(group_mu_);
 
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index b117632..f23f03d 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -23,7 +23,6 @@
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -87,12 +86,6 @@
     }
   }
 
-  DeviceAttributes GetDeviceAttributes(const string& device_name) {
-    Device* device = nullptr;
-    TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device));
-    return device->attributes();
-  }
-
   string task_name_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
@@ -194,13 +187,12 @@
     cp->instance.impl_details.subdiv_offsets.push_back(0);
     cp->is_source = false;
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(
-          GetDeviceAttributes(cp->instance.device_names[0]), cp,
-          nullptr /*CancellationManager*/,
-          [&statuses, &note, i](const Status& s) {
-            statuses[i] = s;
-            note[i].Notify();
-          });
+      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+                                nullptr /*CancellationManager*/,
+                                [&statuses, &note, i](const Status& s) {
+                                  statuses[i] = s;
+                                  note[i].Notify();
+                                });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -248,13 +240,12 @@
     CollectiveParams* cp = &cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(
-          GetDeviceAttributes(cp->instance.device_names[0]), cp,
-          nullptr /*CancellationManager*/,
-          [&statuses, &note, i](const Status& s) {
-            statuses[i] = s;
-            note[i].Notify();
-          });
+      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+                                nullptr /*CancellationManager*/,
+                                [&statuses, &note, i](const Status& s) {
+                                  statuses[i] = s;
+                                  note[i].Notify();
+                                });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -287,13 +278,12 @@
     CollectiveParams* cp = &cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(
-          GetDeviceAttributes(cp->instance.device_names[0]), cp,
-          nullptr /*CancellationManager*/,
-          [&statuses, &note, i](const Status& s) {
-            statuses[i] = s;
-            note[i].Notify();
-          });
+      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+                                nullptr /*CancellationManager*/,
+                                [&statuses, &note, i](const Status& s) {
+                                  statuses[i] = s;
+                                  note[i].Notify();
+                                });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -336,8 +326,8 @@
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
       cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
                                    /*is_source*/ i == 0);
-      prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
-                                &cancel_mgr, [&done](const Status& s) {
+      prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
+                                [&done](const Status& s) {
                                   EXPECT_EQ(s.code(), error::ABORTED);
                                   EXPECT_EQ(s.error_message(), "__aborted__");
                                   done.DecrementCount();
@@ -365,8 +355,8 @@
             strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
         cp[i] = MakeCollectiveParams(group_key, instance_key,
                                      /*is_source*/ i == 0);
-        prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
-                                  &cancel_mgr, [&done](const Status& s) {
+        prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
+                                  [&done](const Status& s) {
                                     EXPECT_EQ(s.code(), error::OK);
                                     done.DecrementCount();
                                   });
@@ -383,13 +373,12 @@
               strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
           cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
                                        /*is_source*/ i == 0);
-          prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
-                                    &cancel_mgr, [&done](const Status& s) {
-                                      EXPECT_EQ(s.code(), error::ABORTED);
-                                      EXPECT_EQ(s.error_message(),
-                                                "__aborted__");
-                                      done.DecrementCount();
-                                    });
+          prl_->CompleteParamsAsync(
+              device, &cp[i], &cancel_mgr, [&done](const Status& s) {
+                EXPECT_EQ(s.code(), error::ABORTED);
+                EXPECT_EQ(s.error_message(), "__aborted__");
+                done.DecrementCount();
+              });
           start.DecrementCount();
         });
   }
@@ -413,8 +402,8 @@
             strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
         cp[i] = MakeCollectiveParams(group_key, instance_key,
                                      /*is_source*/ i == 0);
-        prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
-                                  &cancel_mgr, [&done](const Status& s) {
+        prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
+                                  [&done](const Status& s) {
                                     EXPECT_EQ(s.code(), error::OK);
                                     done.DecrementCount();
                                   });
@@ -429,7 +418,7 @@
     Notification done;
     auto cp = MakeCollectiveParams(group_key, instance_key,
                                    /*is_source*/ true);
-    prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
+    prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
                               [&done](const Status& s) {
                                 EXPECT_EQ(s.code(), error::ABORTED);
                                 EXPECT_EQ(s.error_message(), "__aborted__");
@@ -468,8 +457,7 @@
               auto cp =
                   MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
                                        /*is_source*/ i == 0);
-              prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
-                                        &cancel_mgr,
+              prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
                                         [&status, &n](const Status& s) {
                                           status = s;
                                           n.Notify();
diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
index ff4d966..c2e6d2a 100644
--- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
@@ -16,7 +16,6 @@
 #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
 
 #include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace tensorflow {
@@ -36,7 +35,7 @@
 };
 
 class TestParamResolver : public ParamResolverInterface {
-  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override {
     done(errors::Internal("Unimplemented"));
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 6e87d28..94570c1 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -571,9 +571,7 @@
         ":device_resolver_distributed",
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/platform:errors",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -608,7 +606,6 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:collective_ops",
-        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 44d9081..650c52c 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -18,18 +18,14 @@
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
-#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/protobuf/config.pb.h"
-#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace {
 
 class CompleteGroupCall : public CancellableCall {
  public:
-  CompleteGroupCall(const CollGroupParams& group,
-                    const DeviceAttributes& device,
+  CompleteGroupCall(const CollGroupParams& group, const string& device_name,
                     const CollectiveType& collective_type,
                     CancellationManager* cancel_mgr,
                     const string& remote_worker, WorkerCacheInterface* wc)
@@ -37,7 +33,7 @@
     req_.set_group_key(group.group_key);
     req_.set_group_size(group.group_size);
     req_.set_device_type(group.device_type.type_string());
-    *req_.mutable_device_attributes() = device;
+    req_.add_device_name(device_name);
     req_.set_collective_type(collective_type);
   }
   ~CompleteGroupCall() override {}
@@ -102,16 +98,16 @@
 }
 
 void CollectiveParamResolverDistributed::CompleteParamsAsync(
-    const DeviceAttributes& device, CollectiveParams* cp,
-    CancellationManager* cancel_mgr, const StatusCallback& done) {
-  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
-          << ": " << cp->ToString();
+    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+    const StatusCallback& done) {
+  VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
+          << cp->ToString();
   CompleteGroupDistributed(device, cp, cancel_mgr,
                            [this, device, cp, cancel_mgr, done](
                                const Status& s, const GroupRec* gr) {
                              if (s.ok()) {
-                               CompleteInstanceDistributed(
-                                   device.name(), gr, cp, cancel_mgr, done);
+                               CompleteInstanceDistributed(device, gr, cp,
+                                                           cancel_mgr, done);
                              } else {
                                done(s);
                              }
@@ -121,28 +117,28 @@
 void CollectiveParamResolverDistributed::CompleteGroupAsync(
     const CompleteGroupRequest* request, CompleteGroupResponse* response,
     CancellationManager* cancel_mgr, const StatusCallback& done) {
-  if (!request->has_device_attributes()) {
-    done(errors::Internal(
-        "CompleteGroupRequest device_attributes is not set. Make sure you're "
-        "running the same version of Tensorflow on all workers."));
-    return;
-  }
   CollectiveParams cp;
   cp.group.group_key = request->group_key();
   cp.group.group_size = request->group_size();
   cp.group.device_type = DeviceType(request->device_type());
+  for (const string& dn : request->device_name()) {
+    cp.instance.device_names.push_back(dn);
+  }
   cp.instance.type = CollectiveType(request->collective_type());
   CompleteGroupDistributed(
-      request->device_attributes(), &cp, cancel_mgr,
+      cp.instance.device_names[0], &cp, cancel_mgr,
       [response, done](const Status& s, const GroupRec* gr) {
         if (s.ok()) {
           mutex_lock l(gr->mu);
           response->set_group_key(gr->group.group_key);
           response->set_group_size(gr->group.group_size);
           response->set_device_type(gr->group.device_type.type_string());
-          response->set_num_tasks(gr->group.num_tasks);
-          for (const auto& item : gr->devices) {
-            *response->add_device_attributes() = item.second;
+          response->set_num_tasks(gr->task_set.size());
+          for (const string& dn : gr->device_list) {
+            response->add_device_name(dn);
+          }
+          for (const string& tn : gr->task_list) {
+            response->add_task_name(tn);
           }
           response->set_communicator_key(
               gr->group.runtime_details.communicator_key);
@@ -156,22 +152,6 @@
 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
     CancellationManager* cancel_mgr, const StatusCallback& done) {
-  GroupRec* gr = GetCachedGroup(request->group_key());
-  if (gr == nullptr) {
-    done(errors::FailedPrecondition(
-        "group ", request->group_key(),
-        " not found. This normally means the server has restarted"));
-    return;
-  }
-  {
-    mutex_lock l(gr->mu);
-    if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
-      done(errors::FailedPrecondition(
-          "group ", request->group_key(),
-          " failed to resolve. This normally means the server has restarted"));
-      return;
-    }
-  }
   CollectiveParams* cp = new CollectiveParams;
   cp->name = request->name();
   cp->group.group_key = request->group_key();
@@ -184,44 +164,56 @@
   for (int32 offset : request->subdiv_offset()) {
     cp->instance.impl_details.subdiv_offsets.push_back(offset);
   }
-  StatusCallback done_and_cleanup = [cp, done](const Status& s) {
+  string* device = new string(request->device());
+  VLOG(1) << "New cp " << cp << " for device " << *device << " : "
+          << cp->ToString();
+  StatusCallback done_and_cleanup = [cp, device, done](const Status& s) {
     done(s);
     delete cp;
+    delete device;
   };
-  CompleteInstanceDistributed(
-      request->device(), gr, cp, cancel_mgr,
-      [this, gr, cp, response, done_and_cleanup](const Status& ci_status) {
-        if (ci_status.ok()) {
-          // Now source_rank should be known, so
-          // retrieve it.
-          FindInstanceRec(
-              gr, cp,
-              [cp, response, done_and_cleanup](const Status& fi_status,
-                                               InstanceRec* ir) {
-                if (fi_status.ok()) {
-                  mutex_lock l(ir->out_mu);
-                  ir->WaitForOutMu(l);
-                  response->set_instance_key(cp->instance.instance_key);
-                  response->set_source_rank(ir->source_rank);
-                  done_and_cleanup(fi_status);
+  // Start by completing the group.
+  CompleteGroupDistributed(
+      *device, cp, cancel_mgr,
+      [this, cp, device, response, cancel_mgr, done_and_cleanup](
+          const Status& cg_status, const GroupRec* gr) {
+        if (cg_status.ok()) {
+          // Then complete the instance.
+          CompleteInstanceDistributed(
+              *device, gr, cp, cancel_mgr,
+              [this, gr, cp, response,
+               done_and_cleanup](const Status& ci_status) {
+                if (ci_status.ok()) {
+                  // Now source_rank should be known, so
+                  // retrieve it.
+                  FindInstanceRec(
+                      gr, cp,
+                      [cp, response, done_and_cleanup](const Status& fi_status,
+                                                       InstanceRec* ir) {
+                        if (fi_status.ok()) {
+                          mutex_lock l(ir->out_mu);
+                          ir->WaitForOutMu(l);
+                          response->set_instance_key(cp->instance.instance_key);
+                          response->set_source_rank(ir->source_rank);
+                          done_and_cleanup(fi_status);
+                        } else {
+                          done_and_cleanup(fi_status);
+                        }
+                      });
                 } else {
-                  done_and_cleanup(fi_status);
+                  done_and_cleanup(ci_status);
                 }
               });
         } else {
-          done_and_cleanup(ci_status);
+          done_and_cleanup(cg_status);
         }
       });
 }
 
-CollectiveParamResolverDistributed::GroupRec*
-CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
+bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
   mutex_lock l(group_mu_);
-  auto it = group_table_.find(group_key);
-  if (it == group_table_.end()) {
-    return nullptr;
-  }
-  return it->second.get();
+  const auto& it = group_table_.find(group_key);
+  return it != group_table_.end();
 }
 
 Status CollectiveParamResolverDistributed::UpdateGroupCache(
@@ -234,19 +226,26 @@
     gr->group.group_key = resp.group_key();
     gr->group.group_size = resp.group_size();
     gr->group.num_tasks = resp.num_tasks();
-    if (resp.device_attributes().empty()) {
-      return errors::Internal(
-          "CompleteGroupResponse device_attributes is empty. Make sure you're "
-          "running the same version of Tensorflow on all workers.");
-    }
-    if (resp.device_attributes_size() != gr->group.group_size) {
+    if (resp.device_name_size() != gr->group.group_size) {
       return errors::Internal(
           "CompleteGroupResponse group_size doesn't match device_name list");
     }
-    for (const DeviceAttributes& device : resp.device_attributes()) {
-      gr->devices[device.name()] = device;
+    for (const string& dn : resp.device_name()) {
+      gr->device_set.insert(dn);
+      gr->device_list.push_back(dn);
     }
+    if (resp.task_name_size() != gr->group.group_size) {
+      return errors::Internal(
+          "CompleteGroupResponse group_size doesn't match task_name list");
+    }
+    for (const string& tn : resp.task_name()) {
+      gr->task_list.push_back(tn);
+      gr->task_set.insert(tn);
+    }
+    CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
     gr->group.runtime_details.communicator_key = resp.communicator_key();
+    VLOG(2) << "Group communicator_key="
+            << absl::CEscape(gr->group.runtime_details.communicator_key);
   }
   {
     // Group membership should never change. Once a record is in group_table_
@@ -274,15 +273,14 @@
 }
 
 void CollectiveParamResolverDistributed::CompleteGroupDistributed(
-    const DeviceAttributes& device, CollectiveParams* cp,
-    CancellationManager* cancel_mgr, const GroupRecCallback& done) {
+    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+    const GroupRecCallback& done) {
   VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
-          << " dev: " << device.name()
-          << " is_leader=" << (group_leader_.empty());
+          << " dev: " << device << " is_leader=" << (group_leader_.empty());
   if (group_leader_.empty()) {
     // This is the group leader, so resolution is local.
     return CompleteGroupLocal(device, cp, done);
-  } else if (GetCachedGroup(cp->group.group_key) == nullptr) {
+  } else if (!GroupIsCached(cp->group.group_key)) {
     // Need to update Group cache from the leader.
     CompleteGroupCall* call =
         new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
index fc692a1..6848874 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
@@ -16,7 +16,6 @@
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
 
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
 
 namespace tensorflow {
 class ConfigProto;
@@ -32,7 +31,7 @@
                                      WorkerCacheInterface* worker_cache,
                                      const string& task_name);
 
-  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override;
 
@@ -47,9 +46,9 @@
                              const StatusCallback& done) override;
 
  protected:
-  // Returns the cached group iff there's an entry for this group_key in the
-  // local group_table_; returns nullptr otherwise.
-  GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
+  // Returns true iff there's an entry for this group_key in the
+  // local group_table_.
+  bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
 
   // Updates group_table_ with contents of resp.
   Status UpdateGroupCache(const CompleteGroupResponse& resp)
@@ -60,8 +59,7 @@
   //
   // Semantics are like those of CompleteGroupLocal but will make a
   // remote call to the group leader if necessary.
-  void CompleteGroupDistributed(const DeviceAttributes& device,
-                                CollectiveParams* cp,
+  void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
                                 CancellationManager* cancel_mgr,
                                 const GroupRecCallback& done);
 
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index a963c02..130a48e 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -15,7 +15,6 @@
 
 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
 
-#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
@@ -24,7 +23,6 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
@@ -43,7 +41,6 @@
   attr.set_name(name);
   attr.set_device_type(type);
   attr.mutable_locality()->set_numa_node(3);  // a non-default value
-  attr.set_incarnation(random::New64());
   return absl::make_unique<FakeDevice>(attr);
 }
 
@@ -128,110 +125,127 @@
 
 class DeviceResDistTest : public ::testing::Test {
  protected:
-  void DefineWorkers(int num_workers, int num_devices,
-                     const string& device_type, bool nccl) {
-    for (int w = 0; w < num_workers; ++w) {
-      string name = strings::StrCat("/job:worker/replica:0/task:", w);
-      DefineWorker(name, device_type, num_devices, nccl);
+  DeviceResDistTest() {}
+
+  ~DeviceResDistTest() override {
+    for (DeviceMgr* dm : device_mgrs_) {
+      delete dm;
+    }
+    for (auto it : dev_resolvers_) {
+      delete it.second;
+    }
+    for (auto it : cp_resolvers_) {
+      delete it.second;
+    }
+    for (FakeWorker* w : workers_) {
+      delete w;
     }
   }
 
-  void DefineWorker(const string& worker_name, const string& device_type,
-                    int num_devices, bool nccl) {
+  void DefineWorkers(int num_workers, int num_devices,
+                     const string& device_type, bool nccl) {
     ConfigProto config;
-    config.mutable_experimental()->set_collective_group_leader(
-        "/job:worker/replica:0/task:0");
-    config.mutable_experimental()->set_collective_nccl(nccl);
+    for (int w = 0; w < num_workers; ++w) {
+      string name = strings::StrCat("/job:worker/replica:0/task:", w);
+      if (w == 0) {
+        config.mutable_experimental()->set_collective_group_leader(name);
+        if (nccl) {
+          config.mutable_experimental()->set_collective_nccl(true);
+        }
+      }
+      DefineWorker(config, name, device_type, num_devices);
+    }
+  }
 
+  void DefineWorker(const ConfigProto& config, const string& worker_name,
+                    const string& device_type, int num_devices) {
     std::vector<std::unique_ptr<Device>> devices;
     for (int i = 0; i < num_devices; ++i) {
       devices.push_back(NewDevice(
           device_type,
           strings::StrCat(worker_name, "/device:", device_type, ":", i)));
     }
-    device_mgrs_[worker_name] =
-        absl::make_unique<StaticDeviceMgr>(std::move(devices));
+    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
+    device_mgrs_.push_back(dev_mgr);
     std::vector<string>* dv = &dev_by_task_[worker_name];
-    dv->clear();
-    for (auto* d : device_mgrs_[worker_name]->ListDevices()) {
+    for (auto* d : dev_mgr->ListDevices()) {
       dv->push_back(d->name());
     }
-    dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
-        device_mgrs_[worker_name].get(), &wc_, worker_name);
-    cp_resolvers_[worker_name] =
-        absl::make_unique<CollectiveParamResolverDistributed>(
-            config, device_mgrs_[worker_name].get(),
-            dev_resolvers_[worker_name].get(), &wc_, worker_name);
-    workers_[worker_name] = absl::make_unique<FakeWorker>(
-        worker_name, device_mgrs_[worker_name].get(),
-        cp_resolvers_[worker_name].get());
-    wc_.AddWorker(worker_name, workers_[worker_name].get());
+    DeviceResolverDistributed* dev_res =
+        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+    dev_resolvers_[worker_name] = dev_res;
+    CollectiveParamResolverDistributed* cp_res =
+        new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
+                                               worker_name);
+    cp_resolvers_[worker_name] = cp_res;
+    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
+    workers_.push_back(fw);
+    wc_.AddWorker(worker_name, fw);
   }
 
-  void DefineCollectiveParams(int num_workers, int num_devices,
-                              const string& device_type) {
-    for (int wi = 0; wi < num_workers; ++wi) {
-      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
-      for (int di = 0; di < num_devices; ++di) {
-        string device_name =
-            strings::StrCat(task_name, "/device:", device_type, ":", di);
-        cp_[device_name] =
-            CreateCollectiveParams(num_workers, num_devices, device_type);
-      }
-    }
-  }
-
-  CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
-                                          const string& device_type) {
+  void DefineCollectiveParams(int num_workers, int num_devices) {
     const int kGroupKey = 5;
     const int kInstanceKey = 3;
-    CollectiveParams cp;
-    cp.group.group_key = kGroupKey;
-    cp.group.group_size = num_workers * num_devices;
-    cp.group.device_type = DeviceType(device_type);
-    cp.group.num_tasks = num_workers;
-    cp.instance.instance_key = kInstanceKey;
-    cp.instance.type = REDUCTION_COLLECTIVE;
-    cp.instance.data_type = DT_FLOAT;
-    cp.instance.shape = TensorShape({64});
-    cp.instance.impl_details.subdiv_offsets.push_back(0);
-    return cp;
-  }
-
-  void IssueRequests(int num_workers, int num_devices) {
-    {
-      mutex_lock l(mu_);
-      num_done_ = 0;
-    }
-    int group_size = num_workers * num_devices;
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
         string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-        IssueRequest(task_name, device_name, group_size);
+        cp_.push_back(CollectiveParams());
+        CollectiveParams& cp = cp_.back();
+        cp.group.group_key = kGroupKey;
+        cp.group.group_size = num_workers * num_devices;
+        cp.group.device_type = DEVICE_CPU;
+        cp.group.num_tasks = num_workers;
+        cp.instance.instance_key = kInstanceKey;
+        cp.instance.type = REDUCTION_COLLECTIVE;
+        cp.instance.data_type = DT_FLOAT;
+        cp.instance.shape = TensorShape({64});
+        cp.instance.impl_details.subdiv_offsets.push_back(0);
       }
     }
   }
 
-  void IssueRequest(const string& task_name, const string& device_name,
-                    int group_size) {
-    Device* device = nullptr;
-    TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
-    CollectiveParams* cp = &cp_[device_name];
-    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
+  void IssueRequests(int num_workers, int num_devices) {
+    const int device_count = num_workers * num_devices;
+    {
+      mutex_lock l(mu_);
+      num_done_ = 0;
+    }
+    cp_.resize(device_count);
+    status_.resize(device_count);
+    int idx = 0;
+    for (int wi = 0; wi < num_workers; ++wi) {
+      for (int di = 0; di < num_devices; ++di) {
+        IssueRequest(num_workers, num_devices, idx);
+        ++idx;
+      }
+    }
+  }
+
+  void IssueRequest(int num_workers, int num_devices, int idx) {
+    int device_count = num_workers * num_devices;
+    int wi = idx / num_devices;
+    int di = idx % num_devices;
+    string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+    string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+    while (idx >= cp_.size()) {
+      status_.resize(idx + 1);
+      cp_.resize(idx + 1);
+    }
+    CollectiveParams* cp = &cp_[idx];
+    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
     CHECK(cp_res);
-    cp_res->CompleteParamsAsync(
-        device->attributes(), cp, &cm_,
-        [this, device_name, group_size](const Status& s) {
-          status_[device_name] = s;
-          {
-            mutex_lock l(mu_);
-            ++num_done_;
-            if (num_done_ == group_size) {
-              done_.notify_all();
-            }
-          }
-        });
+    cp_res->CompleteParamsAsync(device_name, cp, &cm_,
+                                [this, idx, device_count](const Status& s) {
+                                  status_[idx] = s;
+                                  {
+                                    mutex_lock l(mu_);
+                                    ++num_done_;
+                                    if (num_done_ == device_count) {
+                                      done_.notify_all();
+                                    }
+                                  }
+                                });
   }
 
   void ValidateCollectiveParams(int num_workers, int num_devices) {
@@ -245,59 +259,39 @@
     // Verify that all cp_ values get the same set of task and device
     // names, with unique default_rank in the expected order.
     const int dev_count = num_workers * num_devices;
-    string dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
         string device_name = strings::StrCat(task_name, "/device:CPU:", di);
         int idx = wi * num_devices + di;
-        TF_ASSERT_OK(status_[device_name]);
-        EXPECT_EQ(cp_[device_name].default_rank, idx);
-        EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
-        EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
-        EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
+        TF_ASSERT_OK(status_[idx]);
+        EXPECT_EQ(cp_[idx].default_rank, idx);
+        EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
+        EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
+        EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
         if (idx > 0) {
-          EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
-                    cp_[device_name].group.runtime_details.communicator_key);
+          EXPECT_EQ(cp_[0].group.runtime_details.communicator_key,
+                    cp_[idx].group.runtime_details.communicator_key);
           for (int i = 0; i < dev_count; ++i) {
-            EXPECT_EQ(cp_[dev0].instance.device_names[i],
-                      cp_[device_name].instance.device_names[i]);
-            EXPECT_EQ(cp_[dev0].instance.task_names[i],
-                      cp_[device_name].instance.task_names[i]);
+            EXPECT_EQ(cp_[0].instance.device_names[i],
+                      cp_[idx].instance.device_names[i]);
+            EXPECT_EQ(cp_[0].instance.task_names[i],
+                      cp_[idx].instance.task_names[i]);
           }
         }
       }
     }
   }
 
-  void RestartWorker(int worker_idx, int num_workers, int num_devices,
-                     const string& device_type, bool nccl) {
-    string worker_name =
-        strings::StrCat("/job:worker/replica:0/task:", worker_idx);
-    DefineWorker(worker_name, device_type, num_devices, nccl);
-    for (int i = 0; i < num_devices; ++i) {
-      string device_name =
-          strings::StrCat(worker_name, "/device:", device_type, ":", i);
-      cp_[device_name] =
-          CreateCollectiveParams(num_workers, num_devices, device_type);
-      status_.erase(device_name);
-    }
-  }
-
   FakeCache wc_;
   CancellationManager cm_;
-  // Below are keyed by task names.
-  absl::flat_hash_map<string, std::unique_ptr<DeviceMgr>> device_mgrs_;
-  absl::flat_hash_map<string, std::unique_ptr<DeviceResolverDistributed>>
-      dev_resolvers_;
-  absl::flat_hash_map<string,
-                      std::unique_ptr<CollectiveParamResolverDistributed>>
-      cp_resolvers_;
-  absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
-  absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
-  // Below are keyed by device names;
-  absl::flat_hash_map<string, CollectiveParams> cp_;
-  absl::flat_hash_map<string, Status> status_;
+  std::vector<DeviceMgr*> device_mgrs_;
+  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+  std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
+  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::vector<FakeWorker*> workers_;
+  std::vector<CollectiveParams> cp_;
+  std::vector<Status> status_;
   mutex mu_;
   int num_done_ TF_GUARDED_BY(mu_);
   condition_variable done_;
@@ -306,8 +300,8 @@
 TEST_F(DeviceResDistTest, Workers1Devices1) {
   const int num_workers = 1;
   const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineWorkers(num_workers, num_devices, "CPU", false);
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
@@ -315,25 +309,12 @@
 TEST_F(DeviceResDistTest, Workers2Devices2) {
   const int num_workers = 2;
   const int num_devices = 2;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineWorkers(num_workers, num_devices, "CPU", false);
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
 
-TEST_F(DeviceResDistTest, DifferentIncarnation) {
-  const int num_workers = 2;
-  const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
-  IssueRequests(num_workers, num_devices);
-  RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false);
-  const string task_name = "/job:worker/replica:0/task:1";
-  const string device_name = absl::StrCat(task_name, "/device:CPU:0");
-  IssueRequest(task_name, device_name, num_workers * num_devices);
-  EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name]));
-}
-
 #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 namespace {
 // A mock NcclReducer for testing group runtime details initialization with CPU
@@ -366,7 +347,7 @@
   const int num_workers = 4;
   const int num_devices = 3;
   DefineWorkers(num_workers, num_devices, "CPU", true);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 0f7a7ff..72e0b3d 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -180,8 +180,7 @@
   // Called by each collective op at first execution in order to fill out
   // the CollectiveParams structure with data gathered from the full
   // (maybe distributed) collection of peer nodes.
-  virtual void CompleteParamsAsync(const DeviceAttributes& device,
-                                   CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    const StatusCallback& done) = 0;
 
@@ -302,8 +301,7 @@
         "a CollectiveExecutor has not been provided."));
   }
 
-  virtual void CompleteParamsAsync(const DeviceAttributes& device,
-                                   CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    StatusCallback done) {
     done(errors::Internal(
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 522c096..0230852 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -73,7 +73,7 @@
                 << " group " << col_params_.group.group_key << " instance "
                 << col_params_.instance.instance_key;
         col_exec->CompleteParamsAsync(
-            c->device()->attributes(), &col_params_, c->cancellation_manager(),
+            c->device()->name(), &col_params_, c->cancellation_manager(),
             [this, c, done](const Status& s) {
               if (s.ok()) {
                 col_params_.instance.impl_details.dependencies = dependencies_;
@@ -538,8 +538,7 @@
               << " group " << col_params->group.group_key << " instance "
               << col_params->instance.instance_key;
       col_exec->CompleteParamsAsync(
-          c->device()->attributes(), col_params.get(),
-          c->cancellation_manager(),
+          c->device()->name(), col_params.get(), c->cancellation_manager(),
           [c, done = std::move(done), col_params, col_exec](const Status& s) {
             if (s.ok()) {
               auto actual_done = [c, group_key = col_params->group.group_key,
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 0b4b502..739ba8e 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -545,10 +545,8 @@
   int32 group_key = 1;
   int32 group_size = 2;
   string device_type = 3;
+  repeated string device_name = 4;
   int32 collective_type = 5;
-  DeviceAttributes device_attributes = 6;
-
-  reserved 4;
 }
 
 // Gives the complete membership of the group identified by group_key.
@@ -557,10 +555,9 @@
   int32 group_size = 2;
   string device_type = 3;
   int32 num_tasks = 4;  // number of distinct tasks hosting the devices
+  repeated string device_name = 5;
+  repeated string task_name = 6;  // task name prefixes of device_names
   bytes communicator_key = 7;
-  repeated DeviceAttributes device_attributes = 8;
-
-  reserved 5, 6;
 }
 
 // Supplies data about one collective op belonging to the instance identified