//===- Types.cpp - MLIR EDSC Type System Implementation ---------*- C++ -*-===//
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
#include "mlir/EDSC/Types.h"
#include "mlir-c/Core.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
| |
| using llvm::errs; |
| using llvm::Twine; |
| |
| using namespace mlir; |
| using namespace mlir::edsc; |
| using namespace mlir::edsc::detail; |
| |
| namespace mlir { |
| namespace edsc { |
| namespace detail { |
| |
| struct ExprStorage { |
| // Note: this structure is similar to OperationState, but stores lists in a |
| // EDSC bump allocator. |
| ExprKind kind; |
| unsigned id; |
| |
| ArrayRef<Expr> operands; |
| ArrayRef<Type> resultTypes; |
| ArrayRef<NamedAttribute> attributes; |
| |
| ExprStorage(ExprKind kind, ArrayRef<Type> results, ArrayRef<Expr> children, |
| ArrayRef<NamedAttribute> attrs, unsigned exprId = Expr::newId()) |
| : kind(kind), id(exprId) { |
| if (!children.empty()) { |
| auto exprStorage = |
| Expr::globalAllocator()->Allocate<Expr>(children.size()); |
| std::uninitialized_copy(children.begin(), children.end(), exprStorage); |
| operands = llvm::makeArrayRef(exprStorage, children.size()); |
| } else { |
| operands = ArrayRef<Expr>(); |
| } |
| |
| if (!results.empty()) { |
| auto typeStorage = |
| Expr::globalAllocator()->Allocate<Type>(results.size()); |
| std::uninitialized_copy(results.begin(), results.end(), typeStorage); |
| resultTypes = llvm::makeArrayRef(typeStorage, results.size()); |
| } else { |
| resultTypes = ArrayRef<Type>(); |
| } |
| |
| if (!attrs.empty()) { |
| auto attrStorage = |
| Expr::globalAllocator()->Allocate<NamedAttribute>(attrs.size()); |
| std::uninitialized_copy(attrs.begin(), attrs.end(), attrStorage); |
| attributes = llvm::makeArrayRef(attrStorage, attrs.size()); |
| } else { |
| attributes = ArrayRef<NamedAttribute>(); |
| } |
| } |
| }; |
| |
| struct StmtStorage { |
| StmtStorage(Bindable lhs, Expr rhs, llvm::ArrayRef<Stmt> enclosedStmts) |
| : lhs(lhs), rhs(rhs), enclosedStmts(enclosedStmts) {} |
| Bindable lhs; |
| Expr rhs; |
| ArrayRef<Stmt> enclosedStmts; |
| }; |
| |
| } // namespace detail |
| } // namespace edsc |
| } // namespace mlir |
| |
| mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() { |
| Expr::globalAllocator() = &allocator; |
| Bindable::resetIds(); |
| } |
| |
| mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() { |
| Expr::globalAllocator() = nullptr; |
| } |
| |
| mlir::edsc::Expr::Expr() { |
| // Initialize with placement new. |
| storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>(); |
| new (storage) detail::ExprStorage(ExprKind::Unbound, {}, {}, {}); |
| } |
| |
| ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; } |
| |
| unsigned mlir::edsc::Expr::getId() const { |
| return static_cast<ImplType *>(storage)->id; |
| } |
| |
| unsigned &mlir::edsc::Expr::newId() { |
| static thread_local unsigned id = 0; |
| return ++id; |
| } |
| |
| Expr mlir::edsc::op::operator+(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::Add, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator-(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::Sub, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::Mul, lhs, rhs); |
| } |
| |
| Expr mlir::edsc::op::operator==(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::EQ, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator!=(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::NE, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator<(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::LT, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator<=(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::LE, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator>(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::GT, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator>=(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::GE, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::And, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) { |
| return BinaryExpr(ExprKind::Or, lhs, rhs); |
| } |
| Expr mlir::edsc::op::operator!(Expr expr) { |
| return UnaryExpr(ExprKind::Negate, expr); |
| } |
| |
| llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) { |
| llvm::SmallVector<Expr, 8> res; |
| res.reserve(n); |
| for (auto i = 0; i < n; ++i) { |
| res.push_back(Expr()); |
| } |
| return res; |
| } |
| |
| static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) { |
| llvm::SmallVector<Expr, 8> exprs; |
| exprs.reserve(exprList.n); |
| for (unsigned i = 0; i < exprList.n; ++i) { |
| exprs.push_back(Expr(exprList.exprs[i])); |
| } |
| return exprs; |
| } |
| |
| static void fillStmts(edsc_stmt_list_t enclosedStmts, |
| llvm::SmallVector<Stmt, 8> *stmts) { |
| stmts->reserve(enclosedStmts.n); |
| for (unsigned i = 0; i < enclosedStmts.n; ++i) { |
| stmts->push_back(Stmt(enclosedStmts.stmts[i])); |
| } |
| } |
| |
| Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) { |
| return VariadicExpr(ExprKind::Alloc, sizes, memrefType); |
| } |
| |
| Stmt mlir::edsc::StmtList(ArrayRef<Stmt> stmts) { |
| return Stmt(StmtBlockLikeExpr(ExprKind::StmtList, {}), stmts); |
| } |
| |
| edsc_stmt_t StmtList(edsc_stmt_list_t enclosedStmts) { |
| llvm::SmallVector<Stmt, 8> stmts; |
| fillStmts(enclosedStmts, &stmts); |
| return Stmt(mlir::edsc::StmtList(stmts)); |
| } |
| |
| Expr mlir::edsc::dealloc(Expr memref) { |
| return UnaryExpr(ExprKind::Dealloc, memref); |
| } |
| |
| Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) { |
| Expr idx; |
| return For(Bindable(idx), lb, ub, step, stmts); |
| } |
| |
| Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step, |
| ArrayRef<Stmt> stmts) { |
| return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts); |
| } |
| |
| Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs, |
| ArrayRef<Expr> ubs, ArrayRef<Expr> steps, |
| ArrayRef<Stmt> enclosedStmts) { |
| assert(!indices.empty()); |
| assert(indices.size() == lbs.size()); |
| assert(indices.size() == ubs.size()); |
| assert(indices.size() == steps.size()); |
| Expr iv = indices.back(); |
| Stmt curStmt = |
| For(Bindable(iv), lbs.back(), ubs.back(), steps.back(), enclosedStmts); |
| for (int64_t i = indices.size() - 2; i >= 0; --i) { |
| Expr iiv = indices[i]; |
| curStmt.set(For(Bindable(iiv), lbs[i], ubs[i], steps[i], |
| llvm::ArrayRef<Stmt>{&curStmt, 1})); |
| } |
| return curStmt; |
| } |
| |
| edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, |
| edsc_expr_t step, edsc_stmt_list_t enclosedStmts) { |
| llvm::SmallVector<Stmt, 8> stmts; |
| fillStmts(enclosedStmts, &stmts); |
| return Stmt( |
| For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step), stmts)); |
| } |
| |
| edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs, |
| edsc_expr_list_t ubs, edsc_expr_list_t steps, |
| edsc_stmt_list_t enclosedStmts) { |
| llvm::SmallVector<Stmt, 8> stmts; |
| fillStmts(enclosedStmts, &stmts); |
| return Stmt(For(makeExprs(ivs), makeExprs(lbs), makeExprs(ubs), |
| makeExprs(steps), stmts)); |
| } |
| |
| Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) { |
| SmallVector<Expr, 8> exprs; |
| exprs.push_back(m); |
| exprs.append(indices.begin(), indices.end()); |
| return VariadicExpr(ExprKind::Load, exprs); |
| } |
| |
| edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) { |
| Indexed i(Expr(indexed.base).cast<Bindable>()); |
| auto exprs = makeExprs(indices); |
| Expr res = i(exprs); |
| return res; |
| } |
| |
| Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) { |
| SmallVector<Expr, 8> exprs; |
| exprs.push_back(val); |
| exprs.push_back(m); |
| exprs.append(indices.begin(), indices.end()); |
| return VariadicExpr(ExprKind::Store, exprs); |
| } |
| |
| edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed, |
| edsc_expr_list_t indices) { |
| Indexed i(Expr(indexed.base).cast<Bindable>()); |
| auto exprs = makeExprs(indices); |
| Indexed loc = i(exprs); |
| return Stmt(loc = Expr(value)); |
| } |
| |
| Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) { |
| return TernaryExpr(ExprKind::Select, cond, lhs, rhs); |
| } |
| |
| edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) { |
| return select(Expr(cond), Expr(lhs), Expr(rhs)); |
| } |
| |
| Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) { |
| return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType}); |
| } |
| |
| Stmt mlir::edsc::Return(ArrayRef<Expr> values) { |
| return VariadicExpr(ExprKind::Return, values); |
| } |
| |
| edsc_stmt_t Return(edsc_expr_list_t values) { |
| return Stmt(Return(makeExprs(values))); |
| } |
| |
| void mlir::edsc::Expr::print(raw_ostream &os) const { |
| if (auto unbound = this->dyn_cast<Bindable>()) { |
| os << "$" << unbound.getId(); |
| return; |
| } else if (auto un = this->dyn_cast<UnaryExpr>()) { |
| switch (un.getKind()) { |
| case ExprKind::Negate: |
| os << "~"; |
| break; |
| default: { |
| os << "unknown_unary"; |
| } |
| } |
| os << un.getExpr(); |
| } else if (auto bin = this->dyn_cast<BinaryExpr>()) { |
| os << "(" << bin.getLHS(); |
| switch (bin.getKind()) { |
| case ExprKind::Add: |
| os << " + "; |
| break; |
| case ExprKind::Sub: |
| os << " - "; |
| break; |
| case ExprKind::Mul: |
| os << " * "; |
| break; |
| case ExprKind::Div: |
| os << " / "; |
| break; |
| case ExprKind::LT: |
| os << " < "; |
| break; |
| case ExprKind::LE: |
| os << " <= "; |
| break; |
| case ExprKind::GT: |
| os << " > "; |
| break; |
| case ExprKind::GE: |
| os << " >= "; |
| break; |
| case ExprKind::EQ: |
| os << " == "; |
| break; |
| case ExprKind::NE: |
| os << " != "; |
| break; |
| case ExprKind::And: |
| os << " && "; |
| break; |
| case ExprKind::Or: |
| os << " || "; |
| break; |
| default: { |
| os << "unknown_binary"; |
| } |
| } |
| os << bin.getRHS() << ")"; |
| return; |
| } else if (auto ter = this->dyn_cast<TernaryExpr>()) { |
| switch (ter.getKind()) { |
| case ExprKind::Select: |
| os << "select(" << ter.getCond() << ", " << ter.getLHS() << ", " |
| << ter.getRHS() << ")"; |
| return; |
| default: { |
| os << "unknown_ternary"; |
| } |
| } |
| } else if (auto nar = this->dyn_cast<VariadicExpr>()) { |
| auto exprs = nar.getExprs(); |
| switch (nar.getKind()) { |
| case ExprKind::Load: |
| os << "load(" << exprs[0] << "["; |
| interleaveComma(ArrayRef<Expr>(exprs.begin() + 1, exprs.size() - 1), os); |
| os << "])"; |
| return; |
| case ExprKind::Store: |
| os << "store(" << exprs[0] << ", " << exprs[1] << "["; |
| interleaveComma(ArrayRef<Expr>(exprs.begin() + 2, exprs.size() - 2), os); |
| os << "])"; |
| return; |
| case ExprKind::Return: |
| interleaveComma(exprs, os); |
| return; |
| default: { |
| os << "unknown_variadic"; |
| } |
| } |
| } else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) { |
| auto exprs = stmtLikeExpr.getExprs(); |
| switch (stmtLikeExpr.getKind()) { |
| // We only print the lb, ub and step here, which are the StmtBlockLike |
| // part of the `for` StmtBlockLikeExpr. |
| case ExprKind::For: |
| assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs"); |
| os << exprs[0] << " to " << exprs[1] << " step " << exprs[2]; |
| return; |
| default: { |
| os << "unknown_stmt"; |
| } |
| } |
| } |
| os << "unknown_kind(" << static_cast<int>(getKind()) << ")"; |
| } |
| |
| void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); } |
| |
| std::string mlir::edsc::Expr::str() const { |
| std::string res; |
| llvm::raw_string_ostream os(res); |
| this->print(os); |
| return res; |
| } |
| |
| llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, |
| const Expr &expr) { |
| expr.print(os); |
| return os; |
| } |
| |
| edsc_expr_t makeBindable() { return Bindable(Expr()); } |
| |
| mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr) |
| : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { |
| // Initialize with placement new. |
| new (storage) detail::ExprStorage(kind, {}, {expr}, {}); |
| } |
| Expr mlir::edsc::UnaryExpr::getExpr() const { |
| return static_cast<ImplType *>(storage)->operands.front(); |
| } |
| |
| mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs) |
| : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { |
| // Initialize with placement new. |
| new (storage) detail::ExprStorage(kind, {}, {lhs, rhs}, {}); |
| } |
| Expr mlir::edsc::BinaryExpr::getLHS() const { |
| return static_cast<ImplType *>(storage)->operands.front(); |
| } |
| Expr mlir::edsc::BinaryExpr::getRHS() const { |
| return static_cast<ImplType *>(storage)->operands.back(); |
| } |
| |
| mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, |
| Expr rhs) |
| : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { |
| // Initialize with placement new. |
| new (storage) detail::ExprStorage(kind, {}, {cond, lhs, rhs}, {}); |
| } |
| Expr mlir::edsc::TernaryExpr::getCond() const { |
| return static_cast<ImplType *>(storage)->operands[0]; |
| } |
| Expr mlir::edsc::TernaryExpr::getLHS() const { |
| return static_cast<ImplType *>(storage)->operands[1]; |
| } |
| Expr mlir::edsc::TernaryExpr::getRHS() const { |
| return static_cast<ImplType *>(storage)->operands[2]; |
| } |
| |
| mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs, |
| ArrayRef<Type> types) |
| : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { |
| // Initialize with placement new. |
| new (storage) detail::ExprStorage(kind, types, exprs, {}); |
| } |
| ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const { |
| return static_cast<ImplType *>(storage)->operands; |
| } |
| ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const { |
| return static_cast<ImplType *>(storage)->resultTypes; |
| } |
| |
| mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, |
| ArrayRef<Expr> exprs, |
| ArrayRef<Type> types) |
| : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { |
| // Initialize with placement new. |
| new (storage) detail::ExprStorage(kind, types, exprs, {}); |
| } |
| ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const { |
| return static_cast<ImplType *>(storage)->operands; |
| } |
| |
| mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs, |
| llvm::ArrayRef<Stmt> enclosedStmts) { |
| storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>(); |
| // Initialize with placement new. |
| auto enclosedStmtStorage = |
| Expr::globalAllocator()->Allocate<Stmt>(enclosedStmts.size()); |
| std::uninitialized_copy(enclosedStmts.begin(), enclosedStmts.end(), |
| enclosedStmtStorage); |
| new (storage) detail::StmtStorage{ |
| lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())}; |
| } |
| |
| mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts) |
| : Stmt(Bindable(Expr()), rhs, enclosedStmts) {} |
| |
| edsc_stmt_t makeStmt(edsc_expr_t e) { |
| assert(e && "unexpected empty expression"); |
| return Stmt(Expr(e)); |
| } |
| |
| Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) { |
| Stmt res(Bindable(Expr()), expr, {}); |
| std::swap(res.storage, this->storage); |
| return *this; |
| } |
| |
| Expr mlir::edsc::Stmt::getLHS() const { |
| return static_cast<ImplType *>(storage)->lhs; |
| } |
| |
| Expr mlir::edsc::Stmt::getRHS() const { |
| return static_cast<ImplType *>(storage)->rhs; |
| } |
| |
| llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const { |
| return storage->enclosedStmts; |
| } |
| |
/// Prints this statement to `os`, using `indent` as the prefix for
/// block-like statements.
///
/// - A statement without storage prints `null_storage`.
/// - A `for` statement prints `for(<iv> = <lb> to <ub> step <step>) {`. Each
///   enclosed statement follows on its own line, terminated by `;`, and a
///   closing `}` ends the block.
/// - A `stmt_list` statement prints `stmt_list {`, each enclosed statement on
///   its own line without `;` terminators, then a closing `}`.
/// - Any other block-like kind prints `TODO`.
/// - A plain statement prints `<lhs> = <rhs>` with no indentation of its own.
/// NOTE(review): inside a `stmt_list`, plain (non-block-like) statements get no
/// indentation, whereas inside a `for` body they are indented explicitly.
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
  if (!storage) {
    os << "null_storage";
    return;
  }
  auto lhs = getLHS();
  auto rhs = getRHS();

  if (auto stmtExpr = rhs.dyn_cast<StmtBlockLikeExpr>()) {
    switch (stmtExpr.getKind()) {
    case ExprKind::For:
      os << indent << "for(" << lhs << " = " << stmtExpr << ") {";
      os << "\n";
      for (const auto &s : getEnclosedStmts()) {
        // Block-like statements print their own indent; plain ones do not,
        // so indent them here.
        if (!s.getRHS().isa<StmtBlockLikeExpr>()) {
          os << indent << "  ";
        }
        s.print(os, indent + "  ");
        os << ";\n";
      }
      os << indent << "}";
      return;
    case ExprKind::StmtList:
      os << indent << "stmt_list {";
      for (auto &s : getEnclosedStmts()) {
        os << "\n";
        s.print(os, indent + "  ");
      }
      os << "\n" << indent << "}";
      return;
    default: {
      // TODO(ntv): print more statement cases.
      os << "TODO";
    }
    }
  } else {
    os << lhs << " = " << rhs;
  }
}
| |
| void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); } |
| |
| std::string mlir::edsc::Stmt::str() const { |
| std::string res; |
| llvm::raw_string_ostream os(res); |
| this->print(os); |
| return res; |
| } |
| |
| llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, |
| const Stmt &stmt) { |
| stmt.print(os); |
| return os; |
| } |
| |
| Indexed mlir::edsc::Indexed::operator()(llvm::ArrayRef<Expr> indices) { |
| Indexed res(base); |
| res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end()); |
| return res; |
| } |
| |
| // NOLINTNEXTLINE: unconventional-assign-operator |
| Stmt mlir::edsc::Indexed::operator=(Expr expr) { |
| return Stmt(store(expr, base, indices)); |
| } |
| |
| edsc_indexed_t makeIndexed(edsc_expr_t expr) { |
| return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}}; |
| } |
| |
| edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) { |
| return edsc_indexed_t{indexed.base, indices}; |
| } |
| |
| mlir_type_t makeScalarType(mlir_context_t context, const char *name, |
| unsigned bitwidth) { |
| mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context); |
| mlir_type_t res = |
| llvm::StringSwitch<mlir_type_t>(name) |
| .Case("bf16", |
| mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()}) |
| .Case("f16", |
| mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()}) |
| .Case("f32", |
| mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()}) |
| .Case("f64", |
| mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()}) |
| .Case("index", |
| mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()}) |
| .Case("i", |
| mlir_type_t{ |
| mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()}) |
| .Default(mlir_type_t{nullptr}); |
| if (!res) { |
| llvm_unreachable("Invalid type specifier"); |
| } |
| return res; |
| } |
| |
| mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, |
| int64_list_t sizes) { |
| auto t = mlir::MemRefType::get( |
| llvm::ArrayRef<int64_t>(sizes.values, sizes.n), |
| mlir::Type::getFromOpaquePointer(elemType), |
| {mlir::AffineMap::getMultiDimIdentityMap( |
| sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))}, |
| 0); |
| return mlir_type_t{t.getAsOpaquePointer()}; |
| } |
| |
| mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, |
| mlir_type_list_t outputs) { |
| llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n); |
| for (unsigned i = 0; i < inputs.n; ++i) { |
| ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]); |
| } |
| for (unsigned i = 0; i < outputs.n; ++i) { |
| outs[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]); |
| } |
| auto ft = mlir::FunctionType::get( |
| ins, outs, reinterpret_cast<mlir::MLIRContext *>(context)); |
| return mlir_type_t{ft.getAsOpaquePointer()}; |
| } |
| |
| unsigned getFunctionArity(mlir_func_t function) { |
| auto *f = reinterpret_cast<mlir::Function *>(function); |
| return f->getNumArguments(); |
| } |