Enable binary TensorFlow ops in "xla-legalize-tf-with-tf2xla" pass
This enables binary ops in binary_ops_test.py that have kernel defined in tf2xla/kernels/binary_ops.cc and doesn't already have legalizations. Some tests are disabled if the op is not supported or either using unsigned int or complex constants.
PiperOrigin-RevId: 307031001
Change-Id: I76745c8d9443f19a50dc79c5e78d0dce0d8fddd4
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 8b76a5d..63a7026 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -77,13 +77,24 @@
// building valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// all tf2xla kernels.
- return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) ||
- isa<TF::BiasAddGradOp>(op) || isa<TF::CastOp>(op) ||
- isa<TF::ComplexAbsOp>(op) || isa<TF::GreaterOp>(op) ||
- isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
- isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
- isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
+ return isa<TF::AbsOp>(op) || isa<TF::AddV2Op>(op) || isa<TF::Atan2Op>(op) ||
+ isa<TF::BatchMatMulV2Op>(op) || isa<TF::BiasAddOp>(op) ||
+ isa<TF::BiasAddGradOp>(op) || isa<TF::BitwiseAndOp>(op) ||
+ isa<TF::BitwiseOrOp>(op) || isa<TF::BitwiseXorOp>(op) ||
+ isa<TF::CastOp>(op) || isa<TF::ComplexAbsOp>(op) ||
+ isa<TF::DivNoNanOp>(op) || isa<TF::EqualOp>(op) ||
+ isa<TF::FloorDivOp>(op) || isa<TF::FloorModOp>(op) ||
+ isa<TF::GreaterOp>(op) || isa<TF::GreaterEqualOp>(op) ||
+ isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) ||
+ isa<TF::LeftShiftOp>(op) || isa<TF::LessOp>(op) ||
+ isa<TF::LessEqualOp>(op) || isa<TF::LogicalAndOp>(op) ||
+ isa<TF::LogicalNotOp>(op) || isa<TF::LogicalOrOp>(op) ||
+ isa<TF::LogOp>(op) || isa<TF::MatMulOp>(op) || isa<TF::MulOp>(op) ||
+ isa<TF::NegOp>(op) || isa<TF::NotEqualOp>(op) || isa<TF::PowOp>(op) ||
+ isa<TF::RealDivOp>(op) || isa<TF::RightShiftOp>(op) ||
+ isa<TF::SinOp>(op) || isa<TF::SelectV2Op>(op) || isa<TF::SubOp>(op) ||
isa<TF::SquareOp>(op) || isa<TF::TransposeOp>(op) ||
+ isa<TF::TruncateDivOp>(op) || isa<TF::TruncateModOp>(op) ||
isa<TF::UnpackOp>(op);
}
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 01e5165..1098abe 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -200,6 +200,7 @@
name = "binary_ops_test",
size = "medium",
srcs = ["binary_ops_test.py"],
+ enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index f4df06c..d9721a3 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -26,6 +26,7 @@
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
@@ -72,6 +73,8 @@
self.assertAllCloseAccordingToType(
result[i], expected[i], rtol=rtol, atol=atol)
+ @test_util.disable_mlir_bridge(
+ "F16 type is not supported in CreateDenseElementsAttrFromLiteral")
def testFloatOps(self):
for dtype in self.float_types:
if dtype == dtypes.bfloat16.as_numpy_dtype:
@@ -296,6 +299,7 @@
]
self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected)
+ @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
def testAdd(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -322,6 +326,7 @@
expected=np.array([3.0269620882574744, 3.3149631512242195],
dtype=dtype))
+ @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
def testMultiply(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -385,6 +390,7 @@
expected=np.array([[16], [81]], dtype=dtype),
rtol=rtol)
+ @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
def testNumericOps(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -474,6 +480,7 @@
expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
dtype=np.int64))
+ @test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation")
def testNextAfter(self):
for dtype in self.numeric_types:
if dtype in [np.float32, np.float64]:
@@ -501,6 +508,8 @@
expected=expected,
equality_test=NextAfterEqualityTest)
+ @test_util.disable_mlir_bridge(
+ "Complex types not supported in CreateDenseElementsAttrFromLiteral")
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
@@ -724,6 +733,8 @@
for dtype in self.signed_int_types - {np.int8}:
self._testRemainder(dtype)
+ @test_util.disable_mlir_bridge(
+ "F16 type is not supported in CreateDenseElementsAttrFromLiteral")
def testFloatRemainder(self):
for dtype in self.float_types:
self._testRemainder(dtype)
@@ -923,6 +934,7 @@
expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool)
self._testBinary(op, lhs, rhs, expected=expected)
+ @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
def testBroadcasting(self):
"""Tests broadcasting behavior of an operator."""
@@ -1490,6 +1502,7 @@
np.array([1, 0], dtype=np.int32),
expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))
+ @test_util.disable_mlir_bridge("Enable tf.Cross Compilation")
def testCross(self):
for dtype in self.float_types:
self._testBinary(