blob: e261cc8c836127a0ec61d273b39cb6363085be3e [file] [log] [blame]
//===- Types.h - MLIR EDSC Type System Implementation -----------*- C++ -*-===//
// Copyright 2019 The MLIR Authors.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
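// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and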
// limitations under the License.
// =============================================================================
#include "mlir/EDSC/Types.h"
#include "mlir-c/Core.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
using llvm::errs;
using llvm::Twine;
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::detail;
// Storage records behind the Expr/Stmt handle classes. The handle constructors
// later in this file obtain memory from Expr::globalAllocator() and
// placement-construct one of these records into it.
namespace mlir {
namespace edsc {
namespace detail {
// Common base of all expression storage: holds only the kind tag.
struct ExprStorage {
ExprStorage(ExprKind kind) : kind(kind) {}
ExprKind kind;
// Storage for a Bindable: the kind is always ExprKind::Unbound and `id`
// carries the numeric identifier supplied by the caller.
struct BindableStorage : public ExprStorage {
BindableStorage(unsigned id) : ExprStorage(ExprKind::Unbound), id(id) {}
unsigned id;
// Storage for an expression with a single operand.
struct UnaryExprStorage : public ExprStorage {
UnaryExprStorage(ExprKind k, Expr expr) : ExprStorage(k), expr(expr) {}
Expr expr;
// Storage for an expression with a left and a right operand.
struct BinaryExprStorage : public ExprStorage {
BinaryExprStorage(ExprKind k, Expr lhs, Expr rhs)
: ExprStorage(k), lhs(lhs), rhs(rhs) {}
Expr lhs, rhs;
// Storage for an expression with a condition plus two operands.
struct TernaryExprStorage : public ExprStorage {
TernaryExprStorage(ExprKind k, Expr cond, Expr lhs, Expr rhs)
: ExprStorage(k), cond(cond), lhs(lhs), rhs(rhs) {}
Expr cond, lhs, rhs;
// Storage for an expression with any number of operands and types. Both
// members are ArrayRef views built from the begin/end of the arguments, so the
// element memory has to be provided by whoever constructs this record.
struct VariadicExprStorage : public ExprStorage {
VariadicExprStorage(ExprKind k, ArrayRef<Expr> exprs, ArrayRef<Type> types)
: ExprStorage(k), exprs(exprs.begin(), exprs.end()),
types(types.begin(), types.end()) {}
ArrayRef<Expr> exprs;
ArrayRef<Type> types;
// Storage for a statement: the bindable on the left-hand side, the expression
// on the right-hand side, and an ArrayRef view of the enclosed statements.
struct StmtStorage {
StmtStorage(Bindable lhs, Expr rhs, llvm::ArrayRef<Stmt> enclosedStmts)
: lhs(lhs), rhs(rhs), enclosedStmts(enclosedStmts) {}
Bindable lhs;
Expr rhs;
ArrayRef<Stmt> enclosedStmts;
} // namespace detail
} // namespace edsc
} // namespace mlir
// Installs this context's `allocator` as the one returned by
// Expr::globalAllocator().
mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() {
Expr::globalAllocator() = &allocator;
// Clears the global allocator pointer. It is reset to nullptr, not to whatever
// allocator may have been installed before this context was created.
mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
Expr::globalAllocator() = nullptr;
// Returns the kind tag recorded in this expression's storage.
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
// Operator overloads on Expr. None of them evaluates anything: each one builds
// a BinaryExpr of the matching ExprKind with `*this` as the left operand and
// `other` as the right operand, and returns it as an Expr. This includes the
// comparison and logical operators, which therefore return Expr, not bool.
Expr mlir::edsc::Expr::operator+(Expr other) const {
return BinaryExpr(ExprKind::Add, *this, other);
Expr mlir::edsc::Expr::operator-(Expr other) const {
return BinaryExpr(ExprKind::Sub, *this, other);
Expr mlir::edsc::Expr::operator*(Expr other) const {
return BinaryExpr(ExprKind::Mul, *this, other);
Expr mlir::edsc::Expr::operator==(Expr other) const {
return BinaryExpr(ExprKind::EQ, *this, other);
Expr mlir::edsc::Expr::operator!=(Expr other) const {
return BinaryExpr(ExprKind::NE, *this, other);
Expr mlir::edsc::Expr::operator<(Expr other) const {
return BinaryExpr(ExprKind::LT, *this, other);
Expr mlir::edsc::Expr::operator<=(Expr other) const {
return BinaryExpr(ExprKind::LE, *this, other);
Expr mlir::edsc::Expr::operator>(Expr other) const {
return BinaryExpr(ExprKind::GT, *this, other);
Expr mlir::edsc::Expr::operator>=(Expr other) const {
return BinaryExpr(ExprKind::GE, *this, other);
Expr mlir::edsc::Expr::operator&&(Expr other) const {
return BinaryExpr(ExprKind::And, *this, other);
Expr mlir::edsc::Expr::operator||(Expr other) const {
return BinaryExpr(ExprKind::Or, *this, other);
// Free functions.
// Returns a vector of `n` freshly default-constructed Bindables.
//
// Each element is created by its own `Bindable()` call so that every entry
// receives a distinct id from Bindable::newId(); filling the vector with
// copies of a single Bindable would not achieve that.
llvm::SmallVector<Bindable, 8> mlir::edsc::makeBindables(unsigned n) {
  llvm::SmallVector<Bindable, 8> res;
  res.reserve(n);
  // Use an unsigned counter to match the type of `n`.
  for (unsigned i = 0; i < n; ++i) {
    res.push_back(Bindable());
  }
  return res;
}
// File-local overload taking a C-API expression list; it iterates over the
// `exprList.n` entries and returns a vector of Bindable.
// NOTE(review): the loop body is not visible in this excerpt, so as shown the
// returned vector stays empty. Presumably each entry is converted to a
// Bindable and appended — confirm against the complete source.
static llvm::SmallVector<Bindable, 8> makeBindables(edsc_expr_list_t exprList) {
llvm::SmallVector<Bindable, 8> exprs;
for (unsigned i = 0; i < exprList.n; ++i) {
return exprs;
// Returns a vector of Expr built by a loop that runs `n` times.
// NOTE(review): the loop body is not visible in this excerpt, so as shown the
// returned vector stays empty; presumably one Expr is appended per iteration —
// confirm against the complete source. Also note that `auto i = 0` deduces a
// signed int that is compared against the unsigned `n`.
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
llvm::SmallVector<Expr, 8> res;
for (auto i = 0; i < n; ++i) {
return res;
// Returns a vector holding one Expr per entry of `bindables`, in the same
// order. The element-wise Bindable-to-Expr conversion is the same range
// construction that the Bindable-bound `For` overloads in this file use.
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(ArrayRef<Bindable> bindables) {
  return llvm::SmallVector<Expr, 8>(bindables.begin(), bindables.end());
}
// File-local overload taking a C-API expression list; it iterates over the
// `exprList.n` entries and returns a vector of Expr.
// NOTE(review): the loop body is not visible in this excerpt, so as shown the
// returned vector stays empty. Presumably each entry is wrapped in an Expr and
// appended — confirm against the complete source.
static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
llvm::SmallVector<Expr, 8> exprs;
for (unsigned i = 0; i < exprList.n; ++i) {
return exprs;
// File-local helper taking a C-API statement list; it iterates over the
// `enclosedStmts.n` entries and returns a vector of Stmt.
// NOTE(review): the loop body is not visible in this excerpt, so as shown the
// returned vector stays empty. Presumably each entry is wrapped in a Stmt and
// appended — confirm against the complete source.
static llvm::SmallVector<Stmt, 8> makeStmts(edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
for (unsigned i = 0; i < enclosedStmts.n; ++i) {
return stmts;
// Builds a VariadicExpr of kind Alloc whose operands are `sizes` and whose
// type list consists of the single `memrefType`.
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
// Builds a block statement: a Stmt whose right-hand side is a
// StmtBlockLikeExpr of kind Block with no operands and whose enclosed
// statements are `stmts`.
Stmt mlir::edsc::Block(ArrayRef<Stmt> stmts) {
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
// C-API entry point: converts the statement list with makeStmts and forwards
// to mlir::edsc::Block.
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) {
return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts)));
// Builds a UnaryExpr of kind Dealloc whose operand is `memref`.
Expr mlir::edsc::dealloc(Expr memref) {
return UnaryExpr(ExprKind::Dealloc, memref);
// Builds a single `for` statement over [lb, ub) with the given step, using a
// freshly created Bindable as the induction variable. The bindable is local to
// this function; callers that need to refer to it should use the overload
// below that takes `idx`.
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Bindable idx;
return For(idx, lb, ub, step, stmts);
// Builds a single `for` statement: `idx` is the statement's left-hand side,
// the right-hand side is a StmtBlockLikeExpr of kind For with operands
// {lb, ub, step}, and `stmts` are the enclosed statements.
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
ArrayRef<Stmt> stmts) {
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
// Builds a perfectly nested loop nest, one `for` per entry of `indices`.
//
// `indices[0]` is the outermost loop and `indices.back()` the innermost;
// `enclosedStmts` become the body of the innermost loop. `lbs`, `ubs` and
// `steps` must all have the same length as `indices`, and that length must be
// at least one because the innermost loop is built from the last entries.
Stmt mlir::edsc::For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
                     ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
                     ArrayRef<Stmt> enclosedStmts) {
  // back() on an empty range is undefined behaviour, so reject it up front.
  assert(!indices.empty() && "expected at least one induction variable");
  assert(indices.size() == lbs.size());
  assert(indices.size() == ubs.size());
  assert(indices.size() == steps.size());
  // Start from the innermost loop, which owns the user-provided body.
  Stmt curStmt =
      For(indices.back(), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
  // Wrap outwards. Convert the size to a signed type before subtracting so
  // that a single-loop nest yields -1 instead of relying on unsigned
  // wrap-around.
  for (int64_t i = static_cast<int64_t>(indices.size()) - 2; i >= 0; --i) {
    curStmt = For(indices[i], lbs[i], ubs[i], steps[i], {curStmt});
  }
  return curStmt;
}
// Loop-nest overload whose lower bounds, upper bounds and steps are all given
// as Bindables. Each range is copied into a temporary SmallVector<Expr, 8> and
// the call is forwarded to the ArrayRef<Expr>-based loop-nest overload.
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Bindable> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()},
SmallVector<Expr, 8>{ubs.begin(), ubs.end()},
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
// Same as above, except that the upper bounds are already Exprs and are passed
// through unchanged; only `lbs` and `steps` are converted.
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Expr> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()}, ubs,
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
// C-API entry point building a single `for` statement.
//
// `iv` is cast to a Bindable and used as the induction variable; `lb`, `ub`
// and `step` are wrapped as Exprs. The enclosed statements are converted with
// makeStmts, the same helper the C-API `Block` wrapper uses.
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
                edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
  return Stmt(For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step),
                  makeStmts(enclosedStmts)));
}
// C-API entry point building a loop nest.
//
// `ivs` is converted to Bindables, which must live in a named local because
// the loop-nest `For` overload takes them as a MutableArrayRef. The bounds and
// steps are converted with makeExprs and the enclosed statements with
// makeStmts.
edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
                    edsc_expr_list_t ubs, edsc_expr_list_t steps,
                    edsc_stmt_list_t enclosedStmts) {
  auto bindables = makeBindables(ivs);
  return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps),
                  makeStmts(enclosedStmts)));
}
// Shared implementation of the `load` overloads: collects the operands into a
// SmallVector<Expr, 8> and returns a VariadicExpr of kind Load built from it.
// `indices` may hold either Bindables or Exprs.
// NOTE(review): in the lines visible here `m` is never added to `exprs`, so
// only the indices become operands. Presumably the memref is pushed before the
// indices — confirm against the complete source.
template <typename BindableOrExpr>
static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
SmallVector<Expr, 8> exprs;
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Load, exprs);
// Public `load` overloads. They differ only in how the indices are supplied
// (one Expr, one Bindable, a vector of Bindables, or an ArrayRef of Exprs);
// every one forwards to loadBuilder.
Expr mlir::edsc::load(Expr m, Expr index) {
return loadBuilder<Expr>(m, {index});
Expr mlir::edsc::load(Expr m, Bindable index) {
return loadBuilder<Bindable>(m, {index});
Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
return loadBuilder(m, ArrayRef<Bindable>{indices.begin(), indices.end()});
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
return loadBuilder(m, indices);
// C-API entry point for a load. Wraps `indexed.base` (cast to Bindable) in an
// Indexed, attaches the converted `indices` via operator[], and returns the
// result after converting it to an Expr.
// NOTE(review): the Indexed-to-Expr conversion used on the second line is
// declared outside this file; presumably it produces the load expression.
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Expr res = i[makeExprs(indices)];
return res;
// Shared implementation of the `store` overloads: collects the operands into a
// SmallVector<Expr, 8> and returns a VariadicExpr of kind Store built from it.
// `indices` may hold either Bindables or Exprs.
// NOTE(review): in the lines visible here neither `val` nor `m` is added to
// `exprs`, so only the indices become operands. Presumably both are pushed
// before the indices — confirm the order against the complete source.
template <typename BindableOrExpr>
static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
SmallVector<Expr, 8> exprs;
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Store, exprs);
// Public `store` overloads. They differ only in how the indices are supplied
// (one Expr, one Bindable, a vector of Bindables, or an ArrayRef of Exprs);
// every one forwards to storeBuilder.
Expr mlir::edsc::store(Expr val, Expr m, Expr index) {
return storeBuilder<Expr>(val, m, {index});
Expr mlir::edsc::store(Expr val, Expr m, Bindable index) {
return storeBuilder<Bindable>(val, m, {index});
Expr mlir::edsc::store(Expr val, Expr m,
const llvm::SmallVectorImpl<Bindable> &indices) {
return storeBuilder(val, m,
ArrayRef<Bindable>{indices.begin(), indices.end()});
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
return storeBuilder(val, m, indices);
// C-API entry point for a store. Wraps `indexed.base` (cast to Bindable) in an
// Indexed, attaches the converted `indices`, and assigns `value` to it.
// Indexed::operator= yields the Stmt that is returned.
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Indexed loc = i[makeExprs(indices)];
return Stmt(loc = Expr(value));
// Builds a TernaryExpr of kind Select from a condition and two operands.
Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
return TernaryExpr(ExprKind::Select, cond, lhs, rhs);
// C-API entry point: wraps the three handles as Exprs and forwards to select.
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
return select(Expr(cond), Expr(lhs), Expr(rhs));
// Builds a VariadicExpr of kind VectorTypeCast with `memrefExpr` as its only
// operand and `memrefType` as its only type.
Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType});
// Builds a VariadicExpr of kind Return over `values`; the expression is
// converted to the Stmt return type on return.
Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
return VariadicExpr(ExprKind::Return, values);
// C-API entry point: converts the list with makeExprs and forwards to Return.
edsc_stmt_t Return(edsc_expr_list_t values) {
return Stmt(Return(makeExprs(values)));
// Writes a textual form of this expression to `os`. The concrete expression
// class is discovered with a chain of dyn_cast checks (Bindable, UnaryExpr,
// BinaryExpr, TernaryExpr, VariadicExpr, StmtBlockLikeExpr), and within each
// class the kind selects the text; kinds without a dedicated case print an
// "unknown_*" marker.
// NOTE(review): no `break` statements or closing braces are visible between
// the cases in this excerpt; confirm the exact control flow against the
// complete source.
void mlir::edsc::Expr::print(raw_ostream &os) const {
if (auto unbound = this->dyn_cast<Bindable>()) {
// Bindables print as `$<id>`.
os << "$" << unbound.getId();
} else if (auto un = this->dyn_cast<UnaryExpr>()) {
// No unary kind has a dedicated printed form here.
os << "unknown_unary";
} else if (auto bin = this->dyn_cast<BinaryExpr>()) {
// Binary expressions print as `(lhs <op> rhs)`.
os << "(" << bin.getLHS();
switch (bin.getKind()) {
case ExprKind::Add:
os << " + ";
case ExprKind::Sub:
os << " - ";
case ExprKind::Mul:
os << " * ";
case ExprKind::Div:
os << " / ";
case ExprKind::LT:
os << " < ";
case ExprKind::LE:
os << " <= ";
case ExprKind::GT:
os << " > ";
case ExprKind::GE:
os << " >= ";
// EQ, NE, And and Or have no case of their own and end up here.
default: {
os << "unknown_binary";
os << bin.getRHS() << ")";
} else if (auto ter = this->dyn_cast<TernaryExpr>()) {
switch (ter.getKind()) {
case ExprKind::Select:
os << "select(" << ter.getCond() << ", " << ter.getLHS() << ", "
<< ter.getRHS() << ")";
default: {
os << "unknown_ternary";
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
switch (nar.getKind()) {
// Load and Store print a fixed placeholder, not their operands.
case ExprKind::Load:
os << "load( ... )";
case ExprKind::Store:
os << "store( ... )";
case ExprKind::Return:
interleaveComma(nar.getExprs(), os);
default: {
os << "unknown_variadic";
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
auto exprs = stmtLikeExpr.getExprs();
switch (stmtLikeExpr.getKind()) {
// We only print the lb, ub and step here, which are the StmtBlockLike
// part of the `for` StmtBlockLikeExpr.
case ExprKind::For:
assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs");
os << exprs[0] << " to " << exprs[1] << " step " << exprs[2];
default: {
os << "unknown_stmt";
// Fallback for expressions matching none of the classes above: print the
// numeric value of the kind.
os << "unknown_kind(" << static_cast<int>(getKind()) << ")";
// Prints this expression to llvm::errs().
void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); }
std::string mlir::edsc::Expr::str() const {
std::string res;
llvm::raw_string_ostream os(res);
return res;
// Stream insertion for Expr: prints `expr` to `os` via Expr::print and returns
// the stream so that insertions can be chained.
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
                                          const Expr &expr) {
  expr.print(os);
  return os;
}
: Expr(Expr::globalAllocator()->Allocate<detail::BindableStorage>()) {
// Initialize with placement new.
new (storage) detail::BindableStorage{Bindable::newId()};
// C-API entry point: returns a default-constructed Bindable.
edsc_expr_t makeBindable() { return Bindable(); }
// Returns the id recorded in this bindable's storage.
unsigned mlir::edsc::Bindable::getId() const {
return static_cast<ImplType *>(storage)->id;
// Increments a thread_local counter and returns a reference to it. Because the
// counter starts at 0 and is pre-incremented, the first value handed out on a
// thread is 1.
unsigned &mlir::edsc::Bindable::newId() {
static thread_local unsigned id = 0;
return ++id;
// Constructs a unary expression of the given kind. The storage comes from the
// global EDSC allocator and is initialized in place with `kind` and `expr`.
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::UnaryExprStorage{kind, expr};
// Returns the operand recorded in the storage.
Expr mlir::edsc::UnaryExpr::getExpr() const {
return static_cast<ImplType *>(storage)->expr;
// Constructs a binary expression of the given kind. The storage comes from the
// global EDSC allocator and is initialized in place with `kind`, `lhs` and
// `rhs`.
mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
: Expr(Expr::globalAllocator()->Allocate<detail::BinaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::BinaryExprStorage{kind, lhs, rhs};
// Returns the left operand recorded in the storage.
Expr mlir::edsc::BinaryExpr::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
// Returns the right operand recorded in the storage.
Expr mlir::edsc::BinaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
// Constructs a ternary expression of the given kind. The storage comes from
// the global EDSC allocator and is initialized in place with `kind`, `cond`,
// `lhs` and `rhs`.
mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs,
Expr rhs)
: Expr(Expr::globalAllocator()->Allocate<detail::TernaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::TernaryExprStorage{kind, cond, lhs, rhs};
// Returns the condition recorded in the storage.
Expr mlir::edsc::TernaryExpr::getCond() const {
return static_cast<ImplType *>(storage)->cond;
// Returns the left operand recorded in the storage.
Expr mlir::edsc::TernaryExpr::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
// Returns the right operand recorded in the storage.
Expr mlir::edsc::TernaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
// Constructs a variadic expression of the given kind. Besides the node's own
// storage, two arrays are allocated from the global EDSC allocator and filled
// with copies of `exprs` and `types`; the storage record then refers to those
// copies, not to the caller's arrays.
mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
ArrayRef<Type> types)
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
// Initialize with placement new.
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
std::uninitialized_copy(exprs.begin(), exprs.end(), exprStorage);
auto typeStorage = Expr::globalAllocator()->Allocate<Type>(types.size());
std::uninitialized_copy(types.begin(), types.end(), typeStorage);
new (storage) detail::VariadicExprStorage{
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
ArrayRef<Type>(typeStorage, types.size())};
// Returns the operand list recorded in the storage.
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
return static_cast<ImplType *>(storage)->exprs;
// Returns the type list recorded in the storage.
ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const {
return static_cast<ImplType *>(storage)->types;
// Constructs a block-like statement expression of the given kind. It uses the
// same VariadicExprStorage record as VariadicExpr: `exprs` and `types` are
// copied into arrays allocated from the global EDSC allocator and the storage
// record refers to those copies.
mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
ArrayRef<Expr> exprs,
ArrayRef<Type> types)
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
// Initialize with placement new.
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
std::uninitialized_copy(exprs.begin(), exprs.end(), exprStorage);
auto typeStorage = Expr::globalAllocator()->Allocate<Type>(types.size());
std::uninitialized_copy(types.begin(), types.end(), typeStorage);
new (storage) detail::VariadicExprStorage{
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
ArrayRef<Type>(typeStorage, types.size())};
// Returns the operand list recorded in the storage.
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->exprs;
// Returns the type list recorded in the storage.
ArrayRef<Type> mlir::edsc::StmtBlockLikeExpr::getTypes() const {
return static_cast<ImplType *>(storage)->types;
// Constructs a statement `lhs = rhs` with the given enclosed statements. The
// StmtStorage record is allocated from the global EDSC allocator and
// initialized in place; it refers to `enclosedStmtStorage`, not to the
// caller's array.
// NOTE(review): the statement that defines `enclosedStmtStorage` is truncated
// in this excerpt (the allocation call and the destination argument of
// uninitialized_copy are missing). Presumably an array of
// `enclosedStmts.size()` Stmts is allocated and filled — confirm against the
// complete source.
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {
storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
// Initialize with placement new.
auto enclosedStmtStorage =
std::uninitialized_copy(enclosedStmts.begin(), enclosedStmts.end(),
new (storage) detail::StmtStorage{
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
// Constructs a statement from a right-hand side only; a fresh Bindable is
// created to serve as the left-hand side.
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
: Stmt(Bindable(), rhs, enclosedStmts) {}
// C-API entry point: wraps the expression handle `e` in a Stmt. A null handle
// is rejected by the assertion.
edsc_stmt_t makeStmt(edsc_expr_t e) {
assert(e && "unexpected empty expression");
return Stmt(Expr(e));
// Rebinds this statement to `<fresh bindable> = expr` with no enclosed
// statements.
//
// A temporary Stmt is built with new storage and the storage pointers are then
// exchanged, so `*this` ends up referring to the new record.
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
  Stmt res(Bindable(), expr, {});
  std::swap(res.storage, this->storage);
  return *this;
}
// Returns the bindable on the left-hand side of this statement.
Bindable mlir::edsc::Stmt::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
// Returns the expression on the right-hand side of this statement.
Expr mlir::edsc::Stmt::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
// Returns the statements enclosed by this statement. Unlike the two accessors
// above, this one reads through `storage` without a cast to ImplType.
llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
return storage->enclosedStmts;
// Writes a textual form of this statement to `os`, prefixing lines with
// `indent`. Statements whose right-hand side is a StmtBlockLikeExpr are
// printed by kind (`for(...) {` or `block {`, with enclosed statements printed
// recursively at a deeper indent); any other statement prints as `lhs = rhs`.
// The statement must have non-null storage.
// NOTE(review): no `break` statements or closing braces are visible in this
// excerpt; confirm the exact control flow against the complete source.
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
assert(storage && "Unexpected null storage,stmt must be bound to print");
auto lhs = getLHS();
auto rhs = getRHS();
if (auto stmtExpr = rhs.dyn_cast<StmtBlockLikeExpr>()) {
switch (stmtExpr.getKind()) {
case ExprKind::For:
os << indent << "for(" << lhs << " = " << stmtExpr << ") {";
os << "\n";
for (const auto &s : getEnclosedStmts()) {
// Only statements that are not themselves block-like get the extra
// indent written here.
if (!s.getRHS().isa<StmtBlockLikeExpr>()) {
os << indent << " ";
s.print(os, indent + " ");
os << ";\n";
os << indent << "}";
case ExprKind::Block:
os << indent << "block {";
for (auto &s : getEnclosedStmts()) {
os << "\n";
s.print(os, indent + " ");
os << "\n" << indent << "}";
default: {
// TODO(ntv): print more statement cases.
os << "TODO";
} else {
os << lhs << " = " << rhs;
// Prints this statement to llvm::errs().
void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); }
std::string mlir::edsc::Stmt::str() const {
std::string res;
llvm::raw_string_ostream os(res);
return res;
// Stream insertion for Stmt: prints `stmt` to `os` via Stmt::print and returns
// the stream so that insertions can be chained.
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
                                          const Stmt &stmt) {
  stmt.print(os);
  return os;
}
// Returns a new Indexed over the same base whose indices are a copy of
// `indices`. `*this` is left unchanged.
Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
Indexed res(base);
res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
return res;
// Bindable flavour: forwards to the ArrayRef<Expr> overload above.
Indexed mlir::edsc::Indexed::
operator[](llvm::ArrayRef<Bindable> indices) const {
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
// Assignment to an Indexed does not assign anything: it returns a Stmt that
// wraps `store(expr, base, indices)`. Indices must have been attached first
// (for example through operator[]); this is checked by the assertion.
// clang-format off
Stmt mlir::edsc::Indexed::operator=(Expr expr) { // NOLINT: unconventional-assign-operator
// clang-format on
assert(!indices.empty() && "Expected attached indices to Indexed");
Stmt stmt(store(expr, base, indices));
return stmt;
// C-API entry point: returns an edsc_indexed_t with `expr` as its base and an
// empty index list ({nullptr, 0}).
edsc_indexed_t makeIndexed(edsc_expr_t expr) {
return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}};
// C-API entry point: returns an edsc_indexed_t with the same base as `indexed`
// and `indices` as its index list; any indices previously held by `indexed`
// are not carried over.
edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
return edsc_indexed_t{indexed.base, indices};
// C-API entry point mapping a type name to an MLIR scalar type in `context`.
// Recognized names are "bf16", "f16", "f32", "f64", "index" and "i"; only "i"
// (integer) uses `bitwidth`. Any other name trips the assertion and, when
// assertions are compiled out, yields a handle wrapping nullptr.
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth) {
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
if (llvm::StringRef(name) == "bf16") {
return mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()};
if (llvm::StringRef(name) == "f16") {
return mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()};
if (llvm::StringRef(name) == "f32") {
return mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()};
if (llvm::StringRef(name) == "f64") {
return mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()};
if (llvm::StringRef(name) == "index") {
return mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()};
if (llvm::StringRef(name) == "i") {
return mlir_type_t{
mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()};
assert(false && "unknown scalar type");
return mlir_type_t{nullptr};
// C-API entry point that builds a MemRefType whose shape is taken from
// `sizes` and returns it as an opaque handle.
// NOTE(review): the argument list of MemRefType::get is incomplete in this
// excerpt — `elemType` is never referenced and the statement ends with `)},`.
// Presumably the element type and a (possibly empty) list of affine maps are
// passed as well; confirm against the complete source.
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
int64_list_t sizes) {
auto t = mlir::MemRefType::get(
llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
return mlir_type_t{t.getAsOpaquePointer()};
// C-API entry point that builds a FunctionType in `context` from lists of
// opaque input and output type handles and returns it as an opaque handle.
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
                             mlir_type_list_t outputs) {
  llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
  for (unsigned i = 0; i < inputs.n; ++i) {
    ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
  }
  // Result types belong in `outs`. Storing them in `ins` would overwrite the
  // argument types, leave every result type null, and index past the end of
  // `ins` whenever there are more outputs than inputs.
  for (unsigned i = 0; i < outputs.n; ++i) {
    outs[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
  }
  auto ft = mlir::FunctionType::get(
      ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
  return mlir_type_t{ft.getAsOpaquePointer()};
}
// C-API entry point: reinterprets the opaque handle as an mlir::Function and
// returns its number of arguments.
unsigned getFunctionArity(mlir_func_t function) {
auto *f = reinterpret_cast<mlir::Function *>(function);
return f->getNumArguments();