Introduce python bindings for MLIR EDSCs
This CL also introduces a set of python bindings using pybind11. The bindings
are exercised using a `test_py2andpy3.py` test suite that works for both
python 2 and 3.
`test_py3.py` on the other hand uses the more idiomatic,
python 3 only "PEP 3132 -- Extended Iterable Unpacking" to implement a rank
and type-agnostic copy with transposition.
Because python assignment is by reference, we cannot easily make the
assignment operator use the same type of sugaring as in C++; i.e. the
following:
```cpp
Stmt block = edsc::Block({
For(ivs, zeros, shapeA, ones, {
C[ivs] = IA[ivs] + IB[ivs]
})});
```
has no equivalent in the native Python EDSCs at this time.
However, the sugaring can be built as a simple DSL in python and is left as
future work.
PiperOrigin-RevId: 231337667
diff --git a/bindings/python/pybind.cpp b/bindings/python/pybind.cpp
new file mode 100644
index 0000000..e99694b
--- /dev/null
+++ b/bindings/python/pybind.cpp
@@ -0,0 +1,559 @@
+#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h"
+#include "third_party/llvm/llvm/include/llvm/ADT/StringRef.h"
+#include "third_party/llvm/llvm/include/llvm/IR/Module.h"
+#include "third_party/llvm/llvm/include/llvm/Support/TargetSelect.h"
+#include "third_party/llvm/llvm/include/llvm/Support/raw_ostream.h"
+#include <cstddef>
+
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Passes.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Types.h"
+
+static bool inited = [] {
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+ return true;
+}();
+
+namespace mlir {
+namespace edsc {
+namespace python {
+
+static std::vector<std::unique_ptr<mlir::Pass>> getDefaultPasses(
+ const std::vector<const mlir::PassInfo *> &mlirPassInfoList = {}) {
+ std::vector<std::unique_ptr<mlir::Pass>> passList;
+ passList.reserve(mlirPassInfoList.size() + 4);
+ // Run each of the passes that were selected.
+ for (const auto *passInfo : mlirPassInfoList) {
+ passList.emplace_back(passInfo->createPass());
+ }
+ // Append the extra passes for lowering to MLIR.
+ passList.emplace_back(mlir::createConstantFoldPass());
+ passList.emplace_back(mlir::createCSEPass());
+ passList.emplace_back(mlir::createCanonicalizerPass());
+ passList.emplace_back(mlir::createLowerAffinePass());
+ return passList;
+}
+
+// Run the passes sequentially on the given module.
+// Return `nullptr` immediately if any of the passes fails.
+static bool runPasses(const std::vector<std::unique_ptr<mlir::Pass>> &passes,
+ Module *module) {
+ for (const auto &pass : passes) {
+ mlir::PassResult result = pass->runOnModule(module);
+ if (result == mlir::PassResult::Failure || module->verify()) {
+ llvm::errs() << "Pass failed\n";
+ return true;
+ }
+ }
+ return false;
+}
+
+namespace py = pybind11;
+
+struct PythonBindable;
+struct PythonExpr;
+struct PythonStmt;
+
+struct PythonFunction {
+ PythonFunction() : function{nullptr} {}
+ PythonFunction(mlir_func_t f) : function{f} {}
+ PythonFunction(mlir::Function *f) : function{f} {}
+ operator mlir_func_t() { return function; }
+ std::string str() {
+ mlir::Function *f = reinterpret_cast<mlir::Function *>(function);
+ std::string res;
+ llvm::raw_string_ostream os(res);
+ f->print(os);
+ return res;
+ }
+ mlir_func_t function;
+};
+
+struct PythonType {
+ PythonType() : type{nullptr} {}
+ PythonType(mlir_type_t t) : type{t} {}
+ operator mlir_type_t() { return type; }
+ std::string str() {
+ mlir::Type f = mlir::Type::getFromOpaquePointer(type);
+ std::string res;
+ llvm::raw_string_ostream os(res);
+ f.print(os);
+ return res;
+ }
+ mlir_type_t type;
+};
+
+/// Trivial C++ wrappers make use of the EDSC C API.
+struct PythonMLIRModule {
+ PythonMLIRModule() : mlirContext(), module(new mlir::Module(&mlirContext)) {}
+
+ PythonType makeScalarType(const std::string &mlirElemType,
+ unsigned bitwidth) {
+ return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(),
+ bitwidth);
+ }
+ PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
+ return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
+ int64_list_t{sizes.data(), sizes.size()});
+ }
+ PythonFunction makeFunction(const std::string &name,
+ std::vector<PythonType> &inputTypes,
+ std::vector<PythonType> &outputTypes) {
+ std::vector<mlir_type_t> ins(inputTypes.begin(), inputTypes.end());
+ std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end());
+ auto funcType = ::makeFunctionType(
+ mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
+ mlir_type_list_t{outs.data(), outs.size()});
+ auto *func = new mlir::Function(
+ UnknownLoc::get(&mlirContext), name,
+ mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>());
+ func->addEntryBlock();
+ module->getFunctions().push_back(func);
+ return mlir_func_t{func};
+ }
+
+ void compile() {
+ auto created = mlir::ExecutionEngine::create(module.get());
+ llvm::handleAllErrors(created.takeError(),
+ [](const llvm::ErrorInfoBase &b) {
+ b.log(llvm::errs());
+ assert(false);
+ });
+ engine = std::move(*created);
+ }
+
+ std::string getIR() {
+ std::string res;
+ llvm::raw_string_ostream os(res);
+ module->print(os);
+ return res;
+ }
+
+ uint64_t getEngineAddress() {
+ assert(engine && "module must be compiled into engine first");
+ return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
+ }
+
+private:
+ mlir::MLIRContext mlirContext;
+ // One single module in a python-exposed MLIRContext for now.
+ std::unique_ptr<mlir::Module> module;
+ std::unique_ptr<mlir::ExecutionEngine> engine;
+};
+
+struct ContextManager {
+ void enter() { context = new ScopedEDSCContext(); }
+ void exit(py::object, py::object, py::object) {
+ delete context;
+ context = nullptr;
+ }
+ mlir::edsc::ScopedEDSCContext *context;
+};
+
+struct PythonExpr {
+ PythonExpr() : expr{nullptr} {}
+ PythonExpr(const PythonBindable &bindable);
+ PythonExpr(const edsc_expr_t &expr) : expr{expr} {}
+ operator edsc_expr_t() { return expr; }
+ std::string str() {
+ assert(expr && "unexpected empty expr");
+ return Expr(*this).str();
+ }
+ edsc_expr_t expr;
+};
+
+struct PythonBindable : public PythonExpr {
+ PythonBindable() : PythonExpr(edsc_expr_t{makeBindable()}) {}
+ PythonBindable(PythonExpr expr) : PythonExpr(expr) {
+ assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
+ }
+ std::string str() {
+ assert(expr && "unexpected empty expr");
+ return Expr(expr).str();
+ }
+};
+
+struct PythonStmt {
+ PythonStmt() : stmt{nullptr} {}
+ PythonStmt(const edsc_stmt_t &stmt) : stmt{stmt} {}
+ PythonStmt(const PythonExpr &e) : stmt{makeStmt(e.expr)} {}
+ operator edsc_stmt_t() { return stmt; }
+ std::string str() {
+ assert(stmt && "unexpected empty stmt");
+ return Stmt(stmt).str();
+ }
+ edsc_stmt_t stmt;
+};
+
+struct PythonIndexed : public edsc_indexed_t {
+ PythonIndexed() : edsc_indexed_t{makeIndexed(PythonBindable())} {}
+ PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
+ PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
+ operator PythonExpr() { return PythonExpr(base); }
+};
+
+struct MLIRFunctionEmitter {
+ MLIRFunctionEmitter(PythonFunction f)
+ : currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
+ currentBuilder(currentFunction),
+ emitter(¤tBuilder, currentFunction->getLoc()) {}
+
+ PythonExpr bindConstantBF16(double value);
+ PythonExpr bindConstantF16(float value);
+ PythonExpr bindConstantF32(float value);
+ PythonExpr bindConstantF64(double value);
+ PythonExpr bindConstantInt(int64_t value, unsigned bitwidth);
+ PythonExpr bindConstantIndex(int64_t value);
+ PythonExpr bindFunctionArgument(unsigned pos);
+ py::list bindFunctionArguments();
+ py::list bindFunctionArgumentView(unsigned pos);
+ py::list bindMemRefShape(PythonExpr boundMemRef);
+ py::list bindIndexedMemRefShape(PythonIndexed boundMemRef) {
+ return bindMemRefShape(boundMemRef.base);
+ }
+ py::list bindMemRefView(PythonExpr boundMemRef);
+ py::list bindIndexedMemRefView(PythonIndexed boundMemRef) {
+ return bindMemRefView(boundMemRef.base);
+ }
+ void emit(PythonStmt stmt);
+
+private:
+ mlir::Function *currentFunction;
+ mlir::FuncBuilder currentBuilder;
+ mlir::edsc::MLIREmitter emitter;
+ edsc_mlir_emitter_t c_emitter;
+};
+
+static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning,
+ const py::list &stmts) {
+ for (auto &inp : stmts) {
+ owning.push_back(edsc_stmt_t{inp.cast<PythonStmt>()});
+ }
+ return edsc_stmt_list_t{owning.data(), owning.size()};
+}
+
+static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
+ const py::list &exprs) {
+ for (auto &inp : exprs) {
+ owning.push_back(edsc_expr_t{inp.cast<PythonExpr>()});
+ }
+ return edsc_expr_list_t{owning.data(), owning.size()};
+}
+
+PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {}
+
+PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) {
+ return ::bindConstantBF16(edsc_mlir_emitter_t{&emitter}, value);
+}
+
+PythonExpr MLIRFunctionEmitter::bindConstantF16(float value) {
+ return ::bindConstantF16(edsc_mlir_emitter_t{&emitter}, value);
+}
+
+PythonExpr MLIRFunctionEmitter::bindConstantF32(float value) {
+ return ::bindConstantF32(edsc_mlir_emitter_t{&emitter}, value);
+}
+
+PythonExpr MLIRFunctionEmitter::bindConstantF64(double value) {
+ return ::bindConstantF64(edsc_mlir_emitter_t{&emitter}, value);
+}
+
+PythonExpr MLIRFunctionEmitter::bindConstantInt(int64_t value,
+ unsigned bitwidth) {
+ return ::bindConstantInt(edsc_mlir_emitter_t{&emitter}, value, bitwidth);
+}
+
+PythonExpr MLIRFunctionEmitter::bindConstantIndex(int64_t value) {
+ return ::bindConstantIndex(edsc_mlir_emitter_t{&emitter}, value);
+}
+
+PythonExpr MLIRFunctionEmitter::bindFunctionArgument(unsigned pos) {
+ return ::bindFunctionArgument(edsc_mlir_emitter_t{&emitter},
+ mlir_func_t{currentFunction}, pos);
+}
+
+PythonExpr getPythonType(edsc_expr_t e) { return PythonExpr(e); }
+
+template <typename T> py::list makePyList(llvm::ArrayRef<T> owningResults) {
+ py::list res;
+ for (auto e : owningResults) {
+ res.append(getPythonType(e));
+ }
+ return res;
+}
+
+py::list MLIRFunctionEmitter::bindFunctionArguments() {
+ auto arity = getFunctionArity(mlir_func_t{currentFunction});
+ llvm::SmallVector<edsc_expr_t, 8> owningResults(arity);
+ edsc_expr_list_t results{owningResults.data(), owningResults.size()};
+ ::bindFunctionArguments(edsc_mlir_emitter_t{&emitter},
+ mlir_func_t{currentFunction}, &results);
+ return makePyList(ArrayRef<edsc_expr_t>{owningResults});
+}
+
+py::list MLIRFunctionEmitter::bindMemRefShape(PythonExpr boundMemRef) {
+ auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
+ llvm::SmallVector<edsc_expr_t, 8> owningShapes(rank);
+ edsc_expr_list_t resultShapes{owningShapes.data(), owningShapes.size()};
+ ::bindMemRefShape(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultShapes);
+ return makePyList(ArrayRef<edsc_expr_t>{owningShapes});
+}
+
+py::list MLIRFunctionEmitter::bindMemRefView(PythonExpr boundMemRef) {
+ auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
+ // Own the PythonExpr for the arg as well as all its dims.
+ llvm::SmallVector<edsc_expr_t, 8> owningLbs(rank);
+ llvm::SmallVector<edsc_expr_t, 8> owningUbs(rank);
+ llvm::SmallVector<edsc_expr_t, 8> owningSteps(rank);
+ edsc_expr_list_t resultLbs{owningLbs.data(), owningLbs.size()};
+ edsc_expr_list_t resultUbs{owningUbs.data(), owningUbs.size()};
+ edsc_expr_list_t resultSteps{owningSteps.data(), owningSteps.size()};
+ ::bindMemRefView(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultLbs,
+ &resultUbs, &resultSteps);
+ py::list res;
+ res.append(makePyList(ArrayRef<edsc_expr_t>{owningLbs}));
+ res.append(makePyList(ArrayRef<edsc_expr_t>{owningUbs}));
+ res.append(makePyList(ArrayRef<edsc_expr_t>{owningSteps}));
+ return res;
+}
+
+void MLIRFunctionEmitter::emit(PythonStmt stmt) {
+ emitter.emitStmt(Stmt(stmt));
+}
+
+PYBIND11_MODULE(pybind, m) {
+ m.doc() =
+ "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
+ m.def("version", []() { return "EDSC Python extensions v0.0"; });
+ m.def("initContext",
+ []() { return static_cast<void *>(new ScopedEDSCContext()); });
+ m.def("deleteContext",
+ [](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); });
+
+ m.def("Block", [](const py::list &stmts) {
+ SmallVector<edsc_stmt_t, 8> owning;
+ return PythonStmt(::Block(makeCStmts(owning, stmts)));
+ });
+ m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
+ const py::list &steps, const py::list &stmts) {
+ SmallVector<edsc_expr_t, 8> owningIVs;
+ SmallVector<edsc_expr_t, 8> owningLBs;
+ SmallVector<edsc_expr_t, 8> owningUBs;
+ SmallVector<edsc_expr_t, 8> owningSteps;
+ SmallVector<edsc_stmt_t, 8> owningStmts;
+ return PythonStmt(
+ ::ForNest(makeCExprs(owningIVs, ivs), makeCExprs(owningLBs, lbs),
+ makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
+ makeCStmts(owningStmts, stmts)));
+ });
+ m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
+ const py::list &stmts) {
+ SmallVector<edsc_stmt_t, 8> owning;
+ return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
+ });
+ m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
+ return PythonExpr(::Select(cond, e1, e2));
+ });
+ m.def("Return", []() {
+ return PythonStmt(::Return(edsc_expr_list_t{nullptr, 0}));
+ });
+ m.def("Return", [](const py::list &returns) {
+ SmallVector<edsc_expr_t, 8> owningExprs;
+ return PythonStmt(::Return(makeCExprs(owningExprs, returns)));
+ });
+
+#define DEFINE_PYBIND_BINARY_OP(PYTHON_NAME, C_NAME) \
+ m.def(PYTHON_NAME, [](PythonExpr e1, PythonExpr e2) { \
+ return PythonExpr(::C_NAME(e1, e2)); \
+ });
+
+ DEFINE_PYBIND_BINARY_OP("Add", Add);
+ DEFINE_PYBIND_BINARY_OP("Mul", Mul);
+ DEFINE_PYBIND_BINARY_OP("Sub", Sub);
+ // DEFINE_PYBIND_BINARY_OP("Div", Div);
+ DEFINE_PYBIND_BINARY_OP("LT", LT);
+ DEFINE_PYBIND_BINARY_OP("LE", LE);
+ DEFINE_PYBIND_BINARY_OP("GT", GT);
+ DEFINE_PYBIND_BINARY_OP("GE", GE);
+
+#undef DEFINE_PYBIND_BINARY_OP
+
+ py::class_<PythonFunction>(m, "Function",
+ "Wrapping class for mlir::Function.")
+ .def(py::init<PythonFunction>())
+ .def("__str__", &PythonFunction::str);
+
+ py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
+ .def(py::init<PythonType>())
+ .def("__str__", &PythonType::str);
+
+ py::class_<PythonMLIRModule>(
+ m, "MLIRModule",
+ "An MLIRModule is the abstraction that owns the allocations to support "
+ "compilation of a single mlir::Module into an ExecutionEngine backed by "
+ "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
+ "adding functions, compiling the module to obtain an ExecutionEngine on "
+ "which named functions may be called. For now the only means to retrieve "
+ "the ExecutionEngine is by calling `get_engine_address`. This mode of "
+ "execution is limited to passing the pointer to C++ where the function "
+ "is called. Extending the API to allow calling JIT compiled functions "
+ "directly require integration with a tensor library (e.g. numpy). This "
+ "is left as the prerogative of libraries and frameworks for now.")
+ .def(py::init<>())
+ .def("make_function", &PythonMLIRModule::makeFunction,
+ "Creates a new mlir::Function in the current mlir::Module.")
+ .def(
+ "make_scalar_type",
+ [](PythonMLIRModule &instance, const std::string &type,
+ unsigned bitwidth) {
+ return instance.makeScalarType(type, bitwidth);
+ },
+ py::arg("type"), py::arg("bitwidth") = 0,
+ "Returns a scalar mlir::Type using the following convention:\n"
+ " - makeScalarType(c, \"bf16\") return an `mlir::Type::getBF16`\n"
+ " - makeScalarType(c, \"f16\") return an `mlir::Type::getF16`\n"
+ " - makeScalarType(c, \"f32\") return an `mlir::Type::getF32`\n"
+ " - makeScalarType(c, \"f64\") return an `mlir::Type::getF64`\n"
+ " - makeScalarType(c, \"index\") return an `mlir::Type::getIndex`\n"
+ " - makeScalarType(c, \"i\", bitwidth) return an "
+ "`mlir::Type::getInteger(bitwidth)`\n\n"
+ " No other combinations are currently supported.")
+ .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
+ "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
+ "denote symbolic dimensions in the resulting memref shape.")
+ .def("compile", &PythonMLIRModule::compile,
+ "Compiles the mlir::Module to LLVMIR a creates new opaque "
+ "ExecutionEngine backed by the ORC JIT.")
+ .def("get_ir", &PythonMLIRModule::getIR,
+ "Returns a dump of the MLIR representation of the module. This is "
+ "used for serde to support out-of-process execution as well as "
+ "debugging purposes.")
+ .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
+ "Returns the address of the compiled ExecutionEngine. This is used "
+ "for in-process execution.");
+
+ py::class_<ContextManager>(
+ m, "ContextManager",
+ "An EDSC context manager is the memory arena containing all the EDSC "
+ "allocations.\nUsage:\n\n"
+ "with E.ContextManager() as _:\n i = E.Expr(E.Bindable())\n ...")
+ .def(py::init<>())
+ .def("__enter__", &ContextManager::enter)
+ .def("__exit__", &ContextManager::exit);
+
+ py::class_<MLIRFunctionEmitter>(
+ m, "MLIRFunctionEmitter",
+ "An MLIRFunctionEmitter is used to fill an empty function body. This is "
+ "a staged process:\n"
+ " 1. create or retrieve an mlir::Function `f` with an empty body;\n"
+ " 2. make an `MLIRFunctionEmitter(f)` to build the current function;\n"
+ " 3. create leaf Expr that are either Bindable or already Expr that are"
+ " bound to constants and function arguments by using methods of "
+ " `MLIRFunctionEmitter`;\n"
+ " 4. build the function body using Expr, Indexed and Stmt;\n"
+ " 5. emit the MLIR to implement the function body.")
+ .def(py::init<PythonFunction>())
+ .def("bind_constant_bf16", &MLIRFunctionEmitter::bindConstantBF16)
+ .def("bind_constant_f16", &MLIRFunctionEmitter::bindConstantF16)
+ .def("bind_constant_f32", &MLIRFunctionEmitter::bindConstantF32)
+ .def("bind_constant_f64", &MLIRFunctionEmitter::bindConstantF64)
+ .def("bind_constant_int", &MLIRFunctionEmitter::bindConstantInt)
+ .def("bind_constant_index", &MLIRFunctionEmitter::bindConstantIndex)
+ .def("bind_function_argument", &MLIRFunctionEmitter::bindFunctionArgument,
+ "Returns an Expr that has been bound to a positional argument in "
+ "the current Function.")
+ .def("bind_function_arguments",
+ &MLIRFunctionEmitter::bindFunctionArguments,
+ "Returns a list of Expr where each Expr has been bound to the "
+ "corresponding positional argument in the current Function.")
+ .def("bind_memref_shape", &MLIRFunctionEmitter::bindMemRefShape,
+ "Returns a list of Expr where each Expr has been bound to the "
+ "corresponding dimension of the memref.")
+ .def("bind_memref_view", &MLIRFunctionEmitter::bindMemRefView,
+ "Returns three lists (lower bound, upper bound and step) of Expr "
+ "where each triplet of Expr has been bound to the minimal offset, "
+ "extent and stride of the corresponding dimension of the memref.")
+ .def("bind_indexed_shape", &MLIRFunctionEmitter::bindIndexedMemRefShape,
+ "Same as bind_memref_shape but returns a list of `Indexed` that "
+ "support load and store operations")
+ .def("bind_indexed_view", &MLIRFunctionEmitter::bindIndexedMemRefView,
+ "Same as bind_memref_view but returns lists of `Indexed` that "
+ "support load and store operations")
+ .def("emit", &MLIRFunctionEmitter::emit,
+ "Emits the MLIR for the EDSC expressions and statements in the "
+ "current function body.");
+
+ py::class_<PythonExpr>(m, "Expr", "Wrapping class for mlir::edsc::Expr")
+ .def(py::init<PythonBindable>())
+ .def("__add__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::Add(e1, e2)); })
+ .def("__sub__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
+ .def("__mul__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
+ // .def("__div__", [](PythonExpr e1, PythonExpr e2) { return
+ // PythonExpr(::Div(e1, e2)); })
+ .def("__lt__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
+ .def("__le__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::LE(e1, e2)); })
+ .def("__gt__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::GT(e1, e2)); })
+ .def("__ge__", [](PythonExpr e1,
+ PythonExpr e2) { return PythonExpr(::GE(e1, e2)); })
+ .def("__str__", &PythonExpr::str,
+ R"DOC(Returns the string value for the Expr)DOC");
+
+ py::class_<PythonBindable>(
+ m, "Bindable",
+ "Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
+ "that can be bound manually to specific MLIR SSA Values.")
+ .def(py::init<>())
+ .def("__str__", &PythonBindable::str);
+
+ py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
+ .def(py::init<PythonExpr>())
+ .def("__str__", &PythonStmt::str,
+ R"DOC(Returns the string value for the Expr)DOC");
+
+ py::class_<PythonIndexed>(
+ m, "Indexed",
+ "Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
+ "that support load and store operations.")
+ .def(py::init<>(), R"DOC(Build from fresh Bindable)DOC")
+ .def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
+ .def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
+ .def(
+ "load",
+ [](PythonIndexed &instance, const py::list &indices) {
+ SmallVector<edsc_expr_t, 8> owning;
+ return PythonExpr(Load(instance, makeCExprs(owning, indices)));
+ },
+ R"DOC(Returns an Expr that loads from an Indexed)DOC")
+ .def(
+ "store",
+ [](PythonIndexed &instance, const py::list &indices,
+ PythonExpr value) {
+ SmallVector<edsc_expr_t, 8> owning;
+ return PythonStmt(
+ Store(value, instance, makeCExprs(owning, indices)));
+ },
+ R"DOC(Returns the Stmt that stores into an Indexed)DOC");
+}
+
+} // namespace python
+} // namespace edsc
+} // namespace mlir
diff --git a/bindings/python/test/test_py2and3.py b/bindings/python/test/test_py2and3.py
new file mode 100644
index 0000000..7779849
--- /dev/null
+++ b/bindings/python/test/test_py2and3.py
@@ -0,0 +1,208 @@
+"""Python2 and 3 test for the MLIR EDSC C API and Python bindings"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import google_mlir.bindings.python.pybind as E
+
+help(E)
+
+
+class EdscTest(unittest.TestCase):
+
+ def testBindables(self):
+ with E.ContextManager():
+ i = E.Expr(E.Bindable())
+ self.assertIn("$1", i.__str__())
+
+ def testOneExpr(self):
+ with E.ContextManager():
+ i, lb, ub = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ expr = E.Mul(i, E.Add(lb, ub))
+ str = expr.__str__()
+ self.assertIn("($1 * ($2 + $3))", str)
+
+ def testOneLoop(self):
+ with E.ContextManager():
+ i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
+ str = loop.__str__()
+ self.assertIn("for($1 = $2 to $3 step $4) {", str)
+ self.assertIn("$5 = ($2 + $3)", str)
+
+ def testTwoLoops(self):
+ with E.ContextManager():
+ i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
+ str = loop.__str__()
+ self.assertIn("for($1 = $2 to $3 step $4) {", str)
+ self.assertIn("for($1 = $2 to $3 step $4) {", str)
+ self.assertIn("$5 = $1;", str)
+
+ def testNestedLoops(self):
+ with E.ContextManager():
+ i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ ivs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ lbs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ ubs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ steps = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ loop = E.For(ivs, lbs, ubs, steps, [
+ E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
+ ])
+ str = loop.__str__()
+ self.assertIn("for($5 = $9 to $13 step $17) {", str)
+ self.assertIn("for($6 = $10 to $14 step $18) {", str)
+ self.assertIn("for($7 = $11 to $15 step $19) {", str)
+ self.assertIn("for($8 = $12 to $16 step $20) {", str)
+ self.assertIn("for($1 = $2 to $3 step $4) {", str)
+ self.assertIn("= (($3 * $4) - $2);", str)
+
+ def testIndexed(self):
+ with E.ContextManager():
+ i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
+ stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
+ str = stmt.__str__()
+ self.assertIn(" = store( ... )", str)
+
+ def testMatmul(self):
+ with E.ContextManager():
+ ivs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ lbs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ ubs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ steps = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ i, j, k = ivs[0], ivs[1], ivs[2]
+ A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
+ loop = E.For(
+ ivs, lbs, ubs, steps,
+ [C.store([i, j],
+ C.load([i, j]) + A.load([i, k]) * B.load([k, j]))])
+ str = loop.__str__()
+ self.assertIn("for($1 = $4 to $7 step $10) {", str)
+ self.assertIn("for($2 = $5 to $8 step $11) {", str)
+ self.assertIn("for($3 = $6 to $9 step $12) {", str)
+ self.assertIn(" = store( ... )", str)
+
+ def testArithmetic(self):
+ with E.ContextManager():
+ i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
+ stmt = i + j * k - l
+ str = stmt.__str__()
+ self.assertIn("(($1 + ($2 * $3)) - $4)", str)
+
+ def testSelect(self):
+ with E.ContextManager():
+ i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
+ stmt = E.Select(i > j, i, j)
+ str = stmt.__str__()
+ self.assertIn("select(($1 > $2), $1, $2)", str)
+
+ def testBlock(self):
+ with E.ContextManager():
+ i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)]))
+ stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
+ str = stmt.__str__()
+ self.assertIn("block {", str)
+ self.assertIn("$3 = ($1 + $2)", str)
+ self.assertIn("$4 = ($1 - $2)", str)
+ self.assertIn("}", str)
+
+ def testMLIRScalarTypes(self):
+ module = E.MLIRModule()
+ t = module.make_scalar_type("bf16")
+ self.assertIn("bf16", t.__str__())
+ t = module.make_scalar_type("f16")
+ self.assertIn("f16", t.__str__())
+ t = module.make_scalar_type("f32")
+ self.assertIn("f32", t.__str__())
+ t = module.make_scalar_type("f64")
+ self.assertIn("f64", t.__str__())
+ t = module.make_scalar_type("i", 1)
+ self.assertIn("i1", t.__str__())
+ t = module.make_scalar_type("i", 8)
+ self.assertIn("i8", t.__str__())
+ t = module.make_scalar_type("i", 32)
+ self.assertIn("i32", t.__str__())
+ t = module.make_scalar_type("i", 123)
+ self.assertIn("i123", t.__str__())
+ t = module.make_scalar_type("index")
+ self.assertIn("index", t.__str__())
+
+ def testMLIRFunctionCreation(self):
+ module = E.MLIRModule()
+ t = module.make_scalar_type("f32")
+ self.assertIn("f32", t.__str__())
+ m = module.make_memref_type(t, [3, 4, -1, 5])
+ self.assertIn("memref<3x4x?x5xf32>", m.__str__())
+ f = module.make_function("copy", [m, m], [])
+ self.assertIn(
+ "func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {",
+ f.__str__())
+
+ def testMLIRConstantEmission(self):
+ module = E.MLIRModule()
+ f = module.make_function("constants", [], [])
+ with E.ContextManager():
+ emitter = E.MLIRFunctionEmitter(f)
+ emitter.bind_constant_bf16(1.23)
+ emitter.bind_constant_f16(1.23)
+ emitter.bind_constant_f32(1.23)
+ emitter.bind_constant_f64(1.23)
+ emitter.bind_constant_int(1, 1)
+ emitter.bind_constant_int(123, 8)
+ emitter.bind_constant_int(123, 16)
+ emitter.bind_constant_int(123, 32)
+ emitter.bind_constant_index(123)
+ str = f.__str__()
+ self.assertIn("constant 1.230000e+00 : bf16", str)
+ self.assertIn("constant 1.230470e+00 : f16", str)
+ self.assertIn("constant 1.230000e+00 : f32", str)
+ self.assertIn("constant 1.230000e+00 : f64", str)
+ self.assertIn("constant 1 : i1", str)
+ self.assertIn("constant 123 : i8", str)
+ self.assertIn("constant 123 : i16", str)
+ self.assertIn("constant 123 : i32", str)
+ self.assertIn("constant 123 : index", str)
+
+ # TODO(ntv): support symbolic For bounds with EDSCs
+ def testMLIREmission(self):
+ shape = [3, 4, 5]
+ module = E.MLIRModule()
+ index = module.make_scalar_type("index")
+ t = module.make_scalar_type("f32")
+ m = module.make_memref_type(t, shape)
+ f = module.make_function("copy", [m, m], [])
+
+ with E.ContextManager():
+ emitter = E.MLIRFunctionEmitter(f)
+ zero = emitter.bind_constant_index(0)
+ one = emitter.bind_constant_index(1)
+ input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
+ M, N, O = emitter.bind_indexed_shape(input)
+
+ ivs = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))]))
+ lbs = [zero, zero, zero]
+ ubs = [M, N, O]
+ steps = [one, one, one]
+
+ # TODO(ntv): emitter.assertEqual(M, oM) etc
+ loop = E.Block([
+ E.For(ivs, lbs, ubs, steps, [output.store(ivs, input.load(ivs))]),
+ E.Return()
+ ])
+ emitter.emit(loop)
+
+ # print(f) # uncomment to see the emitted IR
+ str = f.__str__()
+ self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""",
+ str)
+
+ module.compile()
+ self.assertNotEqual(module.get_engine_address(), 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/bindings/python/test/test_py3.py b/bindings/python/test/test_py3.py
new file mode 100644
index 0000000..427add5
--- /dev/null
+++ b/bindings/python/test/test_py3.py
@@ -0,0 +1,47 @@
+"""Python3 test for the MLIR EDSC C API and Python bindings"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import google_mlir.bindings.python.pybind as E
+
+
+class EdscTest(unittest.TestCase):
+
+ def testSugaredMLIREmission(self):
+ shape = [3, 4, 5, 6, 7]
+ shape_t = [7, 4, 5, 6, 3]
+ module = E.MLIRModule()
+ t = module.make_scalar_type("f32")
+ m = module.make_memref_type(t, shape)
+ m_t = module.make_memref_type(t, shape_t)
+ f = module.make_function("copy", [m, m_t], [])
+
+ with E.ContextManager():
+ emitter = E.MLIRFunctionEmitter(f)
+ input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
+ lbs, ubs, steps = emitter.bind_indexed_view(input)
+ i, *ivs, j = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))]))
+
+ # n-D type and rank agnostic copy-transpose-first-last (where n >= 2).
+ loop = E.Block([
+ E.For([i, *ivs, j], lbs, ubs, steps,
+ [output.store([i, *ivs, j], input.load([j, *ivs, i]))]),
+ E.Return()
+ ])
+ emitter.emit(loop)
+
+ # print(f) # uncomment to see the emitted IR
+ str = f.__str__()
+ self.assertIn("load %arg0[%i4, %i1, %i2, %i3, %i0]", str)
+ self.assertIn("store %0, %arg1[%i0, %i1, %i2, %i3, %i4]", str)
+
+ module.compile()
+ self.assertNotEqual(module.get_engine_address(), 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/lib/EDSC/MLIREmitter.cpp b/lib/EDSC/MLIREmitter.cpp
index f0efdb5..72d30ff 100644
--- a/lib/EDSC/MLIREmitter.cpp
+++ b/lib/EDSC/MLIREmitter.cpp
@@ -459,7 +459,8 @@
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantIntOp>(
- b, value, e->getBuilder()->getIntegerType(bitwidth));
+ b, // mlir::APInt(bitwidth, value),
+ value, e->getBuilder()->getIntegerType(bitwidth));
return b;
}
diff --git a/lib/EDSC/Types.cpp b/lib/EDSC/Types.cpp
index e346093..e261cc8 100644
--- a/lib/EDSC/Types.cpp
+++ b/lib/EDSC/Types.cpp
@@ -23,7 +23,6 @@
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -624,8 +623,9 @@
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
}
-// NOLINTNEXTLINE: unconventional-assign-operator
-Stmt mlir::edsc::Indexed::operator=(Expr expr) {
+// clang-format off
+Stmt mlir::edsc::Indexed::operator=(Expr expr) { // NOLINT: unconventional-assign-operator
+ // clang-format on
assert(!indices.empty() && "Expected attached indices to Indexed");
assert(base);
Stmt stmt(store(expr, base, indices));
@@ -644,23 +644,27 @@
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth) {
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
- mlir_type_t res =
- llvm::StringSwitch<mlir_type_t>(name)
- .Case("bf16",
- mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()})
- .Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()})
- .Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()})
- .Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()})
- .Case("index",
- mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()})
- .Case("i",
- mlir_type_t{
- mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()})
- .Default(mlir_type_t{nullptr});
- if (!res) {
- llvm_unreachable("Invalid type specifier");
+ if (llvm::StringRef(name) == "bf16") {
+ return mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()};
}
- return res;
+ if (llvm::StringRef(name) == "f16") {
+ return mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()};
+ }
+ if (llvm::StringRef(name) == "f32") {
+ return mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()};
+ }
+ if (llvm::StringRef(name) == "f64") {
+ return mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()};
+ }
+ if (llvm::StringRef(name) == "index") {
+ return mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()};
+ }
+ if (llvm::StringRef(name) == "i") {
+ return mlir_type_t{
+ mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()};
+ }
+ assert(false && "unknown scalar type");
+ return mlir_type_t{nullptr};
}
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,