[tf][tfg] Relax operand/result type requirements for `ForRegion` and `WhileRegion`
Now that `RegionBranchOpInterface` allows the type-equality requirement along
control edges to be relaxed, change the TFG region ops to only require that
types along those edges be compatible, i.e. have the same tensor element type.

Consequently, `ForRegion` and `WhileRegion` lose their return type inference
and now require both their operand and result types to be spelled explicitly
in the assembly format.
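For illustration, a minimal sketch of the new trailing type syntax on a
single-variable `WhileRegion`, adapted from the updated tests (`%init` stands
for any `tensor<f32>` value in scope and `Cond` is a placeholder op):

    %While, %ctl = WhileRegion(%init) {
    ^bb0(%arg1: tensor<f32>, %arg2: !tf_type.control):
      // Compute the loop predicate and forward the loop variable.
      %Cond, %ctl_0 = Cond : () -> (tensor<i1>)
      condition %Cond : tensor<i1> (%arg1) : tensor<f32>
    } do {
    ^bb0(%arg1: tensor<f32>, %arg2: !tf_type.control):
      yield(%arg1) : tensor<f32>
    } {parallel_iterations = 10 : i64} : (tensor<f32>) -> (tensor<f32>)

Previously the op ended with `: tensor<f32>` and inferred its result types;
now the functional type lists operands and results, which only need matching
element types, so e.g. `(tensor<*xf32>) -> (tensor<f32>)` is accepted.
`ForRegion` likewise spells its index, init, and result types, e.g.
`: (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<f32>)`.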
PiperOrigin-RevId: 434518990
diff --git a/tensorflow/core/ir/ops.cc b/tensorflow/core/ir/ops.cc
index 3a00546..af407f7 100644
--- a/tensorflow/core/ir/ops.cc
+++ b/tensorflow/core/ir/ops.cc
@@ -1281,17 +1281,6 @@
return VerifyLoopRegionArgs(*this, body_region());
}
-LogicalResult ForRegionOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- TypeRange arg_types =
- ForRegionOp::Adaptor(operands, attributes).init().getTypes();
- inferredReturnTypes.assign(arg_types.begin(), arg_types.end());
- inferredReturnTypes.push_back(tf_type::ControlType::get(context));
- return success();
-}
-
OperandRange ForRegionOp::getSuccessorEntryOperands(unsigned index) {
return init();
}
diff --git a/tensorflow/core/ir/ops.td b/tensorflow/core/ir/ops.td
index c55bf42..d01b767 100644
--- a/tensorflow/core/ir/ops.td
+++ b/tensorflow/core/ir/ops.td
@@ -27,7 +27,6 @@
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -445,6 +444,10 @@
Block &else_block() { return else_region().front(); }
YieldOp then_yield();
YieldOp else_yield();
+
+ bool areTypesCompatible(Type lhs, Type rhs) {
+ return getElementTypeOrSelf(lhs) == getElementTypeOrSelf(rhs);
+ }
}];
let extraClassDefinition = [{
@@ -512,6 +515,10 @@
let extraClassDeclaration = [{
Block &branch_block(unsigned idx) { return branches()[idx].front(); }
YieldOp branch_yield(unsigned idx);
+
+ bool areTypesCompatible(Type lhs, Type rhs) {
+ return getElementTypeOrSelf(lhs) == getElementTypeOrSelf(rhs);
+ }
}];
let extraClassDefinition = [{
@@ -576,8 +583,7 @@
// TF graph while loop op with regions.
class TFGraph_WhileLikeRegionOp<string mnemonic> : TFGraph_RegionOp<
- mnemonic, [InferTypeOpInterface, AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ mnemonic, [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getSuccessorEntryOperands"]>,
DeclareOpInterfaceMethods<ControlArgumentInterface,
@@ -600,6 +606,10 @@
Block &body_block() { return body_region().front(); }
ConditionOp cond_condition();
YieldOp body_yield();
+
+ bool areTypesCompatible(Type lhs, Type rhs) {
+ return getElementTypeOrSelf(lhs) == getElementTypeOrSelf(rhs);
+ }
}];
let extraClassDefinition = [{
@@ -615,17 +625,6 @@
return cast<YieldOp>(body_block().getTerminator());
}
- LogicalResult $cppClass::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- TypeRange arg_types = $cppClass::Adaptor(operands, attributes)
- .init().getTypes();
- inferredReturnTypes.assign(arg_types.begin(), arg_types.end());
- inferredReturnTypes.push_back(tf_type::ControlType::get(context));
- return success();
- }
-
OperandRange $cppClass::getSuccessorEntryOperands(unsigned index) {
return init();
}
@@ -657,7 +656,7 @@
$cond_region
`do`
$body_region
- attr-dict (`:` type($init)^)?
+ attr-dict (`:` functional-type($init, $outs)^)?
}];
let hasVerifier = 1;
@@ -673,25 +672,18 @@
let summary = "A stateful region-based while loop operation.";
}
-// The indices of a for-loop op will be scalar i32 tensors. Make this type a
-// buildable type.
-def I32ScalarTensor : 0DTensorOf<[I32]> {
- let builderCall = "RankedTensorType::get({}, $_builder.getI32Type())";
-}
-
// TF graph for loop op with region.
def TFGraph_ForRegionOp : TFGraph_RegionOp<
- "ForRegion", [InferTypeOpInterface, AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ "ForRegion", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getSuccessorEntryOperands"]>,
DeclareOpInterfaceMethods<ControlArgumentInterface,
["getControlToken"]>]> {
let arguments = (ins
// Op Operands.
- I32ScalarTensor:$start,
- I32ScalarTensor:$limit,
- I32ScalarTensor:$delta,
+ I32Tensor:$start,
+ I32Tensor:$limit,
+ I32Tensor:$delta,
Variadic<TFGraph_Tensor>:$init,
Variadic<ControlType>:$ctls,
// Optional attributes.
@@ -703,6 +695,10 @@
let extraClassDeclaration = [{
Block &body_block() { return body_region().front(); }
YieldOp body_yield();
+
+ bool areTypesCompatible(Type lhs, Type rhs) {
+ return getElementTypeOrSelf(lhs) == getElementTypeOrSelf(rhs);
+ }
}];
let extraClassDefinition = [{
@@ -720,7 +716,8 @@
(`(` $init^ `)`)? (`[` $ctls^ `]`)?
`from` $start `to` $limit `by` $delta
$body_region
- attr-dict (`:` type($init)^)?
+ attr-dict `:` `(` type($start) `,` type($limit) `,` type($delta)
+ (`,` type($init)^)? `)` (`->` `(` type($outs)^ `)`)?
}];
let hasVerifier = 1;
diff --git a/tensorflow/core/ir/ops_test.cc b/tensorflow/core/ir/ops_test.cc
index c0b2bec..17959af 100644
--- a/tensorflow/core/ir/ops_test.cc
+++ b/tensorflow/core/ir/ops_test.cc
@@ -49,9 +49,7 @@
yield(%arg1) : tensor<f32>
} else {
yield(%arg1) : tensor<f32>
- } {Tcond = i1, Tout = [f32], output_shapes = [#tf_type.shape<>],
- then_attrs = {}, else_attrs = {}}
- : (tensor<i1>) -> (tensor<f32>)
+ } : (tensor<i1>) -> (tensor<f32>)
return(%IfRegion) : tensor<f32>
}
)mlir";
@@ -91,8 +89,7 @@
yield(%arg1) : tensor<f32>
}, {
yield(%arg1) : tensor<f32>
- } {Tout = [f32], output_shapes = [#tf_type.shape<>], branch_attrs = [{}, {}]}
- : (tensor<i32>) -> (tensor<f32>)
+ } : (tensor<i32>) -> (tensor<f32>)
return(%CaseRegion) : tensor<f32>
}
)mlir";
@@ -136,9 +133,7 @@
} do {
^bb0(%arg1: tensor<f32>, %arg2: !tf_type.control):
yield(%arg1) : tensor<f32>
- } {T = [f32], body_attrs = {}, cond_attrs = {},
- output_shapes = [#tf_type.shape<>], parallel_iterations = 10 : i64}
- : tensor<f32>
+ } {parallel_iterations = 10 : i64} : (tensor<f32>) -> (tensor<f32>)
return(%WhileRegion) : tensor<f32>
}
)mlir";
@@ -176,8 +171,7 @@
^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>,
%arg4: !tf_type.control, %arg5: !tf_type.control):
yield(%arg3) : tensor<f32>
- } {T = [f32], body_attrs = {}, output_shapes = [#tf_type.shape<>]}
- : tensor<f32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<f32>) -> (tensor<f32>)
return(%ForRegion) : tensor<f32>
}
)mlir";
diff --git a/tensorflow/core/ir/tests/region-invalid-ops.mlir b/tensorflow/core/ir/tests/region-invalid-ops.mlir
index 96251fd..d13e133 100644
--- a/tensorflow/core/ir/tests/region-invalid-ops.mlir
+++ b/tensorflow/core/ir/tests/region-invalid-ops.mlir
@@ -69,7 +69,7 @@
} do {
^bb0(%arg0: tensor<*xi32>, %arg1: !tf_type.control):
yield(%arg0) [%arg1] : tensor<*xi32>
- } {parallel_iterations = 10 :i64} : tensor<*xi32>
+ } {parallel_iterations = 10 :i64} : (tensor<*xi32>) -> (tensor<*xi32>)
}
// -----
@@ -86,7 +86,7 @@
^bb0(%arg0: tensor<*xi32>, %arg1: !tf_type.control):
%Cond, %ctl_2 = Cond : () -> (tensor<*xi1>)
condition %Cond : tensor<*xi1> (%arg0) : tensor<*xi32>
- } {parallel_iterations = 10 :i64} : tensor<*xi32>
+ } {parallel_iterations = 10 :i64} : (tensor<*xi32>) -> (tensor<*xi32>)
}
// -----
@@ -98,7 +98,7 @@
%For, %ctl_1 = ForRegion(%Arg) [%ctl] from %Index to %Index by %Index {
^bb0(%arg0: tensor<i64>, %arg1: tensor<*xf32>, %arg2: !tf_type.control, %arg3: !tf_type.control):
yield(%arg1) [%arg3] : tensor<*xf32>
- } : tensor<*xf32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>)
}
// -----
@@ -110,7 +110,7 @@
%For, %ctl_1 = ForRegion(%Arg) [%ctl] from %Index to %Index by %Index {
^bb0(%arg0: tensor<i32>, %arg1: tensor<*xf32>, %arg2: !tf_type.control, %arg3: !tf_type.control):
return(%arg1) [%arg3] : tensor<*xf32>
- } : tensor<*xf32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>)
}
// -----
@@ -122,7 +122,7 @@
%ctl_1 = ForRegion [%ctl] from %Index to %Index by %Index {
^bb0:
yield
- }
+ } : (tensor<i32>, tensor<i32>, tensor<i32>)
}
// -----
@@ -134,7 +134,7 @@
%For, %ctl_1 = ForRegion(%Arg) [%ctl] from %Index to %Index by %Index {
^bb0(%arg0: tensor<i32>, %arg1: tensor<*xf32>, %arg2: !tf_type.control):
yield(%arg1) [%arg2] : tensor<*xf32>
- } : tensor<*xf32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>)
}
// -----
@@ -145,5 +145,5 @@
%For, %ctl_1 = ForRegion(%Index) [%ctl] from %Index to %Index by %Index {
^bb0(%arg0: tensor<i32>, %arg1: !tf_type.control, %arg2: !tf_type.control, %arg3: tensor<*xf32>):
yield(%arg0) [%arg2] : tensor<i32>
- } : tensor<i32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>)
}
diff --git a/tensorflow/core/ir/tests/region-ops.mlir b/tensorflow/core/ir/tests/region-ops.mlir
index 4cbbb7d..b88a7ea 100644
--- a/tensorflow/core/ir/tests/region-ops.mlir
+++ b/tensorflow/core/ir/tests/region-ops.mlir
@@ -69,8 +69,7 @@
// CHECK-NEXT: ^bb0(%[[$ARG0:.*]]: tensor<{{.*}}>, %[[$ARG1:.*]]: tensor<{{.*}}>, %[[$ARG2:.*]]: !tf_type.control, %[[$ARG3:.*]]: !tf_type.control):
// CHECK-NEXT: %[[$FWD:.*]]:2, %[[$CTL_0:.*]] = Fwd(%[[$ARG0]], %[[$ARG1]]) [%[[$ARG2]]]
// CHECK-NEXT: yield(%[[$FWD]]#0, %[[$FWD]]#1) [%[[$CTL_0]], %[[$ARG3]]] : tensor<{{.*}}>, tensor<{{.*}}>
- // CHECK-NEXT: }
- // CHECK-SAME: {parallel_iterations = 10 : i64} : tensor<*xf32>, tensor<*xi32>
+ // CHECK-NEXT: } {parallel_iterations = 10 : i64} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<f32>, tensor<i32>)
%WhileRegion:2, %ctl_0 = WhileRegion(%Op#0, %Op#1) [%ctl] {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>,
%arg2: !tf_type.control, %arg3: !tf_type.control):
@@ -81,7 +80,7 @@
%arg2: !tf_type.control, %arg3: !tf_type.control):
%Fwd:2, %ctl_0 = Fwd(%arg0, %arg1) [%arg2] : (tensor<*xf32>, tensor<*xi32>) -> (tensor<*xf32>, tensor<*xi32>)
yield(%Fwd#0, %Fwd#1) [%ctl_0, %arg3] : tensor<*xf32>, tensor<*xi32>
- } {parallel_iterations = 10 : i64} : tensor<*xf32>, tensor<*xi32>
+ } {parallel_iterations = 10 : i64} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<f32>, tensor<i32>)
}
//===----------------------------------------------------------------------===//
@@ -97,9 +96,9 @@
// CHECK-SAME: from %[[INDEX]]#0 to %[[INDEX]]#1 by %[[INDEX]]#2 {
// CHECK-NEXT: ^bb0(%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<{{.*}}>, %[[ARG2:.*]]: !tf_type.control, %[[ARG3:.*]]: !tf_type.control):
// CHECK-NEXT: yield(%[[ARG1]]) [%[[ARG3]]] : tensor<{{.*}}>
- // CHECK-NEXT: } : tensor<{{.*}}>
+ // CHECK-NEXT: } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<f32>)
%For, %ctl_1 = ForRegion(%Arg) [%ctl] from %Index#0 to %Index#1 by %Index#2 {
^bb0(%arg0: tensor<i32>, %arg1: tensor<*xf32>, %arg2: !tf_type.control, %arg3: !tf_type.control):
yield(%arg1) [%arg3] : tensor<*xf32>
- } : tensor<*xf32>
+ } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<*xf32>) -> (tensor<f32>)
}