Enable binary TensorFlow ops in "xla-legalize-tf-with-tf2xla" pass

This enables the binary ops exercised in binary_ops_test.py that have kernels defined in tf2xla/kernels/binary_ops.cc and do not already have dedicated MLIR legalizations. Some tests are disabled when the op is not yet supported, or when the test uses unsigned integer or complex constants.
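
For context, the test-side mechanism used below is the test_util.disable_mlir_bridge decorator, which takes a short reason string and keeps the decorated test from exercising the MLIR bridge while it continues to run through the existing tf2xla bridge. A minimal sketch of the pattern follows; the class and method names are placeholders (not part of this change), and it assumes the decorator can wrap any TensorFlow test method, as it does for the XLA test cases in the diff:

  from tensorflow.python.framework import test_util
  from tensorflow.python.platform import test


  class ExampleSkipTest(test.TestCase):  # placeholder name

    @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
    def testExample(self):
      # Not run via the MLIR bridge; continues to run through the existing
      # tf2xla path until unsigned integer constants are supported there.
      self.assertEqual(2 + 2, 4)


  if __name__ == "__main__":
    test.main()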

PiperOrigin-RevId: 307031001
Change-Id: I76745c8d9443f19a50dc79c5e78d0dce0d8fddd4
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 8b76a5d..63a7026 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -77,13 +77,24 @@
   // building valid MLIR using MlirHloBuilder.
   // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
   // all tf2xla kernels.
-  return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) ||
-         isa<TF::BiasAddGradOp>(op) || isa<TF::CastOp>(op) ||
-         isa<TF::ComplexAbsOp>(op) || isa<TF::GreaterOp>(op) ||
-         isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
-         isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
-         isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
+  return isa<TF::AbsOp>(op) || isa<TF::AddV2Op>(op) || isa<TF::Atan2Op>(op) ||
+         isa<TF::BatchMatMulV2Op>(op) || isa<TF::BiasAddOp>(op) ||
+         isa<TF::BiasAddGradOp>(op) || isa<TF::BitwiseAndOp>(op) ||
+         isa<TF::BitwiseOrOp>(op) || isa<TF::BitwiseXorOp>(op) ||
+         isa<TF::CastOp>(op) || isa<TF::ComplexAbsOp>(op) ||
+         isa<TF::DivNoNanOp>(op) || isa<TF::EqualOp>(op) ||
+         isa<TF::FloorDivOp>(op) || isa<TF::FloorModOp>(op) ||
+         isa<TF::GreaterOp>(op) || isa<TF::GreaterEqualOp>(op) ||
+         isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) ||
+         isa<TF::LeftShiftOp>(op) || isa<TF::LessOp>(op) ||
+         isa<TF::LessEqualOp>(op) || isa<TF::LogicalAndOp>(op) ||
+         isa<TF::LogicalNotOp>(op) || isa<TF::LogicalOrOp>(op) ||
+         isa<TF::LogOp>(op) || isa<TF::MatMulOp>(op) || isa<TF::MulOp>(op) ||
+         isa<TF::NegOp>(op) || isa<TF::NotEqualOp>(op) || isa<TF::PowOp>(op) ||
+         isa<TF::RealDivOp>(op) || isa<TF::RightShiftOp>(op) ||
+         isa<TF::SinOp>(op) || isa<TF::SelectV2Op>(op) || isa<TF::SubOp>(op) ||
          isa<TF::SquareOp>(op) || isa<TF::TransposeOp>(op) ||
+         isa<TF::TruncateDivOp>(op) || isa<TF::TruncateModOp>(op) ||
          isa<TF::UnpackOp>(op);
 }
 
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 01e5165..1098abe 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -200,6 +200,7 @@
     name = "binary_ops_test",
     size = "medium",
     srcs = ["binary_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 5,
     tags = [
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index f4df06c..d9721a3 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -26,6 +26,7 @@
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import gen_math_ops
@@ -72,6 +73,8 @@
       self.assertAllCloseAccordingToType(
           result[i], expected[i], rtol=rtol, atol=atol)
 
+  @test_util.disable_mlir_bridge(
+      "F16 type is not supported in CreateDenseElementsAttrFromLiteral")
   def testFloatOps(self):
     for dtype in self.float_types:
       if dtype == dtypes.bfloat16.as_numpy_dtype:
@@ -296,6 +299,7 @@
         ]
         self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected)
 
+  @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
   def testAdd(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -322,6 +326,7 @@
             expected=np.array([3.0269620882574744, 3.3149631512242195],
                               dtype=dtype))
 
+  @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
   def testMultiply(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -385,6 +390,7 @@
           expected=np.array([[16], [81]], dtype=dtype),
           rtol=rtol)
 
+  @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
   def testNumericOps(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -474,6 +480,7 @@
           expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
                             dtype=np.int64))
 
+  @test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation")
   def testNextAfter(self):
     for dtype in self.numeric_types:
       if dtype in [np.float32, np.float64]:
@@ -501,6 +508,8 @@
             expected=expected,
             equality_test=NextAfterEqualityTest)
 
+  @test_util.disable_mlir_bridge(
+      "Complex types not supported in CreateDenseElementsAttrFromLiteral")
   def testComplexOps(self):
     for dtype in self.complex_types:
       ctypes = {np.complex64: np.float32, np.complex128: np.float64}
@@ -724,6 +733,8 @@
     for dtype in self.signed_int_types - {np.int8}:
       self._testRemainder(dtype)
 
+  @test_util.disable_mlir_bridge(
+      "F16 type is not supported in CreateDenseElementsAttrFromLiteral")
   def testFloatRemainder(self):
     for dtype in self.float_types:
       self._testRemainder(dtype)
@@ -923,6 +934,7 @@
       expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool)
       self._testBinary(op, lhs, rhs, expected=expected)
 
+  @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
   def testBroadcasting(self):
     """Tests broadcasting behavior of an operator."""
 
@@ -1490,6 +1502,7 @@
           np.array([1, 0], dtype=np.int32),
           expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge("Enable tf.Cross Compilation")
   def testCross(self):
     for dtype in self.float_types:
       self._testBinary(