Allowing replacing non-root operations in DialectConversion.

When dealing with regions, or other patterns that need to generate temporary operations, it is useful to be able to replace other operations than the root op being matched. Before this PR, these operations would still be considered for legalization meaning that the conversion would either fail, erroneously need to mark these ops as legal, or add unnecessary patterns.

PiperOrigin-RevId: 274598513
diff --git a/lib/Transforms/DialectConversion.cpp b/lib/Transforms/DialectConversion.cpp
index 4266368..0007feb 100644
--- a/lib/Transforms/DialectConversion.cpp
+++ b/lib/Transforms/DialectConversion.cpp
@@ -558,7 +558,7 @@
     case BlockActionKind::Split: {
       action.originalBlock->getOperations().splice(
           action.originalBlock->end(), action.block->getOperations());
-      action.block->dropAllUses();
+      action.block->dropAllDefinedValueUses();
       action.block->erase();
       break;
     }
@@ -990,6 +990,21 @@
     }
   }
 
+  // Check all of the replacements to ensure that the pattern actually replaced
+  // the root operation. We also mark any other replaced ops as 'dead' so that
+  // we don't try to legalize them later.
+  bool replacedRoot = false;
+  for (unsigned i = curState.numReplacements,
+                e = rewriterImpl.replacements.size();
+       i != e; ++i) {
+    Operation *replacedOp = rewriterImpl.replacements[i].op;
+    if (replacedOp == op)
+      replacedRoot = true;
+    else
+      rewriterImpl.deadOps.insert(replacedOp);
+  }
+  assert(replacedRoot && "expected pattern to replace the root operation");
+
   // Recursively legalize each of the new operations.
   for (unsigned i = curState.numCreatedOperations,
                 e = rewriterImpl.createdOps.size();
diff --git a/test/Transforms/test-legalizer-full.mlir b/test/Transforms/test-legalizer-full.mlir
index 79494c7..2cf981b 100644
--- a/test/Transforms/test-legalizer-full.mlir
+++ b/test/Transforms/test-legalizer-full.mlir
@@ -19,6 +19,13 @@
   }) : () -> ()
   "test.return"() : () -> ()
 }
+// CHECK-LABEL: func @replace_non_root_illegal_op
+func @replace_non_root_illegal_op() {
+  // CHECK-NEXT: "test.legal_op_b"
+  // CHECK-NEXT: test.return
+  %result = "test.replace_non_root"() : () -> (i32)
+  "test.return"() : () -> ()
+}
 
 // -----
 
diff --git a/test/lib/TestDialect/TestOps.td b/test/lib/TestDialect/TestOps.td
index b26360f..73769e7 100644
--- a/test/lib/TestDialect/TestOps.td
+++ b/test/lib/TestDialect/TestOps.td
@@ -808,6 +808,7 @@
 def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
   Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
+def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
 
 // Check that smaller pattern depths are chosen, i.e. prioritize more direct
 // mappings.
diff --git a/test/lib/TestDialect/TestPatterns.cpp b/test/lib/TestDialect/TestPatterns.cpp
index 83814ee..2dde6a3 100644
--- a/test/lib/TestDialect/TestPatterns.cpp
+++ b/test/lib/TestDialect/TestPatterns.cpp
@@ -249,6 +249,26 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Non-Root Replacement Rewrite Testing
+/// This pattern generates an invalid operation, but replaces it before the
+/// pattern is finished. This checks that we don't need to legalize the
+/// temporary op.
+struct TestNonRootReplacement : public RewritePattern {
+  TestNonRootReplacement(MLIRContext *ctx)
+      : RewritePattern("test.replace_non_root", 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    auto resultType = *op->result_type_begin();
+    auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
+    auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
+
+    rewriter.replaceOp(illegalOp, {legalOp});
+    rewriter.replaceOp(op, {illegalOp});
+    return matchSuccess();
+  }
+};
 } // namespace
 
 namespace {
@@ -301,15 +321,15 @@
         .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
                 TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
                 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-                TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType>(
-            &getContext());
+                TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+                TestNonRootReplacement>(&getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
 
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
+    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
     target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
       // Don't allow F32 operands.