[XLA:GPU] Fix all-reduce aliasing for single operand all-reduce.

Slight clean-up of `CanShareBufferHint`.

PiperOrigin-RevId: 378501752
Change-Id: I9e309579393fb6b97cf2214448d91b48f011cee7
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 296bbfd..c9662e8 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -195,23 +195,25 @@
 absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                         const HloInstruction* operand,
                                         const ShapeIndex& user_index) {
-  // Share the bias buffer with the parent instruction.
-  if (IsCublasGemm(*user)) {
-    if (user->operand_count() == 3 && user->operand(2) == operand) {
-      return true;
-    }
+  switch (user->opcode()) {
+    case HloOpcode::kAllReduce:
+      // NCCL all-reduce can be performed in-place.
+      return user->operand_count() == 1 ||
+             (user_index.size() == 1 &&
+              user->operand(user_index[0]) == operand);
+    case HloOpcode::kCustomCall:
+      // Share the bias buffer with the parent instruction.
+      if (user->custom_call_target() == kGemmCallTarget) {
+        return user->operand_count() == 3 && user->operand(2) == operand;
+      }
+      // The operand of cholesky can be shared with the first output.
+      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
+        return user_index.size() == 1 && user_index[0] == 0;
+      }
+      return false;
+    default:
+      return absl::nullopt;
   }
-  // The operand of cholesky can be shared with the first output.
-  if (user->opcode() == HloOpcode::kCustomCall &&
-      user->custom_call_target() == kCusolverCholeskyCallTarget) {
-    return user_index.size() == 1 && user_index[0] == 0;
-  }
-  // NCCL all-reduce can be performed in-place.
-  if (user->opcode() == HloOpcode::kAllReduce && user_index.size() == 1 &&
-      user->operand(user_index[0]) == operand) {
-    return true;
-  }
-  return absl::nullopt;
 }
 
 // Try to load ptx from files defined in the FLAGS. If successful, return true.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
index c565f27..f3f0711 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler_test.cc
@@ -28,7 +28,46 @@
 using NVPTXCompilerTest = HloTestBase;
 
 TEST_F(NVPTXCompilerTest, AllReducePerformedInplace) {
-  const char* const hlo_string = R"(
+  const absl::string_view hlo_string = R"(
+HloModule Module
+
+summit {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  param0 = f32[128] parameter(0)
+  param1 = f32[128] parameter(1)
+  add = f32[128] add(param0, param1)
+  ROOT allreduce = f32[128] all-reduce(add), replica_groups={}, to_apply=summit
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  NVPTXCompiler compiler;
+  Compiler::CompileOptions compile_options;
+  TF_ASSERT_OK_AND_ASSIGN(auto module_and_buffer_assignment,
+                          compiler.RunHloPassesAndBufferAssignement(
+                              std::move(module),
+                              /*executor=*/nullptr,
+                              /*optimize=*/false, compile_options));
+
+  module = std::move(std::get<0>(module_and_buffer_assignment));
+  std::unique_ptr<BufferAssignment> buffer_assignment =
+      std::move(std::get<1>(module_and_buffer_assignment));
+
+  HloInstruction* all_reduce = module->entry_computation()->root_instruction();
+
+  ASSERT_EQ(
+      buffer_assignment->GetInstructionAllocation(all_reduce, {}),
+      buffer_assignment->GetInstructionAllocation(all_reduce->operand(0), {}));
+}
+
+TEST_F(NVPTXCompilerTest, AllReducePerformedInplaceTwoOperands) {
+  const absl::string_view hlo_string = R"(
 HloModule Module
 
 summit {