Use SymbolUserOpInterface for verification instead

WhileOp can't query other functions during general verification and should instead be doing so by using SymbolUserOpInterface.

PiperOrigin-RevId: 360809385
Change-Id: I7c331f7a38297220c42bc0071574ef5d6b75473e
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index abcca67..379c9c4 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -44,6 +44,7 @@
         "ir/tfrt_ops.td",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectTdFiles",
+        "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 99de2cb..4d48d61 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -35,6 +35,7 @@
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 
 class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
   let results = (outs
@@ -70,7 +71,7 @@
   }];
 }
 
-def TF_CaseOp : TF_Op<"Case", []> {
+def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = [{
 An n-way switch statement which calls a single branch function.
   }];
@@ -295,7 +296,7 @@
   TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
 }
 
-def TF_IfOp : TF_Op<"If", []> {
+def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "output = cond ? then_branch(input) : else_branch(input)";
 
   let description = [{
@@ -334,10 +335,6 @@
   TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
   TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
 
-  let verifier = [{
-    return Verify(*this);
-  }];
-
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
@@ -665,7 +662,7 @@
   let verifier = [{ return VerifyPartitionedCall(*this); }];
 }
 
-def TF_WhileOp : TF_Op<"While", []> {
+def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = [{
 output = input; While (Cond(output)) { output = Body(output) }
   }];
@@ -718,10 +715,6 @@
   TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
   TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
 
-  let verifier = [{
-    return Verify(*this);
-  }];
-
   let extraClassDeclaration = [{
     // Get the condition function.
     FuncOp cond_function() {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 281d3da..92a186b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -761,7 +761,8 @@
 }
 
 static LogicalResult VerifyCaseOrIfOpBranchFunctions(
-    Operation *op, ArrayRef<Attribute> branches,
+    SymbolTableCollection &symbol_table, Operation *op,
+    ArrayRef<Attribute> branches,
     llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
   SmallVector<FunctionType, 2> branch_types;
   branch_types.reserve(branches.size());
@@ -772,7 +773,7 @@
   TypeRangeWithDesc result{op->getResultTypes(), "result"};
 
   for (auto branch : llvm::enumerate(branches)) {
-    auto branch_func = SymbolTable::lookupNearestSymbolFrom<FuncOp>(
+    auto branch_func = symbol_table.lookupNearestSymbolFrom<FuncOp>(
         op, branch.value().cast<SymbolRefAttr>());
     if (!branch_func)
       return op->emitOpError()
@@ -816,12 +817,17 @@
 }
 
 static LogicalResult Verify(CaseOp op) {
-  if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
+  return VerifyCaseOpBase(op, op.branch_index());
+}
+
+LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
   auto branch_name = [](unsigned index) {
     return llvm::formatv("branch #{0}", index).str();
   };
-  return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(),
-                                         branch_name);
+  // TODO(jpienaar): Remove.
+  if (failed(CaseOpAdaptor(*this).verify(getLoc()))) return failure();
+  return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this,
+                                         branches().getValue(), branch_name);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2459,12 +2465,14 @@
 // IfOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult Verify(IfOp op) {
+LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
   auto branch_name = [](unsigned index) -> std::string {
     return index == 0 ? "'then_branch'" : "'else_branch'";
   };
+  // TODO(jpienaar): Remove.
+  if (failed(IfOpAdaptor(*this).verify(getLoc()))) return failure();
   return VerifyCaseOrIfOpBranchFunctions(
-      op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name);
+      symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 6aa3e20..5377d04 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2923,16 +2923,17 @@
   return success();
 }
 
-static LogicalResult Verify(WhileOp op) {
-  auto cond_fn = op.cond_function();
-  auto body_fn = op.body_function();
+LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
+  // TODO(jpienaar): Remove.
+  if (failed(WhileOpAdaptor(*this).verify(getLoc()))) return failure();
+
+  auto cond_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, cond());
+  auto body_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, body());
   if (!cond_fn) {
-    return op.emitOpError("cond refers to an undefined function : ")
-           << op.cond();
+    return emitOpError("cond refers to an undefined function : ") << cond();
   }
   if (!body_fn) {
-    return op.emitOpError("body refers to an undefined function : ")
-           << op.body();
+    return emitOpError("body refers to an undefined function : ") << body();
   }
 
   auto cond_fn_type = cond_fn.getType();
@@ -2940,14 +2941,12 @@
 
   // Verify that the cond function has exactly one result.
   if (cond_fn_type.getNumResults() != 1)
-    return op.emitOpError("requires cond function to have exactly one result");
+    return emitOpError("requires cond function to have exactly one result");
 
-  if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
-                              /*body_input=*/body_fn_type.getInputs(),
-                              /*body_result=*/body_fn_type.getResults(),
-                              op.shape_invariant())))
-    return failure();
-  return success();
+  return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
+                          /*body_input=*/body_fn_type.getInputs(),
+                          /*body_result=*/body_fn_type.getResults(),
+                          shape_invariant());
 }
 
 //===----------------------------------------------------------------------===//