Uses prefer_tf2xla fallback option to cleanly separate op legalization between MLIR and tf2xla. The list of ops MLIR-legalized ops is introduced. All other ops use tf2xla to legalize.
PiperOrigin-RevId: 375725177
Change-Id: Ie56b66d3cfe97c10ac8f36b5fec5b52a7b779168
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir
new file mode 100644
index 0000000..0aa0377
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir
@@ -0,0 +1,29 @@
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false use-tf2xla-fallback=true prefer-tf2xla=true" %s | FileCheck %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false prefer-tf2xla=true" %s | FileCheck --check-prefix NOFALLBACK %s
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+
+// CHECK-LABEL: @abs
+func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: bessel_i0e
+func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
+ // CHECK-NOT: tf.BesselI0e
+ %0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
+ %1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
+ %2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
+ return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
+}
+
+// NOFALLBACK-LABEL: @xla_svd
+func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) {
+ // NOFALLBACK: XlaSvd
+ %s, %u, %v = "tf.XlaSvd"(%arg0) {max_iter = 1, epsilon = 1.0E-09 : f32, precision_config = ""} : (tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>)
+ return %s, %u, %v : tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>
+}
+
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 793eea0..1bbe8da 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -6390,6 +6390,34 @@
}
}
+// Patterns whose root op is in the set `include_ops` are moved from the set
+// `from` to the returned set. This is used to partition patterns by op so they
+// can be cleanly migrated from the old bridge to the MLIR bridge.
+OwningRewritePatternList PatternsIncludeOps(
+ OwningRewritePatternList &from,
+ const llvm::DenseSet<mlir::TypeID> &include_ops);
+
+OwningRewritePatternList PatternsIncludeOps(
+ OwningRewritePatternList &from,
+ const llvm::DenseSet<mlir::TypeID> &include_ops) {
+ OwningRewritePatternList to(from.getContext());
+ // Filter NativePatterns.
+ for (auto &pattern : from.getNativePatterns()) {
+ Optional<OperationName> pat_op_name = pattern->getRootKind();
+ // If the pattern does not have a specific operation, always include it,
+ // If the pattern is in include_ops then include it.
+ bool include =
+ !pat_op_name ||
+ include_ops.count(pat_op_name->getAbstractOperation()->typeID);
+ if (include) to.add(std::move(pattern));
+ }
+
+ // Don't filter PDLPatterns.
+ to.add(std::move(from.getPDLPatterns()));
+
+ return to;
+}
+
} // end namespace
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
@@ -6399,7 +6427,7 @@
llvm::Optional<StringRef> tf2xla_fallback_device_type,
bool prefer_tf2xla) {
MLIRContext *context = op->getContext();
- OwningRewritePatternList patterns(context);
+ OwningRewritePatternList legalize_lower_patterns(context);
// Note that the `OperationConverter` orders patterns lexicographically by:
// 1) Ascending legalization depth (i.e., minimum number of patterns necessary
// to arrive at conversion target). This requires relevant patterns to
@@ -6411,15 +6439,35 @@
// 4) Order of patterns in `OwningRewritePatternList`.
// Add TF->HLO legalization patterns.
- PopulateLegalizeTfPatterns(context, &patterns);
+ PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
// Add TF->TF lowering patterns.
- TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
+ TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
- // Add TF->HLO legalization patterns via TF2XLA fallback.
- if (tf2xla_fallback_device_type.hasValue()) {
+ if (tf2xla_fallback_device_type && prefer_tf2xla) {
+ VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
+ "either native MLIR legalization, or TF2XLA fallback "
+ "legalzation, with a preference toward TF2XLA.";
+ } else if (tf2xla_fallback_device_type) {
+ VLOG(1) << "TF to XLA legalization patterns include all native patterns "
+ "and TF2XLA fallback patterns.";
+ } else {
+ VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
+ }
+
+ // Set patterns to legalize_lower_patters, where in the prefer_tf2xla case
+ // only patterns whose ops are in the set MlirLegalizedUnderPreferTf2XlaSet
+ // are kept.
+ OwningRewritePatternList patterns =
+ (tf2xla_fallback_device_type && prefer_tf2xla)
+ ? PatternsIncludeOps(legalize_lower_patterns,
+ MlirLegalizedUnderPreferTf2XlaSet())
+ : std::move(legalize_lower_patterns);
+
+ if (tf2xla_fallback_device_type) {
+ // Add TF->HLO legalization patterns via TF2XLA fallback.
PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
- patterns, context);
+ patterns, context, prefer_tf2xla);
}
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index c260195..b8f2120 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -280,6 +280,30 @@
return ops.count(abstractOp->typeID);
}
+// List ops that should use MLIR legalization in the case of prefer_tf2xla. All
+// other ops not in this list should use the old bridge's XlaOpKernel
+// legalization.
+const llvm::DenseSet<mlir::TypeID>& MlirLegalizedUnderPreferTf2XlaSet() {
+ // The static variable is a pointer in order to avoid destruction upon thread
+ // termination.
+
+ // clang-format off
+ static const llvm::DenseSet<mlir::TypeID>* ops =
+ new llvm::DenseSet<mlir::TypeID>{
+ // Ops that are legalized in the old bridge using MlirXlaOpKernel
+ TypeID::get<TF::AbsOp>(),
+ };
+ // clang-format on
+ return *ops;
+}
+
+bool IsOpMlirLegalizedUnderPreferTf2Xla(Operation* op) {
+ auto mlir_ops = MlirLegalizedUnderPreferTf2XlaSet();
+ auto* abstractOp = op->getAbstractOperation();
+ if (!abstractOp) return false;
+ return mlir_ops.count(abstractOp->typeID);
+}
+
namespace {
template <typename T, size_t N>
@@ -571,33 +595,42 @@
class Tf2XlaRewritePattern : public RewritePattern {
public:
explicit Tf2XlaRewritePattern(MLIRContext* ctx,
- const std::string& device_type)
+ const std::string& device_type,
+ bool prefer_tf2xla)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
- device_type_(device_type) {}
+ device_type_(device_type),
+ prefer_tf2xla_(prefer_tf2xla) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
- if (!IsOpAllowedTf2XlaFallback(op)) return failure();
+ if (prefer_tf2xla_) {
+ if (IsOpMlirLegalizedUnderPreferTf2Xla(op)) return failure();
+ } else {
+ if (!IsOpAllowedTf2XlaFallback(op)) return failure();
+ }
return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
}
private:
std::string device_type_;
+ bool prefer_tf2xla_;
};
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
- explicit LegalizeTF(llvm::StringRef device_type) {
+ explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
device_type_ = device_type.str();
+ prefer_tf2xla_ = prefer_tf2xla;
}
LegalizeTF(const LegalizeTF&) {}
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
- patterns.insert<Tf2XlaRewritePattern>(&getContext(), device_type_);
+ patterns.insert<Tf2XlaRewritePattern>(&getContext(), device_type_,
+ prefer_tf2xla_);
if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
signalPassFailure();
@@ -609,6 +642,12 @@
Option<std::string> device_type_{
*this, "device-type",
llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
+ Option<bool> prefer_tf2xla_{
+ *this,
+ "prefer-tf2xla",
+ llvm::cl::desc("Enable legalization when it is not in the list of "
+ "MLIR-legalized ops."),
+ };
};
static PassRegistration<LegalizeTF> pass(
@@ -619,13 +658,14 @@
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns,
- MLIRContext* ctx) {
- patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str());
+ MLIRContext* ctx,
+ bool prefer_tf2xla) {
+ patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
- llvm::StringRef device_type) {
- return std::make_unique<LegalizeTF>(device_type);
+ llvm::StringRef device_type, bool prefer_tf2xla) {
+ return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
}
} // end namespace mhlo
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index 5df4565..b4fadcd 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -50,22 +50,29 @@
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
/// specified device type.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
- llvm::StringRef device_type);
+ llvm::StringRef device_type, bool prefer_tf2xla = false);
/// Replaces types that do not exist in MHLO with equivalent types that do
/// exist.
std::unique_ptr<OperationPass<void>> CreateLegalizeTfTypesPass();
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
+/// `prefer_tf2xla` means an op will be included iff it is not in
+/// `MlirLegalizedUnderPreferTf2XlaSet`. `!prefer_tf2xla` mean an op will be
+/// included iff it is in `IsOpAllowedTf2XlaFallback`.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns,
- MLIRContext* ctx);
+ MLIRContext* ctx,
+ bool prefer_tf2xla = false);
/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern
/// list.
void PopulateLegalizeTfPatterns(MLIRContext* context,
OwningRewritePatternList* patterns);
+/// Ops that should be legalized using MLIR given the prefer_tf2xla option.
+const llvm::DenseSet<mlir::TypeID>& MlirLegalizedUnderPreferTf2XlaSet();
+
/// Checks whether the op is supported by the Tf2Xla fallback for legalization.
bool IsOpAllowedTf2XlaFallback(Operation* op);