blob: 5455eba135055dee90e02a390c1b66a729fccd50 [file] [log] [blame]
#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm/include/llvm/IR/Module.h"
#include "third_party/llvm/llvm/include/llvm/Support/TargetSelect.h"
#include "third_party/llvm/llvm/include/llvm/Support/raw_ostream.h"
#include <cstddef>
#include <stdexcept>
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Passes.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Types.h"
/// Registers the host target and its assembly printer with LLVM once, when
/// this extension module is loaded, so that the ExecutionEngine can JIT.
/// Both LLVM entry points return `true` on failure; report that instead of
/// silently ignoring it. `inited` is true only if both succeeded.
static bool inited = [] {
  bool targetFailed = llvm::InitializeNativeTarget();
  bool asmPrinterFailed = llvm::InitializeNativeTargetAsmPrinter();
  bool failed = targetFailed || asmPrinterFailed;
  if (failed)
    llvm::errs() << "Failed to initialize the native LLVM target\n";
  return !failed;
}();
namespace mlir {
namespace edsc {
namespace python {
/// Builds a fresh pass pipeline consisting of the passes described by
/// `mlirPassInfoList` (in the given order), followed unconditionally by
/// constant folding, CSE, canonicalization and affine lowering.
static std::vector<std::unique_ptr<mlir::Pass>> getDefaultPasses(
    const std::vector<const mlir::PassInfo *> &mlirPassInfoList = {}) {
  // Number of passes appended after the caller-selected ones.
  const size_t numTrailingPasses = 4;
  std::vector<std::unique_ptr<mlir::Pass>> pipeline;
  pipeline.reserve(mlirPassInfoList.size() + numTrailingPasses);
  for (const mlir::PassInfo *info : mlirPassInfoList)
    pipeline.emplace_back(info->createPass());
  pipeline.emplace_back(mlir::createConstantFoldPass());
  pipeline.emplace_back(mlir::createCSEPass());
  pipeline.emplace_back(mlir::createCanonicalizerPass());
  pipeline.emplace_back(mlir::createLowerAffinePass());
  return pipeline;
}
// Runs `passes` sequentially on `module`, verifying the module after each
// pass that succeeds. Stops at the first pass that reports failure or leaves
// the module failing verification, and prints "Pass failed" to stderr.
// Returns true on failure and false if every pass succeeded (LLVM-style
// "true means error"), not a pointer.
static bool runPasses(const std::vector<std::unique_ptr<mlir::Pass>> &passes,
                      Module *module) {
  for (const auto &pass : passes) {
    mlir::PassResult result = pass->runOnModule(module);
    // `verify()` is only reached when the pass itself succeeded; a truthy
    // result is treated as a verification failure here.
    if (result == mlir::PassResult::Failure || module->verify()) {
      llvm::errs() << "Pass failed\n";
      return true;
    }
  }
  return false;
}
namespace py = pybind11;
// Forward declarations of the Python wrapper types. PythonExpr declares a
// converting constructor from PythonBindable before PythonBindable is defined.
struct PythonBindable;
struct PythonExpr;
struct PythonStmt;
struct PythonBlock;
/// Python-facing wrapper around an opaque `mlir_func_t` handle that refers to
/// an mlir::Function. The wrapper does not own the function.
struct PythonFunction {
  PythonFunction() : function{nullptr} {}
  PythonFunction(mlir_func_t f) : function{f} {}
  PythonFunction(mlir::Function *f) : function{f} {}
  operator mlir_func_t() { return function; }

  /// Returns the textual MLIR form of the wrapped function. The handle must be
  /// non-null; this matches the other wrappers' `str()`.
  std::string str() {
    assert(function && "unexpected empty function");
    mlir::Function *f = reinterpret_cast<mlir::Function *>(function);
    std::string res;
    llvm::raw_string_ostream os(res);
    f->print(os);
    // raw_string_ostream may buffer output. `str()` flushes it into `res`
    // before the result is copied out. Returning `res` directly relies on
    // NRVO plus the stream's destructor flushing into the returned object.
    return os.str();
  }

  mlir_func_t function;
};
struct PythonType {
PythonType() : type{nullptr} {}
PythonType(mlir_type_t t) : type{t} {}
operator mlir_type_t() { return type; }
std::string str() {
mlir::Type f = mlir::Type::getFromOpaquePointer(type);
std::string res;
llvm::raw_string_ostream os(res);
f.print(os);
return res;
}
mlir_type_t type;
};
/// Trivial C++ wrappers that make use of the EDSC C API.
///
/// A PythonMLIRModule owns an MLIRContext, a single mlir::Module created in
/// that context, and, after `compile()`, the ExecutionEngine built from that
/// module. Members are declared context-first, so they are destroyed engine
/// first, then module, then context.
struct PythonMLIRModule {
  PythonMLIRModule() : mlirContext(), module(new mlir::Module(&mlirContext)) {}

  /// Returns a scalar type named by `mlirElemType` (see the `make_scalar_type`
  /// binding for accepted names); `bitwidth` is forwarded to the C API.
  PythonType makeScalarType(const std::string &mlirElemType,
                            unsigned bitwidth) {
    return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(),
                            bitwidth);
  }

  /// Returns a memref type of `elemType` with shape `sizes`.
  PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
    return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
                            int64_list_t{sizes.data(), sizes.size()});
  }

  /// Returns the index type of this module's context.
  PythonType makeIndexType() {
    return ::makeIndexType(mlir_context_t{&mlirContext});
  }

  /// Creates a function `name` with the given input/output types, gives it an
  /// entry block and appends it to the module, which then holds it.
  /// The type lists are read-only inputs, hence taken by const reference.
  PythonFunction makeFunction(const std::string &name,
                              const std::vector<PythonType> &inputTypes,
                              const std::vector<PythonType> &outputTypes) {
    // Unwrap to the opaque handles expected by the C API.
    auto toOpaque = [](const std::vector<PythonType> &types) {
      std::vector<mlir_type_t> opaque;
      opaque.reserve(types.size());
      for (const PythonType &t : types)
        opaque.push_back(t.type);
      return opaque;
    };
    std::vector<mlir_type_t> ins = toOpaque(inputTypes);
    std::vector<mlir_type_t> outs = toOpaque(outputTypes);
    auto funcType = ::makeFunctionType(
        mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
        mlir_type_list_t{outs.data(), outs.size()});
    auto *func = new mlir::Function(
        UnknownLoc::get(&mlirContext), name,
        mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>());
    func->addEntryBlock();
    module->getFunctions().push_back(func);
    return mlir_func_t{func};
  }

  /// JIT-compiles the module into an ExecutionEngine.
  /// Throws std::runtime_error (a Python RuntimeError) on failure. The former
  /// `assert(false)` vanished under NDEBUG and was followed by a dereference
  /// of an errored llvm::Expected.
  void compile() {
    auto created = mlir::ExecutionEngine::create(module.get());
    if (!created) {
      std::string message = llvm::toString(created.takeError());
      llvm::errs() << message << "\n";
      throw std::runtime_error("failed to create ExecutionEngine: " + message);
    }
    engine = std::move(*created);
  }

  /// Returns the textual MLIR form of the whole module.
  std::string getIR() {
    std::string res;
    llvm::raw_string_ostream os(res);
    module->print(os);
    // Flush the (possibly buffered) stream into `res` before returning.
    return os.str();
  }

  /// Returns the address of the compiled ExecutionEngine as an integer.
  /// Throws std::runtime_error if `compile()` has not succeeded yet. Without
  /// this, release builds would silently return 0.
  uint64_t getEngineAddress() {
    if (!engine)
      throw std::runtime_error("module must be compiled into engine first");
    return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
  }

private:
  mlir::MLIRContext mlirContext;
  // One single module in a python-exposed MLIRContext for now.
  std::unique_ptr<mlir::Module> module;
  std::unique_ptr<mlir::ExecutionEngine> engine;
};
/// Python context manager (`with E.ContextManager():`) that keeps a heap-
/// allocated ScopedEDSCContext alive between `__enter__` and `__exit__`.
struct ContextManager {
  /// Allocates a fresh ScopedEDSCContext.
  /// NOTE(review): entering twice without exiting leaks the first context.
  void enter() { context = new ScopedEDSCContext(); }

  /// Destroys the context created by `enter()`. Harmless when no context is
  /// live, because `context` starts out null and deleting nullptr is a no-op.
  void exit(py::object, py::object, py::object) {
    delete context;
    context = nullptr;
  }

  // Previously left uninitialized, so `__exit__` without `__enter__` deleted
  // an indeterminate pointer.
  mlir::edsc::ScopedEDSCContext *context = nullptr;
};
struct PythonExpr {
PythonExpr() : expr{nullptr} {}
PythonExpr(const PythonBindable &bindable);
PythonExpr(const edsc_expr_t &expr) : expr{expr} {}
operator edsc_expr_t() { return expr; }
std::string str() {
assert(expr && "unexpected empty expr");
return Expr(*this).str();
}
edsc_expr_t expr;
};
struct PythonBindable : public PythonExpr {
explicit PythonBindable(const PythonType &type)
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
PythonBindable(PythonExpr expr) : PythonExpr(expr) {
assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
}
std::string str() {
assert(expr && "unexpected empty expr");
return Expr(expr).str();
}
};
struct PythonStmt {
PythonStmt() : stmt{nullptr} {}
PythonStmt(const edsc_stmt_t &stmt) : stmt{stmt} {}
PythonStmt(const PythonExpr &e) : stmt{makeStmt(e.expr)} {}
operator edsc_stmt_t() { return stmt; }
std::string str() {
assert(stmt && "unexpected empty stmt");
return Stmt(stmt).str();
}
edsc_stmt_t stmt;
};
struct PythonBlock {
PythonBlock() : blk{nullptr} {}
PythonBlock(const edsc_block_t &other) : blk{other} {}
PythonBlock(const PythonBlock &other) = default;
operator edsc_block_t() { return blk; }
std::string str() {
assert(blk && "unexpected empty block");
return StmtBlock(blk).str();
}
edsc_block_t blk;
};
struct PythonIndexed : public edsc_indexed_t {
PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
operator PythonExpr() { return PythonExpr(base); }
};
/// Fills the body of an existing mlir::Function by emitting MLIR for EDSC
/// expressions and statements built on the Python side.
///
/// `emitter` is constructed with a pointer to this object's `currentBuilder`
/// and presumably retains it. An implicit copy would therefore leave the
/// copy's emitter pointing at the original's builder, so copying (and with it
/// moving) is disabled.
struct MLIRFunctionEmitter {
  MLIRFunctionEmitter(PythonFunction f)
      : currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
        currentBuilder(currentFunction),
        emitter(&currentBuilder, currentFunction->getLoc()) {}
  MLIRFunctionEmitter(const MLIRFunctionEmitter &) = delete;
  MLIRFunctionEmitter &operator=(const MLIRFunctionEmitter &) = delete;

  PythonExpr bindConstantBF16(double value);
  PythonExpr bindConstantF16(float value);
  PythonExpr bindConstantF32(float value);
  PythonExpr bindConstantF64(double value);
  PythonExpr bindConstantInt(int64_t value, unsigned bitwidth);
  PythonExpr bindConstantIndex(int64_t value);
  PythonExpr bindFunctionArgument(unsigned pos);
  py::list bindFunctionArguments();
  py::list bindFunctionArgumentView(unsigned pos);
  py::list bindMemRefShape(PythonExpr boundMemRef);
  /// Returns the shape Exprs of the memref underlying `boundMemRef`.
  py::list bindIndexedMemRefShape(PythonIndexed boundMemRef) {
    return bindMemRefShape(boundMemRef.base);
  }
  py::list bindMemRefView(PythonExpr boundMemRef);
  /// Returns the [lbs, ubs, steps] view lists of the memref underlying
  /// `boundMemRef`.
  py::list bindIndexedMemRefView(PythonIndexed boundMemRef) {
    return bindMemRefView(boundMemRef.base);
  }
  void emit(PythonStmt stmt);
  void emitBlock(PythonBlock block);
  void emitBlockBody(PythonBlock block);

private:
  // Declaration order is initialization order: the function, then a builder
  // on it, then the emitter on that builder.
  mlir::Function *currentFunction;
  mlir::FuncBuilder currentBuilder;
  mlir::edsc::MLIREmitter emitter;
  // The former `edsc_mlir_emitter_t c_emitter` member was never initialized
  // or used; callers wrap `&emitter` on demand instead.
};
/// Appends the handle of every Stmt in the Python list `stmts` to `owning` and
/// returns a C list viewing all of `owning`. `owning` must outlive the result.
static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning,
                                   const py::list &stmts) {
  for (const auto &item : stmts) {
    PythonStmt wrapped = item.cast<PythonStmt>();
    owning.push_back(static_cast<edsc_stmt_t>(wrapped));
  }
  edsc_stmt_list_t view{owning.data(), owning.size()};
  return view;
}
/// Appends the handle of every Expr in the Python list `exprs` to `owning` and
/// returns a C list viewing all of `owning`. `owning` must outlive the result.
static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
                                   const py::list &exprs) {
  for (const auto &item : exprs) {
    PythonExpr wrapped = item.cast<PythonExpr>();
    owning.push_back(static_cast<edsc_expr_t>(wrapped));
  }
  edsc_expr_list_t view{owning.data(), owning.size()};
  return view;
}
// Defined here rather than in-class because PythonBindable is incomplete at
// the point of declaration. Copies the bindable's handle via the
// handle-wrapping constructor.
PythonExpr::PythonExpr(const PythonBindable &bindable)
    : PythonExpr(bindable.expr) {}
/// Binds a bf16 constant `value` through the C API and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantBF16(handle, value);
}
/// Binds an f16 constant `value` through the C API and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantF16(float value) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantF16(handle, value);
}
/// Binds an f32 constant `value` through the C API and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantF32(float value) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantF32(handle, value);
}
/// Binds an f64 constant `value` through the C API and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantF64(double value) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantF64(handle, value);
}
/// Binds an integer constant `value` of width `bitwidth` through the C API
/// and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantInt(int64_t value,
                                                unsigned bitwidth) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantInt(handle, value, bitwidth);
}
/// Binds an index constant `value` through the C API and returns its Expr.
PythonExpr MLIRFunctionEmitter::bindConstantIndex(int64_t value) {
  edsc_mlir_emitter_t handle{&emitter};
  return ::bindConstantIndex(handle, value);
}
/// Returns an Expr bound to positional argument `pos` of the current function.
PythonExpr MLIRFunctionEmitter::bindFunctionArgument(unsigned pos) {
  edsc_mlir_emitter_t handle{&emitter};
  mlir_func_t func{currentFunction};
  return ::bindFunctionArgument(handle, func, pos);
}
/// Wraps the raw EDSC expression handle `e` in a PythonExpr. Despite the name,
/// the result is an expression wrapper, not a type.
PythonExpr getPythonType(edsc_expr_t e) {
  PythonExpr wrapped(e);
  return wrapped;
}
/// Returns a new Python list holding one PythonExpr per element of
/// `owningResults`, in order.
template <typename T> py::list makePyList(llvm::ArrayRef<T> owningResults) {
  py::list pyList;
  for (size_t i = 0, end = owningResults.size(); i != end; ++i)
    pyList.append(getPythonType(owningResults[i]));
  return pyList;
}
/// Binds every argument of the current function and returns the resulting
/// Exprs as a Python list, in positional order.
py::list MLIRFunctionEmitter::bindFunctionArguments() {
  mlir_func_t func{currentFunction};
  auto numArgs = getFunctionArity(func);
  // Storage the C API fills in; one slot per argument.
  llvm::SmallVector<edsc_expr_t, 8> boundArgs(numArgs);
  edsc_expr_list_t argList{boundArgs.data(), boundArgs.size()};
  ::bindFunctionArguments(edsc_mlir_emitter_t{&emitter}, func, &argList);
  return makePyList(ArrayRef<edsc_expr_t>{boundArgs});
}
/// Returns a Python list with one Expr per dimension of `boundMemRef`, each
/// bound to that dimension's size.
py::list MLIRFunctionEmitter::bindMemRefShape(PythonExpr boundMemRef) {
  edsc_mlir_emitter_t handle{&emitter};
  auto rank = getBoundMemRefRank(handle, boundMemRef);
  // Storage the C API fills in; one slot per dimension.
  llvm::SmallVector<edsc_expr_t, 8> dims(rank);
  edsc_expr_list_t dimList{dims.data(), dims.size()};
  ::bindMemRefShape(handle, boundMemRef, &dimList);
  return makePyList(ArrayRef<edsc_expr_t>{dims});
}
/// Returns a Python list `[lbs, ubs, steps]`, where each element is itself a
/// list with one Expr per dimension of `boundMemRef`.
py::list MLIRFunctionEmitter::bindMemRefView(PythonExpr boundMemRef) {
  edsc_mlir_emitter_t handle{&emitter};
  auto rank = getBoundMemRefRank(handle, boundMemRef);
  // Per-dimension storage the C API fills in for each of the three lists.
  llvm::SmallVector<edsc_expr_t, 8> lbs(rank);
  llvm::SmallVector<edsc_expr_t, 8> ubs(rank);
  llvm::SmallVector<edsc_expr_t, 8> steps(rank);
  edsc_expr_list_t lbList{lbs.data(), lbs.size()};
  edsc_expr_list_t ubList{ubs.data(), ubs.size()};
  edsc_expr_list_t stepList{steps.data(), steps.size()};
  ::bindMemRefView(handle, boundMemRef, &lbList, &ubList, &stepList);
  py::list triple;
  for (const auto *bound : {&lbs, &ubs, &steps})
    triple.append(makePyList(ArrayRef<edsc_expr_t>{*bound}));
  return triple;
}
/// Emits the MLIR for `stmt` into the current function body.
void MLIRFunctionEmitter::emit(PythonStmt stmt) {
  Stmt toEmit(stmt);
  emitter.emitStmt(toEmit);
}
/// Emits the MLIR for the EDSC statement block `block` as a new block.
void MLIRFunctionEmitter::emitBlock(PythonBlock block) {
  StmtBlock toEmit(block);
  emitter.emitBlock(toEmit);
}
/// Emits only the statements contained in `block`, in place, without
/// creating a new block.
void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) {
  StmtBlock source(block);
  emitter.emitStmts(source.getBody());
}
/// Python extension module entry point. Registers the EDSC free functions
/// (constructors and operators) and the wrapper classes defined above.
///
/// The `owning` SmallVectors in the lambdas below back the C lists produced by
/// makeCStmts/makeCExprs. They must stay alive until the C API call that
/// consumes the list returns, which they do because they are lambda locals.
PYBIND11_MODULE(pybind, m) {
  m.doc() =
      "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
  m.def("version", []() { return "EDSC Python extensions v0.0"; });
  m.def("initContext",
        []() { return static_cast<void *>(new ScopedEDSCContext()); });
  m.def("deleteContext",
        [](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); });
  m.def("Block", [](const py::list &stmts) {
    SmallVector<edsc_stmt_t, 8> owning;
    return PythonBlock(::Block(makeCStmts(owning, stmts)));
  });
  m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
                  const py::list &steps, const py::list &stmts) {
    SmallVector<edsc_expr_t, 8> owningIVs;
    SmallVector<edsc_expr_t, 8> owningLBs;
    SmallVector<edsc_expr_t, 8> owningUBs;
    SmallVector<edsc_expr_t, 8> owningSteps;
    SmallVector<edsc_stmt_t, 8> owningStmts;
    return PythonStmt(
        ::ForNest(makeCExprs(owningIVs, ivs), makeCExprs(owningLBs, lbs),
                  makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
                  makeCStmts(owningStmts, stmts)));
  });
  m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
                  const py::list &stmts) {
    SmallVector<edsc_stmt_t, 8> owning;
    return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
  });
  m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
    return PythonExpr(::Select(cond, e1, e2));
  });
  m.def("Return", []() {
    return PythonStmt(::Return(edsc_expr_list_t{nullptr, 0}));
  });
  m.def("Return", [](const py::list &returns) {
    SmallVector<edsc_expr_t, 8> owningExprs;
    return PythonStmt(::Return(makeCExprs(owningExprs, returns)));
  });

#define DEFINE_PYBIND_BINARY_OP(PYTHON_NAME, C_NAME)                           \
  m.def(PYTHON_NAME, [](PythonExpr e1, PythonExpr e2) {                        \
    return PythonExpr(::C_NAME(e1, e2));                                       \
  });

  DEFINE_PYBIND_BINARY_OP("Add", Add);
  DEFINE_PYBIND_BINARY_OP("Mul", Mul);
  DEFINE_PYBIND_BINARY_OP("Sub", Sub);
  // Div is intentionally not exposed yet (also disabled on Expr below).
  // DEFINE_PYBIND_BINARY_OP("Div", Div);
  DEFINE_PYBIND_BINARY_OP("LT", LT);
  DEFINE_PYBIND_BINARY_OP("LE", LE);
  DEFINE_PYBIND_BINARY_OP("GT", GT);
  DEFINE_PYBIND_BINARY_OP("GE", GE);
  DEFINE_PYBIND_BINARY_OP("EQ", EQ);
  DEFINE_PYBIND_BINARY_OP("NE", NE);
  DEFINE_PYBIND_BINARY_OP("And", And);
  DEFINE_PYBIND_BINARY_OP("Or", Or);

#undef DEFINE_PYBIND_BINARY_OP

#define DEFINE_PYBIND_UNARY_OP(PYTHON_NAME, C_NAME)                            \
  m.def(PYTHON_NAME, [](PythonExpr e1) { return PythonExpr(::C_NAME(e1)); });

  DEFINE_PYBIND_UNARY_OP("Negate", Negate);

#undef DEFINE_PYBIND_UNARY_OP

  py::class_<PythonFunction>(m, "Function",
                             "Wrapping class for mlir::Function.")
      .def(py::init<PythonFunction>())
      .def("__str__", &PythonFunction::str);

  py::class_<PythonBlock>(m, "StmtBlock",
                          "Wrapping class for mlir::edsc::StmtBlock")
      .def(py::init<PythonBlock>())
      .def("__str__", &PythonBlock::str);

  py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
      .def(py::init<PythonType>())
      .def("__str__", &PythonType::str);

  py::class_<PythonMLIRModule>(
      m, "MLIRModule",
      "An MLIRModule is the abstraction that owns the allocations to support "
      "compilation of a single mlir::Module into an ExecutionEngine backed by "
      "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
      "adding functions, compiling the module to obtain an ExecutionEngine on "
      "which named functions may be called. For now the only means to retrieve "
      "the ExecutionEngine is by calling `get_engine_address`. This mode of "
      "execution is limited to passing the pointer to C++ where the function "
      "is called. Extending the API to allow calling JIT compiled functions "
      "directly requires integration with a tensor library (e.g. numpy). This "
      "is left as the prerogative of libraries and frameworks for now.")
      .def(py::init<>())
      .def("make_function", &PythonMLIRModule::makeFunction,
           "Creates a new mlir::Function in the current mlir::Module.")
      .def(
          "make_scalar_type",
          [](PythonMLIRModule &instance, const std::string &type,
             unsigned bitwidth) {
            return instance.makeScalarType(type, bitwidth);
          },
          py::arg("type"), py::arg("bitwidth") = 0,
          "Returns a scalar mlir::Type using the following convention:\n"
          "  - makeScalarType(c, \"bf16\") return an "
          "`mlir::FloatType::getBF16`\n"
          "  - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
          "  - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
          "  - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
          "  - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
          "  - makeScalarType(c, \"i\", bitwidth) return an "
          "`mlir::IntegerType::get(bitwidth)`\n\n"
          " No other combinations are currently supported.")
      .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
           "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
           "denote symbolic dimensions in the resulting memref shape.")
      .def("make_index_type", &PythonMLIRModule::makeIndexType,
           "Returns an mlir::IndexType")
      .def("compile", &PythonMLIRModule::compile,
           "Compiles the mlir::Module to LLVM IR and creates a new opaque "
           "ExecutionEngine backed by the ORC JIT.")
      .def("get_ir", &PythonMLIRModule::getIR,
           "Returns a dump of the MLIR representation of the module. This is "
           "used for serde to support out-of-process execution as well as "
           "debugging purposes.")
      .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
           "Returns the address of the compiled ExecutionEngine. This is used "
           "for in-process execution.");

  py::class_<ContextManager>(
      m, "ContextManager",
      "An EDSC context manager is the memory arena containing all the EDSC "
      "allocations.\nUsage:\n\n"
      "with E.ContextManager() as _:\n  i = E.Expr(E.Bindable())\n  ...")
      .def(py::init<>())
      .def("__enter__", &ContextManager::enter)
      .def("__exit__", &ContextManager::exit);

  py::class_<MLIRFunctionEmitter>(
      m, "MLIRFunctionEmitter",
      "An MLIRFunctionEmitter is used to fill an empty function body. This is "
      "a staged process:\n"
      "  1. create or retrieve an mlir::Function `f` with an empty body;\n"
      "  2. make an `MLIRFunctionEmitter(f)` to build the current function;\n"
      "  3. create leaf Expr that are either Bindable or already Expr that are"
      " bound to constants and function arguments by using methods of"
      " `MLIRFunctionEmitter`;\n"
      "  4. build the function body using Expr, Indexed and Stmt;\n"
      "  5. emit the MLIR to implement the function body.")
      .def(py::init<PythonFunction>())
      .def("bind_constant_bf16", &MLIRFunctionEmitter::bindConstantBF16)
      .def("bind_constant_f16", &MLIRFunctionEmitter::bindConstantF16)
      .def("bind_constant_f32", &MLIRFunctionEmitter::bindConstantF32)
      .def("bind_constant_f64", &MLIRFunctionEmitter::bindConstantF64)
      .def("bind_constant_int", &MLIRFunctionEmitter::bindConstantInt)
      .def("bind_constant_index", &MLIRFunctionEmitter::bindConstantIndex)
      .def("bind_function_argument", &MLIRFunctionEmitter::bindFunctionArgument,
           "Returns an Expr that has been bound to a positional argument in "
           "the current Function.")
      .def("bind_function_arguments",
           &MLIRFunctionEmitter::bindFunctionArguments,
           "Returns a list of Expr where each Expr has been bound to the "
           "corresponding positional argument in the current Function.")
      .def("bind_memref_shape", &MLIRFunctionEmitter::bindMemRefShape,
           "Returns a list of Expr where each Expr has been bound to the "
           "corresponding dimension of the memref.")
      .def("bind_memref_view", &MLIRFunctionEmitter::bindMemRefView,
           "Returns three lists (lower bound, upper bound and step) of Expr "
           "where each triplet of Expr has been bound to the minimal offset, "
           "extent and stride of the corresponding dimension of the memref.")
      .def("bind_indexed_shape", &MLIRFunctionEmitter::bindIndexedMemRefShape,
           "Same as bind_memref_shape but takes an `Indexed` and binds the "
           "shape of its underlying memref")
      .def("bind_indexed_view", &MLIRFunctionEmitter::bindIndexedMemRefView,
           "Same as bind_memref_view but takes an `Indexed` and binds the "
           "view of its underlying memref")
      .def("emit", &MLIRFunctionEmitter::emit,
           "Emits the MLIR for the EDSC expressions and statements in the "
           "current function body.")
      .def("emit", &MLIRFunctionEmitter::emitBlock,
           "Emits the MLIR for the EDSC statements into a new block")
      .def("emit_inplace", &MLIRFunctionEmitter::emitBlockBody,
           "Emits the MLIR for the EDSC statements contained in a EDSC block "
           "into the current function body without creating a new block");

  py::class_<PythonExpr>(m, "Expr", "Wrapping class for mlir::edsc::Expr")
      .def(py::init<PythonBindable>())
      .def("__add__", [](PythonExpr e1,
                         PythonExpr e2) { return PythonExpr(::Add(e1, e2)); })
      .def("__sub__", [](PythonExpr e1,
                         PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
      .def("__mul__", [](PythonExpr e1,
                         PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
      // .def("__div__", [](PythonExpr e1, PythonExpr e2) { return
      // PythonExpr(::Div(e1, e2)); })
      .def("__lt__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
      .def("__le__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::LE(e1, e2)); })
      .def("__gt__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::GT(e1, e2)); })
      .def("__ge__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::GE(e1, e2)); })
      .def("__eq__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::EQ(e1, e2)); })
      .def("__ne__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::NE(e1, e2)); })
      .def("__and__", [](PythonExpr e1,
                         PythonExpr e2) { return PythonExpr(::And(e1, e2)); })
      .def("__or__", [](PythonExpr e1,
                        PythonExpr e2) { return PythonExpr(::Or(e1, e2)); })
      .def("__invert__", [](PythonExpr e) { return PythonExpr(::Negate(e)); })
      .def("__str__", &PythonExpr::str,
           R"DOC(Returns the string value for the Expr)DOC");

  py::class_<PythonBindable>(
      m, "Bindable",
      "Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
      "that can be bound manually to specific MLIR SSA Values.")
      .def(py::init<PythonType>())
      .def("__str__", &PythonBindable::str);

  py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
      .def(py::init<PythonExpr>())
      .def("__str__", &PythonStmt::str,
           R"DOC(Returns the string value for the Stmt)DOC");

  py::class_<PythonIndexed>(
      m, "Indexed",
      "Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
      "that supports load and store operations.")
      .def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
      .def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
      .def(
          "load",
          [](PythonIndexed &instance, const py::list &indices) {
            SmallVector<edsc_expr_t, 8> owning;
            return PythonExpr(Load(instance, makeCExprs(owning, indices)));
          },
          R"DOC(Returns an Expr that loads from an Indexed)DOC")
      .def(
          "store",
          [](PythonIndexed &instance, const py::list &indices,
             PythonExpr value) {
            SmallVector<edsc_expr_t, 8> owning;
            return PythonStmt(
                Store(value, instance, makeCExprs(owning, indices)));
          },
          R"DOC(Returns the Stmt that stores into an Indexed)DOC");
}
} // namespace python
} // namespace edsc
} // namespace mlir