Use SymbolUserOpInterface for verification instead
WhileOp can't query other functions during general verification and should instead be doing so by using SymbolUserOpInterface.
PiperOrigin-RevId: 360809385
Change-Id: I7c331f7a38297220c42bc0071574ef5d6b75473e
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index abcca67..379c9c4 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -44,6 +44,7 @@
"ir/tfrt_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
+ "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 99de2cb..4d48d61 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -35,6 +35,7 @@
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
let results = (outs
@@ -70,7 +71,7 @@
}];
}
-def TF_CaseOp : TF_Op<"Case", []> {
+def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = [{
An n-way switch statement which calls a single branch function.
}];
@@ -295,7 +296,7 @@
TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
}
-def TF_IfOp : TF_Op<"If", []> {
+def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "output = cond ? then_branch(input) : else_branch(input)";
let description = [{
@@ -334,10 +335,6 @@
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
- let verifier = [{
- return Verify(*this);
- }];
-
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
@@ -665,7 +662,7 @@
let verifier = [{ return VerifyPartitionedCall(*this); }];
}
-def TF_WhileOp : TF_Op<"While", []> {
+def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
}];
@@ -718,10 +715,6 @@
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
- let verifier = [{
- return Verify(*this);
- }];
-
let extraClassDeclaration = [{
// Get the condition function.
FuncOp cond_function() {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 281d3da..92a186b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -761,7 +761,8 @@
}
static LogicalResult VerifyCaseOrIfOpBranchFunctions(
- Operation *op, ArrayRef<Attribute> branches,
+ SymbolTableCollection &symbol_table, Operation *op,
+ ArrayRef<Attribute> branches,
llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
SmallVector<FunctionType, 2> branch_types;
branch_types.reserve(branches.size());
@@ -772,7 +773,7 @@
TypeRangeWithDesc result{op->getResultTypes(), "result"};
for (auto branch : llvm::enumerate(branches)) {
- auto branch_func = SymbolTable::lookupNearestSymbolFrom<FuncOp>(
+ auto branch_func = symbol_table.lookupNearestSymbolFrom<FuncOp>(
op, branch.value().cast<SymbolRefAttr>());
if (!branch_func)
return op->emitOpError()
@@ -816,12 +817,17 @@
}
static LogicalResult Verify(CaseOp op) {
- if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
+ return VerifyCaseOpBase(op, op.branch_index());
+}
+
+LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
auto branch_name = [](unsigned index) {
return llvm::formatv("branch #{0}", index).str();
};
- return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(),
- branch_name);
+ // TODO(jpienaar): Remove.
+ if (failed(CaseOpAdaptor(*this).verify(getLoc()))) return failure();
+ return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this,
+ branches().getValue(), branch_name);
}
//===----------------------------------------------------------------------===//
@@ -2459,12 +2465,14 @@
// IfOp
//===----------------------------------------------------------------------===//
-static LogicalResult Verify(IfOp op) {
+LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
auto branch_name = [](unsigned index) -> std::string {
return index == 0 ? "'then_branch'" : "'else_branch'";
};
+ // TODO(jpienaar): Remove.
+ if (failed(IfOpAdaptor(*this).verify(getLoc()))) return failure();
return VerifyCaseOrIfOpBranchFunctions(
- op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name);
+ symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name);
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 6aa3e20..5377d04 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2923,16 +2923,17 @@
return success();
}
-static LogicalResult Verify(WhileOp op) {
- auto cond_fn = op.cond_function();
- auto body_fn = op.body_function();
+LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
+ // TODO(jpienaar): Remove.
+ if (failed(WhileOpAdaptor(*this).verify(getLoc()))) return failure();
+
+ auto cond_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, cond());
+ auto body_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, body());
if (!cond_fn) {
- return op.emitOpError("cond refers to an undefined function : ")
- << op.cond();
+ return emitOpError("cond refers to an undefined function : ") << cond();
}
if (!body_fn) {
- return op.emitOpError("body refers to an undefined function : ")
- << op.body();
+ return emitOpError("body refers to an undefined function : ") << body();
}
auto cond_fn_type = cond_fn.getType();
@@ -2940,14 +2941,12 @@
// Verify that the cond function has exactly one result.
if (cond_fn_type.getNumResults() != 1)
- return op.emitOpError("requires cond function to have exactly one result");
+ return emitOpError("requires cond function to have exactly one result");
- if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
- /*body_input=*/body_fn_type.getInputs(),
- /*body_result=*/body_fn_type.getResults(),
- op.shape_invariant())))
- return failure();
- return success();
+ return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
+ /*body_input=*/body_fn_type.getInputs(),
+ /*body_result=*/body_fn_type.getResults(),
+ shape_invariant());
}
//===----------------------------------------------------------------------===//