blob: d50181ef3da28a7b3c140be451bbe95978c2ca20 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements conversion of `gml_st.loop` to buffer form.
#include <utility>
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
#include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "mlir-hlo/Dialect/gml_st/transforms/rewriters.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace {
using bufferization::ToMemrefOp;
using bufferization::ToTensorOp;
using gml_st::LoopOp;
using linalg::InitTensorOp;
using memref::SubViewOp;
using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;
using vector::TransferReadOp;
using vector::TransferWriteOp;
/// Materialization callback that wraps a single memref value into a
/// `bufferization.to_tensor` op producing a tensor of type `type`.
///
/// Expects exactly one input, and that input must be of a memref type; both
/// conditions are checked with assertions only.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value buffer = inputs[0];
  assert(buffer.getType().isa<BaseMemRefType>());
  return builder.create<ToTensorOp>(loc, type, buffer);
}
// TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
/// Type converter used by the dialect-conversion part of this pass. It maps
/// tensor types to memref types and registers the materializations that
/// bridge between tensor and memref values.
class CustomBufferizeTypeConverter
    : public bufferization::BufferizeTypeConverter {
 public:
  CustomBufferizeTypeConverter() {
    // Keep all types unchanged.
    addConversion([](Type anyType) { return anyType; });

    // Convert RankedTensorType to MemRefType with the same shape and element
    // type.
    addConversion([](RankedTensorType tensorType) -> Type {
      return MemRefType::get(tensorType.getShape(),
                             tensorType.getElementType());
    });

    // Convert UnrankedTensorType to UnrankedMemRefType (memory space 0).
    addConversion([](UnrankedTensorType tensorType) -> Type {
      return UnrankedMemRefType::get(tensorType.getElementType(), 0);
    });

    // memref -> tensor, for both block arguments and sources.
    addArgumentMaterialization(materializeToTensor);
    addSourceMaterialization(materializeToTensor);

    // tensor -> memref.
    addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                                ValueRange inputs, Location loc) -> Value {
      assert(inputs.size() == 1);
      Value input = inputs[0];
      Type inputType = input.getType();

      // Target materialization is invoked if the new operand type does not
      // match the expected type. A special case is when the new operand type
      // is a memref with a specified layout, i.e. non-empty affine map; such
      // a value is passed through untouched.
      // TODO(pifon) : Change how target materialization is invoked in dialect
      // conversion.
      if (auto memrefType = inputType.dyn_cast<MemRefType>()) {
        assert(!memrefType.getLayout().isIdentity());
        return input;
      }

      assert(inputType.isa<TensorType>());
      return builder.create<bufferization::ToMemrefOp>(loc, type, input);
    });
  }
};
/// Rewrites a `tensor.extract_slice` nested inside a `gml_st.loop` into a
/// `memref.subview` of the converted source, reusing the original op's mixed
/// offsets, sizes and strides.
struct BufferizeExtractSliceOp : public OpConversionPattern<ExtractSliceOp> {
  using OpConversionPattern<ExtractSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ExtractSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // This pattern only applies to slices enclosed by a `gml_st.loop`.
    if (!op->getParentOfType<LoopOp>()) return failure();

    auto offsets = op.getMixedOffsets();
    auto sizes = op.getMixedSizes();
    auto strides = op.getMixedStrides();
    rewriter.replaceOpWithNewOp<SubViewOp>(op, adaptor.getSource(), offsets,
                                           sizes, strides);
    return success();
  }
};
/// Rewrites a `linalg.init_tensor` nested inside a `gml_st.loop` into a
/// `memref.alloc` of the converted result type, forwarding the dynamic sizes.
struct BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      InitTensorOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // This pattern only applies to ops enclosed by a `gml_st.loop`.
    if (!op->getParentOfType<LoopOp>()) return failure();

    // The type converter is expected to map the tensor type to a MemRefType;
    // `cast` asserts otherwise.
    auto bufferType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, bufferType,
                                                 adaptor.sizes());
    return success();
  }
};
/// Returns true if `value` is a block argument whose owning block is directly
/// contained in a `gml_st.loop` op.
///
/// Returns false for non-block-argument values and for block arguments of a
/// block that currently has no parent operation (e.g. a detached block), for
/// which `getParentOp()` yields null and `isa<>` must not be called.
bool isBlockArgOfTiledLoop(Value value) {
  auto blockArg = value.dyn_cast<BlockArgument>();
  if (!blockArg) return false;
  Operation *parentOp = blockArg.getOwner()->getParentOp();
  return parentOp && isa<LoopOp>(parentOp);
}
// Attempts to find an existing `memref.subview` of `destMemRef` in the tiled
// loop. The assumption is that in `gml_st.loop` the tile of the output
// tensor that we read and the tile that we write to are the same.
//
// The search only succeeds when `destMemRef` is produced by a
// `bufferization.to_memref` of a `bufferization.to_tensor` whose memref is a
// block argument of a `gml_st.loop`. Returns a null Value if no matching
// subview is found.
Value findExistingSubview(Value destMemRef) {
  auto destCast = destMemRef.getDefiningOp<ToMemrefOp>();
  if (!destCast) return Value{};

  auto tensorCast = destCast.getTensor().getDefiningOp<ToTensorOp>();
  if (!tensorCast) return Value{};

  if (!isBlockArgOfTiledLoop(tensorCast.getMemref())) return Value{};

  // Walk every `to_memref` cast of the same tensor and look through the users
  // of each resulting memref for a `subview` taken from `destMemRef`.
  for (Operation *tensorUser : destCast.getTensor().getUsers()) {
    auto siblingCast = mlir::dyn_cast<ToMemrefOp>(tensorUser);
    if (!siblingCast) continue;

    for (Operation *memrefUser : siblingCast.getMemref().getUsers()) {
      auto candidate = mlir::dyn_cast<SubViewOp>(memrefUser);
      if (!candidate) continue;
      if (candidate.getSource() == destMemRef) return candidate;
    }
  }
  return Value{};
}
/// Rewrites a `tensor.insert_slice` nested inside a `gml_st.loop` into a
/// `memref.copy` of the converted source into a `memref.subview` of the
/// converted destination. An already existing subview of the destination is
/// reused when `findExistingSubview` finds one; otherwise a new subview is
/// created from the op's mixed offsets, sizes and strides. The op's result is
/// replaced with the destination buffer.
struct BufferizeInsertSliceOp : public OpConversionPattern<InsertSliceOp> {
 public:
  using OpConversionPattern<InsertSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      InsertSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Both operands are expected to be converted to memrefs already.
    Value srcBuffer = adaptor.getSource();
    assert(srcBuffer.getType().isa<MemRefType>());
    Value dstBuffer = adaptor.getDest();
    assert(dstBuffer.getType().isa<MemRefType>());

    // This pattern only applies to ops enclosed by a `gml_st.loop`.
    if (!op->getParentOfType<LoopOp>()) return failure();

    Value dstTile = findExistingSubview(dstBuffer);
    if (!dstTile) {
      dstTile = rewriter.create<SubViewOp>(op.getLoc(), dstBuffer,
                                           op.getMixedOffsets(),
                                           op.getMixedSizes(),
                                           op.getMixedStrides());
    }

    rewriter.create<memref::CopyOp>(op.getLoc(), srcBuffer, dstTile);
    rewriter.replaceOp(op, dstBuffer);
    return success();
  }
};
/// Creates a result-less copy of the tensor-based `linalgOp` that operates on
/// buffers. The new op takes `inputs` followed by `outputs` as operands, and
/// the regions of the original op are moved (not copied) into it.
linalg::LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                         linalg::LinalgOp linalgOp,
                                         ValueRange inputs,
                                         ValueRange outputs) {
  // Operand list of the buffer-based op: inputs first, then outputs.
  SmallVector<Value, 8> bufferOperands(inputs.begin(), inputs.end());
  bufferOperands.append(outputs.begin(), outputs.end());

  Operation *bufferOp = linalgOp.cloneWithoutRegions(
      rewriter, linalgOp.getLoc(),
      /*resultTypes=*/ArrayRef<Type>{}, bufferOperands);

  // Move each region of the original op into the matching region of the clone.
  for (auto regionPair :
       llvm::zip(linalgOp->getRegions(), bufferOp->getRegions())) {
    Region &sourceRegion = std::get<0>(regionPair);
    Region &targetRegion = std::get<1>(regionPair);
    rewriter.inlineRegionBefore(sourceRegion, targetRegion,
                                targetRegion.begin());
  }
  return bufferOp;
}
/// Returns the `index`-th variadic operand segment of `operands`, where
/// `sizeAttr` holds the size of every segment as 32-bit integers.
ValueRange getVariadicOperands(DenseIntElementsAttr sizeAttr,
                               ValueRange operands, unsigned index) {
  const uint32_t *segmentSizes = &*sizeAttr.value_begin<uint32_t>();

  // A splat attribute means every segment has the same size, so the start of
  // the requested segment can be computed directly.
  if (sizeAttr.isSplat()) {
    return operands.slice(*segmentSizes * index, *segmentSizes);
  }

  // Otherwise the segment starts after all preceding segments.
  unsigned offset = 0;
  for (unsigned segment = 0; segment != index; ++segment) {
    offset += segmentSizes[segment];
  }
  return operands.slice(offset, segmentSizes[index]);
}
// Bufferize LinalgOps in-place: a Linalg op nested inside a `gml_st.loop` is
// re-created on its converted operands and its results are replaced with the
// output buffers.
struct BufferizeLinalgOp
    : public OpInterfaceConversionPattern<linalg::LinalgOp> {
  using OpInterfaceConversionPattern<
      linalg::LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult matchAndRewrite(
      linalg::LinalgOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // This pattern only applies to ops enclosed by a `gml_st.loop`.
    if (!op->getParentOfType<LoopOp>()) return failure();

    // An op with two variadic operand groups expects a segment size attribute.
    auto segmentSizes =
        op->getAttrOfType<DenseIntElementsAttr>("operand_segment_sizes");
    if (!segmentSizes) return failure();

    // Segment 0 holds the inputs, segment 1 the outputs.
    ValueRange inputBuffers = getVariadicOperands(segmentSizes, operands, 0);
    ValueRange outputBuffers = getVariadicOperands(segmentSizes, operands, 1);

    createLinalgOpOnBuffers(rewriter, op, inputBuffers, outputBuffers);
    rewriter.replaceOp(op, outputBuffers);
    return success();
  }
};
// Convert `gml_st.yield` terminator of `gml_st.loop` to `gml_st.yield` with no
// arguments. A yield that is not directly nested in a `gml_st.loop`, or that
// already has no operands, is not matched.
struct BufferizeLinalgYieldOp : public OpConversionPattern<gml_st::YieldOp> {
  using OpConversionPattern<gml_st::YieldOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      gml_st::YieldOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Only the terminator of a `gml_st.loop` is rewritten.
    if (!isa<LoopOp>(op->getParentOp())) return failure();

    // Nothing to strip if the yield carries no values.
    if (adaptor.getOperands().empty()) return failure();

    rewriter.replaceOpWithNewOp<gml_st::YieldOp>(op);
    return success();
  }
};
// FuncOp-like bufferization pattern for `gml_st.loop`: the loop is re-created
// without results on its converted operands, the original body is moved into
// the new op, and the types of the body's block arguments are converted with
// the pattern's type converter. The results of the original loop are replaced
// with the outputs of the new one.
//
// TODO(b/230082413): This code has to go away if we migrate to one-shot
// bufferization.
struct BufferizeLoopOp : public OpConversionPattern<LoopOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LoopOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Loops without results are not handled by this pattern.
    if (op.getNumResults() == 0) return failure();

    // If an output comes from a `to_tensor` op that has more than one use and
    // whose memref is defined by a `memref.alloc`, give the loop its own
    // buffer by cloning that alloc right before the loop.
    SmallVector<Value, 4> newOperands = adaptor.getOperands();
    for (auto &indexedOutput : llvm::enumerate(op.outputs())) {
      Value output = indexedOutput.value();
      auto toTensor = output.getDefiningOp<bufferization::ToTensorOp>();
      const bool hasSeveralUses = toTensor && !toTensor->hasOneUse();
      if (!hasSeveralUses) continue;

      auto alloc = toTensor.getMemref().getDefiningOp<memref::AllocOp>();
      if (!alloc) continue;

      // Restore the insertion point at the end of this iteration.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(op);
      Operation *clonedAlloc = rewriter.clone(*alloc.getOperation());

      // Outputs follow the control operands and the inputs in the operand
      // list.
      newOperands[op.getNumControlOperands() + op.getNumInputs() +
                  indexedOutput.index()] = clonedAlloc->getResult(0);
    }

    // Carry over all attributes of the original loop.
    DictionaryAttr attrDict = adaptor.getAttributes();
    SmallVector<NamedAttribute> attrs(attrDict.begin(), attrDict.end());

    auto bufferizedLoop = rewriter.create<LoopOp>(
        op.getLoc(), mlir::TypeRange{}, newOperands, attrs);

    // Take the region from the old op and put it in the new op.
    rewriter.inlineRegionBefore(op.getLoopBody(), bufferizedLoop.getLoopBody(),
                                bufferizedLoop.getLoopBody().end());

    // Convert the type of the entry block of the LoopOp's body.
    if (failed(rewriter.convertRegionTypes(&bufferizedLoop.getLoopBody(),
                                           *getTypeConverter()))) {
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    }

    rewriter.replaceOp(op, bufferizedLoop.outputs());
    return success();
  }
};
// TODO(b/199045477): The pattern for vector.transfer_read/write have to be
// moved out of Linalg bufferization to a VectorOps bufferization pass.
//
// Re-creates a `vector.transfer_read` that reads from a tensor so that it
// reads from the converted source instead. Reads whose source is already a
// memref are not matched.
struct BufferizeVectorTransferReadOp
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      vector::TransferReadOp readOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Already on buffers: nothing to do.
    if (readOp.getShapedType().isa<MemRefType>()) return failure();

    // Forward the `in_bounds` attribute only if the original op has one.
    ArrayAttr inBounds =
        adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr();

    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.getSource(), adaptor.getIndices(),
        adaptor.getPermutationMapAttr(), adaptor.getPadding(),
        adaptor.getMask(), inBounds);
    return success();
  }
};
/// Re-creates a `vector.transfer_write` that writes into a tensor so that it
/// writes into the converted destination buffer instead, and replaces the
/// tensor result of the original op with that buffer. Writes whose destination
/// is already a memref are not matched.
struct BufferizeVectorTransferWriteOp
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      vector::TransferWriteOp writeOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Already on buffers: nothing to do.
    if (writeOp.getShapedType().isa<MemRefType>()) return failure();

    // Forward the `in_bounds` attribute only if the original op has one.
    ArrayAttr inBounds =
        adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr();

    // The optional mask must be forwarded as well; dropping it would turn a
    // masked write into an unmasked one. `getMask()` is a null Value when the
    // original op has no mask.
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.getVector(), adaptor.getSource(),
        adaptor.getIndices(), adaptor.getPermutationMapAttr(),
        adaptor.getMask(), inBounds);
    rewriter.replaceOp(writeOp, adaptor.getSource());
    return success();
  }
};
} // namespace
namespace gml_st {
/// Pass that bufferizes `gml_st.loop` ops and the ops nested in them.
///
/// It runs in two stages: first a partial BufferizableOpInterface-based
/// bufferization restricted to the Shape dialect, then a dialect-conversion
/// based bufferization that uses the patterns registered by
/// `populateTiledLoopBufferizePattern`.
struct TiledLoopBufferizePass
    : public TiledLoopBufferizePassBase<TiledLoopBufferizePass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The patterns create `memref.*` ops and the type-converter
    // materializations create `bufferization.to_tensor`/`to_memref` ops, so
    // both dialects have to be loaded before the pass runs.
    registry.insert<mlir::bufferization::BufferizationDialect,
                    memref::MemRefDialect>();
  }

  void runOnOperation() override {
    // Bufferize ops using BufferizableOpInterface. This could be switched to
    // One-Shot Bufferize in the future.
    mlir::bufferization::BufferizationOptions options =
        mlir::bufferization::getPartialBufferizationOptions();
    // TODO(springerm): Add dialects to this filter as more and more dialects
    // will be migrated to BufferizableOpInterface-based bufferization.
    options.opFilter.allowDialect<shape::ShapeDialect>();
    if (failed(mlir::bufferization::bufferizeOp(getOperation(), options))) {
      signalPassFailure();
      return;
    }

    // Bufferize the remaining IR with dialect conversion. This will disappear
    // eventually once all bufferization is done via BufferizableOpInterface.
    if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
  }

 private:
  /// Sets up the conversion target, type converter and patterns, and applies
  /// them as a partial conversion to the operation the pass runs on.
  LogicalResult runDialectConversionBasedBufferization() {
    auto &context = getContext();
    mlir::RewritePatternSet patterns(&context);
    ConversionTarget target(context);
    target.addLegalDialect<
        mlir::arith::ArithmeticDialect,
        mlir::bufferization::BufferizationDialect,
        mlir::complex::ComplexDialect, mlir::lmhlo::LmhloDialect,
        mlir::AffineDialect, mlir::vector::VectorDialect,
        mlir::memref::MemRefDialect, mlir::func::FuncDialect,
        mlir::tensor::TensorDialect, mlir::math::MathDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();

    CustomBufferizeTypeConverter converter;

    // Configure bufferize pattern.
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    mlir::bufferization::populateBufferizeMaterializationLegality(target);
    populateTiledLoopBufferizePattern(&context, &converter, &patterns);
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);

    // Configure legality: the ops below are legal only once all of their
    // types are legal for the type converter.
    auto isLegalOp = [&](Operation *op) { return converter.isLegal(op); };
    target.addDynamicallyLegalDialect<mlir::linalg::LinalgDialect>(isLegalOp);
    target.addDynamicallyLegalOp<mlir::func::CallOp, gml_st::LoopOp,
                                 gml_st::YieldOp, mlir::LLVM::InlineAsmOp,
                                 mlir::vector::TransferWriteOp,
                                 mlir::vector::TransferReadOp>(isLegalOp);

    return applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};
/// Adds all `gml_st.loop` bufferization patterns defined in this file to
/// `patterns`. Every pattern is constructed with `converter` as its type
/// converter and with `context`.
void populateTiledLoopBufferizePattern(
    mlir::MLIRContext *context,
    mlir::bufferization::BufferizeTypeConverter *converter,
    mlir::RewritePatternSet *patterns) {
  mlir::RewritePatternSet &patternSet = *patterns;
  // clang-format off
  patternSet.add<
    BufferizeExtractSliceOp,
    BufferizeInitTensorOp,
    BufferizeInsertSliceOp,
    BufferizeLinalgOp,
    BufferizeLinalgYieldOp,
    BufferizeLoopOp,
    BufferizeVectorTransferReadOp,
    BufferizeVectorTransferWriteOp
  >(*converter, context);
  // clang-format on
}
/// Factory for the tiled-loop bufferization pass: returns a newly allocated
/// `TiledLoopBufferizePass`, which operates on `func::FuncOp`.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTiledLoopBufferizePass() {
return std::make_unique<TiledLoopBufferizePass>();
}
} // namespace gml_st
} // namespace mlir