Avoid reusing channel_id during spatial_partitioner

PiperOrigin-RevId: 288056247
Change-Id: I309f9044fb3b25f7f3f31dc61d767a15128253e4
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index defd6ab..1b6494b 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -133,5 +133,17 @@
   return false;
 }
 
+int64 NextChannelId(const HloModule& module) {
+  int64 next_channel_id = 1;
+  for (const HloComputation* comp : module.computations()) {
+    for (const HloInstruction* hlo : comp->instructions()) {
+      if (DynCast<HloChannelInstruction>(hlo)) {
+        next_channel_id = std::max(next_channel_id, *hlo->channel_id() + 1);
+      }
+    }
+  }
+  return next_channel_id;
+}
+
 }  // namespace hlo_query
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
index 0ea36ae..b7fbc46 100644
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -77,6 +77,10 @@
 // layout.
 bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
 
+// Returns the next available channel id that can be used in the given module
+// (for HloChannelInstructions).
+int64 NextChannelId(const HloModule& module);
+
 }  // namespace hlo_query
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 0b7d2ec..b2beb9d 100755
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1317,37 +1317,24 @@
   return Status::OK();
 }
 
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
-  absl::flat_hash_map<int64, const HloInstruction*> host_channels;
-  // Host send/recv instructions must have their own unique channel.
-  auto check_unique_host_channel = [&](const HloInstruction* instruction) {
-    const HloSendRecvInstruction* sendrecv =
-        DynCast<const HloSendRecvInstruction>(instruction);
-    if (sendrecv->is_host_transfer()) {
-      auto it_inserted =
-          host_channels.insert({*sendrecv->channel_id(), sendrecv});
-      if (!it_inserted.second) {
-        return FailedPrecondition(
-            "Channel %d is used for multiple host send/recv instructions: "
-            "%s "
-            "and "
-            "%s",
-            *sendrecv->channel_id(), sendrecv->ToString(),
-            it_inserted.first->second->ToString());
-      }
-    }
-
-    return Status::OK();
-  };
+// Checks various invariants of channel instructions (send/recv and
+// collectives).
+Status VerifyChannels(const HloModule& module) {
+  absl::flat_hash_map<int64, std::vector<const HloInstruction*>>
+      channel_instructions;
 
   // Send/Recv instruction must have a single user: the corresponding
   // SendDone/RecvDone. with matching channel.
   for (const HloComputation* computation : module.computations()) {
     for (const HloInstruction* instruction : computation->instructions()) {
+      auto channel_instr = DynCast<HloChannelInstruction>(instruction);
+      if (!channel_instr || !channel_instr->channel_id()) {
+        continue;
+      }
+      channel_instructions[*channel_instr->channel_id()].push_back(instruction);
+
       switch (instruction->opcode()) {
         case HloOpcode::kSend: {
-          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
           TF_RET_CHECK(instruction->users().size() == 1);
           const HloInstruction* send_done = instruction->users().front();
           TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
@@ -1356,7 +1343,6 @@
           break;
         }
         case HloOpcode::kRecv: {
-          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
           TF_RET_CHECK(instruction->users().size() == 1);
           const HloInstruction* recv_done = instruction->users().front();
           TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
@@ -1377,6 +1363,39 @@
       }
     }
   }
+
+  // Iterate over each channel to check invariants.
+  for (auto& pair : channel_instructions) {
+    auto& instructions = pair.second;
+    const HloInstruction* first = instructions[0];
+    auto sendrecv = DynCast<HloSendRecvInstruction>(first);
+    if (sendrecv) {
+      absl::flat_hash_set<HloOpcode> opcodes;
+      for (const HloInstruction* instr : instructions) {
+        opcodes.insert(instr->opcode());
+        auto cast = DynCast<HloSendRecvInstruction>(instr);
+        TF_RET_CHECK(cast != nullptr)
+            << "channel " << pair.first
+            << " is used for different types of channel instructions";
+      }
+      if (sendrecv->is_host_transfer()) {
+        TF_RET_CHECK(instructions.size() == 2)
+            << "channel " << pair.first
+            << " is used for multiple host send/recv instructions";
+      } else {
+        TF_RET_CHECK(instructions.size() == opcodes.size())
+            << "channel " << pair.first
+            << " is used for multiple send/recv instructions";
+      }
+    } else {
+      for (const HloInstruction* instr : instructions) {
+        TF_RET_CHECK(first->opcode() == instr->opcode())
+            << "channel " << pair.first
+            << " is used for different types of channel instructions";
+      }
+    }
+  }
+
   return Status::OK();
 }
 
@@ -1680,7 +1699,7 @@
 
   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
   TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module));
-  TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
+  TF_RETURN_IF_ERROR(VerifyChannels(*module));
 
   std::unique_ptr<ShapeVerifier> shape_verifier =
       target_metadata_->GetVerifier();
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 1b27390..c174af6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1013,5 +1013,56 @@
       HasSubstr("mix of layout constrained and unconstrained AllReduce"));
 }
 
+TEST_F(HloVerifierTest, ChannelVerifier) {
+  const char* const kModuleStr = R"(
+  HloModule test
+
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY entry {
+    %input = f32[8,12] parameter(0)
+    %token0 = token[] after-all()
+    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
+    %send-done = token[] send-done(%send), channel_id=1
+    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
+      channel_id=1
+    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr));
+  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+              HasSubstr("used for different types of channel instructions"));
+}
+
+TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
+  const char* const kModuleStr = R"(
+  HloModule test
+
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY entry {
+    %input = f32[8,12] parameter(0)
+    %permute = f32[8,12] collective-permute(%input),
+      source_target_pairs={{0,1},{1,0}}, channel_id=1
+    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
+      channel_id=1
+    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr));
+  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+              HasSubstr("used for different types of channel instructions"));
+}
+
 }  // namespace
 }  // namespace xla