Use rewriter in tflite's lower_static_tensor_list
This corrects two issues with lowering the operation:
1. Functions were created that would never be deleted if dialect conversion failed
2. The created functions would not be removed from the SymbolTable if dialect conversion failed, leaving dangling pointers in the SymbolTable.
PiperOrigin-RevId: 351794561
Change-Id: Ibeb84bad69a6b117929e4f817a7be0f5cba7ade3
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 43f9834..a6aa29e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -550,25 +550,31 @@
auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
branch_result_type);
+ // Create functions in a higher scope before restoring the insertion point.
+ // Additionally, create the SymbolTable before further modifying the module.
+ auto original_point = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
+ SymbolTable manager(op->getParentOfType<ModuleOp>());
+
// Constructs `then_branch`, which is executed when `if_cond` evaluates to
// true.
- FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type);
+ auto then_branch_op = rewriter.create<FuncOp>(loc, "cond_true", func_type);
CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
&rewriter);
// Constructs `else_branch`, which is executed when `if_cond` evaluates to
// false.
- FuncOp else_branch_op = FuncOp::create(loc, "cond_false", func_type);
+ auto else_branch_op = rewriter.create<FuncOp>(loc, "cond_false", func_type);
CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
&rewriter);
// Inserts the two blocks' names into the symbol table held by the module.
// Using SymbolTable will ensure that the inserted symbol names are
// unique.
- SymbolTable manager(op->getParentOfType<ModuleOp>());
manager.insert(then_branch_op);
manager.insert(else_branch_op);
+ rewriter.restoreInsertionPoint(original_point);
rewriter.replaceOpWithNewOp<TF::IfOp>(
op, result_type, if_cond,
/*input=*/
@@ -588,8 +594,9 @@
Type result_type, FuncOp branch_func,
ConversionPatternRewriter *rewriter) const {
auto guard = OpBuilder::InsertionGuard(*rewriter);
- Block *block = branch_func.addEntryBlock();
- rewriter->setInsertionPointToStart(block);
+ Block *block =
+ rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
+ branch_func.getType().getInputs());
auto input_shape = block->getArgument(1);
auto size_diff = block->getArgument(2);
@@ -627,8 +634,9 @@
// size, the else branch is executed.
// Slice the first 'size' rows from the input tensorlist.
auto guard = OpBuilder::InsertionGuard(*rewriter);
- Block *block = branch_func.addEntryBlock();
- rewriter->setInsertionPointToStart(block);
+ Block *block =
+ rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
+ branch_func.getType().getInputs());
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);