[XLA] Make sure that when rematerializing channel instructions we give them a different channel_id.
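
HLO channel instructions (all-gather, all-reduce, send/recv, ...) carry a
channel_id that later passes and the HLO verifier expect to be unique within a
module, and Clone() copies that id verbatim, so a rematerialized copy used to
collide with the original. The pass now seeds a counter from
hlo_query::NextChannelId(*module) when it runs and hands a fresh id to every
rematerialized channel instruction. As a rough sketch of what that seeding step
is expected to compute (SketchNextChannelId is illustrative only, not the
hlo_query implementation):

    #include <algorithm>
    #include <cstdint>

    #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
    #include "tensorflow/compiler/xla/service/hlo_instructions.h"
    #include "tensorflow/compiler/xla/service/hlo_module.h"

    namespace xla {

    // Illustrative sketch: one past the largest channel_id used in `module`,
    // i.e. the value the pass expects hlo_query::NextChannelId to return.
    int64_t SketchNextChannelId(const HloModule& module) {
      int64_t next_id = 1;  // Channel ids are positive; 1 is the first valid id.
      for (const HloComputation* computation : module.computations()) {
        for (const HloInstruction* instr : computation->instructions()) {
          const auto* channel_instr = DynCast<HloChannelInstruction>(instr);
          if (channel_instr != nullptr && channel_instr->channel_id()) {
            next_id = std::max(next_id, *channel_instr->channel_id() + 1);
          }
        }
      }
      return next_id;
    }

    }  // namespace xla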

PiperOrigin-RevId: 412091687
Change-Id: If5499c76f282f92370b9c26ff72d4c4fd9c8e57d
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9e79384..6b922f9 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4432,6 +4432,7 @@
         ":hlo_dce",
         ":hlo_memory_scheduler",
         ":hlo_ordering",
+        ":hlo_query",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index dc218e6..193466e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -41,6 +41,7 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -1534,7 +1535,8 @@
 StatusOr<int64_t> RematerializeInstructions(
     MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
-    InstructionList* instruction_list) {
+    InstructionList* instruction_list,
+    HloRematerialization* rematerialization) {
   int64_t net_instructions_added = 0;
   int64_t total_memory_saved =
       memory_tracker->MemoryReducedIfRematerialized(*best_items);
@@ -1556,6 +1558,11 @@
 
     HloInstruction* remat =
         computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+    // Channel ids must be unique; give the remat clone a fresh channel_id.
+    if (HloChannelInstruction* channel_instr =
+            DynCast<HloChannelInstruction>(remat)) {
+      channel_instr->set_channel_id(rematerialization->NextChannelId());
+    }
 
     // Add control dependencies to the new operation.
     for (auto successor : best->control_successors()) {
@@ -1762,7 +1769,8 @@
     int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
     InstructionList* instruction_list, int64_t memory_limit_bytes,
     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
-    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
+    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+    HloRematerialization* rematerialization) {
   CHECK(min_block_size > 0) << "Negative block size.";
 
   std::vector<Item*> best_items;
@@ -1797,7 +1805,8 @@
     TF_ASSIGN_OR_RETURN(
         num_instructions_added.net_instructions_added,
         RematerializeInstructions(memory_tracker, &best_items,
-                                  remat_move_instructions, instruction_list));
+                                  remat_move_instructions, instruction_list,
+                                  rematerialization));
   }
   return num_instructions_added;
 }
@@ -1926,7 +1935,7 @@
             RematerializeBestBlock(min_block_size, max_block_size,
                                    &memory_tracker, &instruction_list,
                                    memory_limit_bytes, &rematerializable_map,
-                                   &remat_move_instructions));
+                                   &remat_move_instructions, this));
         net_instructions_added += instructions_added.net_instructions_added;
         remat_count += instructions_added.remat_count;
         if (is_first_phase) {
@@ -2044,6 +2053,7 @@
 
   TF_RET_CHECK(module->has_schedule());
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
+  next_channel_id_ = hlo_query::NextChannelId(*module);
 
   // Adjust memory limit to account for the output of the entry
   // computation. This is necessary because the per-computation accounting in
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 94fb062..2619a44 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -102,6 +102,9 @@
 
   absl::string_view name() const override { return "rematerialization"; }
 
+  // Returns the next available channel id and advances the counter.
+  int64_t NextChannelId() { return next_channel_id_++; }
+
   // Runs rematerialization on the given module. Returns whether the module was
   // changed. Requires that the module has a schedule set
   // (HloModule::has_schedule() is true) before running. Returns whether any
@@ -197,6 +200,10 @@
   RematerializationMode mode_;
 
   int64_t min_remat_size_;
+
+  // Next available channel id to assign to rematerialized channel
+  // instructions.
+  int64_t next_channel_id_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 1fbc031..6a7dde6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -1039,6 +1039,62 @@
                                              ::testing::Ne(fusion))),
                    op::Add()));
 }
+
+// Make sure rematerialized all-gathers get a fresh, unique channel_id.
+TEST_F(HloRematerializationTest, AllGatherChannelId) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+ENTRY %mycomp (param: f32[1]) -> f32[1] {
+  %param = f32[1]{0} parameter(0)
+  %reshape = f32[] reshape(f32[1]{0} %param)
+  %broadcast = f32[256,1]{1,0} broadcast(f32[] %reshape), dimensions={}
+  %ag = f32[1024,1]{1,0} all-gather(f32[256,1]{1,0} %broadcast), dimensions={0},
+    channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true
+  %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %ag)
+  %negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %ag)
+  %concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate,
+    f32[1024,1]{1,0} %negate), dimensions={0}
+  %slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate),
+    slice={[0:1], [0:1]}
+  %bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
+  %concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast,
+    f32[1]{0} %bitcast.1), dimensions={0}
+  ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  auto* computation = module->entry_computation();
+  // Save the original root and check the structure of the computation; the
+  // all-gather feeding it is what we expect to be rematerialized.
+  const HloInstruction* slice = computation->root_instruction();
+  ASSERT_THAT(slice, op::Slice(op::Concatenate(
+                         op::Bitcast(op::AllGather(op::Broadcast(_))), _)));
+
+  // The computation requires 16KB without rematerialization but only 12KB
+  // with it, so pick a memory limit between these values (14KB).
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/14 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+
+  // Root should not have changed.
+  EXPECT_EQ(computation->root_instruction(), slice);
+
+  // Original all-gather.
+  const HloInstruction* original_ag = FindInstruction(module.get(), "ag");
+  // The all-gather should have been rematerialized as "ag.remat".
+  const HloInstruction* remat_ag = FindInstruction(module.get(), "ag.remat");
+
+  EXPECT_NE(remat_ag, nullptr);
+  EXPECT_TRUE(original_ag->channel_id().has_value());
+  EXPECT_TRUE(remat_ag->channel_id().has_value());
+  EXPECT_EQ(*remat_ag->channel_id(), *original_ag->channel_id() + 1);
+}
+
 }  // namespace
 
 }  // namespace xla