[XLA] Adapt HLO-to-LHLO legalization to use Buffer Assignment
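
Instead of creating temporary buffers (tagged with the "temp" attribute) and
their deallocs during legalization, the HLO-to-LHLO patterns now build on the
Buffer Assignment infrastructure:

* The conversion patterns derive from BufferAssignmentOpConversionPattern;
  result buffers are allocated at the position computed by
  BufferAssignmentPlacer, and the matching deallocs are inserted later by the
  -buffer-placement pass.
* The hand-written HloToLhloFuncOpConverter and StdToLhloReturnOpConverter are
  replaced by the generic FunctionAndBlockSignatureConverter and
  NonVoidToVoidReturnOpConverter patterns.
* populateHLOToLHLOConversionPattern now takes a BufferAssignmentPlacer and a
  TypeConverter in addition to the context, and the legalization pass applies
  the conversion per FuncOp.
* The GPU kernel lowering pipeline runs the full HLO-to-LHLO legalization
  followed by buffer placement instead of the local NestedHloRegionsConverter.

A minimal sketch of the resulting pipeline wiring, mirroring the change in
kernel_lowering.cc below; the helper name AddHloToLhloPasses and the exact
include paths are illustrative assumptions, only the two addPass calls are
taken from this change:

  #include "mlir/Pass/PassManager.h"                           // mlir::PassManager
  #include "mlir/Transforms/Passes.h"                          // createBufferPlacementPass (assumed location)
  #include "tensorflow/compiler/mlir/xla/transforms/passes.h"  // createLegalizeToLhloPass (assumed location)

  // Hypothetical helper: appends the HLO-to-LHLO lowering steps to a pipeline.
  void AddHloToLhloPasses(mlir::PassManager& pm) {
    // Legalize HLO ops to LHLO; allocs are placed via BufferAssignmentPlacer.
    pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass());
    // Move `AllocOp`s and insert the missing `DeallocOp`s.
    pm.addPass(::mlir::createBufferPlacementPass());
  }
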
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 262533b..53296b2 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -13,33 +13,42 @@
 
 // -----
 
+func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  return %arg0 : tensor<4xf32>
+}
+//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+
+// -----
+
 // CHECK-LABEL: func @func_op_long
 func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
-  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
   %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
   %2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
   %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
   %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
   %5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
-  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
-  // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
   return %5 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
+//      CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
+// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
+// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
+// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
+// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 
 // -----
 
@@ -47,20 +56,20 @@
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
   // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
-  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
-  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
+  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
   %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
   %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
   %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
+  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
   %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
   %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
+  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
   tensor_store %tensor_result, %result : memref<2x2xf32>
-  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT:  dealloc %[[MUL_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
   "xla_lhlo.terminator"() : () -> ()
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index aa29241..756a38f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -39,16 +40,10 @@
 namespace {
 
 constexpr StringRef kTempBufferAttr = "temp";
-
-/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
-Operation* FindInsertionPointForCopy(Value value) {
-  for (const auto& user : value.getUsers()) {
-    if (auto dealloc = dyn_cast<DeallocOp>(user)) {
-      return user;
-    }
-  }
-  return nullptr;
-}
+template <typename T>
+using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
+using StdReturnOpConverter = NonVoidToVoidReturnOpConverter<
+    mlir::ReturnOp, xla_lhlo::TerminatorOp, xla_lhlo::CopyOp>;
 
 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                    Value shape_operand,
@@ -92,8 +87,9 @@
   return alloc;
 }
 
-Value InsertAllocAndDealloc(Location loc, Value result,
-                            ConversionPatternRewriter* rewriter) {
+Value InsertAlloc(Location loc, OpResult result,
+                  BufferAssignmentPlacer* bufferAssignment,
+                  ConversionPatternRewriter* rewriter) {
   auto result_type = result.getType().dyn_cast<ShapedType>();
   if (!result_type || !result_type.hasStaticShape()) {
     result.getDefiningOp()->emitOpError()
@@ -101,31 +97,21 @@
   }
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());
-
-  Operation* op = result.getDefiningOp();
-  auto block = op->getBlock();
-
-  OpBuilder allocBuilder(op);
-  allocBuilder.setInsertionPointToStart(block);  // Inserting at the beginning
-  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
-
-  alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
-
-  allocBuilder.setInsertionPoint(block, std::prev(block->end()));
-  allocBuilder.create<DeallocOp>(loc, alloc);
-
+  OpBuilder::InsertionGuard guard(*rewriter);
+  rewriter->restoreInsertionPoint(
+      bufferAssignment->computeAllocPosition(result));
+  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
   return alloc;
 }
 
 template <typename HloOpTy>
-class HloToLhloOpConverter : public ConversionPattern {
+class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
  public:
-  explicit HloToLhloOpConverter(MLIRContext* context)
-      : ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
-
+  using BaseOpConversion<HloOpTy>::BaseOpConversion;
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      HloOpTy hloOp, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
+    Operation* op = hloOp.getOperation();
     const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : llvm::enumerate(original_results)) {
@@ -135,8 +121,8 @@
         return failure();
       }
       if (resultType.hasStaticShape()) {
-        buffer_args.push_back(
-            InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+        buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
+                                          this->bufferAssignment, &rewriter));
       } else {
         SmallVector<Value, 1> results_shape;
         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@@ -156,9 +142,9 @@
 };
 
 struct HloToLhloDynamicBroadcastInDimOpConverter
-    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
       xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
@@ -175,10 +161,9 @@
   }
 };
 
-struct HloToLhloReduceOpConverter
-    : public OpConversionPattern<xla_hlo::ReduceOp> {
+struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
       xla_hlo::ReduceOp op, ArrayRef<Value> operands,
@@ -194,7 +179,8 @@
     const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : original_results) {
-      buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
+      buffer_args.push_back(
+          InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
     }
     auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
         loc, llvm::None, buffer_args, op.getAttrs());
@@ -230,12 +216,12 @@
   }
 };
 
-class HloToLhloTensorLoadOpConverter : public ConversionPattern {
+class HloToLhloTensorLoadOpConverter
+    : public BaseOpConversion<mlir::TensorLoadOp> {
  public:
-  explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
-      : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
+  using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      mlir::TensorLoadOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     rewriter.replaceOp(op, operands);
     return success();
@@ -243,13 +229,13 @@
 };
 
 // TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloTensorStoreOpConverter : public ConversionPattern {
+class HloToLhloTensorStoreOpConverter
+    : public BaseOpConversion<mlir::TensorStoreOp> {
  public:
-  explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
-      : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
+  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      mlir::TensorStoreOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
         op, llvm::None, operands.front(), operands.back());
@@ -291,7 +277,6 @@
 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 //     "xla_lhlo.multiply"(%0, %arg0, %arg3) :
 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
-//     dealloc %0 : memref<2x2xf32>
 //     "xla_lhlo.terminator"() : () -> ()
 //   }) : () -> ()
 //   return
@@ -313,14 +298,13 @@
 //               %arg1: memref<4xf32>,
 //               %arg2: memref<4xf32>) {
 //   %0 = alloc() : memref<4xf32>
-//   %1 = alloc() : memref<4xf32>
+
 //   "xla_lhlo.maximum"(%arg0, %arg1, %0) :
 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+//   %1 = alloc() : memref<4xf32>
 //   "xla_lhlo.add"(%arg0, %0, %1) :
 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
 //   "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
-//   dealloc %0 : memref<4xf32>
-//   dealloc %1 : memref<4xf32>
 //   "xla_lhlo.terminator"() : () -> ()
 // }
 
@@ -346,101 +330,25 @@
     });
 
     auto module = getOperation();
-    populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
-
-    // Do partial conversion so we can have unknown ops in tests.
-    if (failed(applyPartialConversion(module, target, patterns, nullptr))) {
-      signalPassFailure();
-    }
+    BufferAssignmentTypeConverter converter;
+    module.walk([&](FuncOp func) {
+      BufferAssignmentPlacer bufferAssignment(func);
+      OwningRewritePatternList patterns;
+      populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
+                                         &converter, &patterns);
+      return WalkResult(
+          applyPartialConversion(func, target, patterns, &converter));
+    });
   }
 };
-
-Type ConvertType(Type t) {
-  if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
-    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-  }
-  return t;
-}
-
 }  // namespace
 
-/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
-/// results are converted to memrefs and appended to the argument list.
-class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      FuncOp funcOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    if (funcOp.getBody().getBlocks().size() > 1) {
-      funcOp.emitOpError() << "tensor to buffer conversion expects a single "
-                              "block in the region containing the operation";
-      return failure();
-    }
-
-    auto funcType = funcOp.getType();
-
-    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
-    for (auto argType : llvm::enumerate(funcType.getInputs())) {
-      conversion.addInputs(argType.index(), ConvertType(argType.value()));
-    }
-    for (auto resType : funcType.getResults()) {
-      conversion.addInputs(ConvertType(resType));
-    }
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(
-          rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
-      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
-    });
-    return success();
-  }
-};
-
-/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
-/// result to the corresponding buffer argument.
-class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::ReturnOp returnOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    auto numReturnValues = returnOp.getNumOperands();
-    auto funcOp = returnOp.getParentOfType<FuncOp>();
-    auto numFuncArgs = funcOp.getNumArguments();
-    auto loc = returnOp.getLoc();
-
-    for (auto operand : llvm::enumerate(operands)) {
-      auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
-      auto dstBuffer = funcOp.getArgument(returnArgNumber);
-      if (dstBuffer == operand.value()) {
-        continue;
-      }
-
-      auto dealloc = FindInsertionPointForCopy(operand.value());
-
-      if (dealloc == nullptr) {
-        returnOp.emitOpError()
-            << "Missing dealloc for operand " << operand.index();
-        return failure();
-      }
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPoint(dealloc);
-      rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
-                                        funcOp.getArgument(returnArgNumber));
-    }
-    rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
-    return success();
-  }
-};
-
-void populateHLOToLHLOConversionPattern(MLIRContext* context,
-                                        OwningRewritePatternList* patterns) {
+void populateHLOToLHLOConversionPattern(
+    MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+    TypeConverter* converter, OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
       HloToLhloDynamicBroadcastInDimOpConverter,
-      HloToLhloFuncOpConverter,
       HloToLhloOpConverter<xla_hlo::AbsOp>,
       HloToLhloOpConverter<xla_hlo::AddOp>,
       HloToLhloOpConverter<xla_hlo::AndOp>,
@@ -472,8 +380,9 @@
       HloToLhloReduceOpConverter,
       HloToLhloTensorLoadOpConverter,
       HloToLhloTensorStoreOpConverter,
-      StdToLhloReturnOpConverter
-  >(context);
+      FunctionAndBlockSignatureConverter,
+      StdReturnOpConverter
+  >(context, bufferAssignment, converter);
   // clang-format on
 }
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
index ad81cda..e4f5c93 100644
--- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 
 namespace mlir {
+class BufferAssignmentPlacer;
 namespace xla_hlo {
 
 // Collection of rewrite patterns for lowering a general dot product.
@@ -38,9 +39,9 @@
                               MLIRContext *ctx);
 
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
-void populateHLOToLHLOConversionPattern(MLIRContext *context,
-                                        OwningRewritePatternList *patterns);
-
+void populateHLOToLHLOConversionPattern(
+    MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+    TypeConverter* converter, OwningRewritePatternList* patterns);
 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
 void populateHLOToLinalgConversionPattern(MLIRContext *context,
                                           OwningRewritePatternList *patterns);
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 33d3690..c806e95 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -45,6 +45,7 @@
 #include "mlir/IR/Region.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
@@ -60,34 +61,6 @@
 
 using ::mlir::xla_lhlo::FusionOp;
 
-// Following are some small transformations that are required to clean up code
-// after lowering from linalg to loops.
-
-// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that
-// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp.
-// This is needed, as these ops are not closed from above and hence nested pass
-// managers can not be applied.
-struct NestedHloRegionsConverter
-    : public mlir::PassWrapper<NestedHloRegionsConverter,
-                               ::mlir::FunctionPass> {
-  void runOnFunction() override {
-    auto& ctx = getContext();
-    mlir::OwningRewritePatternList patterns;
-    mlir::ConversionTarget target(ctx);
-    target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
-    ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
-
-    getFunction().walk([&](mlir::Operation* op) {
-      if (op->getNumRegions() == 0) {
-        return;
-      }
-      if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
-        signalPassFailure();
-      }
-    });
-  }
-};
-
 // Replaces a FusionOp by the operations contained in its region.
 struct FusionOpRemover
     : public mlir::PassWrapper<FusionOpRemover, ::mlir::FunctionPass> {
@@ -436,8 +409,10 @@
     tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
   }
 
-  // First, lower bodies of LHLO operations that contain HLO ops.
-  pm.addPass(absl::make_unique<NestedHloRegionsConverter>());
+  // Legalize from HLO to LHLO.
+  pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass());
+  // Move `AllocOp`s and insert the missing `DeallocOp`s.
+  pm.addPass(::mlir::createBufferPlacementPass());
   // Next, we can strip the outer fusion operation.
   pm.addPass(absl::make_unique<FusionOpRemover>());
   // Remove unnecessary LHLO copies.