blob: af5fb599dca34636ad06a1ed7db7782437e406b2 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering HLO dialect to LHLO dialect.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace {
constexpr StringRef kTempBufferAttr = "temp";
Value* GetTensorStoreOrReturnMemRef(Value* value) {
for (const auto& user : value->getUsers()) {
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
if (tensor_store.getOperand(0) == value) {
return tensor_store.getOperand(1);
}
}
if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
if (return_op.getOperand(0) == value) {
auto block = return_op.getOperation()->getBlock();
return *block->args_rbegin();
}
}
}
return nullptr;
}
Operation* GetLastUse(Value* value) {
Operation* last = value->getDefiningOp();
for (auto& user : value->getUses()) {
Operation* user_op = user.getOwner();
if (!user_op->isBeforeInBlock(last)) {
last = user_op;
}
}
return last;
}
Value* InsertAllocAndDealloc(Location loc, Value* result,
ConversionPatternRewriter* rewriter) {
auto result_type = result->getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
emitError(loc,
"tensor to buffer conversion expects statically shaped results");
}
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
Operation* last = GetLastUse(result);
Operation* op = result->getDefiningOp();
OpBuilder allocBuilder(op);
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
allocBuilder.setInsertionPoint(op->getBlock(),
std::next(Block::iterator(last)));
allocBuilder.create<DeallocOp>(loc, alloc);
return alloc;
}
/// For every tensor-type value that is produced in the original function,
/// this function returns the buffer that can be used in the converted
/// function to store that values held in the tensor.
Value* GetBufferForResultValue(Location loc, Value* result,
ConversionPatternRewriter* rewriter) {
if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
return existing_memref;
}
return InsertAllocAndDealloc(loc, result, rewriter);
}
template <typename HloOpTy, typename LhloOpTy>
class HloToLhloOpConverter : public ConversionPattern {
public:
explicit HloToLhloOpConverter(MLIRContext* context)
: ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
if (op->getParentRegion()->getBlocks().size() != 1) {
emitError(op->getLoc(),
"tensor to buffer conversion expects a single block in the "
"region containing the operation");
}
const auto& original_results = op->getResults();
SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
GetBufferForResultValue(op->getLoc(), result, &rewriter));
}
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
llvm::to_vector<4>(original_results));
return matchSuccess();
}
};
struct HloToLHloReduceConverter
: public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return matchFailure();
if (op.getParentRegion()->getBlocks().size() != 1) {
emitError(loc,
"tensor to buffer conversion expects a single block in the "
"region containing the operation");
}
const auto& original_results = op.getResults();
SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
// Create new block arguments with correct type.
auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments();
for (int i = 0; i < original_arg_count; ++i) {
auto old_arg = entry_block.getArgument(i);
auto old_type = old_arg->getType().cast<TensorType>();
auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
}
// Add an argument for the result.
entry_block.addArgument(
entry_block.getArgument(original_arg_count)->getType());
// Remove the old arguments.
for (int i = original_arg_count - 1; i >= 0; --i) {
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
llvm::to_vector<4>(original_results));
return matchSuccess();
}
};
class HloToLhloTensorLoadConverter : public ConversionPattern {
public:
explicit HloToLhloTensorLoadConverter(MLIRContext* context)
: ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands, llvm::to_vector<4>(op->getResults()));
return matchSuccess();
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreConverter : public ConversionPattern {
public:
explicit HloToLhloTensorStoreConverter(MLIRContext* context)
: ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
xla_hlo::ReturnOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
}
};
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
// Example fusion with HLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) {name = "add"} :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) {name = "fusion"} : () -> ()
// return
// }
//
// Transformed fusion with LHLO ops.
// func @fusion(%arg0: memref<2x2xf32>,
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( {
// %0 = alloc() {temp = true} : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.mul"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// dealloc %0 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) {name = "fusion"} : () -> ()
// return
// }
// }
struct HloLegalizeToLhlo : public FunctionPass<HloLegalizeToLhlo> {
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
auto func = getFunction();
populateHLOToLHLOConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure();
}
}
};
} // namespace
void populateHLOToLHLOConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp,
xla_lhlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>,
HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
HloToLHloReduceConverter, HloToLhloReturnConverter,
HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
>(context);
// clang-format on
}
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
} // namespace xla_hlo
} // namespace mlir