// blob: b2adf89bb84d2e86848850cb21cfd075b1d624f3 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering HLO dialect ops to memref dialect
// ops.
#include <functional>
#include <utility>
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace mhlo {
namespace {
// Conversion pattern base for ops of type `T`. It passes operands and the
// result type through `remove_sign_converter` before handing the rewrite to
// the subclass hook `signlessRewrite`, and afterwards casts the produced
// memref back to the op's original element type if the two differ.
template <typename T>
class SignlessOpConversion : public OpConversionPattern<T> {
 public:
  SignlessOpConversion(TypeConverter& type_converter,
                       RemoveSignTypeConverter* remove_sign_converter,
                       MLIRContext* ctx)
      : OpConversionPattern<T>(type_converter, ctx),
        remove_sign_converter_(remove_sign_converter) {}

  LogicalResult matchAndRewrite(
      T op, typename T::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    // Collect the operands, inserting a conversion cast for every operand
    // whose type is changed by the sign-removing converter.
    SmallVector<Value> signless_operands;
    for (Value operand : adaptor.getOperands()) {
      Type operand_type = operand.getType();
      Type signless_type = remove_sign_converter_->convertType(operand_type);
      if (signless_type != operand_type) {
        operand = rewriter
                      .create<UnrealizedConversionCastOp>(loc, signless_type,
                                                          operand)
                      ->getResult(0);
      }
      signless_operands.push_back(operand);
    }
    Type signless_result_type =
        remove_sign_converter_->convertType(op.getType());

    // Let the derived pattern do the real work; a null value means "no
    // match".
    Value result = signlessRewrite(op, signless_operands, signless_result_type,
                                   rewriter);
    if (!result) return failure();

    Type expected_elem =
        op.getType().template cast<ShapedType>().getElementType();
    BaseMemRefType produced_type = result.getType().cast<BaseMemRefType>();
    Type actual_elem = produced_type.getElementType();

    // Nothing to fix up when the element types already agree.
    if (expected_elem == actual_elem) {
      rewriter.replaceOp(op, result);
      return success();
    }

    // The element types differ: rebuild the memref type around the op's
    // original element type and cast the result to it.
    assert(remove_sign_converter_->convertType(expected_elem) == actual_elem);
    Type restored_type;
    if (auto ranked = produced_type.dyn_cast<MemRefType>()) {
      restored_type =
          MemRefType::get(ranked.getShape(), expected_elem, ranked.getLayout(),
                          ranked.getMemorySpace());
    } else {
      restored_type = UnrankedMemRefType::get(expected_elem,
                                              produced_type.getMemorySpace());
    }
    result =
        rewriter.create<UnrealizedConversionCastOp>(loc, restored_type, result)
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }

 protected:
  // Rewrite hook implemented by subclasses. Receives the sign-converted
  // operands and result type; returns the replacement value, or a null Value
  // to make `matchAndRewrite` fail.
  virtual Value signlessRewrite(T op, ArrayRef<Value> operands,
                                Type result_type,
                                ConversionPatternRewriter& rewriter) const = 0;

 private:
  RemoveSignTypeConverter* remove_sign_converter_;
};
// Alias naming the common base class used by the conversion patterns in this
// file; it currently resolves to SignlessOpConversion<T>.
template <typename T>
using BaseOpConversion = SignlessOpConversion<T>;
// Rewrites an mhlo.reshape whose converted operand is an unranked memref into
// a memref.cast to a memref with the result's shape and element type. Any
// other operand type yields a null Value, i.e. the pattern does not apply.
class HloToMemrefReshapeUnrankedConverter
    : public BaseOpConversion<mhlo::ReshapeOp> {
 public:
  using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;

  Value signlessRewrite(mhlo::ReshapeOp op, ArrayRef<Value> operands,
                        Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    mhlo::ReshapeOp::Adaptor adaptor(operands);
    Value source = adaptor.operand();

    // Only operands of unranked memref type are handled by this pattern.
    if (!source.getType().isa<UnrankedMemRefType>()) return {};

    auto target_tensor_type = op_result_type.cast<RankedTensorType>();
    MemRefType target_memref_type = MemRefType::get(
        target_tensor_type.getShape(), target_tensor_type.getElementType());
    Value casted = rewriter.create<memref::CastOp>(op->getLoc(), source,
                                                   target_memref_type);
    return casted;
  }
};
// Rewrites mhlo.dynamic_reshape into memref.reshape. The memref result type
// mirrors the tensor result type: a ranked tensor maps to a MemRefType with
// the same shape and element type, an unranked tensor to an
// UnrankedMemRefType with memory space 0.
class HloToMemrefDynamicReshapeConverter
    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 public:
  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;

  Value signlessRewrite(mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
                        Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    ShapedType memref_type;
    if (auto ranked = op_result_type.dyn_cast<RankedTensorType>()) {
      memref_type = MemRefType::get(ranked.getShape(), ranked.getElementType());
    } else if (auto unranked = op_result_type.dyn_cast<UnrankedTensorType>()) {
      memref_type = UnrankedMemRefType::get(unranked.getElementType(), 0);
    }
    // Neither a ranked nor an unranked tensor result: pattern does not apply.
    if (!memref_type) return {};

    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
    Value reshaped = rewriter.create<memref::ReshapeOp>(
        op.getLoc(), memref_type, adaptor.operand(), adaptor.output_shape());
    return reshaped;
  }
};
// Rewrites mhlo.dynamic_broadcast_in_dim into a memref.reinterpret_cast that
// uses 0 strides for expanded dimensions. When `enforce_identity_maps`
// returns true for the op, the strided result is additionally copied into a
// freshly allocated memref with the default (identity) layout.
//
// TODO(b/175670649) Fix this to no longer access original tensor operands.
class HloToMemrefDynamicBroadcastInDimOpConverter
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 public:
  HloToMemrefDynamicBroadcastInDimOpConverter(
      TypeConverter& converter, RemoveSignTypeConverter* sign_converter,
      MLIRContext* ctx, std::function<bool(Operation*)> enforce_identity_maps)
      : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter,
                                                        sign_converter, ctx),
        enforce_identity_maps_(std::move(enforce_identity_maps)) {}

  // Returns the replacement value for `op`, or a null Value (pattern does not
  // apply) when the result is not a ranked tensor or the converted operand is
  // not a ranked memref.
  Value signlessRewrite(mhlo::DynamicBroadcastInDimOp op,
                        ArrayRef<Value> operands, Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    auto result_type = op_result_type.dyn_cast<RankedTensorType>();
    if (!result_type) return {};
    // InsertDynamicMemrefCastOp needs the operand's rank and shape, so bail
    // out here instead of hitting its unchecked cast<MemRefType>() when the
    // operand is not a ranked memref.
    if (!operands.front().getType().isa<MemRefType>()) return {};
    Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
    if (enforce_identity_maps_(op)) {
      result = CreateCopy(op, result, &rewriter);
    }
    return result;
  }

 private:
  // Inserts dynamic memref to change the layout of the memref to put 0-stride
  // and size of the target dimension if size-1 dimension expansion is
  // necessary.
  //
  // `operand` must have ranked memref type (checked by the caller).
  memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();
    auto operand_rank = operand_type.getRank();

    auto result_type = op.getType().cast<RankedTensorType>();
    auto result_rank = result_type.getRank();

    Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
    Value one = b->create<arith::ConstantIndexOp>(loc, 1);

    // Compute a reversed scan product. Compute the stride for the dimensions so
    // far, working from minor to major dimensions. Additionally, save the
    // operand shape Values to use in the next loop.
    SmallVector<Value, 2> operand_strides(operand_rank, one);
    SmallVector<Value, 2> operand_sizes(operand_rank, one);
    Value stride_so_far = one;
    for (int i = operand_rank - 1; i >= 0; --i) {
      Value operand_dim_size =
          ShapedType::isDynamic(operand_shape[i])
              ? b->create<memref::DimOp>(loc, operand, i).getResult()
              : b->create<arith::ConstantIndexOp>(loc, operand_shape[i])
                    .getResult();
      operand_sizes[i] = operand_dim_size;

      operand_strides[i] = stride_so_far;
      // The running product is not read after dimension 0, so no multiply is
      // emitted for it.
      if (i > 0) {
        stride_so_far =
            b->create<arith::MulIOp>(loc, stride_so_far, operand_dim_size);
      }
    }

    SmallVector<OpFoldResult, 2> sizes, strides;
    sizes.reserve(result_rank);
    strides.reserve(result_rank);

    // Inverse of broadcast_dimensions: result dimension -> operand dimension.
    DenseMap<int, int> output_to_input_dim;
    for (const auto& dim : llvm::enumerate(op.broadcast_dimensions())) {
      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
    }
    for (int i = 0; i < result_rank; ++i) {
      // Read the requested size of result dimension `i` from the
      // output_dimensions tensor, casting it to index type if necessary.
      Value i_val = b->create<arith::ConstantIndexOp>(loc, i);
      Value result_dim_size =
          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
      if (!result_dim_size.getType().isIndex()) {
        result_dim_size = b->create<arith::IndexCastOp>(loc, result_dim_size,
                                                        b->getIndexType());
      }
      if (result_type.isDynamicDim(i)) {
        sizes.push_back(result_dim_size);
      } else {
        sizes.push_back(b->getIndexAttr(result_type.getDimSize(i)));
      }

      auto it = output_to_input_dim.find(i);
      // If the rank of the output is greater than the rank of the input, i.e.
      // there was no output dimension in the inverse broadcast_dimensions map
      // we also set stride to 0 to emulate padding of the shape with 1s and the
      // corresponding expansion.
      if (it == output_to_input_dim.end()) {
        strides.push_back(zero);
        continue;
      }

      // There can be two cases:
      // 1) Operand dim == result dim => expansion is not needed
      //    => stride := flattened buffer stride
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
      int dim = it->second;
      Value is_expansion = b->create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
      Value select = b->create<mlir::SelectOp>(loc, is_expansion, zero,
                                               operand_strides[dim]);
      strides.push_back(select);
    }

    // Type-erased memref type with static rank and dynamic strides.
    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                           MemRefType::kDynamicStrideOrOffset);
    auto type_erased_memref_type = MemRefType::get(
        result_type.getShape(), operand_type.getElementType(),
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));

    auto transformed_operand = b->create<memref::ReinterpretCastOp>(
        loc, type_erased_memref_type, operand,
        /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
    return transformed_operand;
  }

  // Allocates a memref with the shape and element type of `broadcasted` but
  // the default layout, copies `broadcasted` into it, and returns the new
  // buffer. Sizes of dynamic dimensions are read from the op's
  // output_dimensions operand.
  Value CreateCopy(mhlo::DynamicBroadcastInDimOp op, Value broadcasted,
                   OpBuilder* b) const {
    MemRefType result_type = broadcasted.getType().cast<MemRefType>();
    auto loc = op.getLoc();
    SmallVector<Value, 4> dynamic_operands;
    for (int i = 0; i < result_type.getRank(); ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      auto index = b->createOrFold<arith::ConstantIndexOp>(loc, i);
      Value size =
          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), index);
      if (!size.getType().isIndex()) {
        size = b->create<arith::IndexCastOp>(loc, size, b->getIndexType());
      }
      dynamic_operands.push_back(size);
    }
    auto identity_map_memref =
        MemRefType::get(result_type.getShape(), result_type.getElementType());
    auto copy =
        b->create<memref::AllocOp>(loc, identity_map_memref, dynamic_operands);
    b->create<memref::CopyOp>(loc, broadcasted, copy);

    return copy;
  }

  // Predicate deciding, per op, whether the broadcast result must be
  // materialized in an identity-layout buffer (see CreateCopy).
  std::function<bool(Operation*)> enforce_identity_maps_;
};
// Function pass that applies the HLO-to-memref conversion patterns as a
// partial conversion, with mhlo.dynamic_reshape and
// mhlo.dynamic_broadcast_in_dim marked illegal.
struct HloLegalizeToMemrefPass
    : public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
  // Declares every dialect whose ops the patterns may create. Besides memref
  // and tensor ops, the broadcast pattern builds arith ops (constant, muli,
  // cmpi, index_cast) and a standard-dialect select, so those dialects must
  // be registered as dependencies too.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<arith::ArithmeticDialect,
                    bufferization::BufferizationDialect, memref::MemRefDialect,
                    StandardOpsDialect, tensor::TensorDialect>();
  }

 public:
  void runOnFunction() override {
    auto& context = getContext();
    OwningRewritePatternList patterns(&context);
    ConversionTarget target(context);

    // The converters are referenced by pointer from the patterns; they live
    // until the end of this function, i.e. past applyPartialConversion.
    bufferization::BufferizeTypeConverter converter;
    RemoveSignTypeConverter sign_converter;

    populateHLOToMemrefConversionPattern(&converter, &sign_converter,
                                         &patterns);

    target.addIllegalOp<DynamicReshapeOp, DynamicBroadcastInDimOp>();
    target.addLegalDialect<arith::ArithmeticDialect,
                           bufferization::BufferizationDialect, BuiltinDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();

    auto func = getFunction();
    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
// Adds the HLO-to-memref conversion patterns defined in this file to
// `patterns`.
//
// `converter` and `sign_converter` are handed to every pattern and must
// outlive the pattern list. `enforce_identity_maps` is copied into the
// dynamic broadcast pattern, which calls it per op to decide whether the
// broadcast result is copied into an identity-layout buffer.
void populateHLOToMemrefConversionPattern(
    bufferization::BufferizeTypeConverter* converter,
    RemoveSignTypeConverter* sign_converter, OwningRewritePatternList* patterns,
    const std::function<bool(Operation*)>& enforce_identity_maps) {
  MLIRContext* context = patterns->getContext();
  // `enforce_identity_maps` is a const reference, so it can only be copied;
  // wrapping it in std::move would be a misleading no-op.
  patterns->insert<HloToMemrefDynamicBroadcastInDimOpConverter>(
      *converter, sign_converter, context, enforce_identity_maps);
  patterns->insert<HloToMemrefDynamicReshapeConverter,
                   HloToMemrefReshapeUnrankedConverter>(
      *converter, sign_converter, context);
}
// Factory for the HLO-to-memref legalization pass.
std::unique_ptr<FunctionPass> createLegalizeToMemrefPass() {
  std::unique_ptr<FunctionPass> pass =
      std::make_unique<HloLegalizeToMemrefPass>();
  return pass;
}
} // namespace mhlo
} // namespace mlir