Scope instance keys under group keys for collective ops.

Before this change, we stored a single global mapping from instance key to
resolved instance runtime parameters.  This prevented reusing the same
instance key across device groups.

After this change, instance keys are scoped under the group key.  It is now
legal to execute two collectives with the same instance key, as long as they
have different device groups.  This lets the user assign the same instance key
to a logical collective that is sharded across device groups.
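
For illustration, a minimal sketch of the now-legal pattern (it mirrors the new
Python test added below and assumes four logical CPU devices are configured):

  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import gen_collective_ops

  @def_function.function
  def two_groups_one_instance_key():
    instance_key = constant_op.constant(0)   # shared by both groups
    group_size = constant_op.constant(2)
    group0_key = constant_op.constant(0)     # group 0: cpu:0 and cpu:1
    group1_key = constant_op.constant(1)     # group 1: cpu:2 and cpu:3
    results = []
    # Launch all four collectives inside one tf.function so the members of
    # each group execute concurrently.
    for dev, group_key, value in [('CPU:0', group0_key, 1.),
                                  ('CPU:1', group0_key, 2.),
                                  ('CPU:2', group1_key, 3.),
                                  ('CPU:3', group1_key, 4.)]:
      with ops.device('/device:' + dev):
        results.append(gen_collective_ops.collective_reduce_v2(
            constant_op.constant(value), group_size, group_key, instance_key,
            merge_op='Add', final_op='Id', communication_hint='auto'))
    return results  # [3., 3., 7., 7.]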

PiperOrigin-RevId: 324324902
Change-Id: Ib994b68f96c8f6cf1cc634d5a7c4998d9f3fb96c
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index a0153a5..ba21abc 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -586,25 +586,32 @@
   InstanceRec* irec = nullptr;
   bool exit_outside_locks = false;
   {
+    bool found_instance = false;
     mutex_lock l(instance_mu_);
-    auto it = instance_table_.find(cp->instance.instance_key);
-    if (it != instance_table_.end()) {
-      irec = it->second.get();
-      {
-        mutex_lock l(irec->in_mu);
-        if (irec->is_init) {
-          exit_outside_locks = true;
-        } else {
-          irec->init_waiters.push_back([this, done](InstanceRec* irec) {
-            CallbackWithStatus(done, irec);
-          });
-          return;
+    auto group_it = instance_table_.find(gr->group.group_key);
+    if (group_it != instance_table_.end()) {
+      auto instance_it = group_it->second.find(cp->instance.instance_key);
+      if (instance_it != group_it->second.end()) {
+        irec = instance_it->second.get();
+        {
+          mutex_lock l(irec->in_mu);
+          if (irec->is_init) {
+            exit_outside_locks = true;
+          } else {
+            irec->init_waiters.push_back([this, done](InstanceRec* irec) {
+              CallbackWithStatus(done, irec);
+            });
+            return;
+          }
         }
+        found_instance = true;
       }
-    } else {
+    }
+    if (!found_instance) {
       // Create new InstanceRec.
       irec = new InstanceRec;
-      instance_table_[cp->instance.instance_key].reset(irec);
+      instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
+          irec);
     }
   }
   Status status;
@@ -890,8 +897,10 @@
   std::vector<InstanceRec*> instances;
   {
     mutex_lock l(instance_mu_);
-    for (const auto& item : instance_table_) {
-      instances.push_back(item.second.get());
+    for (const auto& group_entry : instance_table_) {
+      for (const auto& item : group_entry.second) {
+        instances.push_back(item.second.get());
+      }
     }
   }
   for (InstanceRec* ir : instances) {
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 2b7528d..40f0f00 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -241,8 +241,8 @@
   gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
       TF_GUARDED_BY(group_mu_);
   mutex instance_mu_;
-  gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
-      TF_GUARDED_BY(instance_mu_);
+  gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
+      instance_table_ TF_GUARDED_BY(instance_mu_);
   mutex status_mu_;
   Status status_ TF_GUARDED_BY(status_mu_);
 };
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index bfcd5b8..650c52c 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -304,10 +304,15 @@
   }
 }
 
-bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
+bool CollectiveParamResolverDistributed::InstanceIsCached(int32 group_key,
+                                                          int32 instance_key) {
   mutex_lock l(instance_mu_);
-  const auto& it = instance_table_.find(instance_key);
-  return it != instance_table_.end();
+  auto group_it = instance_table_.find(group_key);
+  if (group_it == instance_table_.end()) {
+    return false;
+  }
+  auto instance_it = group_it->second.find(instance_key);
+  return instance_it != group_it->second.end();
 }
 
 void CollectiveParamResolverDistributed::UpdateInstanceCache(
@@ -374,7 +379,7 @@
   if (group_leader_.empty()) {
     // This is the group leader so resolution is local.
     return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
-  } else if (InstanceIsCached(cp->instance.instance_key)) {
+  } else if (InstanceIsCached(gr->group.group_key, cp->instance.instance_key)) {
     return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
   } else {
     CompleteInstanceCall* call = new CompleteInstanceCall(
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
index 7d30c3d..6848874 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
@@ -65,7 +65,8 @@
 
   // Returns true iff there's an entry for this instance_key in the
   // local instance_table_.
-  bool InstanceIsCached(int32 instance_key) TF_LOCKS_EXCLUDED(instance_mu_);
+  bool InstanceIsCached(int32 group_key, int32 instance_key)
+      TF_LOCKS_EXCLUDED(instance_mu_);
 
   // Updates instance_table_ with contents of resp.
   void UpdateInstanceCache(const GroupRec* gr, CollectiveParams* cp,
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 51a5219..0230852 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -22,8 +22,10 @@
 
 namespace {
 
-static string CollectiveKey(OpKernelContext* ctx, int32 instance_key) {
-  return strings::StrCat(instance_key, ":", ctx->frame_iter().frame_id, ":",
+static string CollectiveKey(OpKernelContext* ctx, int32 group_key,
+                            int32 instance_key) {
+  return strings::StrCat(group_key, ":", instance_key, ":",
+                         ctx->frame_iter().frame_id, ":",
                          ctx->frame_iter().iter_id);
 }
 
@@ -52,7 +54,8 @@
   // A string encoding instance, frame and iter to be handed off to
   // the implementation for use in generating RecvBuf keys.
   string GetCollectiveKey(OpKernelContext* c) {
-    return CollectiveKey(c, col_params_.instance.instance_key);
+    return CollectiveKey(c, col_params_.group.group_key,
+                         col_params_.instance.instance_key);
   }
 
   // Returns false if calling invocation of ComputeAsync should return
@@ -557,7 +560,8 @@
                       << " instance " << col_params->instance.instance_key;
               col_exec->ExecuteAsync(
                   c, *col_params,
-                  CollectiveKey(c, col_params->instance.instance_key),
+                  CollectiveKey(c, col_params->group.group_key,
+                                col_params->instance.instance_key),
                   actual_done);
             } else {
               c->SetStatus(s);
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 25d9367..4225df7 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -36,6 +36,8 @@
     self.assertEqual(len(cpus), 1)
     config.set_logical_device_configuration(cpus[0], [
         context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
         context.LogicalDeviceConfiguration()
     ])
     context.ensure_initialized()
@@ -78,6 +80,47 @@
     for result in run_all_reduce_2cpus():
       self.assertAllClose(result, [2.], rtol=1e-5, atol=1e-5)
 
+  @test_util.run_v2_only
+  def testInstanceKeyScopedUnderGroupKey(self):
+    self._setup_context()
+
+    @def_function.function
+    def single_all_reduce(in_value, group_size, group_key, instance_key):
+      return gen_collective_ops.collective_reduce_v2(
+          in_value, group_size, group_key, instance_key, merge_op='Add',
+          final_op='Id', communication_hint='auto')
+
+    @def_function.function
+    def run_all_reduce_4cpus_same_instance_key():
+      # Use a common instance key for both groups.
+      instance_key = constant_op.constant(0)
+      # We will create 2 groups each with 2 devices.
+      group_size = constant_op.constant(2)
+      # Group 0 comprises cpu:0 and cpu:1.
+      group0_key = constant_op.constant(0)
+      # Group 1 comprises cpu:2 and cpu:3.
+      group1_key = constant_op.constant(1)
+      collectives = []
+      with ops.device('/device:CPU:0'):
+        collectives.append(single_all_reduce(
+            constant_op.constant(1.), group_size, group0_key, instance_key))
+      with ops.device('/device:CPU:1'):
+        collectives.append(single_all_reduce(
+            constant_op.constant(2.), group_size, group0_key, instance_key))
+      with ops.device('/device:CPU:2'):
+        collectives.append(single_all_reduce(
+            constant_op.constant(3.), group_size, group1_key, instance_key))
+      with ops.device('/device:CPU:3'):
+        collectives.append(single_all_reduce(
+            constant_op.constant(4.), group_size, group1_key, instance_key))
+      return collectives
+
+    results = run_all_reduce_4cpus_same_instance_key()
+    self.assertAllClose(results[0], 3., rtol=1e-5, atol=1e-5)
+    self.assertAllClose(results[1], 3., rtol=1e-5, atol=1e-5)
+    self.assertAllClose(results[2], 7., rtol=1e-5, atol=1e-5)
+    self.assertAllClose(results[3], 7., rtol=1e-5, atol=1e-5)
+
 
 if __name__ == '__main__':
   test.main()