Allow _EagerConst op in fallback legalization in tf2xla preferred mode
Also, remove C++ pattern in favor of TableGen pattern for the op.
PiperOrigin-RevId: 388864240
Change-Id: I4c1e8ce9cc781260cadb1102580c200adf7f8ab4
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 30ba1ae..cdfd555 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2323,17 +2323,6 @@
}
};
-// Bypass _EagerConst
-class ConvertEagerConstOp : public OpRewritePattern<TF::_EagerConstOp> {
- public:
- using OpRewritePattern<TF::_EagerConstOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(TF::_EagerConstOp op,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOp(op, op.getOperand());
- return success();
- }
-};
-
template <typename OpTy>
class ConvertFFTOp : public OpRewritePattern<OpTy> {
public:
@@ -7486,7 +7475,6 @@
ConvertFusedBatchNormV3Op,
ConvertInfeedDequeueTupleOp,
ConvertIdentityNOp,
- ConvertEagerConstOp,
ConvertInplaceUpdateOp,
ConvertLinSpaceOp,
ConvertMaxOp,
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index ac40f7d..3459422 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -345,7 +345,7 @@
// Identity op patterns.
//===----------------------------------------------------------------------===//
-foreach src = [TF_IdentityOp, TF_StopGradientOp] in
+foreach src = [TF_IdentityOp, TF_StopGradientOp, TF__EagerConstOp] in
def : Pat<(src $op), (replaceWithValue $op)>;
// TODO(b/32223192): Support CheckNumerics in HLO.
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 47296b5..ef80340 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -315,6 +315,7 @@
TypeID::get<TF::CumsumOp>(),
TypeID::get<TF::DepthwiseConv2dNativeOp>(),
TypeID::get<TF::DynamicStitchOp>(),
+ TypeID::get<TF::_EagerConstOp>(),
TypeID::get<TF::EmptyOp>(),
TypeID::get<TF::ExpandDimsOp>(),
TypeID::get<TF::FillOp>(),