[tf][tfg] Use SymbolUserOpInterface to verify TFG functional ops

Inspecting other functions in `Op::verify` can lead to threading errors. Verify referenced functions through `SymbolUserOpInterface`.

PiperOrigin-RevId: 421200088
Change-Id: I5aa28e93568f2c50067c5e2b8b96bf26390ec108
diff --git a/tensorflow/core/ir/ops.cc b/tensorflow/core/ir/ops.cc
index 1516eae..05ef8be 100644
--- a/tensorflow/core/ir/ops.cc
+++ b/tensorflow/core/ir/ops.cc
@@ -798,27 +798,27 @@
 // If-Like Ops
 
 template <typename IfLikeOp>
-static LogicalResult VerifyIfLikeOp(IfLikeOp op) {
+static LogicalResult VerifyIfLikeOp(IfLikeOp op,
+                                    SymbolTableCollection &symbol_table) {
+  if (failed(op.verify())) return failure();
   FailureOr<TypeRange> ins = VerifyOperands(op);
   if (failed(ins)) return failure();
   FailureOr<TypeRange> outs = VerifyResults(op);
   if (failed(outs)) return failure();
 
-  Operation *table_op = SymbolTable::getNearestSymbolTable(op);
-  if (!table_op) return op.emitOpError("is not contained in a symbol table");
   SymbolRefAttr then_name = op.then_branch().getName();
   SymbolRefAttr else_name = op.else_branch().getName();
   // The first operand is the condition and is not passed to the functions.
   TypeRange func_args = llvm::drop_begin(*ins);
 
-  auto then_func = dyn_cast_or_null<GraphFuncOp>(
-      SymbolTable::lookupSymbolIn(table_op, then_name));
+  auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+      op, op.then_branch().getName());
   if (then_func &&
       failed(VerifySignature(then_func, op, func_args, *outs, "then")))
     return failure();
 
-  auto else_func = dyn_cast_or_null<GraphFuncOp>(
-      SymbolTable::lookupSymbolIn(table_op, else_name));
+  auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+      op, op.else_branch().getName());
   if (else_func &&
       failed(VerifySignature(else_func, op, func_args, *outs, "else")))
     return failure();
@@ -830,21 +830,21 @@
 // Case-Like Ops
 
 template <typename CaseLikeOp>
-static LogicalResult VerifyCaseLikeOp(CaseLikeOp op) {
+static LogicalResult VerifyCaseLikeOp(CaseLikeOp op,
+                                      SymbolTableCollection &symbol_table) {
+  if (failed(op.verify())) return failure();
   FailureOr<TypeRange> ins = VerifyOperands(op);
   if (failed(ins)) return failure();
   FailureOr<TypeRange> outs = VerifyResults(op);
   if (failed(outs)) return failure();
 
-  Operation *table_op = SymbolTable::getNearestSymbolTable(op);
-  if (!table_op) return op.emitOpError("is not contained in a symbol table");
   // The first operand is the branch index and is not passed to the functions.
   TypeRange func_args = llvm::drop_begin(*ins);
 
   for (auto &it : llvm::enumerate(op.branches())) {
     SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName();
-    auto func = dyn_cast_or_null<GraphFuncOp>(
-        SymbolTable::lookupSymbolIn(table_op, func_name));
+    auto func =
+        symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name);
     if (func && failed(VerifySignature(func, op, func_args, *outs,
                                        "branch #" + Twine(it.index()))))
       return failure();
@@ -856,26 +856,25 @@
 // While-Like Ops
 
 template <typename WhileLikeOp>
-static LogicalResult VerifyWhileLikeOp(WhileLikeOp op) {
+static LogicalResult VerifyWhileLikeOp(WhileLikeOp op,
+                                       SymbolTableCollection &symbol_table) {
+  if (failed(op.verify())) return failure();
   FailureOr<TypeRange> ins = VerifyOperands(op);
   if (failed(ins)) return failure();
   FailureOr<TypeRange> outs = VerifyResults(op);
   if (failed(outs)) return failure();
 
-  Operation *table_op = SymbolTable::getNearestSymbolTable(op);
-  if (!table_op) return op.emitOpError("is not contained in a symbol table");
-  SymbolRefAttr cond_name = op.cond().getName();
   SymbolRefAttr body_name = op.body().getName();
 
-  auto cond_func = dyn_cast_or_null<GraphFuncOp>(
-      SymbolTable::lookupSymbolIn(table_op, cond_name));
+  auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+      op, op.cond().getName());
   auto i1_type = Builder(op.getContext()).getI1Type();
   if (cond_func &&
       failed(VerifySignature(cond_func, op, *ins, i1_type, "cond")))
     return failure();
 
-  auto body_func = dyn_cast_or_null<GraphFuncOp>(
-      SymbolTable::lookupSymbolIn(table_op, body_name));
+  auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+      op, op.body().getName());
   if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body")))
     return failure();
 
@@ -885,20 +884,20 @@
 //===----------------------------------------------------------------------===//
 // ForOp
 
-static LogicalResult VerifyForOp(ForOp op) {
-  FailureOr<TypeRange> ins = VerifyOperands(op);
+LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
+  if (failed(verify())) return failure();
+  FailureOr<TypeRange> ins = VerifyOperands(*this);
   if (failed(ins)) return failure();
-  FailureOr<TypeRange> outs = VerifyResults(op);
+  FailureOr<TypeRange> outs = VerifyResults(*this);
   if (failed(outs)) return failure();
 
-  SymbolRefAttr body_name = op.body().getName();
-  auto body_func =
-      SymbolTable::lookupNearestSymbolFrom<GraphFuncOp>(op, body_name);
+  auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+      *this, body().getName());
   // The first three arguments are the for-loop indices, but the current loop
   // index is passed in.
   TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2);
   if (body_func &&
-      failed(VerifySignature(body_func, op, func_args, *outs, "body")))
+      failed(VerifySignature(body_func, *this, func_args, *outs, "body")))
     return failure();
   return success();
 }
diff --git a/tensorflow/core/ir/ops.td b/tensorflow/core/ir/ops.td
index 8ac8ad0..5e9f0b5 100644
--- a/tensorflow/core/ir/ops.td
+++ b/tensorflow/core/ir/ops.td
@@ -27,6 +27,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // TFGraph op definitions
@@ -254,7 +255,9 @@
     TFGraph_Op<mnemonic, traits>;
 
 // Base class for TFGraph if-like operations.
-class TFGraph_IfLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_IfLikeOp<string mnemonic>
+    : TFGraph_ConcreteOp<mnemonic,
+                         [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins
     // Op operands.
     I1Tensor:$cond,
@@ -269,7 +272,12 @@
   );
   let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
 
-  let verifier = [{ return VerifyIfLikeOp(*this); }];
+  let extraClassDefinition = [{
+    LogicalResult $cppClass::verifySymbolUses(
+        SymbolTableCollection &symbol_table) {
+      return VerifyIfLikeOp(*this, symbol_table);
+    }
+  }];
 }
 
 def TFGraph_IfOp : TFGraph_IfLikeOp<"If"> {
@@ -283,7 +291,9 @@
 }
 
 // Base class for TFGraph case-like operations.
-class TFGraph_CaseLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_CaseLikeOp<string mnemonic>
+    : TFGraph_ConcreteOp<mnemonic,
+                         [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins
     // Op operands.
     I32Tensor:$branch_index,
@@ -297,7 +307,12 @@
 
   let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
 
-  let verifier = [{ return VerifyCaseLikeOp(*this); }];
+  let extraClassDefinition = [{
+    LogicalResult $cppClass::verifySymbolUses(
+        SymbolTableCollection &symbol_table) {
+      return VerifyCaseLikeOp(*this, symbol_table);
+    }
+  }];
 }
 
 def TFGraph_CaseOp : TFGraph_CaseLikeOp<"Case"> {
@@ -311,7 +326,9 @@
 }
 
 // Base class for TFGraph while-like operations.
-class TFGraph_WhileLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_WhileLikeOp<string mnemonic>
+    : TFGraph_ConcreteOp<mnemonic,
+                         [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins
     // Op operands.
     Variadic<TFGraph_TensorOrControlType>:$args,
@@ -324,7 +341,12 @@
   );
   let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
 
-  let verifier = [{ return VerifyWhileLikeOp(*this); }];
+  let extraClassDefinition = [{
+    LogicalResult $cppClass::verifySymbolUses(
+        SymbolTableCollection &symbol_table) {
+      return VerifyWhileLikeOp(*this, symbol_table);
+    }
+  }];
 }
 
 def TFGraph_WhileOp : TFGraph_WhileLikeOp<"While"> {
@@ -338,7 +360,9 @@
 }
 
 // A functional for loop operation.
-def TFGraph_ForOp : TFGraph_ConcreteOp<"For"> {
+def TFGraph_ForOp
+    : TFGraph_ConcreteOp<"For",
+                         [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "A functional for-loop operation.";
 
   let arguments = (ins
@@ -353,8 +377,6 @@
   );
 
   let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
-
-  let verifier = [{ return VerifyForOp(*this); }];
 }
 
 #endif // TFG_OPS