[XLA] Make sure that when rematerializing channel instructions we give them a different channel_id.
PiperOrigin-RevId: 412091687
Change-Id: If5499c76f282f92370b9c26ff72d4c4fd9c8e57d
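
Channel ids must be unique across an HloModule, so a rematerialized (cloned) channel instruction such as an all-gather must not reuse the id of the instruction it was cloned from. A rough sketch of the fix, simplified from the diff below (names taken from the XLA sources, not a verbatim excerpt):

  // Seed a counter once per Run() with the first unused channel id in the module.
  int64_t next_channel_id = hlo_query::NextChannelId(*module);

  // When a channel instruction is rematerialized, give the clone a fresh id.
  HloInstruction* remat =
      computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
  if (auto* channel_instr = DynCast<HloChannelInstruction>(remat)) {
    channel_instr->set_channel_id(next_channel_id++);
  }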
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9e79384..6b922f9 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4432,6 +4432,7 @@
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
+ ":hlo_query",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index dc218e6..193466e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -41,6 +41,7 @@
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -1534,7 +1535,8 @@
StatusOr<int64_t> RematerializeInstructions(
MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
- InstructionList* instruction_list) {
+ InstructionList* instruction_list,
+ HloRematerialization* rematerialization) {
int64_t net_instructions_added = 0;
int64_t total_memory_saved =
memory_tracker->MemoryReducedIfRematerialized(*best_items);
@@ -1556,6 +1558,11 @@
HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+ // Assign a new, unique channel_id to rematerialized channel instructions.
+ if (HloChannelInstruction* channel_instr =
+ DynCast<HloChannelInstruction>(remat)) {
+ channel_instr->set_channel_id(rematerialization->NextChannelId());
+ }
// Add control dependencies to the new operation.
for (auto successor : best->control_successors()) {
@@ -1762,7 +1769,8 @@
int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
InstructionList* instruction_list, int64_t memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
- absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
+ absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+ HloRematerialization* rematerialization) {
CHECK(min_block_size > 0) << "Negative block size.";
std::vector<Item*> best_items;
@@ -1797,7 +1805,8 @@
TF_ASSIGN_OR_RETURN(
num_instructions_added.net_instructions_added,
RematerializeInstructions(memory_tracker, &best_items,
- remat_move_instructions, instruction_list));
+ remat_move_instructions, instruction_list,
+ rematerialization));
}
return num_instructions_added;
}
@@ -1926,7 +1935,7 @@
RematerializeBestBlock(min_block_size, max_block_size,
&memory_tracker, &instruction_list,
memory_limit_bytes, &rematerializable_map,
- &remat_move_instructions));
+ &remat_move_instructions, this));
net_instructions_added += instructions_added.net_instructions_added;
remat_count += instructions_added.remat_count;
if (is_first_phase) {
@@ -2044,6 +2053,7 @@
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
+ next_channel_id_ = hlo_query::NextChannelId(*module);
// Adjust memory limit to account for the output of the entry
// computation. This is necessary because the per-computation accounting in
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 94fb062..2619a44 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -102,6 +102,9 @@
absl::string_view name() const override { return "rematerialization"; }
+ // Returns the next available channel id and advances the counter.
+ int64_t NextChannelId() { return next_channel_id_++; }
+
// Runs rematerialization on the given module. Returns whether the module was
// changed. Requires that the module has a schedule set
// (HloModule::has_schedule() is true) before running. Returns whether any
@@ -197,6 +200,10 @@
RematerializationMode mode_;
int64_t min_remat_size_;
+
+ // Next channel id to assign to rematerialized channel instructions, so that
+ // channel ids stay unique across the module.
+ int64_t next_channel_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 1fbc031..6a7dde6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -1039,6 +1039,62 @@
::testing::Ne(fusion))),
op::Add()));
}
+
+// Make sure that rematerialized all-gathers get new, unique channel ids.
+TEST_F(HloRematerializationTest, AllGatherChannelId) {
+ const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+ENTRY %mycomp (param: f32[1]) -> f32[1] {
+ %param = f32[1]{0} parameter(0)
+ %reshape = f32[] reshape(f32[1]{0} %param)
+ %broadcast = f32[256,1]{1,0} broadcast(f32[] %reshape), dimensions={}
+ %ag = f32[1024,1]{1,0} all-gather(f32[256,1]{1,0} %broadcast), dimensions={0},
+ channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true
+ %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %ag)
+ %negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %ag)
+ %concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate,
+ f32[1024,1]{1,0} %negate), dimensions={0}
+ %slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate),
+ slice={[0:1], [0:1]}
+ %bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
+ %concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast,
+ f32[1]{0} %bitcast.1), dimensions={0}
+ ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ auto* computation = module->entry_computation();
+ // Check that the root is a slice of a concatenate whose first operand is
+ // the bitcast of the all-gather.
+ const HloInstruction* slice = computation->root_instruction();
+ ASSERT_THAT(slice, op::Slice(op::Concatenate(
+ op::Bitcast(op::AllGather(op::Broadcast(_))), _)));
+
+ // The computation requires 16KB without rematerialization but only 12KB with
+ // it, so pick a memory limit between these values (14KB).
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module.get()));
+ EXPECT_TRUE(changed);
+
+ // Root should not have changed.
+ EXPECT_EQ(computation->root_instruction(), slice);
+
+ // Original all-gather.
+ const HloInstruction* original_ag = FindInstruction(module.get(), "ag");
+ // The all-gather should have been rematerialized.
+ const HloInstruction* remat_ag = FindInstruction(module.get(), "ag.remat");
+
+ EXPECT_NE(remat_ag, nullptr);
+ EXPECT_TRUE(original_ag->channel_id().has_value());
+ EXPECT_TRUE(remat_ag->channel_id().has_value());
+ EXPECT_EQ(*remat_ag->channel_id(), *original_ag->channel_id() + 1);
+}
+
} // namespace
} // namespace xla