blob: 903e13ba11e41bac9892695afbeacf767bfcecac [file] [log] [blame]
//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
#define DEBUG_TYPE "dialect-conversion"
//===----------------------------------------------------------------------===//
// ArgConverter
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a simple interface for converting the types of block
/// arguments. This is done by inserting fake cast operations that map from the
/// illegal type to the original type to allow for undoing pending rewrites in
/// the case of failure.
struct ArgConverter {
ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
: castOpName(kCastName, rewriter.getContext()),
loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
rewriter(rewriter) {}
/// Erase any rewrites registered for arguments to blocks within the given
/// region. This function is called when the given region is to be destroyed.
void cancelPendingRewrites(Block *block);
/// Cleanup and undo any generated conversions for the arguments of block.
/// This method differs from 'cancelPendingRewrites' in that it returns the
/// block signature to its original state.
void discardPendingRewrites(Block *block);
/// Replace usages of the cast operations with the argument directly.
void applyRewrites();
/// Return if the signature of the given block has already been converted.
bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
/// Attempt to convert the signature of the given block.
LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
/// Apply the given signature conversion on the given block.
void applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping);
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
/// to perform the conversion.
Operation *convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
BlockAndValueMapping &mapping);
/// A utility function used to create a conversion cast operation with the
/// given input and result types.
Operation *createCast(ArrayRef<Value *> inputs, Type outputType);
/// This is an operation name for a fake operation that is inserted during the
/// conversion process. Operations of this type are guaranteed to never escape
/// the converter.
static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
OperationName castOpName;
/// This is a collection of cast operations that were generated during the
/// conversion process when converting the types of block arguments.
llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping;
/// An instance of the unknown location that is used when generating
/// producers.
Location loc;
/// The type converter to use when changing types.
TypeConverter *typeConverter;
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
};
} // end anonymous namespace
constexpr StringLiteral ArgConverter::kCastName;
/// Erase any rewrites registered for arguments to the given block.
void ArgConverter::cancelPendingRewrites(Block *block) {
auto it = argMapping.find(block);
if (it == argMapping.end())
return;
for (auto *op : it->second) {
op->dropAllDefinedValueUses();
op->erase();
}
argMapping.erase(it);
}
/// Cleanup and undo any generated conversions for the arguments of block.
/// This method differs from 'cancelPendingRewrites' in that it returns the
/// block signature to its original state.
void ArgConverter::discardPendingRewrites(Block *block) {
auto it = argMapping.find(block);
if (it == argMapping.end())
return;
// Erase all of the new arguments.
for (int i = block->getNumArguments() - 1; i >= 0; --i) {
block->getArgument(i)->dropAllUses();
block->eraseArgument(i, /*updatePredTerms=*/false);
}
// Re-instate the old arguments.
auto &mapping = it->second;
for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
auto *op = mapping[i];
auto *arg = block->addArgument(op->getResult(0)->getType());
op->getResult(0)->replaceAllUsesWith(arg);
// If this operation is within a block, it will be cleaned up automatically.
if (!op->getBlock())
op->erase();
}
argMapping.erase(it);
}
/// Replace usages of the cast operations with the argument directly.
void ArgConverter::applyRewrites() {
Block *block;
ArrayRef<Operation *> argOps;
for (auto &mapping : argMapping) {
std::tie(block, argOps) = mapping;
// Process the remapping for each of the original arguments.
for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
auto *op = argOps[i];
// Handle the case of a 1->N value mapping.
if (op->getNumOperands() > 1) {
// If all of the uses were removed, we can drop this op. Otherwise,
// keep the operation alive and let the user handle any remaining
// usages.
if (op->use_empty())
op->erase();
continue;
}
// If mapping is 1-1, replace the remaining uses and drop the cast
// operation.
// FIXME(riverriddle) This should check that the result type and operand
// type are the same, otherwise it should force a conversion to be
// materialized. This works around a current limitation with regards to
// region entry argument type conversion.
if (op->getNumOperands() == 1) {
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->destroy();
continue;
}
// Otherwise, if there are any dangling uses then replace the fake
// conversion operation with one generated by the type converter. This
// is necessary as the cast must persist in the IR after conversion.
auto *opResult = op->getResult(0);
if (!opResult->use_empty()) {
rewriter.setInsertionPointToStart(block);
SmallVector<Value *, 1> operands(op->getOperands());
auto *newOp = typeConverter->materializeConversion(
rewriter, opResult->getType(), operands, op->getLoc());
opResult->replaceAllUsesWith(newOp->getResult(0));
}
op->destroy();
}
}
}
/// Converts the signature of the given entry block.
LogicalResult ArgConverter::convertSignature(Block *block,
BlockAndValueMapping &mapping) {
if (auto conversion = typeConverter->convertBlockSignature(block))
return applySignatureConversion(block, *conversion, mapping), success();
return failure();
}
/// Apply the given signature conversion on the given block.
void ArgConverter::applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
if (origArgCount == 0 && convertedTypes.empty())
return;
SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes));
ArrayRef<Value *> newArgRef(newArgRange);
// Remap each of the original arguments as determined by the signature
// conversion.
auto &newArgMapping = argMapping[block];
rewriter.setInsertionPointToStart(block);
for (unsigned i = 0; i != origArgCount; ++i) {
ArrayRef<Value *> remappedValues;
if (auto inputMap = signatureConversion.getInputMapping(i))
remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
BlockArgument *arg = block->getArgument(i);
newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
}
// Erase all of the original arguments.
for (unsigned i = 0; i != origArgCount; ++i)
block->eraseArgument(0, /*updatePredTerms=*/false);
}
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
/// to perform the conversion.
Operation *ArgConverter::convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
BlockAndValueMapping &mapping) {
// Handle the cases of 1->0 or 1->1 mappings.
if (newValues.size() < 2) {
// Create a temporary producer for the argument during the conversion
// process.
auto *cast = createCast(newValues, origArg->getType());
origArg->replaceAllUsesWith(cast->getResult(0));
// Insert a mapping between this argument and the one that is replacing
// it.
if (!newValues.empty())
mapping.map(cast->getResult(0), newValues[0]);
return cast;
}
// Otherwise, this is a 1->N mapping. Call into the provided type converter
// to pack the new values.
auto *cast = typeConverter->materializeConversion(
rewriter, origArg->getType(), newValues, loc);
assert(cast->getNumResults() == 1 &&
cast->getNumOperands() == newValues.size());
origArg->replaceAllUsesWith(cast->getResult(0));
return cast;
}
/// A utility function used to create a conversion cast operation with the
/// given input and result types.
Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
return Operation::create(loc, castOpName, inputs, outputType, llvm::None,
llvm::None, 0, false);
}
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
unsigned numBlockActions)
: numCreatedOperations(numCreatedOperations),
numReplacements(numReplacements), numBlockActions(numBlockActions) {}
/// The current number of created operations.
unsigned numCreatedOperations;
/// The current number of replacements queued.
unsigned numReplacements;
/// The current number of block actions performed.
unsigned numBlockActions;
};
} // end anonymous namespace
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl {
/// This class represents one requested operation replacement via 'replaceOp'.
struct OpReplacement {
OpReplacement() = default;
OpReplacement(Operation *op, ArrayRef<Value *> newValues)
: op(op), newValues(newValues.begin(), newValues.end()) {}
Operation *op;
SmallVector<Value *, 2> newValues;
};
/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
enum class BlockActionKind { Split, Move, TypeConversion };
/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
/// operations since the position was stored.
struct BlockPosition {
Region *region;
Region::iterator::difference_type position;
};
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action.
struct BlockAction {
static BlockAction getSplit(Block *block, Block *originalBlock) {
BlockAction action{BlockActionKind::Split, block, {}};
action.originalBlock = originalBlock;
return action;
}
static BlockAction getMove(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Move, block, {originalPos}};
}
static BlockAction getTypeConversion(Block *block) {
return BlockAction{BlockActionKind::TypeConversion, block, {}};
}
// The action kind.
BlockActionKind kind;
// A pointer to the block that was created by the action.
Block *block;
union {
// In use if kind == BlockActionKind::Move and contains a pointer to the
// region that originally contained the block as well as the position of
// the block in that region.
BlockPosition originalPosition;
// In use if kind == BlockActionKind::Split and contains a pointer to the
// block that was split into two parts.
Block *originalBlock;
};
};
ConversionPatternRewriterImpl(PatternRewriter &rewriter,
TypeConverter *converter)
: argConverter(converter, rewriter) {}
/// Return the current state of the rewriter.
RewriterState getCurrentState();
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
/// Undo the block actions (motions, splits) one by one in reverse order until
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
void discardRewrites();
/// Apply all requested operation rewrites. This method is invoked when the
/// conversion process succeeds.
void applyRewrites();
/// Convert the signature of the given block.
LogicalResult convertBlockSignature(Block *block);
/// Apply a signature conversion on the given region.
void applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion);
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead);
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
/// Notifies that the blocks of a region are about to be moved.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
Region::iterator before);
/// Remap the given operands to those with potentially different types.
void remapValues(Operation::operand_range operands,
SmallVectorImpl<Value *> &remapped);
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
BlockAndValueMapping mapping;
/// Utility used to convert block arguments.
ArgConverter argConverter;
/// Ordered vector of all of the newly created operations during conversion.
SmallVector<Operation *, 4> createdOps;
/// Ordered vector of any requested operation replacements.
SmallVector<OpReplacement, 4> replacements;
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
blockActions.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Undo any block actions.
undoBlockActions(state.numBlockActions);
// Reset any replaced operations and undo any saved mappings.
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
for (auto *result : repl.op->getResults())
mapping.erase(result);
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOperations)
createdOps.pop_back_val()->erase();
}
void ConversionPatternRewriterImpl::undoBlockActions(
unsigned numActionsToKeep) {
for (auto &action :
llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
switch (action.kind) {
// Merge back the block that was split out.
case BlockActionKind::Split: {
action.originalBlock->getOperations().splice(
action.originalBlock->end(), action.block->getOperations());
action.block->erase();
break;
}
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
originalRegion->getBlocks().splice(
std::next(originalRegion->begin(), action.originalPosition.position),
action.block->getParent()->getBlocks(), action.block);
break;
}
// Undo the type conversion.
case BlockActionKind::TypeConversion: {
argConverter.discardPendingRewrites(action.block);
break;
}
}
}
blockActions.resize(numActionsToKeep);
}
void ConversionPatternRewriterImpl::discardRewrites() {
undoBlockActions();
// Remove any newly created ops.
for (auto *op : createdOps) {
op->dropAllDefinedValueUses();
op->erase();
}
}
void ConversionPatternRewriterImpl::applyRewrites() {
// Apply all of the rewrites replacements requested during conversion.
for (auto &repl : replacements) {
for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
repl.op->getResult(i)->replaceAllUsesWith(
mapping.lookupOrDefault(repl.newValues[i]));
// If this operation defines any regions, drop any pending argument
// rewrites.
if (argConverter.typeConverter && repl.op->getNumRegions()) {
for (auto &region : repl.op->getRegions())
for (auto &block : region)
argConverter.cancelPendingRewrites(&block);
}
}
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed.
for (auto &repl : llvm::reverse(replacements))
repl.op->erase();
argConverter.applyRewrites();
}
LogicalResult
ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
// Check to see if this block should not be converted:
// * There is no type converter.
// * The block has already been converted.
// * This is an entry block, these are converted explicitly via patterns.
if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
block->isEntryBlock())
return success();
// Otherwise, try to convert the block signature.
if (failed(argConverter.convertSignature(block, mapping)))
return failure();
blockActions.push_back(BlockAction::getTypeConversion(block));
return success();
}
void ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
if (!region->empty()) {
argConverter.applySignatureConversion(&region->front(), conversion,
mapping);
blockActions.push_back(BlockAction::getTypeConversion(&region->front()));
}
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) {
assert(newValues.size() == op->getNumResults());
// Create mappings for each of the new result values.
for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
assert((newValues[i] || op->getResult(i)->use_empty()) &&
"result value has remaining uses that must be replaced");
if (newValues[i])
mapping.map(op->getResult(i), newValues[i]);
}
// Record the requested operation replacement.
replacements.emplace_back(op, newValues);
}
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
Block *continuation) {
blockActions.push_back(BlockAction::getSplit(continuation, block));
}
void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
Region &region, Region &parent, Region::iterator before) {
for (auto &pair : llvm::enumerate(region)) {
Block &block = pair.value();
unsigned position = pair.index();
blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
}
}
void ConversionPatternRewriterImpl::remapValues(
Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
remapped.reserve(llvm::size(operands));
for (Value *operand : operands)
remapped.push_back(mapping.lookupOrDefault(operand));
}
//===----------------------------------------------------------------------===//
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
TypeConverter *converter)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
ConversionPatternRewriter::~ConversionPatternRewriter() {}
/// PatternRewriter hook for replacing the results of an operation.
void ConversionPatternRewriter::replaceOp(
Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) {
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
}
/// Apply a signature conversion to the entry block of the given region.
void ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
impl->applySignatureConversion(region, conversion);
}
/// Clone the given operation without cloning its regions.
Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
impl->createdOps.push_back(newOp);
return newOp;
}
/// PatternRewriter hook for splitting a block into two parts.
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = PatternRewriter::splitBlock(block, before);
impl->notifySplitBlock(block, continuation);
return continuation;
}
/// PatternRewriter hook for moving blocks out of a region.
void ConversionPatternRewriter::inlineRegionBefore(Region &region,
Region &parent,
Region::iterator before) {
impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
PatternRewriter::inlineRegionBefore(region, parent, before);
}
/// PatternRewriter hook for creating a new operation.
Operation *
ConversionPatternRewriter::createOperation(const OperationState &state) {
auto *result = OpBuilder::createOperation(state);
impl->createdOps.push_back(result);
return result;
}
/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
// The rewriter caches changes to the IR to allow for operating in-place and
// backtracking. The rewriter is currently not capable of backtracking
// in-place modifications.
llvm_unreachable("in-place operation updates are not supported");
}
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
/// Attempt to match and rewrite the IR root at the specified operation.
PatternMatchResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value *, 4> operands;
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
// If this operation has no successors, invoke the rewrite directly.
if (op->getNumSuccessors() == 0)
return matchAndRewrite(op, operands, dialectRewriter);
// Otherwise, we need to remap the successors.
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
destinations.push_back(op->getSuccessor(i));
// Lookup the successors operands.
unsigned n = op->getNumSuccessorOperands(i);
operandsPerDestination.push_back(
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
seen += n;
}
// Rewrite the operation.
return matchAndRewrite(
op,
llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand),
destinations, operandsPerDestination, dialectRewriter);
}
//===----------------------------------------------------------------------===//
// OperationLegalizer
//===----------------------------------------------------------------------===//
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
/// This class defines a recursive operation legalizer.
class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
OperationLegalizer(ConversionTarget &targetInfo,
const OwningRewritePatternList &patterns)
: target(targetInfo) {
buildLegalizationGraph(patterns);
computeLegalizationGraphBenefit();
}
/// Returns if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
private:
/// Attempt to legalize the given operation by applying the provided pattern.
/// Returns success if the operation was legalized, failure otherwise.
LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
ConversionPatternRewriter &rewriter);
/// Build an optimistic legalization graph given the provided patterns. This
/// function populates 'legalizerPatterns' with the operations that are not
/// directly legal, but may be transitively legal for the current target given
/// the provided patterns.
void buildLegalizationGraph(const OwningRewritePatternList &patterns);
/// Compute the benefit of each node within the computed legalization graph.
/// This orders the patterns within 'legalizerPatterns' based upon two
/// criteria:
/// 1) Prefer patterns that have the lowest legalization depth, i.e.
/// represent the more direct mapping to the target.
/// 2) When comparing patterns with the same legalization depth, prefer the
/// pattern with the highest PatternBenefit. This allows for users to
/// prefer specific legalizations over others.
void computeLegalizationGraphBenefit();
/// The current set of patterns that have been applied.
llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns;
/// The set of legality information for operations transitively supported by
/// the target.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
/// The legalization information provided by the target.
ConversionTarget &target;
};
} // namespace
bool OperationLegalizer::isIllegal(Operation *op) const {
// Check if the target explicitly marked this operation as illegal.
if (auto action = target.getOpAction(op->getName()))
return action == LegalizationAction::Illegal;
return false;
}
LogicalResult
OperationLegalizer::legalize(Operation *op,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
<< "\n");
// Check if this operation is legal on the target.
if (target.isLegal(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "-- Success : Operation marked legal by the target\n");
return success();
}
// Otherwise, we need to apply a legalization pattern to this operation.
auto it = legalizerPatterns.find(op->getName());
if (it == legalizerPatterns.end()) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
return failure();
}
// The patterns are sorted by expected benefit, so try to apply each in-order.
for (auto *pattern : it->second)
if (succeeded(legalizePattern(op, pattern, rewriter)))
return success();
LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
return failure();
}
LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
llvm::dbgs() << ")'.\n";
});
// Ensure that we don't cycle by not allowing the same pattern to be
// applied twice in the same recursion stack.
// TODO(riverriddle) We could eventually converge, but that requires more
// complicated analysis.
if (!appliedPatterns.insert(pattern).second) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
return failure();
}
auto &rewriterImpl = rewriter.getImpl();
RewriterState curState = rewriterImpl.getCurrentState();
auto cleanupFailure = [&] {
// Reset the rewriter state and pop this pattern.
rewriterImpl.resetState(curState);
appliedPatterns.erase(pattern);
return failure();
};
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
if (!pattern->matchAndRewrite(op, rewriter)) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
return cleanupFailure();
}
// If the pattern moved any blocks, try to legalize their types. This ensures
// that the types of the block arguments are legal for the region they were
// moved into.
for (unsigned i = curState.numBlockActions,
e = rewriterImpl.blockActions.size();
i != e; ++i) {
auto &action = rewriterImpl.blockActions[i];
if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move)
continue;
// Convert the block signature.
if (failed(rewriterImpl.convertBlockSignature(action.block))) {
LLVM_DEBUG(llvm::dbgs()
<< "-- FAIL: failed to convert types of moved block.\n");
return cleanupFailure();
}
}
// Recursively legalize each of the new operations.
for (unsigned i = curState.numCreatedOperations,
e = rewriterImpl.createdOps.size();
i != e; ++i) {
if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
return cleanupFailure();
}
}
appliedPatterns.erase(pattern);
return success();
}
void OperationLegalizer::buildLegalizationGraph(
const OwningRewritePatternList &patterns) {
// A mapping between an operation and a set of operations that can be used to
// generate it.
DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
// A mapping between an operation and any currently invalid patterns it has.
DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
// A worklist of patterns to consider for legality.
llvm::SetVector<RewritePattern *> patternWorklist;
// Build the mapping from operations to the parent ops that may generate them.
for (auto &pattern : patterns) {
auto root = pattern->getRootKind();
// Skip operations that are always known to be legal.
if (target.getOpAction(root) == LegalizationAction::Legal)
continue;
// Add this pattern to the invalid set for the root op and record this root
// as a parent for any generated operations.
invalidPatterns[root].insert(pattern.get());
for (auto op : pattern->getGeneratedOps())
parentOps[op].insert(root);
// Add this pattern to the worklist.
patternWorklist.insert(pattern.get());
}
while (!patternWorklist.empty()) {
auto *pattern = patternWorklist.pop_back_val();
// Check to see if any of the generated operations are invalid.
if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
auto action = target.getOpAction(op);
return !legalizerPatterns.count(op) &&
(!action || action == LegalizationAction::Illegal);
}))
continue;
// Otherwise, if all of the generated operation are valid, this op is now
// legal so add all of the child patterns to the worklist.
legalizerPatterns[pattern->getRootKind()].push_back(pattern);
invalidPatterns[pattern->getRootKind()].erase(pattern);
// Add any invalid patterns of the parent operations to see if they have now
// become legal.
for (auto op : parentOps[pattern->getRootKind()])
patternWorklist.set_union(invalidPatterns[op]);
}
}
void OperationLegalizer::computeLegalizationGraphBenefit() {
// The smallest pattern depth, when legalizing an operation.
DenseMap<OperationName, unsigned> minPatternDepth;
// Compute the minimum legalization depth for a given operation.
std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
// Check for existing depth.
auto depthIt = minPatternDepth.find(op);
if (depthIt != minPatternDepth.end())
return depthIt->second;
// If a mapping for this operation does not exist, then this operation
// is always legal. Return 0 as the depth for a directly legal operation.
auto opPatternsIt = legalizerPatterns.find(op);
if (opPatternsIt == legalizerPatterns.end())
return 0u;
auto &minDepth = minPatternDepth[op];
if (opPatternsIt->second.empty())
return minDepth;
// Initialize the depth to the maximum value.
minDepth = std::numeric_limits<unsigned>::max();
// Compute the depth for each pattern used to legalize this operation.
SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth;
patternsByDepth.reserve(opPatternsIt->second.size());
for (RewritePattern *pattern : opPatternsIt->second) {
unsigned depth = 0;
for (auto generatedOp : pattern->getGeneratedOps())
depth = std::max(depth, computeDepth(generatedOp) + 1);
patternsByDepth.emplace_back(pattern, depth);
// Update the min depth for this operation.
minDepth = std::min(minDepth, depth);
}
// If the operation only has one legalization pattern, there is no need to
// sort them.
if (patternsByDepth.size() == 1)
return minDepth;
// Sort the patterns by those likely to be the most beneficial.
llvm::array_pod_sort(
patternsByDepth.begin(), patternsByDepth.end(),
[](const std::pair<RewritePattern *, unsigned> *lhs,
const std::pair<RewritePattern *, unsigned> *rhs) {
// First sort by the smaller pattern legalization depth.
if (lhs->second != rhs->second)
return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
&rhs->second);
// Then sort by the larger pattern benefit.
auto lhsBenefit = lhs->first->getBenefit();
auto rhsBenefit = rhs->first->getBenefit();
return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
&lhsBenefit);
});
// Update the legalization pattern to use the new sorted list.
opPatternsIt->second.clear();
for (auto &patternIt : patternsByDepth)
opPatternsIt->second.push_back(patternIt.first);
return minDepth;
};
// For each operation that is transitively legal, compute a cost for it.
for (auto &opIt : legalizerPatterns)
if (!minPatternDepth.count(opIt.first))
computeDepth(opIt.first);
}
//===----------------------------------------------------------------------===//
// OperationConverter
//===----------------------------------------------------------------------===//
namespace {
enum OpConversionMode {
// In this mode, the conversion will ignore failed conversions to allow
// illegal operations to co-exist in the IR.
Partial,
// In this mode, all operations must be legal for the given target for the
// conversion to succeeed.
Full,
// In this mode, operations are analyzed for legality. No actual rewrites are
// applied to the operations on success.
Analysis,
};
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
explicit OperationConverter(ConversionTarget &target,
const OwningRewritePatternList &patterns,
OpConversionMode mode,
DenseSet<Operation *> *legalizableOps = nullptr)
: opLegalizer(target, patterns), mode(mode),
legalizableOps(legalizableOps) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops,
TypeConverter *typeConverter);
private:
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
/// Recursively collect all of the operations to convert from within 'region'.
LogicalResult computeConversionSet(Region &region,
std::vector<Operation *> &toConvert);
/// Converts the type signatures of the blocks nested within 'op'.
LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
Operation *op);
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
/// A set of pre-existing operations that were found to be legalizable to the
/// target. This field is only used when mode == OpConversionMode::Analysis.
DenseSet<Operation *> *legalizableOps;
};
} // end anonymous namespace
LogicalResult
OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
Operation *op) {
// Check to see if type signatures need to be converted.
if (!rewriter.getImpl().argConverter.typeConverter)
return success();
for (auto &region : op->getRegions()) {
for (auto &block : region)
if (failed(rewriter.getImpl().convertBlockSignature(&block)))
return failure();
}
return success();
}
LogicalResult
OperationConverter::computeConversionSet(Region &region,
std::vector<Operation *> &toConvert) {
if (region.empty())
return success();
// Traverse starting from the entry block.
SmallVector<Block *, 16> worklist(1, &region.front());
DenseSet<Block *> visitedBlocks;
visitedBlocks.insert(&region.front());
while (!worklist.empty()) {
auto *block = worklist.pop_back_val();
// Compute the conversion set of each of the nested operations.
for (auto &op : *block) {
toConvert.emplace_back(&op);
for (auto &region : op.getRegions())
computeConversionSet(region, toConvert);
}
// Recurse to children that haven't been visited.
for (Block *succ : block->getSuccessors())
if (visitedBlocks.insert(succ).second)
worklist.push_back(succ);
}
// Check that all blocks in the region were visited.
if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1),
[&](Block &block) { return !visitedBlocks.count(&block); }))
return emitError(region.getLoc(), "unreachable blocks were not converted");
return success();
}
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) {
// Legalize the given operation.
if (failed(opLegalizer.legalize(op, rewriter))) {
// Handle the case of a failed conversion for each of the different modes.
/// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
/// Partial conversions allow conversions to fail iff the operation was not
/// explicitly marked as illegal.
if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
} else {
/// Analysis conversions don't fail if any operations fail to legalize,
/// they are only interested in the operations that were successfully
/// legalized.
if (mode == OpConversionMode::Analysis)
legalizableOps->insert(op);
// If legalization succeeded, convert the types any of the blocks within
// this operation.
if (failed(convertBlockSignatures(rewriter, op)))
return failure();
}
return success();
}
LogicalResult
OperationConverter::convertOperations(ArrayRef<Operation *> ops,
TypeConverter *typeConverter) {
if (ops.empty())
return success();
/// Compute the set of operations and blocks to convert.
std::vector<Operation *> toConvert;
for (auto *op : ops) {
toConvert.emplace_back(op);
for (auto &region : op->getRegions())
if (failed(computeConversionSet(region, toConvert)))
return failure();
}
// Convert each operation and discard rewrites on failure.
ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
return rewriter.getImpl().discardRewrites(), failure();
// Otherwise, the body conversion succeeded. Apply rewrites if this is not an
// analysis conversion.
if (mode == OpConversionMode::Analysis)
rewriter.getImpl().discardRewrites();
else
rewriter.getImpl().applyRewrites();
return success();
}
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
ArrayRef<Type> types) {
assert(!types.empty() && "expected valid types");
remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
addInputs(types);
}
/// Append new input types to the signature conversion, this should only be
/// used if the new types are not intended to remap an existing input.
void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
assert(!types.empty() &&
"1->0 type remappings don't need to be added explicitly");
argTypes.append(types.begin(), types.end());
}
/// Remap an input of the original signature with a range of types in the
/// new signature.
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
unsigned newInputNo,
unsigned newInputCount) {
assert(!remappedInputs[origInputNo] && "input has already been remapped");
assert(newInputCount != 0 && "expected valid input count");
remappedInputs[origInputNo] = InputMapping{newInputNo, newInputCount};
}
/// This hooks allows for converting a type.
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) {
if (auto newT = convertType(t)) {
results.push_back(newT);
return success();
}
return failure();
}
/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// otherwise.
LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
SmallVectorImpl<Type> &results) {
for (auto type : types)
if (failed(convertType(type, results)))
return failure();
return success();
}
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool TypeConverter::isLegal(Type type) {
SmallVector<Type, 1> results;
return succeeded(convertType(type, results)) && results.size() == 1 &&
results.front() == type;
}
/// Return true if the inputs and outputs of the given function type are
/// legal.
bool TypeConverter::isSignatureLegal(FunctionType funcType) {
return llvm::all_of(
llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
[this](Type type) { return isLegal(type); });
}
/// This hook allows for converting a specific argument of a signature.
LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) {
// Try to convert the given input type.
SmallVector<Type, 1> convertedTypes;
if (failed(convertType(type, convertedTypes)))
return failure();
// If this argument is being dropped, there is nothing left to do.
if (convertedTypes.empty())
return success();
// Otherwise, add the new inputs.
result.addInputs(inputNo, convertedTypes);
return success();
}
/// Create a default conversion pattern that rewrites the type signature of a
/// FuncOp.
namespace {
struct FuncOpSignatureConversion : public ConversionPattern {
FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(FuncOp::getOperationName(), 1, ctx),
converter(converter) {}
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
return matchFailure();
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
convertedResults, funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
/// The type converter to use when rewriting the signature.
TypeConverter &converter;
};
} // end anonymous namespace
void mlir::populateFuncOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) {
patterns.insert<FuncOpSignatureConversion>(ctx, converter);
}
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a valid
/// conversion for the signature on success, None otherwise.
auto TypeConverter::convertBlockSignature(Block *block)
-> llvm::Optional<SignatureConversion> {
SignatureConversion conversion(block->getNumArguments());
for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
conversion)))
return llvm::None;
return conversion;
}
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
legalOperations[op] = action;
}
/// Register a legality action for the given dialects.
void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
LegalizationAction action) {
for (StringRef dialect : dialectNames)
legalDialects[dialect] = action;
}
/// Get the legality action for the given operation.
auto ConversionTarget::getOpAction(OperationName op) const
-> llvm::Optional<LegalizationAction> {
// Check for an action for this specific operation.
auto it = legalOperations.find(op);
if (it != legalOperations.end())
return it->second;
// Otherwise, default to checking for an action on the parent dialect.
auto dialectIt = legalDialects.find(op.getDialect());
if (dialectIt != legalDialects.end())
return dialectIt->second;
return llvm::None;
}
/// Return if the given operation instance is legal on this target.
bool ConversionTarget::isLegal(Operation *op) const {
auto action = getOpAction(op->getName());
// Handle dynamic legality.
if (action == LegalizationAction::Dynamic) {
// Check for callbacks on the operation or dialect.
auto opFn = opLegalityFns.find(op->getName());
if (opFn != opLegalityFns.end())
return opFn->second(op);
auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
if (dialectFn != dialectLegalityFns.end())
return dialectFn->second(op);
// Otherwise, invoke the hook on the derived instance.
return isDynamicallyLegal(op);
}
// Otherwise, the operation is only legal if it was marked 'Legal'.
return action == LegalizationAction::Legal;
}
/// Set the dynamic legality callback for the given operation.
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
opLegalityFns[name] = callback;
}
/// Set the dynamic legality callback for the given dialects.
void ConversionTarget::setLegalityCallback(
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
for (StringRef dialect : dialects)
dialectLegalityFns[dialect] = callback;
}
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Apply a partial conversion on the given operations, and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize.
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns, TypeConverter *converter) {
OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
return opConverter.convertOperations(ops, converter);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter) {
return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
converter);
}
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method will return failure if the conversion of any
/// operation fails.
LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter) {
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
return opConverter.convertOperations(ops, converter);
}
LogicalResult
mlir::applyFullConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter) {
return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
converter);
}
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
LogicalResult mlir::applyAnalysisConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
&convertedOps);
return opConverter.convertOperations(ops, converter);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
DenseSet<Operation *> &convertedOps,
TypeConverter *converter) {
return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
convertedOps, converter);
}