[XLA:GPU] Fix all-reduce aliasing for single-operand all-reduce.
Slight clean-up of `CanShareBufferHint`.
PiperOrigin-RevId: 378501752
Change-Id: I9e309579393fb6b97cf2214448d91b48f011cee7
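
Background: previously `CanShareBufferHint` only allowed an all-reduce to reuse an
operand buffer when `user_index.size() == 1`, i.e. when the all-reduce result is a
tuple and the index names the element that corresponds to `operand`. A
single-operand all-reduce produces a non-tuple result, so `user_index` is empty and
the hint never fired for that case; the new `operand_count() == 1` branch covers it.
A rough HLO sketch of the two cases (instruction and computation names such as
`sum`, `x`, `y` are illustrative only, not part of this change):

    // Single operand: non-tuple result, user_index is {} -> operand buffer can
    // now be reused because user->operand_count() == 1.
    ar0 = f32[128] all-reduce(x), replica_groups={}, to_apply=sum

    // Multiple operands: tuple result, user_index is {i}; the buffer is shared
    // only if user->operand(i) is the queried operand.
    ar1 = (f32[128], f32[64]) all-reduce(x, y), replica_groups={}, to_apply=sum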
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 296bbfd..c9662e8 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -195,23 +195,25 @@
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                        const HloInstruction* operand,
                                        const ShapeIndex& user_index) {
-  // Share the bias buffer with the parent instruction.
-  if (IsCublasGemm(*user)) {
-    if (user->operand_count() == 3 && user->operand(2) == operand) {
-      return true;
-    }
+  switch (user->opcode()) {
+    case HloOpcode::kAllReduce:
+      // NCCL all-reduce can be performed in-place.
+      return user->operand_count() == 1 ||
+             (user_index.size() == 1 &&
+              user->operand(user_index[0]) == operand);
+    case HloOpcode::kCustomCall:
+      // Share the bias buffer with the parent instruction.
+      if (user->custom_call_target() == kGemmCallTarget) {
+        return user->operand_count() == 3 && user->operand(2) == operand;
+      }
+      // The operand of cholesky can be shared with the first output.
+      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
+        return user_index.size() == 1 && user_index[0] == 0;
+      }
+      return false;
+    default:
+      return absl::nullopt;
  }
-  // The operand of cholesky can be shared with the first output.
-  if (user->opcode() == HloOpcode::kCustomCall &&
-      user->custom_call_target() == kCusolverCholeskyCallTarget) {
-    return user_index.size() == 1 && user_index[0] == 0;
-  }
-  // NCCL all-reduce can be performed in-place.
-  if (user->opcode() == HloOpcode::kAllReduce && user_index.size() == 1 &&
-      user->operand(user_index[0]) == operand) {
-    return true;
-  }
-  return absl::nullopt;
}

// Try to load ptx from files defined in the FLAGS. If successful, return true.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
index c565f27..f3f0711 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
@@ -28,7 +28,46 @@
using NVPTXCompilerTest = HloTestBase;

TEST_F(NVPTXCompilerTest, AllReducePerformedInplace) {
-  const char* const hlo_string = R"(
+  const absl::string_view hlo_string = R"(
+HloModule Module
+
+summit {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  param0 = f32[128] parameter(0)
+  param1 = f32[128] parameter(1)
+  add = f32[128] add(param0, param1)
+  ROOT allreduce = f32[128] all-reduce(add), replica_groups={}, to_apply=summit
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  NVPTXCompiler compiler;
+  Compiler::CompileOptions compile_options;
+  TF_ASSERT_OK_AND_ASSIGN(auto module_and_buffer_assignment,
+                          compiler.RunHloPassesAndBufferAssignement(
+                              std::move(module),
+                              /*executor=*/nullptr,
+                              /*optimize=*/false, compile_options));
+
+  module = std::move(std::get<0>(module_and_buffer_assignment));
+  std::unique_ptr<BufferAssignment> buffer_assignment =
+      std::move(std::get<1>(module_and_buffer_assignment));
+
+  HloInstruction* all_reduce = module->entry_computation()->root_instruction();
+
+  ASSERT_EQ(
+      buffer_assignment->GetInstructionAllocation(all_reduce, {}),
+      buffer_assignment->GetInstructionAllocation(all_reduce->operand(0), {}));
+}
+
+TEST_F(NVPTXCompilerTest, AllReducePerformedInplaceTwoOperands) {
+  const absl::string_view hlo_string = R"(
HloModule Module

summit {