Use TF legalization patterns to determine if op should be marked for OutsideCompilation.
PiperOrigin-RevId: 325042393
Change-Id: I075db3bb540e9cbc682b699bdf468021ce5debdb
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index c6f0083..63908c8 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -827,6 +827,7 @@
":xla_sharding_util",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite:validators",
+ "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:sharding_builder",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
index 9b28b3b..afad117 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
@@ -1,154 +1,159 @@
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
-
-// CHECK-LABEL: func @op_string_result
-func @op_string_result() -> tensor<?xi32> {
+// CHECK-LABEL: func @unsupported_op
+func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
- // CHECK-NOT: _xla_outside_compilation
- // CHECK: "tf.B"
+ // CHECK: "tf.UnsupportedOp"
// CHECK-SAME: _xla_outside_compilation
- // CHECK: "tf.C"
+ // CHECK: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
- %1 = "tf.A"() : () -> tensor<?xi32>
- %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<!tf.string>
- %3 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %3 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+ tf_device.return %2 : tensor<i32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
+ return %0 : tensor<i32>
}
-// CHECK-LABEL: func @op_string_operand
-func @op_string_operand(%arg0: tensor<!tf.string>) -> tensor<?xi32> {
+// CHECK-LABEL: func @op_string_result
+func @op_string_result() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
- // CHECK: "tf.B"
+ // CHECK: "tf.Const"
// CHECK-SAME: _xla_outside_compilation
- // CHECK: "tf.C"
+ // CHECK-SAME: tf.string
+ // CHECK: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
- %1 = "tf.A"() : () -> tensor<?xi32>
- %2 = "tf.B"(%arg0) : (tensor<!tf.string>) -> tensor<?xi32>
- %3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %3 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.Const"() {value = dense<"x"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+ %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+ tf_device.return %3 : tensor<i32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
+ return %0 : tensor<i32>
+}
+// CHECK-LABEL: func @op_string_operand
+func @op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
+ %0 = "tf_device.cluster"() ( {
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+ // CHECK-NOT: _xla_outside_compilation
+ // CHECK: "tf.StringToNumber"
+ // CHECK-SAME: _xla_outside_compilation
+ // CHECK-SAME: tf.string
+ // CHECK: "tf.Identity"
+ // CHECK-NOT: _xla_outside_compilation
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.StringToNumber"(%arg0) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+ %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+ tf_device.return %3 : tensor<i32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
+ return %0 : tensor<i32>
}
// CHECK-LABEL: func @op_string_operand_string_result
-func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<?xi32> {
+func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
- // CHECK: "tf.B"
+ // CHECK: "tf.Identity"
// CHECK-SAME: _xla_outside_compilation
- // CHECK: "tf.C"
+ // CHECK-SAME: tf.string
+ // CHECK: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
- %1 = "tf.A"() : () -> tensor<?xi32>
- %2 = "tf.B"(%arg0) : (tensor<!tf.string>) -> tensor<!tf.string>
- %3 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %3 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.Identity"(%arg0) : (tensor<!tf.string>) -> tensor<!tf.string>
+ %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+ tf_device.return %3 : tensor<i32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
+ return %0 : tensor<i32>
}
-
// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation.
// CHECK-LABEL: func @if_region_captured_string
-func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<?xi32> {
+func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
- // CHECK: "tf.D"
- // CHECK-SAME: _xla_outside_compilation
- // CHECK: _xla_outside_compilation
- // CHECK-SAME: is_stateless = true
- %1 = "tf.A"() : () -> tensor<?xi32>
+ // CHECK: "tf.StringToNumber"
+ // CHECK: _xla_outside_compilation = "auto", is_stateless = true
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.IfRegion"(%arg0) ( {
- %3 = "tf.D"(%arg1) : (tensor<!tf.string>) -> tensor<?xi32>
- "tf.Yield"(%3) : (tensor<?xi32>) -> ()
+ %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+ "tf.Yield"(%3) : (tensor<f32>) -> ()
}, {
- %4 = "tf.H"() : () -> tensor<?xi32>
- "tf.Yield"(%4) : (tensor<?xi32>) -> ()
- }) {is_stateless = true} : (tensor<i1>) -> (tensor<?xi32>)
- %5 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %5 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ %4 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+ "tf.Yield"(%4) : (tensor<f32>) -> ()
+ }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+ %5 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+ tf_device.return %5 : tensor<f32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
+ return %0 : tensor<f32>
}
-// Test that op with a string results/operands inside a tf.IfRegion branch is marked for outside compilation.
+// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
// CHECK-LABEL: func @if_region_string_op
-func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
+func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
// CHECK-NOT: _xla_outside_compilation
- %1 = "tf.A"() : () -> tensor<?xi32>
- %2 = "tf.IfRegion"(%arg0)({
- // CHECK: "tf.D"
- // CHECK-NOT: _xla_outside_compilation
- %3 = "tf.D"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
- "tf.Yield"(%3) : (tensor<?xi32>) -> ()
- }, {
- // CHECK: "tf.F"
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.IfRegion"(%arg0) ( {
+ %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+ "tf.Yield"(%3) : (tensor<f32>) -> ()
+ }, {
+ // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+ // CHECK-NEXT: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation
- %4 = "tf.F"() : () -> tensor<!tf.string>
- // CHECK: "tf.G"
- // CHECK-SAME: _xla_outside_compilation
- %5 = "tf.G"(%4) : (tensor<!tf.string>) -> tensor<?xi32>
- %6 = "tf.H"() : () -> tensor<?xi32>
- "tf.Yield"(%6) : (tensor<?xi32>) -> ()
- }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
- // CHECK: "tf.C"
- // CHECK-NOT: _xla_outside_compilation
- %7 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %7 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ %4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+ %5 = "tf.StringToNumber"(%4) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+ "tf.Yield"(%5) : (tensor<f32>) -> ()
+ // CHECK: {is_stateless
+ }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+ %6 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+ tf_device.return %6: tensor<f32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
+ return %0 : tensor<f32>
}
-// Test that op with a string results/operands inside a tf.IfRegion branch is marked for outside compilation.
+// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
// CHECK-LABEL: func @nested_if_region_string_op
-func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
+func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
- // CHECK: "tf.A"
+ // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
// CHECK-NOT: _xla_outside_compilation
- %1 = "tf.A"() : () -> tensor<?xi32>
- %2 = "tf.IfRegion"(%arg0)({
- // CHECK: "tf.D"
- // CHECK-NOT: _xla_outside_compilation
- %3 = "tf.D"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
- "tf.Yield"(%3) : (tensor<?xi32>) -> ()
- }, {
- %4 = "tf.E"() : () -> tensor<i1>
- %5 = "tf.IfRegion"(%4)({
- // CHECK: "tf.F"
- // CHECK-NOT: _xla_outside_compilation
- %6 = "tf.F"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
- "tf.Yield"(%6) : (tensor<?xi32>) -> ()
+ %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %2 = "tf.IfRegion"(%arg0) ( {
+ %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+ "tf.Yield"(%3) : (tensor<f32>) -> ()
}, {
- // CHECK: "tf.G"
- // CHECK-SAME: _xla_outside_compilation
- %7 = "tf.G"() : () -> tensor<!tf.string>
- // CHECK: "tf.H"
- // CHECK-SAME: _xla_outside_compilation
- %8 = "tf.H"(%7) : (tensor<!tf.string>) -> tensor<?xi32>
- %9 = "tf.I"() : () -> tensor<?xi32>
- "tf.Yield"(%9) : (tensor<?xi32>) -> ()
- }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
- "tf.Yield"(%5) : (tensor<?xi32>) -> ()
- }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
- // CHECK: "tf.C"
- // CHECK-NOT: _xla_outside_compilation
- %10 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
- tf_device.return %10 : tensor<?xi32>
- }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
- return %0 : tensor<?xi32>
+ // CHECK: "tf.Const"() {value = dense<true> : tensor<i1>}
+ // CHECK-NOT: _xla_outside_compilation
+ %4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+ %5 = "tf.IfRegion"(%4)({
+ // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+ // CHECK-NEXT: "tf.StringToNumber"
+ // CHECK-SAME: _xla_outside_compilation
+ %6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+ %7 = "tf.StringToNumber"(%6) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+ "tf.Yield"(%7) : (tensor<f32>) -> ()
+ }, {
+ // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+ // CHECK-NOT: _xla_outside_compilation
+ %8 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+ "tf.Yield"(%8) : (tensor<f32>) -> ()
+ // CHECK: {is_stateless
+ }){is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+ "tf.Yield"(%5) : (tensor<f32>) -> ()
+ // CHECK: {is_stateless
+ }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+ %9 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+ tf_device.return %9: tensor<f32>
+ }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
+ return %0 : tensor<f32>
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
index 8f1f3ec..71146cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
@@ -25,6 +25,7 @@
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
namespace mlir {
namespace TFDevice {
@@ -43,6 +44,14 @@
void runOnOperation() override;
};
+// Inserts the names of tf.IfRegion and tf.Yield into `supported_ops`.
+// TODO(b/159128666): Use the control flow legalization passes once added.
+void AddSupportedControlFlowOps(MLIRContext* context,
+                                llvm::DenseSet<OperationName>* supported_ops) {
+  for (const char* op_name : {"tf.IfRegion", "tf.Yield"})
+    supported_ops->insert(OperationName(op_name, context));
+}
+
bool HasStringOperand(Operation& op) {
for (auto operand : op.getOperands()) {
if (getElementTypeOrSelf(operand).isa<TF::StringType>()) return true;
@@ -57,12 +66,18 @@
return false;
}
+bool MatchesPattern(Operation& op,
+                    const llvm::DenseSet<OperationName>& supported_ops) {
+  return supported_ops.count(op.getName()) != 0;  // Lookup by op name only.
+}
+
// Checks if the op is supported inside of a device cluster.
-bool IsSupportedOp(Operation& op) {
- if (HasStringOperand(op) || HasStringResult(op)) {
- return false;
- }
- return true;
+bool IsSupportedOp(Operation& op,
+                   const llvm::DenseSet<OperationName>& supported_ops) {
+  // TODO(b/161726307): Check the allowed ops list in LegalizeTfWithTf2XlaPass
+  // as well.
+  if (HasStringOperand(op) || HasStringResult(op)) return false;
+  return MatchesPattern(op, supported_ops);
}
bool HasCapturedStringOperand(TF::IfRegionOp* if_op) {
@@ -83,9 +98,10 @@
return string_operand;
}
-LogicalResult MarkUncompilableOps(Block* block) {
+LogicalResult MarkUncompilableOps(
+ Block* block, llvm::DenseSet<OperationName>& supported_ops) {
block->walk([&](Operation* op) {
- if (!IsSupportedOp(*op)) {
+ if (!IsSupportedOp(*op, supported_ops)) {
op->setAttr(kXlaOutsideCompilationAttr,
StringAttr::get("auto", op->getContext()));
}
@@ -101,9 +117,21 @@
void MarkOpsForOutsideCompilation::runOnOperation() {
auto module = getOperation();
+ OwningRewritePatternList patterns;
+ mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
+
+ // `supported_ops` contains the names of all ops that can potentially be
+ // lowered into HLO on the device. Membership does not guarantee that a later
+ // pass will actually lower the op, but an op that is not in this set cannot
+ // be lowered by any subsequent pass.
+ llvm::DenseSet<OperationName> supported_ops;
+ for (auto& pattern : patterns) {
+ supported_ops.insert(*pattern->getRootKind());
+ }
+ AddSupportedControlFlowOps(module.getContext(), &supported_ops);
auto result = module.walk([&](tf_device::ClusterOp cluster) {
- if (failed(MarkUncompilableOps(&cluster.GetBody())))
+ if (failed(MarkUncompilableOps(&cluster.GetBody(), supported_ops)))
return WalkResult::interrupt();
return WalkResult::advance();