Add support for saving and restoring the insertion point of a FuncBuilder. This also updates the edsc::ScopedContext to use a single builder that saves/restores insertion points. This is necessary for using edscs within RewritePatterns.

--

PiperOrigin-RevId: 248812645
diff --git a/bindings/python/pybind.cpp b/bindings/python/pybind.cpp
index 720f381..8f0cd63 100644
--- a/bindings/python/pybind.cpp
+++ b/bindings/python/pybind.cpp
@@ -235,18 +235,22 @@
 
   PythonFunction enter() {
     assert(function.function && "function is not set up");
-    context = new mlir::edsc::ScopedContext(
-        static_cast<mlir::Function *>(function.function));
+    auto *mlirFunc = static_cast<mlir::Function *>(function.function);
+    contextBuilder.emplace(mlirFunc);
+    context =
+        new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc());
     return function;
   }
 
   void exit(py::object, py::object, py::object) {
     delete context;
     context = nullptr;
+    contextBuilder.reset();
   }
 
   PythonFunction function;
   mlir::edsc::ScopedContext *context;
+  llvm::Optional<FuncBuilder> contextBuilder;
 };
 
 PythonFunctionContext PythonMLIRModule::makeFunctionContext(
diff --git a/examples/Linalg/Linalg2/Example.cpp b/examples/Linalg/Linalg2/Example.cpp
index a8f9e9a..0de8a90 100644
--- a/examples/Linalg/Linalg2/Example.cpp
+++ b/examples/Linalg/Linalg2/Example.cpp
@@ -39,7 +39,8 @@
   mlir::Function *f =
       makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
 
-  ScopedContext scope(f);
+  FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
 
   // clang-format off
   ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
@@ -77,7 +78,8 @@
   mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
                                    {indexType, indexType, indexType}, {});
 
-  ScopedContext scope(f);
+  FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
 
   // clang-format off
   ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
diff --git a/examples/Linalg/Linalg2/lib/Transforms.cpp b/examples/Linalg/Linalg2/lib/Transforms.cpp
index d78d6aa..4523830 100644
--- a/examples/Linalg/Linalg2/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg2/lib/Transforms.cpp
@@ -101,8 +101,8 @@
 }
 
 ViewOp linalg::emitAndReturnFullyComposedView(Value *v) {
-  ScopedContext scope(FuncBuilder(v->getDefiningOp()),
-                      v->getDefiningOp()->getLoc());
+  FuncBuilder builder(v->getDefiningOp());
+  ScopedContext scope(builder, v->getDefiningOp()->getLoc());
   assert(v->getType().isa<ViewType>() && "must be a ViewType");
   auto *memRef = getViewSupportingMemRef(v);
   auto chain = getViewChain(v);
diff --git a/examples/Linalg/Linalg3/Conversion.cpp b/examples/Linalg/Linalg3/Conversion.cpp
index ba7b7eb..0d7b22b 100644
--- a/examples/Linalg/Linalg3/Conversion.cpp
+++ b/examples/Linalg/Linalg3/Conversion.cpp
@@ -44,7 +44,8 @@
       module, name,
       {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
 
-  ScopedContext scope(f);
+  FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
   // clang-format off
   ValueHandle
     M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/Example.cpp b/examples/Linalg/Linalg3/Example.cpp
index 32b5c9e..69717e8 100644
--- a/examples/Linalg/Linalg3/Example.cpp
+++ b/examples/Linalg/Linalg3/Example.cpp
@@ -41,7 +41,8 @@
       module, name,
       {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
 
-  ScopedContext scope(f);
+  mlir::FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
   // clang-format off
   ValueHandle
     M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/Execution.cpp b/examples/Linalg/Linalg3/Execution.cpp
index 9ea3a69..a3cd7f7 100644
--- a/examples/Linalg/Linalg3/Execution.cpp
+++ b/examples/Linalg/Linalg3/Execution.cpp
@@ -44,7 +44,8 @@
       module, name,
       {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
 
-  ScopedContext scope(f);
+  mlir::FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
   // clang-format off
   ValueHandle
     M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg3/lib/TensorOps.cpp b/examples/Linalg/Linalg3/lib/TensorOps.cpp
index 2209e9d..5d59b82 100644
--- a/examples/Linalg/Linalg3/lib/TensorOps.cpp
+++ b/examples/Linalg/Linalg3/lib/TensorOps.cpp
@@ -62,8 +62,10 @@
   using edsc::op::operator*;
   using edsc::op::operator==;
   using edsc::intrinsics::select;
-  ScopedContext scope( // account for affine.terminator in loop.
-      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+
+  // Account for affine.terminator in loop.
+  FuncBuilder builder(body, std::prev(body->end(), 1));
+  ScopedContext scope(builder, innermostLoop.getLoc());
   FloatType fTy = getOperand(0)
                       ->getType()
                       .cast<ViewType>()
@@ -106,7 +108,8 @@
   assert(
       llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
   // clang-format off
-  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  FuncBuilder builder(op);
+  ScopedContext scope(builder, op->getLoc());
   IndexHandle i;
   using linalg::common::LoopNestRangeBuilder;
   LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
@@ -132,8 +135,9 @@
   using edsc::op::operator*;
   using edsc::op::operator==;
   using edsc::intrinsics::select;
-  ScopedContext scope( // account for affine.terminator in loop.
-      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+  // Account for affine.terminator in loop.
+  FuncBuilder builder(body, std::prev(body->end(), 1));
+  ScopedContext scope(builder, innermostLoop.getLoc());
   FloatType fTy = getOperand(0)
                       ->getType()
                       .cast<ViewType>()
@@ -181,7 +185,8 @@
       llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
   using linalg::common::LoopNestRangeBuilder;
   // clang-format off
-  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  FuncBuilder builder(op);
+  ScopedContext scope(builder, op->getLoc());
   IndexHandle j;
   LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
     [&j, &vA, &vB, &vC]() {
@@ -206,8 +211,9 @@
   using edsc::op::operator*;
   using edsc::op::operator==;
   using edsc::intrinsics::select;
-  ScopedContext scope( // account for affine.terminator in loop.
-      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+  // Account for affine.terminator in loop.
+  FuncBuilder builder(body, std::prev(body->end(), 1));
+  ScopedContext scope(builder, innermostLoop.getLoc());
   FloatType fTy = getOperand(0)
                       ->getType()
                       .cast<ViewType>()
diff --git a/examples/Linalg/Linalg3/lib/Transforms.cpp b/examples/Linalg/Linalg3/lib/Transforms.cpp
index 6309300..de19200 100644
--- a/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -162,8 +162,8 @@
 template <class ContractionOp>
 static SmallVector<mlir::AffineForOp, 4>
 writeContractionAsLoops(ContractionOp contraction) {
-  ScopedContext scope(FuncBuilder(contraction.getOperation()),
-                      contraction.getLoc());
+  FuncBuilder builder(contraction.getOperation());
+  ScopedContext scope(builder, contraction.getLoc());
   auto allRanges = getRanges(contraction);
   auto loopRanges =
       makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);
@@ -279,7 +279,8 @@
   SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                       : cast<ViewOp>(load.getView()->getDefiningOp());
-  ScopedContext scope(FuncBuilder(load), load.getLoc());
+  FuncBuilder builder(load);
+  ScopedContext scope(builder, load.getLoc());
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(load, view);
   rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
@@ -294,7 +295,8 @@
   SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                       : cast<ViewOp>(store.getView()->getDefiningOp());
-  ScopedContext scope(FuncBuilder(store), store.getLoc());
+  FuncBuilder builder(store);
+  ScopedContext scope(builder, store.getLoc());
   auto *valueToStore = store.getValueToStore();
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(store, view);
diff --git a/examples/Linalg/Linalg4/Example.cpp b/examples/Linalg/Linalg4/Example.cpp
index cdd87d7..bb32758 100644
--- a/examples/Linalg/Linalg4/Example.cpp
+++ b/examples/Linalg/Linalg4/Example.cpp
@@ -41,7 +41,9 @@
       module, name,
       {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
 
-  ScopedContext scope(f);
+  FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
+
   // clang-format off
   ValueHandle
     M = dim(f->getArgument(0), 0),
diff --git a/examples/Linalg/Linalg4/lib/Transforms.cpp b/examples/Linalg/Linalg4/lib/Transforms.cpp
index 6771257..0e34189 100644
--- a/examples/Linalg/Linalg4/lib/Transforms.cpp
+++ b/examples/Linalg/Linalg4/lib/Transforms.cpp
@@ -149,7 +149,8 @@
          contraction.getNumParallelDims() + contraction.getNumReductionDims());
 
   auto *op = static_cast<ConcreteOp *>(&contraction);
-  ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
+  mlir::FuncBuilder builder(op->getOperation());
+  ScopedContext scope(builder, op->getLoc());
   SmallVector<IndexHandle, 4> ivs(tileSizes.size());
   auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
 
diff --git a/include/mlir/EDSC/Builders.h b/include/mlir/EDSC/Builders.h
index 39302f6..124e637 100644
--- a/include/mlir/EDSC/Builders.h
+++ b/include/mlir/EDSC/Builders.h
@@ -50,9 +50,13 @@
 /// setting and restoring of insertion points.
 class ScopedContext {
 public:
-  /// Sets location to fun->getLoc() in case the provided Loction* is null.
-  ScopedContext(Function *fun, Location *loc = nullptr);
-  ScopedContext(FuncBuilder builder, Location location);
+  ScopedContext(FuncBuilder &builder, Location location);
+
+  /// Sets the insertion point of the builder to 'newInsertPt' for the duration
+  /// of the scope. The existing insertion point of the builder is restored on
+  /// destruction.
+  ScopedContext(FuncBuilder &builder, FuncBuilder::InsertPoint newInsertPt,
+                Location location);
   ~ScopedContext();
 
   static MLIRContext *getContext();
@@ -70,8 +74,10 @@
 
   static ScopedContext *&getCurrentScopedContext();
 
-  /// Current FuncBuilder.
-  FuncBuilder builder;
+  /// Top level FuncBuilder.
+  FuncBuilder &builder;
+  /// The previous insertion point of the builder.
+  llvm::Optional<FuncBuilder::InsertPoint> prevBuilderInsertPoint;
   /// Current location.
   Location location;
   /// Parent context we return into.
@@ -115,9 +121,10 @@
   /// Step back "prev" times from the end of the block to set up the insertion
   /// point, which is useful for non-empty blocks.
   void enter(mlir::Block *block, int prev = 0) {
-    bodyScope =
-        new ScopedContext(FuncBuilder(block, std::prev(block->end(), prev)),
-                          ScopedContext::getLocation());
+    bodyScope = new ScopedContext(
+        *ScopedContext::getBuilder(),
+        FuncBuilder::InsertPoint(block, std::prev(block->end(), prev)),
+        ScopedContext::getLocation());
     bodyScope->nestedBuilder = this;
   }
 
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index ca12f39..b8c3819 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -215,6 +215,27 @@
   /// Return the function this builder is referring to.
   Function *getFunction() const { return function; }
 
+  /// This class represents a saved insertion point.
+  class InsertPoint {
+  public:
+    /// Creates a new insertion point which doesn't point to anything.
+    InsertPoint() = default;
+
+    /// Creates a new insertion point at the given location.
+    InsertPoint(Block *insertBlock, Block::iterator insertPt)
+        : block(insertBlock), point(insertPt) {}
+
+    /// Returns true if this insert point is set.
+    bool isSet() const { return (block != nullptr); }
+
+    Block *getBlock() const { return block; }
+    Block::iterator getPoint() const { return point; }
+
+  private:
+    Block *block = nullptr;
+    Block::iterator point;
+  };
+
   /// Reset the insertion point to no location.  Creating an operation without a
   /// set insertion point is an error, but this can still be useful when the
   /// current insertion point a builder refers to is being removed.
@@ -223,6 +244,19 @@
     insertPoint = Block::iterator();
   }
 
+  /// Return a saved insertion point.
+  InsertPoint saveInsertionPoint() const {
+    return InsertPoint(getInsertionBlock(), getInsertionPoint());
+  }
+
+  /// Restore the insert point to a previously saved point.
+  void restoreInsertionPoint(InsertPoint ip) {
+    if (ip.isSet())
+      setInsertionPoint(ip.getBlock(), ip.getPoint());
+    else
+      clearInsertionPoint();
+  }
+
   /// Set the insertion point to the specified location.
   void setInsertionPoint(Block *block, Block::iterator insertPoint) {
     // TODO: check that insertPoint is in this rather than some other block.
diff --git a/lib/EDSC/Builders.cpp b/lib/EDSC/Builders.cpp
index fa9f739..087b819 100644
--- a/lib/EDSC/Builders.cpp
+++ b/lib/EDSC/Builders.cpp
@@ -24,23 +24,33 @@
 using namespace mlir;
 using namespace mlir::edsc;
 
-mlir::edsc::ScopedContext::ScopedContext(Function *fun, Location *loc)
-    : builder(fun), location(loc ? *loc : fun->getLoc()),
-      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
-      nestedBuilder(nullptr) {
-  getCurrentScopedContext() = this;
-}
-
-mlir::edsc::ScopedContext::ScopedContext(FuncBuilder builder, Location location)
+mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder,
+                                         Location location)
     : builder(builder), location(location),
       enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
       nestedBuilder(nullptr) {
   getCurrentScopedContext() = this;
 }
 
+/// Sets the insertion point of the builder to 'newInsertPt' for the duration
+/// of the scope. The existing insertion point of the builder is restored on
+/// destruction.
+mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder,
+                                         FuncBuilder::InsertPoint newInsertPt,
+                                         Location location)
+    : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()),
+      location(location),
+      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
+      nestedBuilder(nullptr) {
+  getCurrentScopedContext() = this;
+  builder.restoreInsertionPoint(newInsertPt);
+}
+
 mlir::edsc::ScopedContext::~ScopedContext() {
   assert(!nestedBuilder &&
          "Active NestedBuilder must have been exited at this point!");
+  if (prevBuilderInsertPoint)
+    builder.restoreInsertionPoint(*prevBuilderInsertPoint);
   getCurrentScopedContext() = enclosingScopedContext;
 }
 
diff --git a/lib/Linalg/Transforms/Tiling.cpp b/lib/Linalg/Transforms/Tiling.cpp
index f330cf4..b2cff60 100644
--- a/lib/Linalg/Transforms/Tiling.cpp
+++ b/lib/Linalg/Transforms/Tiling.cpp
@@ -56,7 +56,8 @@
     auto it = map.find(v);
     if (it != map.end())
       return it->second;
-    edsc::ScopedContext s(&f);
+    FuncBuilder builder(f);
+    edsc::ScopedContext s(builder, f.getLoc());
     return map.insert(std::make_pair(v, edsc::intrinsics::constant_index(v)))
         .first->getSecond();
   }
@@ -258,7 +259,8 @@
              tileSizes.size() &&
          "expected matching number of tile sizes and loops");
 
-  ScopedContext scope(FuncBuilder(op.getOperation()), op.getLoc());
+  FuncBuilder builder(op.getOperation());
+  ScopedContext scope(builder, op.getLoc());
   auto loopRanges = makeTiledLoopRanges(
       scope.getBuilder(), scope.getLocation(),
       // The flattened loopToOperandRangesMaps is expected to be an invertible
diff --git a/lib/Transforms/LowerVectorTransfers.cpp b/lib/Transforms/LowerVectorTransfers.cpp
index 53c97cf..c2c899e 100644
--- a/lib/Transforms/LowerVectorTransfers.cpp
+++ b/lib/Transforms/LowerVectorTransfers.cpp
@@ -267,7 +267,7 @@
   VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
 
   // 1. Setup all the captures.
-  ScopedContext scope(FuncBuilder(op), transfer.getLoc());
+  ScopedContext scope(rewriter, transfer.getLoc());
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   VectorView vectorView(transfer.getVector());
@@ -326,7 +326,7 @@
   VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
 
   // 1. Setup all the captures.
-  ScopedContext scope(FuncBuilder(op), transfer.getLoc());
+  ScopedContext scope(rewriter, transfer.getLoc());
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   ValueHandle vectorValue(transfer.getVector());
diff --git a/test/EDSC/builder-api-test.cpp b/test/EDSC/builder-api-test.cpp
index 07bba2e..92efd80 100644
--- a/test/EDSC/builder-api-test.cpp
+++ b/test/EDSC/builder-api-test.cpp
@@ -62,7 +62,8 @@
   auto f =
       makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
       ub(f->getArgument(1));
   ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
@@ -112,7 +113,8 @@
   auto f = makeFunction("builder_dynamic_for", {},
                         {indexType, indexType, indexType, indexType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
       c(f->getArgument(2)), d(f->getArgument(3));
   LoopBuilder(&i, a - b, c + d, 2)({});
@@ -134,7 +136,8 @@
   auto f = makeFunction("builder_max_min_for", {},
                         {indexType, indexType, indexType, indexType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
       ub1(f->getArgument(2)), ub2(f->getArgument(3));
   LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)({});
@@ -154,7 +157,8 @@
   using namespace edsc::op;
   auto f = makeFunction("builder_blocks");
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
       c2(ValueHandle::create<ConstantIntOp>(1234, 32));
   ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -201,7 +205,8 @@
   using namespace edsc::op;
   auto f = makeFunction("builder_blocks_eager");
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
       c2(ValueHandle::create<ConstantIntOp>(1234, 32));
   ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -243,7 +248,8 @@
   auto f = makeFunction("builder_cond_branch", {},
                         {IntegerType::get(1, &globalContext())});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle funcArg(f->getArgument(0));
   ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
       c64(ValueHandle::create<ConstantIntOp>(64, 64)),
@@ -283,7 +289,8 @@
   auto f = makeFunction("builder_cond_branch_eager", {},
                         {IntegerType::get(1, &globalContext())});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle funcArg(f->getArgument(0));
   ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
       c64(ValueHandle::create<ConstantIntOp>(64, 64)),
@@ -322,7 +329,8 @@
   auto f =
       makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   // clang-format off
   ValueHandle f7(
       ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
@@ -373,7 +381,8 @@
   auto indexType = IndexType::get(&globalContext());
   auto f = makeFunction("custom_ops", {}, {indexType, indexType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
   CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
   CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
@@ -410,7 +419,9 @@
   using namespace edsc::op;
   auto indexType = IndexType::get(&globalContext());
   auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
-  ScopedContext scope(f.get());
+
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   BlockHandle b1;
   // clang-format off
   ValueHandle::create<ConstantIntOp>(0, 32);
@@ -435,7 +446,8 @@
   auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
   auto f = makeFunction("select_op", {}, {memrefType});
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   // clang-format off
   ValueHandle zero = constant_index(0), one = constant_index(1);
   MemRefView vA(f->getArgument(0));
@@ -469,9 +481,9 @@
   auto memrefType =
       MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
   auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
-  FuncBuilder builder(f.get());
 
-  ScopedContext scope(f.get());
+  FuncBuilder builder(*f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle zero = constant_index(0);
   MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
       vC(f->getArgument(2));
@@ -544,7 +556,8 @@
   mlir::Module module(&globalContext());
   module.getFunctions().push_back(f);
 
-  ScopedContext scope(f);
+  FuncBuilder builder(f);
+  ScopedContext scope(builder, f->getLoc());
   ValueHandle zero = constant_index(0);
   MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
       vC(f->getArgument(2));
diff --git a/test/Transforms/Vectorize/lower_vector_transfers.mlir b/test/Transforms/Vectorize/lower_vector_transfers.mlir
index a5717f5..aed9b1f 100644
--- a/test/Transforms/Vectorize/lower_vector_transfers.mlir
+++ b/test/Transforms/Vectorize/lower_vector_transfers.mlir
@@ -126,8 +126,8 @@
 
 // CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
 func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
-  // CHECK-NEXT:  %[[C0:.*]] = constant 0 : index
   // CHECK-NEXT:  %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00>
+  // CHECK-NEXT:  %[[C0:.*]] = constant 0 : index
   // CHECK-NEXT:  %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
   // CHECK-NEXT:  affine.for %[[I0:.*]] = 0 to %arg0 step 3 {
   // CHECK-NEXT:    affine.for %[[I1:.*]] = 0 to %arg1 step 4 {