Avoid reusing channel_id values in the spatial_partitioner
PiperOrigin-RevId: 288056247
Change-Id: I309f9044fb3b25f7f3f31dc61d767a15128253e4
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index defd6ab..1b6494b 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -133,5 +133,17 @@
return false;
}
+int64 NextChannelId(const HloModule& module) {
+ int64 next_channel_id = 1;
+ for (const HloComputation* comp : module.computations()) {
+ for (const HloInstruction* hlo : comp->instructions()) {
+ if (DynCast<HloChannelInstruction>(hlo)) {
+ next_channel_id = std::max(next_channel_id, *hlo->channel_id() + 1);
+ }
+ }
+ }
+ return next_channel_id;
+}
+
} // namespace hlo_query
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
index 0ea36ae..b7fbc46 100644
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -77,6 +77,10 @@
// layout.
bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
+// Returns the next available channel id that can be used in the given module
+// (for HloChannelInstructions).
+int64 NextChannelId(const HloModule& module);
+
} // namespace hlo_query
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 0b7d2ec..b2beb9d 100755
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1317,37 +1317,24 @@
return Status::OK();
}
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
- absl::flat_hash_map<int64, const HloInstruction*> host_channels;
- // Host send/recv instructions must have their own unique channel.
- auto check_unique_host_channel = [&](const HloInstruction* instruction) {
- const HloSendRecvInstruction* sendrecv =
- DynCast<const HloSendRecvInstruction>(instruction);
- if (sendrecv->is_host_transfer()) {
- auto it_inserted =
- host_channels.insert({*sendrecv->channel_id(), sendrecv});
- if (!it_inserted.second) {
- return FailedPrecondition(
- "Channel %d is used for multiple host send/recv instructions: "
- "%s "
- "and "
- "%s",
- *sendrecv->channel_id(), sendrecv->ToString(),
- it_inserted.first->second->ToString());
- }
- }
-
- return Status::OK();
- };
+// Checks various invariants of channel instructions (send/recv and
+// collectives).
+Status VerifyChannels(const HloModule& module) {
+ absl::flat_hash_map<int64, std::vector<const HloInstruction*>>
+ channel_instructions;
// Send/Recv instruction must have a single user: the corresponding
// SendDone/RecvDone. with matching channel.
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
+ auto channel_instr = DynCast<HloChannelInstruction>(instruction);
+ if (!channel_instr || !channel_instr->channel_id()) {
+ continue;
+ }
+ channel_instructions[*channel_instr->channel_id()].push_back(instruction);
+
switch (instruction->opcode()) {
case HloOpcode::kSend: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
TF_RET_CHECK(instruction->users().size() == 1);
const HloInstruction* send_done = instruction->users().front();
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
@@ -1356,7 +1343,6 @@
break;
}
case HloOpcode::kRecv: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
TF_RET_CHECK(instruction->users().size() == 1);
const HloInstruction* recv_done = instruction->users().front();
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
@@ -1377,6 +1363,39 @@
}
}
}
+
+ // Iterate over each channel to check invariants.
+ for (auto& pair : channel_instructions) {
+ auto& instructions = pair.second;
+ const HloInstruction* first = instructions[0];
+ auto sendrecv = DynCast<HloSendRecvInstruction>(first);
+ if (sendrecv) {
+ absl::flat_hash_set<HloOpcode> opcodes;
+ for (const HloInstruction* instr : instructions) {
+ opcodes.insert(instr->opcode());
+ auto cast = DynCast<HloSendRecvInstruction>(instr);
+ TF_RET_CHECK(cast != nullptr)
+ << "channel " << pair.first
+ << " is used for different types of channel instructions";
+ }
+ if (sendrecv->is_host_transfer()) {
+ TF_RET_CHECK(instructions.size() == 2)
+ << "channel " << pair.first
+ << " is used for multiple host send/recv instructions";
+ } else {
+ TF_RET_CHECK(instructions.size() == opcodes.size())
+ << "channel " << pair.first
+ << " is used for multiple send/recv instructions";
+ }
+ } else {
+ for (const HloInstruction* instr : instructions) {
+ TF_RET_CHECK(first->opcode() == instr->opcode())
+ << "channel " << pair.first
+ << " is used for different types of channel instructions";
+ }
+ }
+ }
+
return Status::OK();
}
@@ -1680,7 +1699,7 @@
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module));
- TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
+ TF_RETURN_IF_ERROR(VerifyChannels(*module));
std::unique_ptr<ShapeVerifier> shape_verifier =
target_metadata_->GetVerifier();
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 1b27390..c174af6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1013,5 +1013,56 @@
HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
+TEST_F(HloVerifierTest, ChannelVerifier) {
+ const char* const kModuleStr = R"(
+ HloModule test
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ %input = f32[8,12] parameter(0)
+ %token0 = token[] after-all()
+ %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
+ %send-done = token[] send-done(%send), channel_id=1
+ %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
+ channel_id=1
+ ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(kModuleStr));
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("used for different types of channel instructions"));
+}
+
+TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
+ const char* const kModuleStr = R"(
+ HloModule test
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ %input = f32[8,12] parameter(0)
+ %permute = f32[8,12] collective-permute(%input),
+ source_target_pairs={{0,1},{1,0}}, channel_id=1
+ %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
+ channel_id=1
+ ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(kModuleStr));
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("used for different types of channel instructions"));
+}
+
} // namespace
} // namespace xla