blob: afbeda15fe6884be52f1a435eaee39101d9141b4 [file] [log] [blame]
//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// A striped difference-bound matrix (SDBM) expression is one of: a constant
// expression, an identifier, a binary expression with a constant RHS that uses
// the + or the stripe operator, or a difference between two identifiers.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SDBMEXPR_H
#define MLIR_IR_SDBMEXPR_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
// Forward declarations of types defined outside this header. Within this file
// `AffineExpr` only appears in function declarations (as a return or parameter
// type) and `MLIRContext` only behind a pointer, so their full definitions are
// not required here.
class AffineExpr;
class MLIRContext;
/// Discriminator identifying the concrete kind of an SDBM expression. The
/// `isClassFor` hooks of the expression classes compare against these values.
/// The enumerator order determines the underlying integer values.
enum class SDBMExprKind {
  /// Sum of a varying expression and a constant (`SDBMSumExpr`).
  Add,
  /// Stripe expression "x # C" (`SDBMStripeExpr`).
  Stripe,
  /// Difference between two positive expressions (`SDBMDiffExpr`).
  Diff,
  /// Integer constant (`SDBMConstantExpr`).
  Constant,
  /// Dimension identifier (`SDBMDimExpr`).
  DimId,
  /// Symbol identifier (`SDBMSymbolExpr`).
  SymbolId,
  /// Negation of a positive expression (`SDBMNegExpr`).
  Neg
};
// Forward declarations of the storage types that the expression classes below
// name as their `ImplType`. Only the declarations are visible in this header;
// the definitions live elsewhere.
namespace detail {
// Named as `ImplType` by SDBMExpr and SDBMVaryingExpr.
struct SDBMExprStorage;
// Named as `ImplType` by SDBMSumExpr and SDBMStripeExpr.
struct SDBMBinaryExprStorage;
// Named as `ImplType` by SDBMDiffExpr.
struct SDBMDiffExprStorage;
// Named as `ImplType` by SDBMInputExpr, SDBMDimExpr and SDBMSymbolExpr.
struct SDBMPositiveExprStorage;
// Named as `ImplType` by SDBMConstantExpr.
struct SDBMConstantExprStorage;
// Named as `ImplType` by SDBMNegExpr.
struct SDBMNegExprStorage;
} // namespace detail
// Forward declarations. SDBMConstantExpr, SDBMDimExpr and SDBMSymbolExpr are
// defined further down in this header. SDBMDialect is only declared here and is
// used exclusively through pointers in this file.
class SDBMConstantExpr;
class SDBMDialect;
class SDBMDimExpr;
class SDBMSymbolExpr;
/// Base class for Striped Difference-Bounded Matrix (SDBM) expressions, the
/// left-hand side expressions of the SDBM framework. SDBM expressions form a
/// subset of affine expressions for which the operations used in loop
/// transformations admit low-complexity algorithms. The following forms are
/// supported:
///   - constant expressions;
///   - single variables (dimensions and symbols) with a +1 or -1 coefficient;
///   - stripe expressions "x # C", where "x" is a single variable or another
///     stripe expression, "#" is the stripe operator, and "C" is a constant
///     expression; "#" is defined as x - x mod C;
///   - sums of a single variable/stripe expression and a constant expression;
///   - differences between single variable/stripe expressions.
///
/// The `SDBMExpr` class hierarchy offers a type-safe interface for building
/// and manipulating SDBM expressions. For instance, the LHS of a sum
/// expression must be a single variable or a stripe expression. These
/// restrictions deliberately force callers to perform the simplifications
/// needed to remain within the SDBM domain, because SDBM expressions do not
/// combine in more cases than they do. This choice may be revisited later.
///
/// `SDBMExpr` and its subclasses are thin wrappers around a pointer owned by
/// an MLIRContext and are meant to be passed by value. The pointed-to storage
/// is uniqued in the MLIRContext and immortal.
class SDBMExpr {
public:
  using ImplType = detail::SDBMExprStorage;

  /// Constructs a null expression.
  SDBMExpr() : impl(nullptr) {}

  /// Wraps the given storage pointer. The conversion is intentionally
  /// implicit.
  /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}

  /// Since an expression is just a wrapper around a unique'ed immutable
  /// pointer, copying and assigning it is trivial.
  SDBMExpr(const SDBMExpr &) = default;
  SDBMExpr &operator=(const SDBMExpr &) = default;

  /// Two expressions are equal exactly when they wrap the same pointer.
  bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
  bool operator!=(const SDBMExpr &other) const { return impl != other.impl; }

  /// Explicit conversion to `bool`: false for a null expression, true
  /// otherwise.
  explicit operator bool() const { return impl != nullptr; }
  /// True for a null expression.
  bool operator!() const { return impl == nullptr; }

  /// Negate the given SDBM expression.
  SDBMExpr operator-();

  /// Prints the SDBM expression.
  void print(raw_ostream &os) const;
  void dump() const;

  /// LLVM-style casts. `isa` defers to the static `isClassFor` hook of the
  /// target class.
  template <typename U> bool isa() const { return U::isClassFor(*this); }

  /// Returns this expression as a `U` if it is one, and a default-constructed
  /// `U` otherwise.
  template <typename U> U dyn_cast() const {
    if (isa<U>())
      return U(impl);
    return U();
  }

  /// Returns this expression as a `U`; asserts that it actually is a `U`.
  template <typename U> U cast() const {
    assert(isa<U>() && "cast to incorrect subtype");
    return U(impl);
  }

  /// Support for LLVM hashing: hashes the wrapped pointer.
  ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }

  /// Returns the kind of the SDBM expression.
  SDBMExprKind getKind() const;

  /// Returns the MLIR context in which this expression lives.
  MLIRContext *getContext() const;

  /// Returns the SDBM dialect instance.
  SDBMDialect *getDialect() const;

  /// Converts the SDBM expression into an affine expression. The conversion
  /// always succeeds because SDBM expressions are a subset of affine ones.
  AffineExpr getAsAffineExpr() const;

  /// Attempts to build an SDBM expression from the given affine expression.
  /// Returns llvm::None when the affine expression cannot be represented as
  /// SDBM. For stripe expressions, the conversion recognizes (nested)
  /// multiplicative ((x floordiv B) * B) and additive (x - x mod B) patterns.
  static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);

protected:
  /// Storage pointer wrapped by this expression; null for a null expression.
  ImplType *impl;
};
/// SDBM constant expression, wraps a 64-bit integer.
class SDBMConstantExpr : public SDBMExpr {
public:
using ImplType = detail::SDBMConstantExprStorage;
using SDBMExpr::SDBMExpr;
/// Obtain or create a constant expression unique'ed in the given dialect
/// (which belongs to a context).
static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Constant;
}
int64_t getValue() const;
};
/// A varying SDBM expression is any of the following:
///   - an input variable expression;
///   - a stripe expression;
///   - a negation (product with -1) of either of the above;
///   - a sum of a varying and a constant expression;
///   - a difference between varying expressions.
class SDBMVaryingExpr : public SDBMExpr {
public:
  using ImplType = detail::SDBMExprStorage;
  using SDBMExpr::SDBMExpr;

  /// Returns true for every expression kind except `Constant`.
  static bool isClassFor(const SDBMExpr &expr) {
    switch (expr.getKind()) {
    case SDBMExprKind::DimId:
    case SDBMExprKind::SymbolId:
    case SDBMExprKind::Neg:
    case SDBMExprKind::Stripe:
    case SDBMExprKind::Add:
    case SDBMExprKind::Diff:
      return true;
    case SDBMExprKind::Constant:
      break;
    }
    return false;
  }
};
/// SDBM positive variable expression can be one of:
/// - single variable expression;
/// - stripe expression.
class SDBMPositiveExpr : public SDBMVaryingExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Stripe;
}
};
/// SDBM sum expression. LHS is a varying expression and RHS is always a
/// constant expression.
class SDBMSumExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a sum expression unique'ed in the given context.
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
SDBMExprKind kind = expr.getKind();
return kind == SDBMExprKind::Add;
}
SDBMVaryingExpr getLHS() const;
SDBMConstantExpr getRHS() const;
};
/// SDBM difference expression. Both LHS and RHS are positive variable
/// expressions.
class SDBMDiffExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMDiffExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a difference expression unique'ed in the given context.
static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Diff;
}
SDBMPositiveExpr getLHS() const;
SDBMPositiveExpr getRHS() const;
};
/// SDBM stripe expression "x # C" where "x" is a positive variable expression,
/// "C" is a constant expression and "#" is the stripe operator defined as:
/// x # C = x - x mod C.
class SDBMStripeExpr : public SDBMPositiveExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMPositiveExpr::SDBMPositiveExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Stripe;
}
static SDBMStripeExpr get(SDBMPositiveExpr var,
SDBMConstantExpr stripeFactor);
SDBMPositiveExpr getVar() const;
SDBMConstantExpr getStripeFactor() const;
};
/// SDBM "input" variable expression can be either a dimension identifier or
/// a symbol identifier. When used to define SDBM functions, dimensions are
/// interpreted as function arguments while symbols are treated as unknown but
/// constant values, hence the name.
class SDBMInputExpr : public SDBMPositiveExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMPositiveExpr::SDBMPositiveExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId;
}
unsigned getPosition() const;
};
/// SDBM dimension expression. Dimensions correspond to function arguments
/// when defining functions using SDBM expressions.
class SDBMDimExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a dimension expression unique'ed in the given dialect
/// (which belongs to a context).
static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId;
}
};
/// SDBM symbol expression. Symbols correspond to symbolic constants when
/// defining functions using SDBM expressions.
class SDBMSymbolExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a symbol expression unique'ed in the given dialect (which
/// belongs to a context).
static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::SymbolId;
}
};
/// Negation of an SDBM variable expression. Equivalent to multiplying the
/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
class SDBMNegExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMNegExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a negation expression unique'ed in the given context.
static SDBMNegExpr get(SDBMPositiveExpr var);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Neg;
}
SDBMPositiveExpr getVar() const;
};
/// A visitor class for SDBM expressions. Calls the kind-specific function
/// depending on the kind of expression it visits.
template <typename Derived, typename Result = void> class SDBMVisitor {
public:
/// Visit the given SDBM expression, dispatching to kind-specific functions.
Result visit(SDBMExpr expr) {
auto *derived = static_cast<Derived *>(this);
switch (expr.getKind()) {
case SDBMExprKind::Add:
case SDBMExprKind::Diff:
case SDBMExprKind::DimId:
case SDBMExprKind::SymbolId:
case SDBMExprKind::Neg:
case SDBMExprKind::Stripe:
return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
case SDBMExprKind::Constant:
return derived->visitConstant(expr.cast<SDBMConstantExpr>());
}
llvm_unreachable("unsupported SDBM expression kind");
}
/// Traverse the SDBM expression tree calling `visit` on each node
/// in depth-first preorder.
void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
/// Traverse the SDBM expression tree calling `visit` on each node in
/// depth-first postorder.
void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
protected:
/// Default visitors do nothing.
void visitSum(SDBMSumExpr) {}
void visitDiff(SDBMDiffExpr) {}
void visitStripe(SDBMStripeExpr) {}
void visitDim(SDBMDimExpr) {}
void visitSymbol(SDBMSymbolExpr) {}
void visitNeg(SDBMNegExpr) {}
void visitConstant(SDBMConstantExpr) {}
/// Default implementation of visitPositive dispatches to the special
/// functions for stripes and other variables. Concrete visitors can override
/// it.
Result visitPositive(SDBMPositiveExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
return derived->visitStripe(expr.cast<SDBMStripeExpr>());
else
return derived->visitInput(expr.cast<SDBMInputExpr>());
}
/// Default implementation of visitInput dispatches to the special
/// functions for dimensions or symbols. Concrete visitors can override it to
/// visit all variables instead.
Result visitInput(SDBMInputExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::DimId)
return derived->visitDim(expr.cast<SDBMDimExpr>());
else
return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
}
/// Default implementation of visitVarying dispatches to the special
/// functions for variables and negations thereof. Concerete visitors can
/// override it to visit all variables and negations instead.
Result visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
return derived->visitPositive(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
return derived->visitNeg(neg);
else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
return derived->visitSum(sum);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
return derived->visitDiff(diff);
llvm_unreachable("unhandled subtype of varying SDBM expression");
}
template <bool isPreorder> void walk(SDBMExpr expr) {
if (isPreorder)
visit(expr);
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
walk<isPreorder>(sumExpr.getLHS());
walk<isPreorder>(sumExpr.getRHS());
} else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
walk<isPreorder>(diffExpr.getLHS());
walk<isPreorder>(diffExpr.getRHS());
} else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
walk<isPreorder>(stripeExpr.getVar());
walk<isPreorder>(stripeExpr.getStripeFactor());
} else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
walk<isPreorder>(negExpr.getVar());
}
if (!isPreorder)
visit(expr);
}
};
/// Arithmetic operator overloads for SDBM expressions. They assert that their
/// arguments have the proper SDBM expression subtype and perform
/// canonicalization and constant folding on the expressions.
namespace ops_assertions {

/// Adds two SDBM expressions. At least one of them must be a constant or a
/// negation, and they cannot both be negations at the same time.
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);

/// Adds an integer to an SDBM expression; the integer is first wrapped into a
/// constant expression obtained from the dialect of `lhs`.
inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
  SDBMExpr constant = SDBMConstantExpr::get(lhs.getDialect(), rhs);
  return lhs + constant;
}

/// Adds an SDBM expression to an integer; the integer is first wrapped into a
/// constant expression obtained from the dialect of `rhs`.
inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
  SDBMExpr constant = SDBMConstantExpr::get(rhs.getDialect(), lhs);
  return constant + rhs;
}

/// Subtracts one SDBM expression from another. Neither of them may be a
/// difference expression.
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);

/// Subtracts an integer from an SDBM expression; the integer is first wrapped
/// into a constant expression obtained from the dialect of `lhs`.
inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
  SDBMExpr constant = SDBMConstantExpr::get(lhs.getDialect(), rhs);
  return lhs - constant;
}

/// Subtracts an SDBM expression from an integer; the integer is first wrapped
/// into a constant expression obtained from the dialect of `rhs`.
inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
  SDBMExpr constant = SDBMConstantExpr::get(rhs.getDialect(), lhs);
  return constant - rhs;
}

/// Builds a stripe expression from a positive expression and a positive
/// constant stripe factor.
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);

/// Same as above, with the stripe factor given as an integer that is first
/// wrapped into a constant expression obtained from the dialect of `expr`.
inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
  SDBMExpr constantFactor = SDBMConstantExpr::get(expr.getDialect(), factor);
  return stripe(expr, constantFactor);
}

} // namespace ops_assertions
} // end namespace mlir
namespace llvm {
// DenseMap support for SDBMExpr. The empty and tombstone keys wrap the
// corresponding `void *` sentinels; hashing and equality are delegated to the
// expression itself.
template <> struct DenseMapInfo<mlir::SDBMExpr> {
  static mlir::SDBMExpr getEmptyKey() {
    void *emptyPtr = llvm::DenseMapInfo<void *>::getEmptyKey();
    return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(emptyPtr));
  }

  static mlir::SDBMExpr getTombstoneKey() {
    void *tombstonePtr = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return mlir::SDBMExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(tombstonePtr));
  }

  static unsigned getHashValue(mlir::SDBMExpr expr) {
    return expr.hash_value();
  }

  static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
    return lhs == rhs;
  }
};
// DenseMap support for SDBMVaryingExpr. The empty and tombstone keys wrap the
// corresponding `void *` sentinels; hashing and equality are delegated to the
// expression itself.
template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
  static mlir::SDBMVaryingExpr getEmptyKey() {
    void *emptyPtr = llvm::DenseMapInfo<void *>::getEmptyKey();
    return mlir::SDBMVaryingExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(emptyPtr));
  }

  static mlir::SDBMVaryingExpr getTombstoneKey() {
    void *tombstonePtr = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return mlir::SDBMVaryingExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(tombstonePtr));
  }

  static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
    return expr.hash_value();
  }

  static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
    return lhs == rhs;
  }
};
// DenseMap support for SDBMPositiveExpr. The empty and tombstone keys wrap the
// corresponding `void *` sentinels; hashing and equality are delegated to the
// expression itself.
template <> struct DenseMapInfo<mlir::SDBMPositiveExpr> {
  static mlir::SDBMPositiveExpr getEmptyKey() {
    void *emptyPtr = llvm::DenseMapInfo<void *>::getEmptyKey();
    return mlir::SDBMPositiveExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(emptyPtr));
  }

  static mlir::SDBMPositiveExpr getTombstoneKey() {
    void *tombstonePtr = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return mlir::SDBMPositiveExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(tombstonePtr));
  }

  static unsigned getHashValue(mlir::SDBMPositiveExpr expr) {
    return expr.hash_value();
  }

  static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) {
    return lhs == rhs;
  }
};
// DenseMap support for SDBMConstantExpr. The empty and tombstone keys wrap the
// corresponding `void *` sentinels; hashing and equality are delegated to the
// expression itself.
template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
  static mlir::SDBMConstantExpr getEmptyKey() {
    void *emptyPtr = llvm::DenseMapInfo<void *>::getEmptyKey();
    return mlir::SDBMConstantExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(emptyPtr));
  }

  static mlir::SDBMConstantExpr getTombstoneKey() {
    void *tombstonePtr = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return mlir::SDBMConstantExpr(
        static_cast<mlir::SDBMExpr::ImplType *>(tombstonePtr));
  }

  static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
    return expr.hash_value();
  }

  static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
    return lhs == rhs;
  }
};
} // namespace llvm
#endif // MLIR_IR_SDBMEXPR_H