Add XlaReducePrecision TensorFlow op.

See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.

This is somewhat useful for testing fp8. Unfortunately, its usefulness is currently quite limited: certain fp8 formats, such as the ones described in https://arxiv.org/abs/2206.02915, represent Inf/NaN differently from the standard IEEE floating-point dtypes.
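
As a rough illustration, a sketch using the new Python wrapper added in this change (the values mirror the unit test below; normally this runs under XLA compilation):

  import numpy as np
  from tensorflow.compiler.tf2xla.python import xla

  x = np.array([1 + 2**-2 + 2**-4, 128, 256], dtype=np.float32)
  # With 4 exponent bits and 2 mantissa bits, the third fraction bit of the
  # first element is rounded away and 256 overflows to IEEE Inf. An fp8
  # format that encodes Inf/NaN differently (as in the paper above) cannot
  # be reproduced exactly this way, hence the limitation noted above.
  y = xla.reduce_precision(x, exponent_bits=4, mantissa_bits=2)
  # Expected: [1.25, 128., inf]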

PiperOrigin-RevId: 468765094
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index a8855bc..42b43cd 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2240,6 +2240,7 @@
       "XlaPad",
       "XlaRecv",
       "XlaReduce",
+      "XlaReducePrecision",
       "XlaReduceWindow",
       "XlaRemoveDynamicDimensionSize",
       "XlaReplicaId",
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 888f490..9fc9583 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -70,6 +70,8 @@
                           XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(Name("XlaVariadicReduceV2").Device(DEVICE),          \
                           XlaCompileOnDemandOp);                               \
+  REGISTER_KERNEL_BUILDER(Name("XlaReducePrecision").Device(DEVICE),           \
+                          XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow")                              \
                               .HostMemory("window_dimensions")                 \
                               .HostMemory("window_strides")                    \
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b7de9e1..8b57b79 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -20237,6 +20237,27 @@
   let hasCanonicalizer = 1;
 }
 
+def TF_XlaReducePrecisionOp : TF_Op<"XlaReducePrecision", [NoSideEffect]> {
+  let summary = "Wraps the XLA ReducePrecision operator";
+
+  let description = [{
+documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
+  }];
+
+  let arguments = (ins
+    Arg<TF_FloatTensor, [{array of floating-point type.}]>:$operand,
+
+    I64Attr:$exponent_bits,
+    I64Attr:$mantissa_bits
+  );
+
+  let results = (outs
+    TF_FloatTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_XlaReduceScatterOp : TF_Op<"XlaReduceScatter", [NoSideEffect]> {
   let summary = "Wraps the XLA ReduceScatter operator";
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 5e6452e..f847b7e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -7223,6 +7223,31 @@
     return success();
   }
 };
+
+// Convert tf.XlaReducePrecision to mhlo.ReducePrecision
+class ConvertXlaReducePrecisionOp
+    : public OpRewritePattern<TF::XlaReducePrecisionOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op,
+                                PatternRewriter &rewriter) const override {
+    IntegerType int32_type = rewriter.getIntegerType(32);
+    APInt exponent_bits = op.exponent_bitsAttr().getValue();
+    // Truncating to 32 bits is safe, since passing any number above the dtype
+    // size (which is at most 64, for float64) is equivalent to passing the
+    // dtype size.
+    IntegerAttr new_exponent_attr =
+        IntegerAttr::get(int32_type, exponent_bits.truncSSat(32));
+    APInt mantissa_bits = op.mantissa_bitsAttr().getValue();
+    IntegerAttr new_mantissa_attr =
+        IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32));
+    rewriter.replaceOpWithNewOp<mhlo::ReducePrecisionOp>(
+        op, op.getType(), op.operand(), new_exponent_attr, new_mantissa_attr);
+    return success();
+  }
+};
+
 }  // end namespace
 
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
@@ -7308,6 +7333,7 @@
     ConvertXlaShardingOp,
     ConvertXlaDynamicUpdateSliceOp,
     ConvertXlaConvV2Op,
+    ConvertXlaReducePrecisionOp,
     ConvertXlaReduceScatterOp,
     ConvertXlaReduceWindowOp,
     ConvertXlaRngBitGeneratorOp,
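
Side note (illustration only, not part of the patch): the saturating truncation used above, llvm::APInt::truncSSat(32), can be pictured in Python as:

  INT32_MAX = 2**31 - 1
  INT32_MIN = -(2**31)

  def trunc_ssat_32(bits):
    # Signed saturating truncation to 32 bits: out-of-range values clamp to
    # the nearest representable int32. Any bit count above the widest float's
    # size (64, for float64) already behaves like the maximum, so clamping
    # does not change the op's semantics.
    return max(INT32_MIN, min(bits, INT32_MAX))

  assert trunc_ssat_32(2**33) == INT32_MAX  # matches the 2**33 case in the test below
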
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index ebeb95c..bb075fd 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -49,8 +49,8 @@
         output = op(*placeholders)
       result = session.run(output, feeds)
       if not equality_fn:
-        equality_fn = self.assertAllClose
-      equality_fn(result, expected, rtol=1e-3)
+        equality_fn = lambda x, y: self.assertAllClose(x, y, rtol=1e-3)
+      equality_fn(result, expected)
 
   def testAdd(self):
     if xla_test.test.is_built_with_rocm():
@@ -439,7 +439,8 @@
               np.exp(1) * 8 / 7 + 1e5 / 3
           ],
                             dtype=dtype))
-      error_term_equality = functools.partial(self.assertAllClose, atol=.005)
+      error_term_equality = functools.partial(
+          self.assertAllClose, rtol=1e-3, atol=.005)
       self._assertOpOutputMatchesExpected(
           kahan_sum_reduction(dims=[0], output_idx=1),
           args=(xs[shuffle_indices],),
@@ -658,6 +659,28 @@
     self._assertOpOutputMatchesExpected(
         xla.optimization_barrier, args=args, expected=args)
 
+  def test_reduce_precision(self):
+    arg = np.array([1 + 2**-2 + 2**-4, 128, 256], dtype=np.float32)
+    expected = np.array([1 + 2**-2, 128, float('Inf')], dtype=np.float32)
+    exponent_bits = 4
+    mantissa_bits = 2
+    self._assertOpOutputMatchesExpected(
+        lambda x: xla.reduce_precision(x, exponent_bits, mantissa_bits),
+        args=(arg,),
+        expected=expected,
+        equality_fn=self.assertAllEqual)
+
+    arg = np.array([4], dtype=np.float32)
+    expected = np.array([4], dtype=np.float32)
+    # Test passing numbers that cannot fit in a 32-bit integer.
+    exponent_bits = 2**33
+    mantissa_bits = 2**33
+    self._assertOpOutputMatchesExpected(
+        lambda x: xla.reduce_precision(x, exponent_bits, mantissa_bits),
+        args=(arg,),
+        expected=expected,
+        equality_fn=self.assertAllEqual)
+
 
 class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
 
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 7000a20..a3caa66 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -154,6 +154,7 @@
         "xla_optimization_barrier_op.cc",
         "xla_pad_op.cc",
         "xla_reduce_op.cc",
+        "xla_reduce_precision_op.cc",
         "xla_select_and_scatter_op.cc",
         "xla_self_adjoint_eig_op.cc",
         "xla_svd_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc
new file mode 100644
index 0000000..045fdf1
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc
@@ -0,0 +1,25 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_XLA_OP(Name("XlaReducePrecision"), MlirXlaOpKernel);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index cad8db8..9541379 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -1228,6 +1228,22 @@
 input: A Tuple of Arrays of any type.
 )doc");
 
+REGISTER_OP("XlaReducePrecision")
+    .Input("operand: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr("exponent_bits: int")
+    .Attr("mantissa_bits: int")
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+Wraps the XLA ReducePrecision operator
+  documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
+
+operand: array of floating-point type.
+exponent_bits: number of exponent bits in the lower-precision format.
+mantissa_bits: number of mantissa bits in the lower-precision format.
+)doc");
+
 REGISTER_OP("XlaCustomCall")
     .Input("args: T")
     .Output("output: dtype")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 902081c..aa07730 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -593,3 +593,7 @@
 
 def optimization_barrier(*args):
   return gen_xla_ops.xla_optimization_barrier(args)
+
+
+def reduce_precision(operand, exponent_bits, mantissa_bits):
+  return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)
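
A hedged usage sketch for the new wrapper (illustration only; the bfloat16-like parameters are just an example):

  import tensorflow as tf
  from tensorflow.compiler.tf2xla.python import xla

  @tf.function(jit_compile=True)
  def round_to_bf16_grid(x):
    # bfloat16 has 8 exponent bits and 7 mantissa bits, so this rounds a
    # float32 input onto the bfloat16 value grid while keeping float32 dtype.
    return xla.reduce_precision(x, 8, 7)

  print(round_to_bf16_grid(tf.constant([1.0 + 2.0**-10])))  # ~[1.0]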