blob: 1ea1d81d9d184df631f4916f9c3a3e487f37be4a [file] [log] [blame]
//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
// or a difference expression between two identifiers.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SDBMExpr.h"
#include "SDBMExprDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
/// A simple compositional matcher for AffineExpr
///
/// Example usage:
///
/// ```c++
/// AffineExprMatcher x, C, m;
/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
/// AffineExprMatcher pattern2 = x + ((x % C) * m);
/// if (pattern1.match(expr) || pattern2.match(expr)) {
/// ...
/// }
/// ```
class AffineExprMatcherStorage;
class AffineExprMatcher {
public:
AffineExprMatcher();
AffineExprMatcher(const AffineExprMatcher &other);
AffineExprMatcher operator+(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Add, *this, other);
}
AffineExprMatcher operator*(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Mul, *this, other);
}
AffineExprMatcher floorDiv(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
}
AffineExprMatcher ceilDiv(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
}
AffineExprMatcher operator%(AffineExprMatcher other) {
return AffineExprMatcher(AffineExprKind::Mod, *this, other);
}
AffineExpr match(AffineExpr expr);
AffineExpr matched();
Optional<int> getMatchedConstantValue();
private:
AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
AffineExprKind kind; // only used to match in binary op cases.
// A shared_ptr allows multiple references to same matcher storage without
// worrying about ownership or dealing with an arena. To be cleaned up if we
// go with this.
std::shared_ptr<AffineExprMatcherStorage> storage;
};
class AffineExprMatcherStorage {
public:
AffineExprMatcherStorage() {}
AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
: subExprs(other.subExprs.begin(), other.subExprs.end()),
matched(other.matched) {}
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
: subExprs(exprs.begin(), exprs.end()) {}
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
: subExprs({a, b}) {}
llvm::SmallVector<AffineExprMatcher, 0> subExprs;
AffineExpr matched;
};
} // namespace
AffineExprMatcher::AffineExprMatcher()
: kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
: kind(other.kind), storage(other.storage) {}
Optional<int> AffineExprMatcher::getMatchedConstantValue() {
if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
return cst.getValue();
return None;
}
AffineExpr AffineExprMatcher::match(AffineExpr expr) {
if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
if (storage->matched)
if (storage->matched != expr)
return AffineExpr();
storage->matched = expr;
return storage->matched;
}
if (kind != expr.getKind()) {
return AffineExpr();
}
if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
if (!storage->subExprs.empty() &&
!storage->subExprs[0].match(bin.getLHS())) {
return AffineExpr();
}
if (!storage->subExprs.empty() &&
!storage->subExprs[1].match(bin.getRHS())) {
return AffineExpr();
}
if (storage->matched)
if (storage->matched != expr)
return AffineExpr();
storage->matched = expr;
return storage->matched;
}
llvm_unreachable("binary expected");
}
AffineExpr AffineExprMatcher::matched() { return storage->matched; }
AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
AffineExprMatcher b)
: kind(k), storage(new AffineExprMatcherStorage(a, b)) {
storage->subExprs.push_back(a);
storage->subExprs.push_back(b);
}
//===----------------------------------------------------------------------===//
// SDBMExpr
//===----------------------------------------------------------------------===//
SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
template <typename Derived, typename Result = void> class SDBMVisitor {
public:
/// Visit the given SDBM expression, dispatching to kind-specific functions.
Result visit(SDBMExpr expr) {
auto *derived = static_cast<Derived *>(this);
switch (expr.getKind()) {
case SDBMExprKind::Add:
case SDBMExprKind::Diff:
case SDBMExprKind::DimId:
case SDBMExprKind::SymbolId:
case SDBMExprKind::Neg:
case SDBMExprKind::Stripe:
return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
case SDBMExprKind::Constant:
return derived->visitConstant(expr.cast<SDBMConstantExpr>());
}
llvm_unreachable("Unknown SDBMExpr Kind");
}
protected:
/// Default visitors do nothing.
void visitSum(SDBMSumExpr) {}
void visitDiff(SDBMDiffExpr) {}
void visitStripe(SDBMStripeExpr) {}
void visitDim(SDBMDimExpr) {}
void visitSymbol(SDBMSymbolExpr) {}
void visitNeg(SDBMNegExpr) {}
void visitConstant(SDBMConstantExpr) {}
/// Default implementation of visitPositive dispatches to the special
/// functions for stripes and other variables. Concrete visitors can override
/// it.
Result visitPositive(SDBMPositiveExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
return derived->visitStripe(expr.cast<SDBMStripeExpr>());
else
return derived->visitInput(expr.cast<SDBMInputExpr>());
}
/// Default implementation of visitInput dispatches to the special
/// functions for dimensions or symbols. Concrete visitors can override it to
/// visit all variables instead.
Result visitInput(SDBMInputExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::DimId)
return derived->visitDim(expr.cast<SDBMDimExpr>());
else
return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
}
/// Default implementation of visitVarying dispatches to the special
/// functions for variables and negations thereof. Concerete visitors can
/// override it to visit all variables and negations isntead.
Result visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
return derived->visitPositive(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
return derived->visitNeg(neg);
else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
return derived->visitSum(sum);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
return derived->visitDiff(diff);
llvm_unreachable("unhandled subtype of varying SDBM expression");
}
};
void SDBMExpr::print(raw_ostream &os) const {
struct Printer : public SDBMVisitor<Printer> {
Printer(raw_ostream &ostream) : prn(ostream) {}
void visitSum(SDBMSumExpr expr) {
visitVarying(expr.getLHS());
prn << " + ";
visitConstant(expr.getRHS());
}
void visitDiff(SDBMDiffExpr expr) {
visitPositive(expr.getLHS());
prn << " - ";
visitPositive(expr.getRHS());
}
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
void visitStripe(SDBMStripeExpr expr) {
visitPositive(expr.getVar());
prn << " # ";
visitConstant(expr.getStripeFactor());
}
void visitNeg(SDBMNegExpr expr) {
prn << '-';
visitPositive(expr.getVar());
}
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
raw_ostream &prn;
};
Printer printer(os);
printer.visit(*this);
}
void SDBMExpr::dump() const {
print(llvm::errs());
llvm::errs() << '\n';
}
//===----------------------------------------------------------------------===//
// SDBMSumExpr
//===----------------------------------------------------------------------===//
SDBMVaryingExpr SDBMSumExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs;
}
SDBMConstantExpr SDBMSumExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
AffineExpr SDBMExpr::getAsAffineExpr() const {
struct Converter : public SDBMVisitor<Converter, AffineExpr> {
AffineExpr visitSum(SDBMSumExpr expr) {
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
return lhs + rhs;
}
AffineExpr visitStripe(SDBMStripeExpr expr) {
AffineExpr lhs = visit(expr.getVar()),
rhs = visit(expr.getStripeFactor());
return lhs - (lhs % rhs);
}
AffineExpr visitDiff(SDBMDiffExpr expr) {
AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
return lhs - rhs;
}
AffineExpr visitDim(SDBMDimExpr expr) {
return getAffineDimExpr(expr.getPosition(), expr.getContext());
}
AffineExpr visitSymbol(SDBMSymbolExpr expr) {
return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
}
AffineExpr visitNeg(SDBMNegExpr expr) {
return getAffineBinaryOpExpr(AffineExprKind::Mul,
getAffineConstantExpr(-1, expr.getContext()),
visit(expr.getVar()));
}
AffineExpr visitConstant(SDBMConstantExpr expr) {
return getAffineConstantExpr(expr.getValue(), expr.getContext());
}
} converter;
return converter.visit(*this);
}
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
// Attempt to recover a stripe expression. Because AffineExprs don't have
// a first-class difference kind, we check for both x + -1 * (x mod C) and
// -1 * (x mod C) + x cases.
AffineExprMatcher x, C, m;
AffineExprMatcher pattern1 = ((x % C) * m) + x;
AffineExprMatcher pattern2 = x + ((x % C) * m);
if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
(pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
if (auto convertedLHS = visit(x.matched())) {
// TODO(ntv): return convertedLHS.stripe(C);
return SDBMStripeExpr::get(
convertedLHS.cast<SDBMPositiveExpr>(),
visit(C.matched()).cast<SDBMConstantExpr>());
}
}
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// In a "add" AffineExpr, the constant always appears on the right. If
// there were two constants, they would have been folded away.
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
// SDBM accepts LHS variables and RHS constants in a sum.
auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
if (rhsConstant && lhsVar)
return SDBMSumExpr::get(lhsVar, rhsConstant);
// The sum of a negated variable and a non-negated variable is a
// difference, supported as a special kind in SDBM. Because AffineExprs
// don't have first-class difference kind, check both LHS and RHS for
// negation.
auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
if (lhsNeg && rhsVar)
return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
if (rhsNeg && lhsVar)
return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
// Other cases don't fit into SDBM.
return {};
}
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
// Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
AffineExprMatcher x, C;
AffineExprMatcher pattern = (x.floorDiv(C)) * C;
if (pattern.match(expr)) {
if (SDBMExpr converted = visit(x.matched())) {
if (auto varConverted = converted.dyn_cast<SDBMPositiveExpr>())
// TODO(ntv): return varConverted.stripe(C.getConstantValue());
return SDBMStripeExpr::get(
varConverted,
SDBMConstantExpr::get(varConverted.getContext(),
C.getMatchedConstantValue().getValue()));
}
}
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// In a "mul" AffineExpr, the constant always appears on the right. If
// there were two constants, they would have been folded away.
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
if (!rhsConstant)
return {};
// The only supported "multiplication" expression is an SDBM is dimension
// negation, that is a product of dimension and constant -1.
auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
if (lhsVar && rhsConstant.getValue() == -1)
return SDBMNegExpr::get(lhsVar);
// Other multiplications are not allowed in SDBM.
return {};
}
SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return {};
// 'mod' can only be converted to SDBM if its LHS is a variable
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
if (!lhsVar || !rhsConstant)
return {};
return SDBMDiffExpr::get(lhsVar,
SDBMStripeExpr::get(lhsVar, rhsConstant));
}
// `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
// Dimensions, symbols and constants are converted trivially.
SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
return SDBMConstantExpr::get(expr.getContext(), expr.getValue());
}
SDBMExpr visitDimExpr(AffineDimExpr expr) {
return SDBMDimExpr::get(expr.getContext(), expr.getPosition());
}
SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
return SDBMSymbolExpr::get(expr.getContext(), expr.getPosition());
}
} converter;
if (auto result = converter.visit(affine))
return result;
return None;
}
//===----------------------------------------------------------------------===//
// SDBMDiffExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs;
}
SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMStripeExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMStripeExpr::getVar() const {
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
return lhs.cast<SDBMPositiveExpr>();
return {};
}
SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMInputExpr
//===----------------------------------------------------------------------===//
unsigned SDBMInputExpr::getPosition() const {
return static_cast<ImplType *>(impl)->position;
}
//===----------------------------------------------------------------------===//
// SDBMConstantExpr
//===----------------------------------------------------------------------===//
int64_t SDBMConstantExpr::getValue() const {
return static_cast<ImplType *>(impl)->constant;
}
//===----------------------------------------------------------------------===//
// SDBMNegExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMNegExpr::getVar() const {
return static_cast<ImplType *>(impl)->dim;
}