[XLA] Adapting HLO-to-LHLO-legalization to use Buffer Assignment
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 262533b..53296b2 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -13,33 +13,42 @@
// -----
+func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ return %arg0 : tensor<4xf32>
+}
+// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+
+// -----
+
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
- // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
- // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
- // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
- // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
return %5 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
+// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
+// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
+// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
+// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
+// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
// -----
@@ -47,20 +56,20 @@
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
- // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
- // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
+ // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
+ // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
+ // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
- // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index aa29241..756a38f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -39,16 +40,10 @@
namespace {
constexpr StringRef kTempBufferAttr = "temp";
-
-/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
-Operation* FindInsertionPointForCopy(Value value) {
- for (const auto& user : value.getUsers()) {
- if (auto dealloc = dyn_cast<DeallocOp>(user)) {
- return user;
- }
- }
- return nullptr;
-}
+template <typename T>
+using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
+using StdReturnOpConverter = NonVoidToVoidReturnOpConverter<
+ mlir::ReturnOp, xla_lhlo::TerminatorOp, xla_lhlo::CopyOp>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
@@ -92,8 +87,9 @@
return alloc;
}
-Value InsertAllocAndDealloc(Location loc, Value result,
- ConversionPatternRewriter* rewriter) {
+Value InsertAlloc(Location loc, OpResult result,
+ BufferAssignmentPlacer* bufferAssignment,
+ ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
result.getDefiningOp()->emitOpError()
@@ -101,31 +97,21 @@
}
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
-
- Operation* op = result.getDefiningOp();
- auto block = op->getBlock();
-
- OpBuilder allocBuilder(op);
- allocBuilder.setInsertionPointToStart(block); // Inserting at the beginning
- auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
-
- alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
-
- allocBuilder.setInsertionPoint(block, std::prev(block->end()));
- allocBuilder.create<DeallocOp>(loc, alloc);
-
+ OpBuilder::InsertionGuard guard(*rewriter);
+ rewriter->restoreInsertionPoint(
+ bufferAssignment->computeAllocPosition(result));
+ auto alloc = rewriter->create<AllocOp>(loc, memref_type);
return alloc;
}
template <typename HloOpTy>
-class HloToLhloOpConverter : public ConversionPattern {
+class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
public:
- explicit HloToLhloOpConverter(MLIRContext* context)
- : ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
-
+ using BaseOpConversion<HloOpTy>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ HloOpTy hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
+ Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
@@ -135,8 +121,8 @@
return failure();
}
if (resultType.hasStaticShape()) {
- buffer_args.push_back(
- InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+ buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
+ this->bufferAssignment, &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@@ -156,9 +142,9 @@
};
struct HloToLhloDynamicBroadcastInDimOpConverter
- : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
@@ -175,10 +161,9 @@
}
};
-struct HloToLhloReduceOpConverter
- : public OpConversionPattern<xla_hlo::ReduceOp> {
+struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
@@ -194,7 +179,8 @@
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
- buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
+ buffer_args.push_back(
+ InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
@@ -230,12 +216,12 @@
}
};
-class HloToLhloTensorLoadOpConverter : public ConversionPattern {
+class HloToLhloTensorLoadOpConverter
+ : public BaseOpConversion<mlir::TensorLoadOp> {
public:
- explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
- : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
+ using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ mlir::TensorLoadOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands);
return success();
@@ -243,13 +229,13 @@
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloTensorStoreOpConverter : public ConversionPattern {
+class HloToLhloTensorStoreOpConverter
+ : public BaseOpConversion<mlir::TensorStoreOp> {
public:
- explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
- : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
+ using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
@@ -291,7 +277,6 @@
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
-// dealloc %0 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) : () -> ()
// return
@@ -313,14 +298,13 @@
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32>
-// %1 = alloc() : memref<4xf32>
+//
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+// %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
-// dealloc %0 : memref<4xf32>
-// dealloc %1 : memref<4xf32>
// "xla_lhlo.terminator"() : () -> ()
// }
@@ -346,101 +330,25 @@
});
auto module = getOperation();
- populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
-
- // Do partial conversion so we can have unknown ops in tests.
- if (failed(applyPartialConversion(module, target, patterns, nullptr))) {
- signalPassFailure();
- }
+ BufferAssignmentTypeConverter converter;
+ module.walk([&](FuncOp func) {
+ BufferAssignmentPlacer bufferAssignment(func);
+ OwningRewritePatternList patterns;
+ populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
+ &converter, &patterns);
+ return WalkResult(
+ applyPartialConversion(func, target, patterns, &converter));
+ });
}
};
-
-Type ConvertType(Type t) {
- if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- }
- return t;
-}
-
} // namespace
-/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
-/// results are converted to memrefs and appended to the argument list.
-class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter& rewriter) const final {
- if (funcOp.getBody().getBlocks().size() > 1) {
- funcOp.emitOpError() << "tensor to buffer conversion expects a single "
- "block in the region containing the operation";
- return failure();
- }
-
- auto funcType = funcOp.getType();
-
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
- for (auto argType : llvm::enumerate(funcType.getInputs())) {
- conversion.addInputs(argType.index(), ConvertType(argType.value()));
- }
- for (auto resType : funcType.getResults()) {
- conversion.addInputs(ConvertType(resType));
- }
- rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(
- rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
- });
- return success();
- }
-};
-
-/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
-/// result to the corresponding buffer argument.
-class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::ReturnOp returnOp, ArrayRef<Value> operands,
- ConversionPatternRewriter& rewriter) const final {
- auto numReturnValues = returnOp.getNumOperands();
- auto funcOp = returnOp.getParentOfType<FuncOp>();
- auto numFuncArgs = funcOp.getNumArguments();
- auto loc = returnOp.getLoc();
-
- for (auto operand : llvm::enumerate(operands)) {
- auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
- auto dstBuffer = funcOp.getArgument(returnArgNumber);
- if (dstBuffer == operand.value()) {
- continue;
- }
-
- auto dealloc = FindInsertionPointForCopy(operand.value());
-
- if (dealloc == nullptr) {
- returnOp.emitOpError()
- << "Missing dealloc for operand " << operand.index();
- return failure();
- }
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(dealloc);
- rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
- funcOp.getArgument(returnArgNumber));
- }
- rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
- return success();
- }
-};
-
-void populateHLOToLHLOConversionPattern(MLIRContext* context,
- OwningRewritePatternList* patterns) {
+void populateHLOToLHLOConversionPattern(
+ MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+ TypeConverter* converter, OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
- HloToLhloFuncOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
@@ -472,8 +380,9 @@
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter,
- StdToLhloReturnOpConverter
- >(context);
+ FunctionAndBlockSignatureConverter,
+ StdReturnOpConverter
+ >(context, bufferAssignment, converter);
// clang-format on
}
diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
index ad81cda..e4f5c93 100644
--- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
@@ -23,6 +23,7 @@
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir {
+class BufferAssignmentPlacer;
namespace xla_hlo {
// Collection of rewrite patterns for lowering a general dot product.
@@ -38,9 +39,9 @@
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
-void populateHLOToLHLOConversionPattern(MLIRContext *context,
- OwningRewritePatternList *patterns);
-
+void populateHLOToLHLOConversionPattern(
+ MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+ TypeConverter* converter, OwningRewritePatternList* patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
OwningRewritePatternList *patterns);
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 33d3690..c806e95 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -45,6 +45,7 @@
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
@@ -60,34 +61,6 @@
using ::mlir::xla_lhlo::FusionOp;
-// Following are some small transformations that are required to clean up code
-// after lowering from linalg to loops.
-
-// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that
-// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp.
-// This is needed, as these ops are not closed from above and hence nested pass
-// managers can not be applied.
-struct NestedHloRegionsConverter
- : public mlir::PassWrapper<NestedHloRegionsConverter,
- ::mlir::FunctionPass> {
- void runOnFunction() override {
- auto& ctx = getContext();
- mlir::OwningRewritePatternList patterns;
- mlir::ConversionTarget target(ctx);
- target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
- ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
-
- getFunction().walk([&](mlir::Operation* op) {
- if (op->getNumRegions() == 0) {
- return;
- }
- if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
- signalPassFailure();
- }
- });
- }
-};
-
// Replaces a FusionOp by the operations contained in its region.
struct FusionOpRemover
: public mlir::PassWrapper<FusionOpRemover, ::mlir::FunctionPass> {
@@ -436,8 +409,10 @@
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
}
- // First, lower bodies of LHLO operations that contain HLO ops.
- pm.addPass(absl::make_unique<NestedHloRegionsConverter>());
+ // Legalize from HLO to LHLO.
+ pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass());
+  // Moves `AllocOp`s and inserts missing `DeallocOp`s.
+ pm.addPass(::mlir::createBufferPlacementPass());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
// Remove unnecessary LHLO copies.