Change RedundantCopiesRemoval to a pass

The copy-removal logic previously ran as a hard-coded post-processing step at
the end of the HloLegalizeToLhlo conversion. It is now a separate FunctionPass
exposed through createLhloCopyRemovalPass(); the "hlo-legalize-to-lhlo" flag
registers a pipeline that runs the legalization followed by copy removal, and
the GPU kernel lowering pipeline schedules the new pass explicitly. A test is
added for removal of the copy created from a tensor_store.

diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 28c5af3..538b392 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -45,6 +45,19 @@
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
+// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
+func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
+ %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tensor_store %0, %arg2 : memref<f32>
+ return
+}
+// CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
+// CHECK-NOT: alloc
+// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+// CHECK-NOT: "xla_lhlo.copy"
+// CHECK-NOT: dealloc
+// CHECK: "xla_lhlo.terminator"() : () -> ()
+
// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
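
The new test above is driven by the "hlo-legalize-to-lhlo" entry point, which (per the registration further down) now expands into the legalization pass followed by the copy-removal pass. Assuming the test file's existing RUN line invokes tf-opt (the RUN line is not part of this hunk), the same behavior can be reproduced manually with an invocation along the lines of:

  tf-opt -hlo-legalize-to-lhlo input.mlir
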
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 2baa06a..835f87f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -49,17 +50,6 @@
return nullptr;
}
-Value GetTensorStore(Value value) {
- for (const auto& user : value.getUsers()) {
- if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
- if (tensor_store.getOperand(0) == value) {
- return tensor_store.getOperand(1);
- }
- }
- }
- return nullptr;
-}
-
Value InsertAllocAndDealloc(Location loc, Value result,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
@@ -85,17 +75,6 @@
return alloc;
}
-/// For every tensor-type value that is produced in the original function,
-/// this function returns the buffer that can be used in the converted
-/// function to store that values held in the tensor.
-Value GetBufferForResultValue(Location loc, Value result,
- ConversionPatternRewriter* rewriter) {
- if (auto existing_memref = GetTensorStore(result)) {
- return existing_memref;
- }
- return InsertAllocAndDealloc(loc, result, rewriter);
-}
-
template <typename HloOpTy, typename LhloOpTy>
class HloToLhloOpConverter : public ConversionPattern {
public:
@@ -137,7 +116,7 @@
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
- buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
+ buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
@@ -200,38 +179,6 @@
}
};
-/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
-/// argument. All uses of the buffer are replaced with the block argument.
-void RemoveRedundantCopies(ModuleOp module) {
- llvm::SmallVector<Operation*, 2> eraseList;
- module.walk([&](xla_lhlo::CopyOp copyOp) {
- auto arguments = copyOp.getOperation()->getBlock()->getArguments();
- if (std::any_of(
- arguments.begin(), arguments.end(),
- [&](BlockArgument arg) { return copyOp.output() == arg; }) &&
- std::none_of(
- arguments.begin(), arguments.end(),
- [&](BlockArgument arg) { return copyOp.operand() == arg; })) {
- Value operand = copyOp.operand();
- Value output = copyOp.output();
- copyOp.erase();
- for (auto op : operand.getUsers()) {
- if (!isa<DeallocOp>(op)) {
- op->replaceUsesOfWith(operand, output);
- }
- }
- auto allocOp = operand.getDefiningOp();
- if (auto deallocOp = dyn_cast<DeallocOp>(*allocOp->getUsers().begin())) {
- eraseList.push_back(deallocOp);
- eraseList.push_back(allocOp);
- }
- }
- });
- for (auto op : eraseList) {
- op->erase();
- }
-}
-
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
@@ -321,8 +268,6 @@
if (failed(applyFullConversion(module, target, patterns, nullptr))) {
signalPassFailure();
}
-
- RemoveRedundantCopies(module);
}
};
@@ -442,12 +387,57 @@
// clang-format on
}
+/// Removes xla_lhlo.copy operations that copy from an allocated buffer to a
+/// block argument. All uses of the buffer are replaced with the block
+/// argument.
+struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
+ void runOnFunction() override {
+ llvm::SmallVector<mlir::Operation*, 2> eraseList;
+ getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
+ auto arguments = copyOp.getOperation()->getBlock()->getArguments();
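+      // Rewrite copies whose destination is a block argument (typically the
+      // function's output buffer) and whose source is not itself a block
+      // argument, i.e. a locally allocated temporary.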
+ if (std::any_of(arguments.begin(), arguments.end(),
+ [&](mlir::BlockArgument arg) {
+ return copyOp.output() == arg;
+ }) &&
+ std::none_of(arguments.begin(), arguments.end(),
+ [&](mlir::BlockArgument arg) {
+ return copyOp.operand() == arg;
+ })) {
+ mlir::Value operand = copyOp.operand();
+ mlir::Value output = copyOp.output();
+ copyOp.erase();
+ for (auto op : operand.getUsers()) {
+ if (!mlir::isa<mlir::DeallocOp>(op)) {
+ op->replaceUsesOfWith(operand, output);
+ }
+ }
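+        // After rewriting the uses, the allocated buffer should only be
+        // consumed by its dealloc (if any); queue the dead alloc/dealloc
+        // pair for erasure.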
+ auto allocOp = operand.getDefiningOp();
+ if (auto deallocOp =
+ mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) {
+ eraseList.push_back(deallocOp);
+ eraseList.push_back(allocOp);
+ }
+ }
+ });
+ for (auto op : eraseList) {
+ op->erase();
+ }
+  }
+};
+
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}
-static PassRegistration<HloLegalizeToLhlo> legalize_pass(
- "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
+std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
+ return absl::make_unique<RedundantCopiesRemoval>();
+}
+
+static PassPipelineRegistration<> legalize_pass(
+ "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect",
+ [](mlir::OpPassManager& pm) {
+ pm.addPass(createLegalizeToLhloPass());
+ pm.addPass(createLhloCopyRemovalPass());
+ });
} // namespace xla_hlo
} // namespace mlir
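
Note that createLegalizeToLhloPass() by itself no longer removes the copies; callers composing their own pipelines also need to schedule createLhloCopyRemovalPass(), which is what the registration above and the kernel_lowering.cc change below do. A minimal sketch of the programmatic usage (assuming an MLIRContext named context and a parsed ModuleOp named module; it simply mirrors the registered pipeline):

  #include "mlir/Pass/PassManager.h"
  #include "tensorflow/compiler/mlir/xla/transforms/passes.h"

  mlir::PassManager pm(&context);
  pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
  pm.addPass(mlir::xla_hlo::createLhloCopyRemovalPass());
  if (mlir::failed(pm.run(module))) {
    // Lowering or copy removal failed; propagate the error.
  }
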
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index c890a81..14a2e33 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -58,6 +58,11 @@
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();
+// Removes unnecessary LHLO copies which copy from allocated buffers to block
+// arguments. Such copies are created when TensorStoreOps are replaced with
+// xla_lhlo.copy operations during HLO-to-LHLO lowering.
+std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();
+
} // namespace xla_hlo
namespace xla_lhlo {
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 76ad2b3..79e56e4 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -276,6 +276,8 @@
pm.addPass(absl::make_unique<FusionToLhloConverter>());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
+ // Remove unnecessary Lhlo copies.
+ pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
// Transform lhlo operations to LinAlg.
pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations. This will yield a single tiled loop nest where