Move CollTaskParams into CollGroupParams

CollTaskParams only store whether a device is local or not. CollTaskParams::is_local corresponds to the devices, instead of the tasks. It makes more sense to fold CollTaskParams simply into CollGroupParams.

CollTaskParams used to populated as part of CompleteInstance. After the move it will be populated as par of CompleteGroup, which also makes more sense.

PiperOrigin-RevId: 398386333
Change-Id: Ic2f4f56bb8c9fecbc9e5edf0516c4f2648930659
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index adff95d..476bf45 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -272,7 +272,7 @@
   device_assignment = xla::DeviceAssignment(params->group.group_size, 1);
   for (int device_idx = 0; device_idx < params->group.group_size;
        device_idx++) {
-    const DeviceAttributes& device = params->group.devices[device_idx];
+    const DeviceAttributes& device = params->group.members[device_idx].device;
     if (device.xla_global_id() == -1) {
       if (params->group.device_type == DEVICE_TPU) {
         return errors::InvalidArgument(
@@ -309,7 +309,7 @@
     for (int device_idx = 0; device_idx < params->group.group_size;
          device_idx++) {
       const DeviceAttributes& device_attributes =
-          params->group.devices[device_idx];
+          params->group.members[device_idx].device;
       Device* resolved_device = nullptr;
       Status lookup_status =
           device_mgr->LookupDevice(device_attributes.name(), &resolved_device);
diff --git a/tensorflow/core/common_runtime/all_to_all.cc b/tensorflow/core/common_runtime/all_to_all.cc
index abaa016..f7e6eb6 100644
--- a/tensorflow/core/common_runtime/all_to_all.cc
+++ b/tensorflow/core/common_runtime/all_to_all.cc
@@ -120,8 +120,8 @@
   string send_buf_key =
       strings::StrCat(col_ctx_->exec_key, src_rank, target_rank);
   col_ctx_->col_exec->remote_access()->PostToPeer(
-      col_params_->group.devices[target_rank].name(),
-      col_params_->group.task_names[target_rank], send_buf_key,
+      col_params_->group.members[target_rank].device.name(),
+      col_params_->group.members[target_rank].task, send_buf_key,
       col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
       col_ctx_->op_ctx->cancellation_manager(), done);
@@ -132,10 +132,10 @@
   string recv_buf_key =
       strings::StrCat(col_ctx_->exec_key, src_rank, target_rank);
   col_ctx_->col_exec->remote_access()->RecvFromPeer(
-      col_params_->group.devices[src_rank].name(),
-      col_params_->group.task_names[src_rank],
-      col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device,
-      col_ctx_->op_ctx->op_device_context(),
+      col_params_->group.members[src_rank].device.name(),
+      col_params_->group.members[src_rank].task,
+      col_params_->group.members[src_rank].is_local, recv_buf_key,
+      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
       0, col_ctx_->op_ctx->cancellation_manager(), done);
 }
diff --git a/tensorflow/core/common_runtime/all_to_all_test.cc b/tensorflow/core/common_runtime/all_to_all_test.cc
index 11ddc2e..7ee8434 100644
--- a/tensorflow/core/common_runtime/all_to_all_test.cc
+++ b/tensorflow/core/common_runtime/all_to_all_test.cc
@@ -48,7 +48,7 @@
                                                tensors[i].shape());
       Device* device = nullptr;
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(
-          col_params->group.devices[i].name(), &device));
+          col_params->group.members[i].device.name(), &device));
       TF_CHECK_OK(RunCollective(test_env_.get(), col_params.get(), device,
                                 &tensors[i], &tensors[i]));
       counter.DecrementCount();
@@ -82,7 +82,7 @@
                                                tensors[i].shape());
       Device* device = nullptr;
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(
-          col_params->group.devices[i].name(), &device));
+          col_params->group.members[i].device.name(), &device));
       Status status = RunCollective(test_env_.get(), col_params.get(), device,
                                     &tensors[i], &tensors[i]);
       if (!status.ok()) {
@@ -114,7 +114,7 @@
                                                tensors[i].shape());
       Device* device = nullptr;
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(
-          col_params->group.devices[i].name(), &device));
+          col_params->group.members[i].device.name(), &device));
       Status status = RunCollective(test_env_.get(), col_params.get(), device,
                                     &tensors[i], &tensors[i]);
       counter.DecrementCount();
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index ea0e30f..a14a415 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -473,7 +473,7 @@
   mutex_lock l(launch_mu_);
   if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
     const string& task_name =
-        col_params.group.task_names[col_params.default_rank];
+        col_params.group.members[col_params.default_rank].task;
     const int32_t num_devices =
         col_params.group.num_devices_per_task.at(task_name);
     launched_[col_params.instance.instance_key] = num_devices;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 0714d37..b4336b9 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -174,7 +174,7 @@
       // Insert device if not already present.
       auto it = gr->incarnations_by_device_name.find(device.name());
       if (it == gr->incarnations_by_device_name.end()) {
-        if (gr->group.devices.size() == gr->group.group_size) {
+        if (gr->group.members.size() == gr->group.group_size) {
           // The group is already full.
           gr->status =
               errors::Internal("Device ", device.name(),
@@ -183,18 +183,20 @@
         } else {
           // This is a new device that has not yet joined the group.
           gr->incarnations_by_device_name[device.name()] = device.incarnation();
-          gr->group.devices.push_back(device);
+          CollGroupMember member;
+          member.device = device;
+          gr->group.members.push_back(std::move(member));
           new_device = true;
           if (VLOG_IS_ON(1)) {
             string dev_buf;
-            for (const auto& d : gr->group.devices) {
-              strings::StrAppend(&dev_buf, ",", d.name());
+            for (const auto& m : gr->group.members) {
+              strings::StrAppend(&dev_buf, ",", m.device.name());
             }
             VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                     << " group_size=" << gr->group.group_size << " (current"
                     << " devices)=(" << dev_buf << ") (number of"
                     << " devices pending)="
-                    << (gr->group.group_size - gr->group.devices.size());
+                    << (gr->group.group_size - gr->group.members.size());
           }
         }
       } else {
@@ -212,14 +214,14 @@
     if (gr->status.ok()) {
       // If the group is not yet complete, queue to wait for it.
       VLOG(2) << "group_size " << gr->group.group_size << " set size "
-              << gr->group.devices.size() << " gr " << gr;
+              << gr->group.members.size() << " gr " << gr;
 
-      if (gr->group.devices.size() < gr->group.group_size) {
+      if (gr->group.members.size() < gr->group.group_size) {
         gr->pending_done.push_back(std::move(done));
         gr->pending_params.push_back(group_params);
         return;
       }
-      CHECK_EQ(gr->group.devices.size(), gr->group.group_size);
+      CHECK_EQ(gr->group.members.size(), gr->group.group_size);
       // We get a full group. Fill in remaining fields in gr->group.
       if (new_device) {
         FinishGroup(gr);
@@ -257,16 +259,16 @@
 // Create a populated GlobalDeviceMap from CollInstanceParams and localities.
 GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) {
   GlobalDeviceMap gdm;
-  CHECK_EQ(gp.devices.size(), gp.task_names.size());
-  for (int i = 0; i < gp.devices.size(); ++i) {
-    TaskDeviceMap& tdm = gdm[gp.task_names[i]];
-    DevRec* dr = &tdm[gp.devices[i].name()];
-    dr->task = gp.task_names[i];
-    dr->device = gp.devices[i].name();
+  CHECK_EQ(gp.members.size(), gp.members.size());
+  for (int i = 0; i < gp.members.size(); ++i) {
+    TaskDeviceMap& tdm = gdm[gp.members[i].task];
+    DevRec* dr = &tdm[gp.members[i].device.name()];
+    dr->task = gp.members[i].task;
+    dr->device = gp.members[i].device.name();
     dr->original_rank = i;
     dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
     dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
-    dr->locality = &gp.devices[i].locality();
+    dr->locality = &gp.members[i].device.locality();
   }
   return gdm;
 }
@@ -394,16 +396,14 @@
     TaskDeviceMap& tdm = iter.second;
     OrderTaskDeviceMap(gpu_ring_order, &tdm);
   }
-  // Connect the global rank order by the order in which tasks first appear.
-  std::set<string> ordered_tasks;
+  // Connect the global rank order by the lexicographical order of the tasks.
+  std::set<string> tasks;
+  for (const CollGroupMember& member : gp.members) {
+    tasks.insert(member.task);
+  }
   int next_rank = 0;
-  for (int i = 0; i < gp.task_names.size(); ++i) {
-    const string& task_name = gp.task_names[i];
-    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
-      continue;
-    }
-    ordered_tasks.insert(task_name);
-    TaskDeviceMap* tdm = &gdm[task_name];
+  for (const string& task : tasks) {
+    TaskDeviceMap* tdm = &gdm[task];
     for (auto& it : *tdm) {
       it.second.global_rank = it.second.local_rank + next_rank;
     }
@@ -417,19 +417,9 @@
 // be sorted.
 void SetDevPerTask(CollGroupParams* gp) {
   gp->num_devices_per_task.clear();
-  const string* last_task_name = &gp->task_names[0];
-  int count = 0;
-  for (const string& task_name : gp->task_names) {
-    if (task_name == *last_task_name) {
-      ++count;
-    } else {
-      gp->num_devices_per_task[*last_task_name] = count;
-      count = 1;
-      last_task_name = &task_name;
-    }
+  for (const CollGroupMember& member : gp->members) {
+    gp->num_devices_per_task[member.task]++;
   }
-  gp->num_devices_per_task[*last_task_name] = count;
-
   gp->same_num_devices_per_task = false;
   int dev_per_task = -1;
   for (const auto& task_dev : gp->num_devices_per_task) {
@@ -445,18 +435,13 @@
 }  // namespace
 
 void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
-  // Sort devices lexicographically first.
-  std::sort(gr->group.devices.begin(), gr->group.devices.end(),
-            [](const DeviceAttributes& lhs, const DeviceAttributes& rhs) {
-              return lhs.name() < rhs.name();
-            });
-  // Build task_names, which is needed by CompleteDefaultRanking.
-  gr->group.task_names.reserve(gr->group.devices.size());
-  for (const DeviceAttributes& device : gr->group.devices) {
-    gr->group.task_names.push_back(TaskNameFromDeviceName(device.name()));
+  // Populate group member task and is_local.
+  for (CollGroupMember& member : gr->group.members) {
+    member.task = TaskNameFromDeviceName(member.device.name());
+    member.is_local = member.task == task_name_;
   }
-  // Establish the final order of gp->devices and gp->task_names by
-  // considering localities of all devices.
+  // Establish the order of the members by considering localities of all
+  // devices.
   CompleteDefaultRanking(&gr->group);
   SetDevPerTask(&gr->group);
   gr->group.num_tasks =
@@ -476,7 +461,7 @@
   }
   {
     mutex_lock l(gr->mu);
-    if (gr->group.devices.size() == gr->group.group_size) {
+    if (gr->group.members.size() == gr->group.group_size) {
       // The group is already complete. There's no need to cancel.
       return;
     }
@@ -489,19 +474,11 @@
   }
 }
 
-void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
-                                                       CollectiveParams* cp) {
-  cp->task.is_local.resize(cp->group.group_size, false);
-  for (int i = 0; i < cp->group.group_size; ++i) {
-    cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
-  }
-}
-
 void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                   CollectiveParams* cp) {
-  CHECK_EQ(cp->group.group_size, cp->group.devices.size()) << cp->ToString();
+  CHECK_EQ(cp->group.group_size, cp->group.members.size()) << cp->ToString();
   for (int i = 0; i < cp->group.group_size; ++i) {
-    if (cp->group.devices[i].name() == device) {
+    if (cp->group.members[i].device.name() == device) {
       cp->default_rank = i;
       break;
     }
@@ -512,12 +489,6 @@
     const CollectiveParams* cp, InstanceRec* ir) {
   ir->shared->instance = cp->instance;
   ir->shared->default_rank = -1;
-
-  // Set is_local and task_names in *shared prior to invoking
-  // GetDeviceAttributesAsync.  In a distributed context this function can be
-  // called by a derived class, some of the devices may be non-local and
-  // GetDeviceAttributesAsync will use those fields to launch RPCs.
-  CompleteTaskIsLocal(task_name_, ir->shared);
 }
 
 // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
@@ -525,30 +496,33 @@
 // TensorFlow runtime.  This set of devices may be a superset of the devices
 // participating in this instance of collectives.
 void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) {
+  // Sort gp->member to avoid indeterminism.
+  std::sort(gp->members.begin(), gp->members.end(),
+            [](const CollGroupMember& lhs, const CollGroupMember& rhs) {
+              return lhs.device.name() < rhs.device.name();
+            });
   // Establish an instance-specific default rank order for devices
   // based on localities.  This rank order should be a good ring
   // order, if possible.
   GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_);
   // Reflect the new global ranking on shared
-  std::vector<DeviceAttributes> new_devices(gp->group_size);
-  std::vector<string> new_task_names(gp->group_size);
+  std::vector<CollGroupMember> new_members(gp->group_size);
   for (const auto& git : gdm) {
     const TaskDeviceMap& tdm = git.second;
     for (const auto& tit : tdm) {
       const DevRec& dr = tit.second;
-      new_devices[dr.global_rank] = gp->devices[dr.original_rank];
-      new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
+      new_members[dr.global_rank] = std::move(gp->members[dr.original_rank]);
     }
   }
 
   if (VLOG_IS_ON(2)) {
     string buf;
-    for (const auto& d : new_devices) strings::StrAppend(&buf, "\n", d.name());
+    for (const auto& m : new_members)
+      strings::StrAppend(&buf, "\n", m.device.name());
     VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
             << buf;
   }
-  gp->devices = std::move(new_devices);
-  gp->task_names = std::move(new_task_names);
+  gp->members = std::move(new_members);
 }
 
 CollectiveParamResolverLocal::InstanceRec*
@@ -724,7 +698,6 @@
   // Populate the fields common across task.
   AssignCollectiveType(cp);
   SetDefaultRank(device, cp);
-  CompleteTaskIsLocal(task_name_, cp);
 
   CollectiveImplementationInterface* col_impl;
   status = CollectiveRegistry::LookupParamResolverInstance(
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index e9891ed..64ad8bb 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -172,10 +172,6 @@
   Status GetLocalDeviceLocalities(const CollectiveParams& cp,
                                   std::vector<DeviceLocality>* localities);
 
-  // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
-  // Precondition: cp->device_names is fully populated and in final order.
-  void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
-
   // Sets cp->instance_default_rank according to location of device in
   // current ordering of cp->instance.device_names.
   void SetDefaultRank(const string& device, CollectiveParams* cp);
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index dac2046..c61f976 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -71,8 +71,8 @@
     ResetParamResolver(config);
     prl_->CompleteDefaultRanking(&group);
     std::vector<string> actual_device_order;
-    for (const DeviceAttributes& device : group.devices) {
-      actual_device_order.push_back(device.name());
+    for (const CollGroupMember& member : group.members) {
+      actual_device_order.push_back(member.device.name());
     }
     EXPECT_EQ(actual_device_order, expected_device_order);
   }
@@ -97,9 +97,9 @@
   group.group_size = kNumGpus;
   std::unordered_set<int> clique1 = {0, 1, 6, 7};
   for (int gpu_idx = 0; gpu_idx < kNumGpus; ++gpu_idx) {
-    group.task_names.push_back("/job:localhost/replica:0/task:0");
-    DeviceAttributes device;
-    device.set_name(strings::StrCat(
+    CollGroupMember member;
+    member.task = "/job:localhost/replica:0/task:0";
+    member.device.set_name(strings::StrCat(
         "/job:localhost/replica:0/task:0/device:GPU:", gpu_idx));
     // Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected
     // components.  Across components, connect 3 and 7.
@@ -109,19 +109,19 @@
       bool link_in_clique1 = clique1.find(link_idx) != clique1.end();
       if ((gpu_in_clique1 && link_in_clique1) ||
           (!gpu_in_clique1 && !link_in_clique1)) {
-        LocalLinks* links = device.mutable_locality()->mutable_links();
+        LocalLinks* links = member.device.mutable_locality()->mutable_links();
         InterconnectLink* ilink = links->add_link();
         ilink->set_device_id(link_idx);
         ilink->set_strength(2);
       } else if ((gpu_idx == 3 && link_idx == 7) ||
                  (gpu_idx == 7 && link_idx == 3)) {
-        LocalLinks* links = device.mutable_locality()->mutable_links();
+        LocalLinks* links = member.device.mutable_locality()->mutable_links();
         InterconnectLink* ilink = links->add_link();
         ilink->set_device_id(link_idx);
         ilink->set_strength(1);
       }
     }
-    group.devices.push_back(device);
+    group.members.push_back(member);
   }
   RunCompleteDefaultRanking(group, {1, 3, 5, 7, 6, 4, 2, 0},
                             {
@@ -193,12 +193,12 @@
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
     TF_ASSERT_OK(statuses[i]);
-    ASSERT_EQ(cps[i]->group.devices.size(), 3);
+    ASSERT_EQ(cps[i]->group.members.size(), 3);
     for (int j = 0; j < NUM_DEVS; ++j) {
       EXPECT_EQ(
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
-          cps[i]->group.devices[j].name());
-      EXPECT_TRUE(cps[i]->task.is_local[j]);
+          cps[i]->group.members[j].device.name());
+      EXPECT_TRUE(cps[i]->group.members[j].is_local);
     }
     EXPECT_EQ(cps[i]->instance.impl_details.subdiv_source_rank.size(), 0);
     EXPECT_FALSE(cps[i]->is_source);
@@ -248,12 +248,12 @@
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
     TF_ASSERT_OK(statuses[i]);
-    ASSERT_EQ(cps[i]->group.devices.size(), 3);
+    ASSERT_EQ(cps[i]->group.members.size(), 3);
     for (int j = 0; j < NUM_DEVS; ++j) {
       EXPECT_EQ(
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
-          cps[i]->group.devices[j].name());
-      EXPECT_TRUE(cps[i]->task.is_local[j]);
+          cps[i]->group.members[j].device.name());
+      EXPECT_TRUE(cps[i]->group.members[j].is_local);
     }
     EXPECT_EQ(cps[i]->is_source, (i == 1));
     EXPECT_EQ(cps[i]->default_rank, i);
diff --git a/tensorflow/core/common_runtime/collective_test_util.cc b/tensorflow/core/common_runtime/collective_test_util.cc
index 5c989fb..77ce74c 100644
--- a/tensorflow/core/common_runtime/collective_test_util.cc
+++ b/tensorflow/core/common_runtime/collective_test_util.cc
@@ -216,14 +216,14 @@
     col_params->group.num_devices_per_task[task_name] =
         test_env.num_devices_per_worker;
     for (int di = 0; di < test_env.num_devices_per_worker; ++di) {
-      DeviceAttributes device;
-      device.set_name(strings::StrCat(
+      CollGroupMember member;
+      member.device.set_name(strings::StrCat(
           task_name, "/device:", test_env.device_type.type_string(), ":", di));
-      col_params->group.devices.push_back(device);
-      col_params->group.task_names.push_back(task_name);
+      member.task = task_name;
       // Normally each device would set is_local to its own perspective but
       // this test runs in a single process so is_local is always true.
-      col_params->task.is_local.push_back(true);
+      member.is_local = true;
+      col_params->group.members.push_back(member);
     }
   }
 
diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc
index 9b49e12..0c86d1f 100644
--- a/tensorflow/core/common_runtime/collective_util.cc
+++ b/tensorflow/core/common_runtime/collective_util.cc
@@ -60,8 +60,9 @@
     for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
       int idx = subdiv_perms[sdi][di];
       if (idx >= 0) {
-        CHECK_GT(col_params.group.devices.size(), idx);
-        strings::StrAppend(&buf, col_params.group.devices[idx].name(), "\n");
+        CHECK_GT(col_params.group.members.size(), idx);
+        strings::StrAppend(&buf, col_params.group.members[idx].device.name(),
+                           "\n");
       }
     }
     strings::StrAppend(&buf, " subdiv_offsets: ");
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
index 21b7c99..326c0d6 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -81,20 +81,18 @@
   CHECK_EQ(col_params->instance.impl_details.collective_name,
            "HierarchicalTreeBroadcast");
   const string& device_name =
-      col_params->group.devices[col_params->default_rank].name();
+      col_params->group.members[col_params->default_rank].device.name();
   // Start by counting the devices in each task.
   // Precondition: device_names must be sorted so that all devices in
   // the same task are adjacent.
-  VLOG(2) << "Sorted task names: "
-          << absl::StrJoin(col_params->group.task_names, ", ");
   std::vector<int> dev_per_task;
-  const string* prior_task_name = &col_params->group.task_names[0];
+  const string* prior_task_name = &col_params->group.members[0].task;
   int dev_count = 1;
   for (int di = 1; di < col_params->group.group_size; ++di) {
-    if (col_params->group.task_names[di] != *prior_task_name) {
+    if (col_params->group.members[di].task != *prior_task_name) {
       dev_per_task.push_back(dev_count);
       dev_count = 1;
-      prior_task_name = &col_params->group.task_names[di];
+      prior_task_name = &col_params->group.members[di].task;
     } else {
       ++dev_count;
     }
@@ -137,13 +135,13 @@
         // Source device belongs to this task.
         perm.push_back(col_params->source_rank);
         participate =
-            col_params->group.devices[col_params->source_rank].name() ==
+            col_params->group.members[col_params->source_rank].device.name() ==
             device_name;
       } else {
         // Source does not belong to this task, choose dev 0.
         perm.push_back(device_count);
-        participate =
-            col_params->group.devices[device_count].name() == device_name;
+        participate = col_params->group.members[device_count].device.name() ==
+                      device_name;
       }
       if (participate) col_params->subdiv_rank.push_back(ti);
       device_count += dev_per_task[ti];
@@ -167,7 +165,7 @@
     int subdiv_source = 0;
     for (int di = 0; di < dev_per_task[ti]; di++) {
       perm.push_back(abs_di);
-      if (col_params->group.devices[abs_di].name() == device_name) {
+      if (col_params->group.members[abs_di].device.name() == device_name) {
         participate = true;
         col_params->subdiv_rank.push_back(di);
       }
@@ -420,11 +418,12 @@
       col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
   VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
           << col_ctx_->device_name << " to_device "
-          << col_params_->group.devices[dst_idx].name() << " subdiv=" << subdiv
-          << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
+          << col_params_->group.members[dst_idx].device.name()
+          << " subdiv=" << subdiv << " dst_rank=" << dst_rank
+          << " dst_idx=" << dst_idx;
   col_ctx_->col_exec->remote_access()->PostToPeer(
-      col_params_->group.devices[dst_idx].name(),
-      col_params_->group.task_names[dst_idx], send_buf_key, col_ctx_->device,
+      col_params_->group.members[dst_idx].device.name(),
+      col_params_->group.members[dst_idx].task, send_buf_key, col_ctx_->device,
       col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), src_tensor,
       col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
@@ -439,14 +438,14 @@
   int src_idx =
       col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
   VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
-          << col_params_->group.devices[src_idx].name() << " to_device "
+          << col_params_->group.members[src_idx].device.name() << " to_device "
           << col_ctx_->device_name << " subdiv=" << subdiv
           << " src_rank=" << src_rank << " src_idx=" << src_idx;
   col_ctx_->col_exec->remote_access()->RecvFromPeer(
-      col_params_->group.devices[src_idx].name(),
-      col_params_->group.task_names[src_idx],
-      col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
-      col_ctx_->op_ctx->op_device_context(),
+      col_params_->group.members[src_idx].device.name(),
+      col_params_->group.members[src_idx].task,
+      col_params_->group.members[src_idx].is_local, recv_buf_key,
+      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
       col_ctx_->device_locality, 0 /*stream_index*/,
       col_ctx_->op_ctx->cancellation_manager(), done);
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 9726117..be93bae 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -222,7 +222,7 @@
       // In the test we always broadcast from rank 0.
       col_params_->is_source = (rank == 0);
       col_params_->source_rank = 0;
-      string dev_name = col_params_->group.devices[rank].name();
+      string dev_name = col_params_->group.members[rank].device.name();
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
           << "Couldn't find device " << dev_name
           << " existing devices: " << test_env_->device_mgr->DebugString();
@@ -359,10 +359,10 @@
   for (int ti = 0; ti < cp->group.num_tasks; ti++) {
     string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
     for (int di = 0; di < dev_per_task[ti]; di++) {
-      DeviceAttributes device;
-      device.set_name(strings::StrCat(task_name, "/device:GPU:", di));
-      cp->group.task_names.push_back(task_name);
-      cp->group.devices.push_back(device);
+      CollGroupMember member;
+      member.device.set_name(strings::StrCat(task_name, "/device:GPU:", di));
+      member.task = task_name;
+      cp->group.members.push_back(member);
       cp->group.group_size++;
     }
   }
diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc
index c1dcd20..b83157d 100644
--- a/tensorflow/core/common_runtime/permuter.cc
+++ b/tensorflow/core/common_runtime/permuter.cc
@@ -87,7 +87,7 @@
           << " target_rank=" << target_rank << " src_rank=" << src_rank;
   col_ctx_->col_exec->remote_access()->PostToPeer(
       col_params_->instance.devices[target_rank],
-      col_params_->group.task_names[target_rank], send_buf_key,
+      col_params_->group.members[target_rank].task, send_buf_key,
       col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
       col_ctx_->op_ctx->cancellation_manager(), done);
@@ -103,9 +103,9 @@
           << " target_rank=" << target_rank << " src_rank=" << src_rank;
   col_ctx_->col_exec->remote_access()->RecvFromPeer(
       col_params_->instance.devices[src_rank],
-      col_params_->group.task_names[src_rank],
-      col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device,
-      col_ctx_->op_ctx->op_device_context(),
+      col_params_->group.members[src_rank].task,
+      col_params_->group.members[src_rank].is_local, recv_buf_key,
+      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
       0, col_ctx_->op_ctx->cancellation_manager(), done);
 }
diff --git a/tensorflow/core/common_runtime/permuter_test.cc b/tensorflow/core/common_runtime/permuter_test.cc
index 499a112..8f1b123 100644
--- a/tensorflow/core/common_runtime/permuter_test.cc
+++ b/tensorflow/core/common_runtime/permuter_test.cc
@@ -147,10 +147,10 @@
       col_params_ = CreateCollectiveParams(*test_env_, rank, "Permute",
                                            PERMUTE_COLLECTIVE, dtype, shape);
       col_params_->instance.permutation = std::move(permutation);
-      for (const DeviceAttributes& device : col_params_->group.devices) {
-        col_params_->instance.devices.push_back(device.name());
+      for (const CollGroupMember& member : col_params_->group.members) {
+        col_params_->instance.devices.push_back(member.device.name());
       }
-      string dev_name = col_params_->group.devices[rank].name();
+      string dev_name = col_params_->group.members[rank].device.name();
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
           << "Couldn't find device " << dev_name
           << " existing devices: " << test_env_->device_mgr->DebugString();
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index 800ff03..1bfd3b7 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -179,7 +179,7 @@
 
 Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
   const string& device_name =
-      col_params->group.devices[col_params->default_rank].name();
+      col_params->group.members[col_params->default_rank].device.name();
   // Each subdiv permutation is a ring formed by rotating each
   // single-task subsequence of devices by an offset.  This makes most
   // sense when each task has the same number of devices but we can't
@@ -189,16 +189,14 @@
   // Start by counting the devices in each task.
   // Precondition: device_names must be sorted so that all devices in
   // the same task are adjacent.
-  VLOG(2) << "Sorted task names: "
-          << absl::StrJoin(col_params->group.task_names, ", ");
   std::vector<int> dev_per_task;
-  const string* prior_task_name = &col_params->group.task_names[0];
+  const string* prior_task_name = &col_params->group.members[0].task;
   int dev_count = 1;
   for (int di = 1; di < col_params->group.group_size; ++di) {
-    if (col_params->group.task_names[di] != *prior_task_name) {
+    if (col_params->group.members[di].task != *prior_task_name) {
       dev_per_task.push_back(dev_count);
       dev_count = 1;
-      prior_task_name = &col_params->group.task_names[di];
+      prior_task_name = &col_params->group.members[di].task;
     } else {
       ++dev_count;
     }
@@ -242,7 +240,8 @@
         int permuted_di = prior_dev_count + offset_di;
         int rank = static_cast<int>(perm.size());
         perm.push_back(permuted_di);
-        if (col_params->group.devices[permuted_di].name() == device_name) {
+        if (col_params->group.members[permuted_di].device.name() ==
+            device_name) {
           DCHECK_EQ(permuted_di, col_params->default_rank);
           col_params->subdiv_rank[sdi] = rank;
         }
@@ -346,8 +345,8 @@
                          .subdiv_permutations[subdiv_idx][recv_from_rank];
   int send_dev_idx = col_params_->instance.impl_details
                          .subdiv_permutations[subdiv_idx][send_to_rank];
-  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
-  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
+  rf->recv_is_remote = !col_params_->group.members[rf->recv_dev_idx].is_local;
+  rf->send_is_remote = !col_params_->group.members[send_dev_idx].is_local;
   if (ca_->ChunkBytes(rf->sc_idx) > 0) {
     // In pass 0 we skip Recv when rank = chunk_idx
     rf->do_recv = (rf->chunk_idx != rf->rank);
@@ -405,8 +404,8 @@
   int send_to_dev_idx = col_params_->instance.impl_details
                             .subdiv_permutations[rf->subdiv_idx][send_to_rank];
   col_ctx_->col_exec->remote_access()->PostToPeer(
-      col_params_->group.devices[send_to_dev_idx].name(),
-      col_params_->group.task_names[send_to_dev_idx], send_buf_key,
+      col_params_->group.members[send_to_dev_idx].device.name(),
+      col_params_->group.members[send_to_dev_idx].task, send_buf_key,
       col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
       col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
@@ -425,9 +424,9 @@
                            ? &rf->tmp_chunk
                            : &rf->chunk;
   col_ctx_->col_exec->remote_access()->RecvFromPeer(
-      col_params_->group.devices[rf->recv_dev_idx].name(),
-      col_params_->group.task_names[rf->recv_dev_idx],
-      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
+      col_params_->group.members[rf->recv_dev_idx].device.name(),
+      col_params_->group.members[rf->recv_dev_idx].task,
+      col_params_->group.members[rf->recv_dev_idx].is_local, recv_buf_key,
       col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
       col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
       col_ctx_->device_locality, rf->subdiv_idx,
diff --git a/tensorflow/core/common_runtime/ring_gatherer.cc b/tensorflow/core/common_runtime/ring_gatherer.cc
index 0877333..27a6853 100644
--- a/tensorflow/core/common_runtime/ring_gatherer.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer.cc
@@ -71,9 +71,9 @@
 
   if (VLOG_IS_ON(1)) {
     string buf;
-    for (int r = 0; r < col_params_->group.devices.size(); ++r) {
+    for (int r = 0; r < col_params_->group.members.size(); ++r) {
       strings::StrAppend(&buf, "dev ", r, " : ",
-                         col_params_->group.devices[r].name(), "\n");
+                         col_params_->group.members[r].device.name(), "\n");
     }
     for (int sd = 0;
          sd < col_params_->instance.impl_details.subdiv_permutations.size();
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index abad866..063f5f9 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -131,7 +131,7 @@
             GenerateEvenSubdivOffsets(test_env->num_devices_per_worker,
                                       num_subdivs);
       }
-      string dev_name = col_params_->group.devices[rank].name();
+      string dev_name = col_params_->group.members[rank].device.name();
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
           << "Couldn't find device " << dev_name
           << " existing devices: " << test_env_->device_mgr->DebugString();
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index 6d2b3d3..23b3de7 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -67,9 +67,9 @@
 
   if (VLOG_IS_ON(1)) {
     string buf;
-    for (int r = 0; r < col_params_->group.devices.size(); ++r) {
+    for (int r = 0; r < col_params_->group.members.size(); ++r) {
       strings::StrAppend(&buf, "dev ", r, " : ",
-                         col_params_->group.devices[r].name(), "\n");
+                         col_params_->group.members[r].device.name(), "\n");
     }
     for (int sd = 0;
          sd < col_params_->instance.impl_details.subdiv_permutations.size();
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index eafce7b..c16e5c2 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -166,7 +166,7 @@
             GenerateEvenSubdivOffsets(test_env->num_devices_per_worker,
                                       num_subdivs);
       }
-      string dev_name = col_params_->group.devices[rank].name();
+      string dev_name = col_params_->group.members[rank].device.name();
       TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
           << "Couldn't find device " << dev_name
           << " existing devices: " << test_env_->device_mgr->DebugString();
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 563c282..fbbcd5e 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -114,7 +114,12 @@
         device, &cp->group, cancel_mgr,
         [this, device, cp, cancel_mgr, done](Status s) {
           if (s.ok()) {
-            s = dev_resolver_->UpdateDeviceAttributes(cp->group.devices);
+            std::vector<DeviceAttributes> devices;
+            devices.reserve(cp->group.group_size);
+            for (const CollGroupMember& m : cp->group.members) {
+              devices.push_back(m.device);
+            }
+            s = dev_resolver_->UpdateDeviceAttributes(devices);
           }
           if (s.ok()) {
             CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
@@ -156,7 +161,7 @@
     if (!gr->status.ok()) {
       done(gr->status);
       return;
-    } else if (gr->group.devices.size() != gr->group.group_size) {
+    } else if (gr->group.members.size() != gr->group.group_size) {
       done(errors::FailedPrecondition(
           "group ", request->group_key(),
           " failed to resolve. This normally means the server has restarted"));
@@ -226,9 +231,11 @@
       return errors::Internal(
           "CompleteGroupResponse group_size doesn't match device_name list");
     }
-    gr->group.devices.reserve(resp.device_attributes().size());
+    gr->group.members.reserve(resp.device_attributes().size());
     for (const DeviceAttributes& device : resp.device_attributes()) {
-      gr->group.devices.push_back(device);
+      CollGroupMember member;
+      member.device = device;
+      gr->group.members.push_back(std::move(member));
       gr->incarnations_by_device_name[device.name()] = device.incarnation();
     }
     gr->group.runtime_details.communicator_key = resp.communicator_key();
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index c52a4e2..29d2a1a 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -252,18 +252,19 @@
         int idx = wi * num_devices + di;
         TF_ASSERT_OK(status_[device_name]);
         EXPECT_EQ(cp_[device_name]->default_rank, idx);
-        EXPECT_EQ(cp_[device_name]->group.devices.size(), dev_count);
-        EXPECT_EQ(cp_[device_name]->group.devices[idx].name(), device_name);
-        EXPECT_EQ(cp_[device_name]->group.task_names[idx], task_name);
+        EXPECT_EQ(cp_[device_name]->group.members.size(), dev_count);
+        EXPECT_EQ(cp_[device_name]->group.members[idx].device.name(),
+                  device_name);
+        EXPECT_EQ(cp_[device_name]->group.members[idx].task, task_name);
         ValidateDeviceResolver(*cp_[device_name], task_name);
         if (idx > 0) {
           EXPECT_EQ(cp_[dev0]->group.runtime_details.communicator_key,
                     cp_[device_name]->group.runtime_details.communicator_key);
           for (int i = 0; i < dev_count; ++i) {
-            EXPECT_EQ(cp_[dev0]->group.devices[i].name(),
-                      cp_[device_name]->group.devices[i].name());
-            EXPECT_EQ(cp_[dev0]->group.task_names[i],
-                      cp_[device_name]->group.task_names[i]);
+            EXPECT_EQ(cp_[dev0]->group.members[i].device.name(),
+                      cp_[device_name]->group.members[i].device.name());
+            EXPECT_EQ(cp_[dev0]->group.members[i].task,
+                      cp_[device_name]->group.members[i].task);
           }
         }
       }
@@ -271,10 +272,10 @@
   }
 
   void ValidateDeviceResolver(const CollectiveParams& cp, const string& task) {
-    for (const DeviceAttributes& device : cp.group.devices) {
+    for (const CollGroupMember& member : cp.group.members) {
       DeviceAttributes attributes;
-      TF_ASSERT_OK(dev_resolvers_[task]->GetDeviceAttributes(device.name(),
-                                                             &attributes));
+      TF_ASSERT_OK(dev_resolvers_[task]->GetDeviceAttributes(
+          member.device.name(), &attributes));
     }
   }
 
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 862392b..c612756 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -414,8 +414,8 @@
             response->set_group_size(group_params->group_size);
             response->set_device_type(group_params->device_type.type_string());
             response->set_num_tasks(group_params->num_tasks);
-            for (const DeviceAttributes& device : group_params->devices) {
-              *response->add_device_attributes() = device;
+            for (const CollGroupMember& member : group_params->members) {
+              *response->add_device_attributes() = member.device;
             }
             response->set_communicator_key(
                 group_params->runtime_details.communicator_key);
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index fede957..c632e45 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -58,12 +58,8 @@
       "CollGroupParams {group_key=", group_key, " group_size=", group_size,
       " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
       " runtime_details=", runtime_details.ToString(), " devices {");
-  for (const auto& d : devices) {
-    strings::StrAppend(&v, d.name(), ",");
-  }
-  strings::StrAppend(&v, "} task_names={");
-  for (const auto& n : task_names) {
-    strings::StrAppend(&v, n, ", ");
+  for (const auto& m : members) {
+    strings::StrAppend(&v, m.device.name(), ",");
   }
   strings::StrAppend(&v, "} num_devices_per_task={");
   for (const auto& dpt : num_devices_per_task) {
@@ -138,19 +134,9 @@
   return v;
 }
 
-string CollTaskParams::ToString() const {
-  string v = strings::StrCat("CollTaskParams {is_local={");
-  for (const auto& b : is_local) {
-    strings::StrAppend(&v, static_cast<int>(b), ",");
-  }
-  strings::StrAppend(&v, "}}");
-  return v;
-}
-
 string CollectiveParams::ToString() const {
   string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
   strings::StrAppend(&v, " ", instance.ToString());
-  strings::StrAppend(&v, " ", task.ToString());
   strings::StrAppend(&v, " default_rank=", default_rank,
                      " is_source=", is_source, " source_rank=", source_rank,
                      " subdiv_rank={");
@@ -183,7 +169,8 @@
       input(input),
       output(output),
       device(nullptr),
-      device_name(col_params->group.devices[col_params->default_rank].name()) {}
+      device_name(
+          col_params->group.members[col_params->default_rank].device.name()) {}
 
 /*static*/
 int64_t CollectiveExecutor::kInvalidId = -1;
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index d2582fb..a7abf1a 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -58,6 +58,12 @@
   string ToString() const;
 };
 
+struct CollGroupMember {
+  DeviceAttributes device;
+  string task;
+  bool is_local;
+};
+
 // Data common to all members of a device group.
 // All members share the same device set but its order is
 // particular to an instance so it is stored there.
@@ -65,10 +71,8 @@
   int32 group_key;
   int32 group_size;
   DeviceType device_type;
-  // Devices in this group, in default rank order.
-  std::vector<DeviceAttributes> devices;
-  // Task name prefix of corresponding device name.
-  std::vector<string> task_names;
+  // Members in this group, in default rank order.
+  std::vector<CollGroupMember> members;
   // True if every task has the same number of devices.
   bool same_num_devices_per_task = false;
   // Task -> number of devices on that task.
@@ -130,18 +134,10 @@
   std::vector<int> permutation;
 };
 
-// Data common to all instance members in the same task.
-struct CollTaskParams {
-  // True for devices that are local to the process, i.e. no RPC needed.
-  std::vector<bool> is_local;
-  string ToString() const;
-};
-
 // Unique to a single CollectiveOp node.
 struct CollectiveParams : public core::RefCounted {
   CollGroupParams group;
   CollInstanceParams instance;
-  CollTaskParams task;
 
   string name = "";        // node name used only for log or error messages
   int default_rank = -1;   // index of this op within device_names
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 5f2ae25..c59464c 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -111,7 +111,7 @@
   // immediately.
   bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
                              const DoneCallback& done) {
-    if (col_params_->group.group_size > col_params_->group.devices.size()) {
+    if (col_params_->group.group_size > col_params_->group.members.size()) {
       // This is the first invocation: Finish initializing col_params_.
       // Schedule the `CompleteParamsAsync` call on a work queue that can handle
       // blocking work because it's not guaranteed that this call cannot block.
diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc
index 0596905..ec01699 100644
--- a/tensorflow/core/nccl/collective_communicator.cc
+++ b/tensorflow/core/nccl/collective_communicator.cc
@@ -83,7 +83,7 @@
   const CollectiveParams* col_params = col_ctx->col_params;
   const int num_global_devices = col_params->group.group_size;
   const int num_local_devices = col_params->group.num_devices_per_task.at(
-      col_params->group.task_names[col_params->default_rank]);
+      col_params->group.members[col_params->default_rank].task);
   const string nccl_collective_key =
       NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
   auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
@@ -120,7 +120,7 @@
       col_params->source_rank);
   VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
           << " num_tasks " << col_params->group.num_tasks << " current task "
-          << col_params->group.task_names[col_params->default_rank]
+          << col_params->group.members[col_params->default_rank].task
           << " num local devices " << num_local_devices
           << " num global devices " << num_global_devices << " device "
           << col_ctx->device_name << " instance "