Add legalization of HLO reduce to LHLO reduce.
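
The new HloToLHloReduceConverter inlines the reduction body into an
xla_lhlo.reduce whose operands are the converted inputs plus one output
buffer per result, rewrites the body's tensor block arguments to memref
arguments (with one extra trailing argument for the output buffer), and
terminates the body with xla_lhlo.terminator. GetTensorStoreOrReturnMemRef
now also recognizes values yielded via xla_hlo.return, a
HloToLhloReturnConverter erases the now-redundant returns, and an
xla_hlo::ConstOp -> xla_lhlo::ConstOp conversion is registered as well.
The previously disabled FusedReduce test is re-enabled.

As a rough, hand-written sketch (SSA names and the exact attribute/type
spelling are illustrative only, not taken from the test or the generated
IR), the lowering takes

  %red = "xla_hlo.reduce"(%input, %init) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = xla_hlo.add %lhs, %rhs : tensor<f32>
    "xla_hlo.return"(%sum) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>}
    : (tensor<100x10xf32>, tensor<f32>) -> tensor<10xf32>

to roughly

  "xla_lhlo.reduce"(%input_buf, %init_buf, %out_buf) ( {
  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
    "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }) {dimensions = dense<0> : tensor<1xi64>}
    : (memref<100x10xf32>, memref<f32>, memref<10xf32>) -> ()

where the ops inside the body are handled by the existing per-op patterns.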

PiperOrigin-RevId: 283928453
Change-Id: Ib4d878e41473fe41c1ef20f269542aa0f248b723
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 58d5b7a..af5fb59 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -18,6 +18,7 @@
 #include "absl/memory/memory.h"
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
 #include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/Location.h"  // TF:local_config_mlir
@@ -38,13 +39,19 @@
 
 constexpr StringRef kTempBufferAttr = "temp";
 
-Value* GetTensorStoreMemRef(Value* value) {
+Value* GetTensorStoreOrReturnMemRef(Value* value) {
   for (const auto& user : value->getUsers()) {
     if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
       if (tensor_store.getOperand(0) == value) {
         return tensor_store.getOperand(1);
       }
     }
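+    // A value may also be yielded via xla_hlo.return. In a region that has
+    // already been rewritten to buffer form (e.g. an xla_lhlo.reduce body),
+    // the output buffer is the last block argument.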
+    if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
+      if (return_op.getOperand(0) == value) {
+        auto block = return_op.getOperation()->getBlock();
+        return *block->args_rbegin();
+      }
+    }
   }
   return nullptr;
 }
@@ -88,8 +95,8 @@
 /// function to store the values held in the tensor.
 Value* GetBufferForResultValue(Location loc, Value* result,
                                ConversionPatternRewriter* rewriter) {
-  if (auto tensor_store_memref = GetTensorStoreMemRef(result)) {
-    return tensor_store_memref;
+  if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
+    return existing_memref;
   }
   return InsertAllocAndDealloc(loc, result, rewriter);
 }
@@ -122,6 +129,62 @@
   }
 };
 
+struct HloToLHloReduceConverter
+    : public OpConversionPattern<xla_hlo::ReduceOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ReduceOp op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    auto loc = op.getLoc();
+    // TODO(b/137624192) Implement variadic reduce.
+    if (op.getNumResults() != 1) return matchFailure();
+    if (op.getParentRegion()->getBlocks().size() != 1) {
+      emitError(loc,
+                "tensor to buffer conversion expects a single block in the "
+                "region containing the operation");
+      return matchFailure();
+    }
+    const auto& original_results = op.getResults();
+    SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
+    for (auto result : original_results) {
+      buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
+    }
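+    // LHLO ops produce no SSA results; the output buffers are passed as
+    // trailing operands instead.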
+    auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
+        loc, llvm::None, buffer_args, op.getAttrs());
+
+    // Copy over the operations inside the region.
+    rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
+
+    // Create new block arguments with correct type.
+    auto& entry_block = new_op.body().front();
+    int original_arg_count = entry_block.getNumArguments();
+    for (int i = 0; i < original_arg_count; ++i) {
+      auto old_arg = entry_block.getArgument(i);
+      auto old_type = old_arg->getType().cast<TensorType>();
+      auto new_type =
+          MemRefType::get(old_type.getShape(), old_type.getElementType());
+      auto new_arg = entry_block.addArgument(new_type);
+      rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
+    }
+    // Add an argument for the result.
+    entry_block.addArgument(
+        entry_block.getArgument(original_arg_count)->getType());
+    // Remove the old arguments.
+    for (int i = original_arg_count - 1; i >= 0; --i) {
+      entry_block.eraseArgument(i);
+    }
+    // Insert terminator at the end.
+    rewriter.setInsertionPointToEnd(&entry_block);
+    rewriter.create<xla_lhlo::TerminatorOp>(loc);
+
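+    // Replace uses of the original tensor results with the output buffers
+    // (the entries of buffer_args past the converted operands).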
+    rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
+                       llvm::to_vector<4>(original_results));
+
+    return matchSuccess();
+  }
+};
+
 class HloToLhloTensorLoadConverter : public ConversionPattern {
  public:
   explicit HloToLhloTensorLoadConverter(MLIRContext* context)
@@ -135,6 +198,7 @@
   }
 };
 
+// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
 class HloToLhloTensorStoreConverter : public ConversionPattern {
  public:
   explicit HloToLhloTensorStoreConverter(MLIRContext* context)
@@ -148,6 +212,19 @@
   }
 };
 
+// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
+class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ReturnOp op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
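+    // The returned value is already written into the output buffer (see
+    // GetTensorStoreOrReturnMemRef), so the return itself can be dropped.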
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 // buffers if necessary.
 //
@@ -215,6 +292,7 @@
                            xla_lhlo::BroadcastInDimOp>,
       HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
       HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
+      HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
       HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
       HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
       HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
@@ -229,6 +307,7 @@
       HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
       HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
       HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
+      HloToLHloReduceConverter, HloToLhloReturnConverter,
       HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
   >(context);
   // clang-format on
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 87042f5..c749af3 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -65,8 +65,8 @@
     mlir::OwningRewritePatternList patterns;
     mlir::ConversionTarget target(ctx);
     target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
-
     ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
+
     getFunction().walk([&](FusionOp op) {
       if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
         signalPassFailure();
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index e63f948..e3b736d 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -255,45 +255,44 @@
                      LoweringStage::GPU);
 }
 
-// TODO(herhut): Re-enable once we can lower hlo_reduce to proper lhlo_reduce.
-// TEST_F(LhloGenTest, FusedReduce) {
-//   CompileAndVerifyIr(R"(
-// HloModule FusedReduce
-//
-// %add (x: f32[], y: f32[]) -> f32[] {
-//   %x = f32[] parameter(0)
-//   %y = f32[] parameter(1)
-//   ROOT %add = f32[] add(f32[] %x, f32[] %y)
-// }
-//
-// %fused_computation (param: f32[100,10]) -> f32[10] {
-//   %param = f32[100,10] parameter(0)
-//   %constant = f32[] constant(0)
-//   ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
-//       dimensions={0}, to_apply=%add
-// }
-//
-// ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
-//   %x = f32[100,10] parameter(0)
-//   ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
-//       calls=%fused_computation
-// }
-// )",
-//                      R"(
-// ;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
-// ;CHECK: "xla_lhlo.fusion"() ( {
-// ;CHECK:   %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
-// ;CHECK:   %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
-// ;CHECK:   %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
-// ;CHECK:     ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
-// ;CHECK:       %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
-// ;CHECK:       "xla_hlo.return"(%[[ADD]])
-// ;CHECK:     })
-// ;CHECK:   tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
-// ;CHECK:   "xla_lhlo.terminator"()
-// ;CHECK-NEXT: })
-//       )");
-// }
+TEST_F(LhloGenTest, FusedReduce) {
+  CompileAndVerifyIr(R"(
+HloModule FusedReduce
+
+%add (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+%fused_computation (param: f32[100,10]) -> f32[10] {
+  %param = f32[100,10] parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
+      dimensions={0}, to_apply=%add
+}
+
+ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
+  %x = f32[100,10] parameter(0)
+  ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
+      calls=%fused_computation
+}
+)",
+                     R"(
+;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
+;CHECK: "xla_lhlo.fusion"() ( {
+;CHECK:   %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
+;CHECK:   %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
+;CHECK:   %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
+;CHECK:     ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
+;CHECK:       %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
+;CHECK:       "xla_hlo.return"(%[[ADD]])
+;CHECK:     })
+;CHECK:   tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
+;CHECK:   "xla_lhlo.terminator"()
+;CHECK-NEXT: })
+      )");
+}
 
 TEST_F(LhloGenTest, Broadcast) {
   CompileAndVerifyIr(R"(