Add support for saving and restoring the insertion point of a FuncBuilder. This also updates the edsc::ScopedContext to use a single builder that saves/restores insertion points. This is necessary for using edscs within RewritePatterns.
--
PiperOrigin-RevId: 248812645
diff --git a/bindings/python/pybind.cpp b/bindings/python/pybind.cpp
index 720f381..8f0cd63 100644
--- a/bindings/python/pybind.cpp
+++ b/bindings/python/pybind.cpp
@@ -235,18 +235,22 @@
PythonFunction enter() {
assert(function.function && "function is not set up");
- context = new mlir::edsc::ScopedContext(
- static_cast<mlir::Function *>(function.function));
+ auto *mlirFunc = static_cast<mlir::Function *>(function.function);
+ contextBuilder.emplace(mlirFunc);
+ context =
+ new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc());
return function;
}
void exit(py::object, py::object, py::object) {
delete context;
context = nullptr;
+ contextBuilder.reset();
}
PythonFunction function;
mlir::edsc::ScopedContext *context;
+ llvm::Optional<FuncBuilder> contextBuilder;
};
PythonFunctionContext PythonMLIRModule::makeFunctionContext(
diff --git a/examples/Linalg/Linalg2/Example.cpp b/examples/Linalg/Linalg2/Example.cpp
index a8f9e9a..0de8a90 100644
--- a/examples/Linalg/Linalg2/Example.cpp
+++ b/examples/Linalg/Linalg2/Example.cpp
@@ -39,7 +39,8 @@
mlir::Function *f =
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
- ScopedContext scope(f);
+ FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
@@ -77,7 +78,8 @@
mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
{indexType, indexType, indexType}, {});
- ScopedContext scope(f);
+ FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
diff --git a/examples/Linalg/Linalg2/lib/Transforms.cpp b/examples/Linalg/Linalg2/lib/Transforms.cpp
index d78d6aa..4523830 100644
--- a/examples/Linalg/Linalg2/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg2/lib/Transforms.cpp
@@ -101,8 +101,8 @@
}
ViewOp linalg::emitAndReturnFullyComposedView(Value *v) {
- ScopedContext scope(FuncBuilder(v->getDefiningOp()),
- v->getDefiningOp()->getLoc());
+ FuncBuilder builder(v->getDefiningOp());
+ ScopedContext scope(builder, v->getDefiningOp()->getLoc());
assert(v->getType().isa<ViewType>() && "must be a ViewType");
auto *memRef = getViewSupportingMemRef(v);
auto chain = getViewChain(v);
diff --git a/examples/Linalg/Linalg3/Conversion.cpp b/examples/Linalg/Linalg3/Conversion.cpp
index ba7b7eb..0d7b22b 100644
--- a/examples/Linalg/Linalg3/Conversion.cpp
+++ b/examples/Linalg/Linalg3/Conversion.cpp
@@ -44,7 +44,8 @@
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- ScopedContext scope(f);
+ FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/Example.cpp b/examples/Linalg/Linalg3/Example.cpp
index 32b5c9e..69717e8 100644
--- a/examples/Linalg/Linalg3/Example.cpp
+++ b/examples/Linalg/Linalg3/Example.cpp
@@ -41,7 +41,8 @@
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- ScopedContext scope(f);
+ mlir::FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/Execution.cpp b/examples/Linalg/Linalg3/Execution.cpp
index 9ea3a69..a3cd7f7 100644
--- a/examples/Linalg/Linalg3/Execution.cpp
+++ b/examples/Linalg/Linalg3/Execution.cpp
@@ -44,7 +44,8 @@
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- ScopedContext scope(f);
+ mlir::FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/lib/TensorOps.cpp b/examples/Linalg/Linalg3/lib/TensorOps.cpp
index 2209e9d..5d59b82 100644
--- a/examples/Linalg/Linalg3/lib/TensorOps.cpp
+++ b/examples/Linalg/Linalg3/lib/TensorOps.cpp
@@ -62,8 +62,10 @@
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
- ScopedContext scope( // account for affine.terminator in loop.
- FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+
+ // Account for affine.terminator in loop.
+ FuncBuilder builder(body, std::prev(body->end(), 1));
+ ScopedContext scope(builder, innermostLoop.getLoc());
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
@@ -106,7 +108,8 @@
assert(
llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
// clang-format off
- ScopedContext scope(FuncBuilder(op), op->getLoc());
+ FuncBuilder builder(op);
+ ScopedContext scope(builder, op->getLoc());
IndexHandle i;
using linalg::common::LoopNestRangeBuilder;
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
@@ -132,8 +135,9 @@
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
- ScopedContext scope( // account for affine.terminator in loop.
- FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+ // Account for affine.terminator in loop.
+ FuncBuilder builder(body, std::prev(body->end(), 1));
+ ScopedContext scope(builder, innermostLoop.getLoc());
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
@@ -181,7 +185,8 @@
llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
using linalg::common::LoopNestRangeBuilder;
// clang-format off
- ScopedContext scope(FuncBuilder(op), op->getLoc());
+ FuncBuilder builder(op);
+ ScopedContext scope(builder, op->getLoc());
IndexHandle j;
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
[&j, &vA, &vB, &vC]() {
@@ -206,8 +211,9 @@
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
- ScopedContext scope( // account for affine.terminator in loop.
- FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+ // Account for affine.terminator in loop.
+ FuncBuilder builder(body, std::prev(body->end(), 1));
+ ScopedContext scope(builder, innermostLoop.getLoc());
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
diff --git a/examples/Linalg/Linalg3/lib/Transforms.cpp b/examples/Linalg/Linalg3/lib/Transforms.cpp
index 6309300..de19200 100644
--- a/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -162,8 +162,8 @@
template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeContractionAsLoops(ContractionOp contraction) {
- ScopedContext scope(FuncBuilder(contraction.getOperation()),
- contraction.getLoc());
+ FuncBuilder builder(contraction.getOperation());
+ ScopedContext scope(builder, contraction.getLoc());
auto allRanges = getRanges(contraction);
auto loopRanges =
makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);
@@ -279,7 +279,8 @@
SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(load.getView()->getDefiningOp());
- ScopedContext scope(FuncBuilder(load), load.getLoc());
+ FuncBuilder builder(load);
+ ScopedContext scope(builder, load.getLoc());
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(load, view);
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
@@ -294,7 +295,8 @@
SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(store.getView()->getDefiningOp());
- ScopedContext scope(FuncBuilder(store), store.getLoc());
+ FuncBuilder builder(store);
+ ScopedContext scope(builder, store.getLoc());
auto *valueToStore = store.getValueToStore();
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(store, view);
diff --git a/examples/Linalg/Linalg4/Example.cpp b/examples/Linalg/Linalg4/Example.cpp
index cdd87d7..bb32758 100644
--- a/examples/Linalg/Linalg4/Example.cpp
+++ b/examples/Linalg/Linalg4/Example.cpp
@@ -41,7 +41,9 @@
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- ScopedContext scope(f);
+ FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
+
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg4/lib/Transforms.cpp b/examples/Linalg/Linalg4/lib/Transforms.cpp
index 6771257..0e34189 100644
--- a/examples/Linalg/Linalg4/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg4/lib/Transforms.cpp
@@ -149,7 +149,8 @@
contraction.getNumParallelDims() + contraction.getNumReductionDims());
auto *op = static_cast<ConcreteOp *>(&contraction);
- ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
+ mlir::FuncBuilder builder(op->getOperation());
+ ScopedContext scope(builder, op->getLoc());
SmallVector<IndexHandle, 4> ivs(tileSizes.size());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
diff --git a/include/mlir/EDSC/Builders.h b/include/mlir/EDSC/Builders.h
index 39302f6..124e637 100644
--- a/include/mlir/EDSC/Builders.h
+++ b/include/mlir/EDSC/Builders.h
@@ -50,9 +50,13 @@
/// setting and restoring of insertion points.
class ScopedContext {
public:
- /// Sets location to fun->getLoc() in case the provided Loction* is null.
- ScopedContext(Function *fun, Location *loc = nullptr);
- ScopedContext(FuncBuilder builder, Location location);
+ ScopedContext(FuncBuilder &builder, Location location);
+
+ /// Sets the insertion point of the builder to 'newInsertPt' for the duration
+ /// of the scope. The existing insertion point of the builder is restored on
+ /// destruction.
+ ScopedContext(FuncBuilder &builder, FuncBuilder::InsertPoint newInsertPt,
+ Location location);
~ScopedContext();
static MLIRContext *getContext();
@@ -70,8 +74,10 @@
static ScopedContext *&getCurrentScopedContext();
- /// Current FuncBuilder.
- FuncBuilder builder;
+ /// Top level FuncBuilder.
+ FuncBuilder &builder;
+ /// The previous insertion point of the builder.
+ llvm::Optional<FuncBuilder::InsertPoint> prevBuilderInsertPoint;
/// Current location.
Location location;
/// Parent context we return into.
@@ -115,9 +121,10 @@
/// Step back "prev" times from the end of the block to set up the insertion
/// point, which is useful for non-empty blocks.
void enter(mlir::Block *block, int prev = 0) {
- bodyScope =
- new ScopedContext(FuncBuilder(block, std::prev(block->end(), prev)),
- ScopedContext::getLocation());
+ bodyScope = new ScopedContext(
+ *ScopedContext::getBuilder(),
+ FuncBuilder::InsertPoint(block, std::prev(block->end(), prev)),
+ ScopedContext::getLocation());
bodyScope->nestedBuilder = this;
}
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index ca12f39..b8c3819 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -215,6 +215,27 @@
/// Return the function this builder is referring to.
Function *getFunction() const { return function; }
+ /// This class represents a saved insertion point.
+ class InsertPoint {
+ public:
+ /// Creates a new insertion point which doesn't point to anything.
+ InsertPoint() = default;
+
+ /// Creates a new insertion point at the given location.
+ InsertPoint(Block *insertBlock, Block::iterator insertPt)
+ : block(insertBlock), point(insertPt) {}
+
+ /// Returns true if this insert point is set.
+ bool isSet() const { return (block != nullptr); }
+
+ Block *getBlock() const { return block; }
+ Block::iterator getPoint() const { return point; }
+
+ private:
+ Block *block = nullptr;
+ Block::iterator point;
+ };
+
/// Reset the insertion point to no location. Creating an operation without a
/// set insertion point is an error, but this can still be useful when the
/// current insertion point a builder refers to is being removed.
@@ -223,6 +244,19 @@
insertPoint = Block::iterator();
}
+ /// Return a saved insertion point.
+ InsertPoint saveInsertionPoint() const {
+ return InsertPoint(getInsertionBlock(), getInsertionPoint());
+ }
+
+ /// Restore the insert point to a previously saved point.
+ void restoreInsertionPoint(InsertPoint ip) {
+ if (ip.isSet())
+ setInsertionPoint(ip.getBlock(), ip.getPoint());
+ else
+ clearInsertionPoint();
+ }
+
/// Set the insertion point to the specified location.
void setInsertionPoint(Block *block, Block::iterator insertPoint) {
// TODO: check that insertPoint is in this rather than some other block.
diff --git a/lib/EDSC/Builders.cpp b/lib/EDSC/Builders.cpp
index fa9f739..087b819 100644
--- a/lib/EDSC/Builders.cpp
+++ b/lib/EDSC/Builders.cpp
@@ -24,23 +24,33 @@
using namespace mlir;
using namespace mlir::edsc;
-mlir::edsc::ScopedContext::ScopedContext(Function *fun, Location *loc)
- : builder(fun), location(loc ? *loc : fun->getLoc()),
- enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
- nestedBuilder(nullptr) {
- getCurrentScopedContext() = this;
-}
-
-mlir::edsc::ScopedContext::ScopedContext(FuncBuilder builder, Location location)
+mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder,
+ Location location)
: builder(builder), location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
nestedBuilder(nullptr) {
getCurrentScopedContext() = this;
}
+/// Sets the insertion point of the builder to 'newInsertPt' for the duration
+/// of the scope. The existing insertion point of the builder is restored on
+/// destruction.
+mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder,
+ FuncBuilder::InsertPoint newInsertPt,
+ Location location)
+ : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()),
+ location(location),
+ enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
+ nestedBuilder(nullptr) {
+ getCurrentScopedContext() = this;
+ builder.restoreInsertionPoint(newInsertPt);
+}
+
mlir::edsc::ScopedContext::~ScopedContext() {
assert(!nestedBuilder &&
"Active NestedBuilder must have been exited at this point!");
+ if (prevBuilderInsertPoint)
+ builder.restoreInsertionPoint(*prevBuilderInsertPoint);
getCurrentScopedContext() = enclosingScopedContext;
}
diff --git a/lib/Linalg/Transforms/Tiling.cpp b/lib/Linalg/Transforms/Tiling.cpp
index f330cf4..b2cff60 100644
--- a/lib/Linalg/Transforms/Tiling.cpp
+++ b/lib/Linalg/Transforms/Tiling.cpp
@@ -56,7 +56,8 @@
auto it = map.find(v);
if (it != map.end())
return it->second;
- edsc::ScopedContext s(&f);
+ FuncBuilder builder(f);
+ edsc::ScopedContext s(builder, f.getLoc());
return map.insert(std::make_pair(v, edsc::intrinsics::constant_index(v)))
.first->getSecond();
}
@@ -258,7 +259,8 @@
tileSizes.size() &&
"expected matching number of tile sizes and loops");
- ScopedContext scope(FuncBuilder(op.getOperation()), op.getLoc());
+ FuncBuilder builder(op.getOperation());
+ ScopedContext scope(builder, op.getLoc());
auto loopRanges = makeTiledLoopRanges(
scope.getBuilder(), scope.getLocation(),
// The flattened loopToOperandRangesMaps is expected to be an invertible
diff --git a/lib/Transforms/LowerVectorTransfers.cpp b/lib/Transforms/LowerVectorTransfers.cpp
index 53c97cf..c2c899e 100644
--- a/lib/Transforms/LowerVectorTransfers.cpp
+++ b/lib/Transforms/LowerVectorTransfers.cpp
@@ -267,7 +267,7 @@
VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
// 1. Setup all the captures.
- ScopedContext scope(FuncBuilder(op), transfer.getLoc());
+ ScopedContext scope(rewriter, transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
VectorView vectorView(transfer.getVector());
@@ -326,7 +326,7 @@
VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
// 1. Setup all the captures.
- ScopedContext scope(FuncBuilder(op), transfer.getLoc());
+ ScopedContext scope(rewriter, transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
ValueHandle vectorValue(transfer.getVector());
diff --git a/test/EDSC/builder-api-test.cpp b/test/EDSC/builder-api-test.cpp
index 07bba2e..92efd80 100644
--- a/test/EDSC/builder-api-test.cpp
+++ b/test/EDSC/builder-api-test.cpp
@@ -62,7 +62,8 @@
auto f =
makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
ub(f->getArgument(1));
ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
@@ -112,7 +113,8 @@
auto f = makeFunction("builder_dynamic_for", {},
{indexType, indexType, indexType, indexType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
c(f->getArgument(2)), d(f->getArgument(3));
LoopBuilder(&i, a - b, c + d, 2)({});
@@ -134,7 +136,8 @@
auto f = makeFunction("builder_max_min_for", {},
{indexType, indexType, indexType, indexType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
ub1(f->getArgument(2)), ub2(f->getArgument(3));
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)({});
@@ -154,7 +157,8 @@
using namespace edsc::op;
auto f = makeFunction("builder_blocks");
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -201,7 +205,8 @@
using namespace edsc::op;
auto f = makeFunction("builder_blocks_eager");
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -243,7 +248,8 @@
auto f = makeFunction("builder_cond_branch", {},
{IntegerType::get(1, &globalContext())});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle funcArg(f->getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
@@ -283,7 +289,8 @@
auto f = makeFunction("builder_cond_branch_eager", {},
{IntegerType::get(1, &globalContext())});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle funcArg(f->getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
@@ -322,7 +329,8 @@
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
@@ -373,7 +381,8 @@
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
@@ -410,7 +419,9 @@
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
- ScopedContext scope(f.get());
+
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
BlockHandle b1;
// clang-format off
ValueHandle::create<ConstantIntOp>(0, 32);
@@ -435,7 +446,8 @@
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle zero = constant_index(0), one = constant_index(1);
MemRefView vA(f->getArgument(0));
@@ -469,9 +481,9 @@
auto memrefType =
MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
- FuncBuilder builder(f.get());
- ScopedContext scope(f.get());
+ FuncBuilder builder(*f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle zero = constant_index(0);
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
vC(f->getArgument(2));
@@ -544,7 +556,8 @@
mlir::Module module(&globalContext());
module.getFunctions().push_back(f);
- ScopedContext scope(f);
+ FuncBuilder builder(f);
+ ScopedContext scope(builder, f->getLoc());
ValueHandle zero = constant_index(0);
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
vC(f->getArgument(2));
diff --git a/test/Transforms/Vectorize/lower_vector_transfers.mlir b/test/Transforms/Vectorize/lower_vector_transfers.mlir
index a5717f5..aed9b1f 100644
--- a/test/Transforms/Vectorize/lower_vector_transfers.mlir
+++ b/test/Transforms/Vectorize/lower_vector_transfers.mlir
@@ -126,8 +126,8 @@
// CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
- // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00>
+ // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
// CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 {