[tf][tfg] Use SymbolUserOpInterface to verify TFG functional ops
Inspecting other functions in `Op::verify` can lead to threading errors. Verify referenced functions through `SymbolUserOpInterface`.
PiperOrigin-RevId: 421200088
Change-Id: I5aa28e93568f2c50067c5e2b8b96bf26390ec108
diff --git a/tensorflow/core/ir/ops.cc b/tensorflow/core/ir/ops.cc
index 1516eae..05ef8be 100644
--- a/tensorflow/core/ir/ops.cc
+++ b/tensorflow/core/ir/ops.cc
@@ -798,27 +798,27 @@
// If-Like Ops
template <typename IfLikeOp>
-static LogicalResult VerifyIfLikeOp(IfLikeOp op) {
+static LogicalResult VerifyIfLikeOp(IfLikeOp op,
+ SymbolTableCollection &symbol_table) {
+ if (failed(op.verify())) return failure();
FailureOr<TypeRange> ins = VerifyOperands(op);
if (failed(ins)) return failure();
FailureOr<TypeRange> outs = VerifyResults(op);
if (failed(outs)) return failure();
- Operation *table_op = SymbolTable::getNearestSymbolTable(op);
- if (!table_op) return op.emitOpError("is not contained in a symbol table");
SymbolRefAttr then_name = op.then_branch().getName();
SymbolRefAttr else_name = op.else_branch().getName();
// The first operand is the condition and is not passed to the functions.
TypeRange func_args = llvm::drop_begin(*ins);
- auto then_func = dyn_cast_or_null<GraphFuncOp>(
- SymbolTable::lookupSymbolIn(table_op, then_name));
+ auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+ op, op.then_branch().getName());
if (then_func &&
failed(VerifySignature(then_func, op, func_args, *outs, "then")))
return failure();
- auto else_func = dyn_cast_or_null<GraphFuncOp>(
- SymbolTable::lookupSymbolIn(table_op, else_name));
+ auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+ op, op.else_branch().getName());
if (else_func &&
failed(VerifySignature(else_func, op, func_args, *outs, "else")))
return failure();
@@ -830,21 +830,21 @@
// Case-Like Ops
template <typename CaseLikeOp>
-static LogicalResult VerifyCaseLikeOp(CaseLikeOp op) {
+static LogicalResult VerifyCaseLikeOp(CaseLikeOp op,
+ SymbolTableCollection &symbol_table) {
+ if (failed(op.verify())) return failure();
FailureOr<TypeRange> ins = VerifyOperands(op);
if (failed(ins)) return failure();
FailureOr<TypeRange> outs = VerifyResults(op);
if (failed(outs)) return failure();
- Operation *table_op = SymbolTable::getNearestSymbolTable(op);
- if (!table_op) return op.emitOpError("is not contained in a symbol table");
// The first operand is the branch index and is not passed to the functions.
TypeRange func_args = llvm::drop_begin(*ins);
for (auto &it : llvm::enumerate(op.branches())) {
SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName();
- auto func = dyn_cast_or_null<GraphFuncOp>(
- SymbolTable::lookupSymbolIn(table_op, func_name));
+ auto func =
+ symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name);
if (func && failed(VerifySignature(func, op, func_args, *outs,
"branch #" + Twine(it.index()))))
return failure();
@@ -856,26 +856,25 @@
// While-Like Ops
template <typename WhileLikeOp>
-static LogicalResult VerifyWhileLikeOp(WhileLikeOp op) {
+static LogicalResult VerifyWhileLikeOp(WhileLikeOp op,
+ SymbolTableCollection &symbol_table) {
+ if (failed(op.verify())) return failure();
FailureOr<TypeRange> ins = VerifyOperands(op);
if (failed(ins)) return failure();
FailureOr<TypeRange> outs = VerifyResults(op);
if (failed(outs)) return failure();
- Operation *table_op = SymbolTable::getNearestSymbolTable(op);
- if (!table_op) return op.emitOpError("is not contained in a symbol table");
- SymbolRefAttr cond_name = op.cond().getName();
SymbolRefAttr body_name = op.body().getName();
- auto cond_func = dyn_cast_or_null<GraphFuncOp>(
- SymbolTable::lookupSymbolIn(table_op, cond_name));
+ auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+ op, op.cond().getName());
auto i1_type = Builder(op.getContext()).getI1Type();
if (cond_func &&
failed(VerifySignature(cond_func, op, *ins, i1_type, "cond")))
return failure();
- auto body_func = dyn_cast_or_null<GraphFuncOp>(
- SymbolTable::lookupSymbolIn(table_op, body_name));
+ auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+ op, op.body().getName());
if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body")))
return failure();
@@ -885,20 +884,20 @@
//===----------------------------------------------------------------------===//
// ForOp
-static LogicalResult VerifyForOp(ForOp op) {
- FailureOr<TypeRange> ins = VerifyOperands(op);
+LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
+ if (failed(verify())) return failure();
+ FailureOr<TypeRange> ins = VerifyOperands(*this);
if (failed(ins)) return failure();
- FailureOr<TypeRange> outs = VerifyResults(op);
+ FailureOr<TypeRange> outs = VerifyResults(*this);
if (failed(outs)) return failure();
- SymbolRefAttr body_name = op.body().getName();
- auto body_func =
- SymbolTable::lookupNearestSymbolFrom<GraphFuncOp>(op, body_name);
+ auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
+ *this, body().getName());
// The first three arguments are the for-loop indices, but the current loop
// index is passed in.
TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2);
if (body_func &&
- failed(VerifySignature(body_func, op, func_args, *outs, "body")))
+ failed(VerifySignature(body_func, *this, func_args, *outs, "body")))
return failure();
return success();
}
diff --git a/tensorflow/core/ir/ops.td b/tensorflow/core/ir/ops.td
index 8ac8ad0..5e9f0b5 100644
--- a/tensorflow/core/ir/ops.td
+++ b/tensorflow/core/ir/ops.td
@@ -27,6 +27,7 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
// TFGraph op definitions
@@ -254,7 +255,9 @@
TFGraph_Op<mnemonic, traits>;
// Base class for TFGraph if-like operations.
-class TFGraph_IfLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_IfLikeOp<string mnemonic>
+ : TFGraph_ConcreteOp<mnemonic,
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
// Op operands.
I1Tensor:$cond,
@@ -269,7 +272,12 @@
);
let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
- let verifier = [{ return VerifyIfLikeOp(*this); }];
+ let extraClassDefinition = [{
+ LogicalResult $cppClass::verifySymbolUses(
+ SymbolTableCollection &symbol_table) {
+ return VerifyIfLikeOp(*this, symbol_table);
+ }
+ }];
}
def TFGraph_IfOp : TFGraph_IfLikeOp<"If"> {
@@ -283,7 +291,9 @@
}
// Base class for TFGraph case-like operations.
-class TFGraph_CaseLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_CaseLikeOp<string mnemonic>
+ : TFGraph_ConcreteOp<mnemonic,
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
// Op operands.
I32Tensor:$branch_index,
@@ -297,7 +307,12 @@
let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
- let verifier = [{ return VerifyCaseLikeOp(*this); }];
+ let extraClassDefinition = [{
+ LogicalResult $cppClass::verifySymbolUses(
+ SymbolTableCollection &symbol_table) {
+ return VerifyCaseLikeOp(*this, symbol_table);
+ }
+ }];
}
def TFGraph_CaseOp : TFGraph_CaseLikeOp<"Case"> {
@@ -311,7 +326,9 @@
}
// Base class for TFGraph while-like operations.
-class TFGraph_WhileLikeOp<string mnemonic> : TFGraph_ConcreteOp<mnemonic> {
+class TFGraph_WhileLikeOp<string mnemonic>
+ : TFGraph_ConcreteOp<mnemonic,
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
// Op operands.
Variadic<TFGraph_TensorOrControlType>:$args,
@@ -324,7 +341,12 @@
);
let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
- let verifier = [{ return VerifyWhileLikeOp(*this); }];
+ let extraClassDefinition = [{
+ LogicalResult $cppClass::verifySymbolUses(
+ SymbolTableCollection &symbol_table) {
+ return VerifyWhileLikeOp(*this, symbol_table);
+ }
+ }];
}
def TFGraph_WhileOp : TFGraph_WhileLikeOp<"While"> {
@@ -338,7 +360,9 @@
}
// A functional for loop operation.
-def TFGraph_ForOp : TFGraph_ConcreteOp<"For"> {
+def TFGraph_ForOp
+ : TFGraph_ConcreteOp<"For",
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "A functional for-loop operation.";
let arguments = (ins
@@ -353,8 +377,6 @@
);
let results = (outs Variadic<TFGraph_Tensor>:$outs, ControlType:$ctl);
-
- let verifier = [{ return VerifyForOp(*this); }];
}
#endif // TFG_OPS