Add XlaReducePrecision TensorFlow op.
See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
This is somewhat useful for testing fp8. Unfortunately, its usefulness is currently quite limited because certain fp8 representations, such as the ones described in https://arxiv.org/abs/2206.02915, represent Inf/NaN differently from the standard IEEE floating-point dtypes.
PiperOrigin-RevId: 468765094
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index a8855bc..42b43cd 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2240,6 +2240,7 @@
"XlaPad",
"XlaRecv",
"XlaReduce",
+ "XlaReducePrecision",
"XlaReduceWindow",
"XlaRemoveDynamicDimensionSize",
"XlaReplicaId",
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 888f490..9fc9583 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -70,6 +70,8 @@
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaVariadicReduceV2").Device(DEVICE), \
XlaCompileOnDemandOp); \
+ REGISTER_KERNEL_BUILDER(Name("XlaReducePrecision").Device(DEVICE), \
+ XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow") \
.HostMemory("window_dimensions") \
.HostMemory("window_strides") \
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b7de9e1..8b57b79 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -20237,6 +20237,27 @@
let hasCanonicalizer = 1;
}
+def TF_XlaReducePrecisionOp : TF_Op<"XlaReducePrecision", [NoSideEffect]> {
+ let summary = "Wraps the XLA ReducePrecision operator";
+
+ let description = [{
+documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
+ }];
+
+ let arguments = (ins
+ Arg<TF_FloatTensor, [{array of floating-point type.}]>:$operand,
+
+ I64Attr:$exponent_bits,
+ I64Attr:$mantissa_bits
+ );
+
+ let results = (outs
+ TF_FloatTensor:$output
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_XlaReduceScatterOp : TF_Op<"XlaReduceScatter", [NoSideEffect]> {
let summary = "Wraps the XLA ReduceScatter operator";
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 5e6452e..f847b7e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -7223,6 +7223,31 @@
return success();
}
};
+
+// Converts tf.XlaReducePrecision to mhlo.ReducePrecision.
+class ConvertXlaReducePrecisionOp
+ : public OpRewritePattern<TF::XlaReducePrecisionOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op,
+ PatternRewriter &rewriter) const override {
+ IntegerType int32_type = rewriter.getIntegerType(32);
+ APInt exponent_bits = op.exponent_bitsAttr().getValue();
+ // Narrowing to 32 bits with signed saturation is safe, since passing any
+ // number above the dtype size (which is at most 64, for float64) is
+ // equivalent to passing the dtype size.
+ IntegerAttr new_exponent_attr =
+ IntegerAttr::get(int32_type, exponent_bits.truncSSat(32));
+ APInt mantissa_bits = op.mantissa_bitsAttr().getValue();
+ IntegerAttr new_mantissa_attr =
+ IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32));
+ rewriter.replaceOpWithNewOp<mhlo::ReducePrecisionOp>(
+ op, op.getType(), op.operand(), new_exponent_attr, new_mantissa_attr);
+ return success();
+ }
+};
+
} // end namespace
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
@@ -7308,6 +7333,7 @@
ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp,
ConvertXlaConvV2Op,
+ ConvertXlaReducePrecisionOp,
ConvertXlaReduceScatterOp,
ConvertXlaReduceWindowOp,
ConvertXlaRngBitGeneratorOp,
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index ebeb95c..bb075fd 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -49,8 +49,8 @@
output = op(*placeholders)
result = session.run(output, feeds)
if not equality_fn:
- equality_fn = self.assertAllClose
- equality_fn(result, expected, rtol=1e-3)
+ equality_fn = lambda x, y: self.assertAllClose(x, y, rtol=1e-3)
+ equality_fn(result, expected)
def testAdd(self):
if xla_test.test.is_built_with_rocm():
@@ -439,7 +439,8 @@
np.exp(1) * 8 / 7 + 1e5 / 3
],
dtype=dtype))
- error_term_equality = functools.partial(self.assertAllClose, atol=.005)
+ error_term_equality = functools.partial(
+ self.assertAllClose, rtol=1e-3, atol=.005)
self._assertOpOutputMatchesExpected(
kahan_sum_reduction(dims=[0], output_idx=1),
args=(xs[shuffle_indices],),
@@ -658,6 +659,28 @@
self._assertOpOutputMatchesExpected(
xla.optimization_barrier, args=args, expected=args)
+ def test_reduce_precision(self):
+ arg = np.array([1 + 2**-2 + 2**-4, 128, 256], dtype=np.float32)
+ expected = np.array([1 + 2**-2, 128, float('Inf')], dtype=np.float32)
+ exponent_bits = 4
+ mantissa_bits = 2
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.reduce_precision(x, exponent_bits, mantissa_bits),
+ args=(arg,),
+ expected=expected,
+ equality_fn=self.assertAllEqual)
+
+ arg = np.array([4], dtype=np.float32)
+ expected = np.array([4], dtype=np.float32)
+ # Test passing numbers that cannot fit in a 32-bit integer.
+ exponent_bits = 2**33
+ mantissa_bits = 2**33
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.reduce_precision(x, exponent_bits, mantissa_bits),
+ args=(arg,),
+ expected=expected,
+ equality_fn=self.assertAllEqual)
+
class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 7000a20..a3caa66 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -154,6 +154,7 @@
"xla_optimization_barrier_op.cc",
"xla_pad_op.cc",
"xla_reduce_op.cc",
+ "xla_reduce_precision_op.cc",
"xla_select_and_scatter_op.cc",
"xla_self_adjoint_eig_op.cc",
"xla_svd_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc
new file mode 100644
index 0000000..045fdf1
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_precision_op.cc
@@ -0,0 +1,25 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_XLA_OP(Name("XlaReducePrecision"), MlirXlaOpKernel);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index cad8db8..9541379 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -1228,6 +1228,22 @@
input: A Tuple of Arrays of any type.
)doc");
+REGISTER_OP("XlaReducePrecision")
+ .Input("operand: T")
+ .Output("output: T")
+ .Attr("T: {bfloat16, half, float, double}")
+ .Attr("exponent_bits: int")
+ .Attr("mantissa_bits: int")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA ReducePrecision operator
+ documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
+
+operand: array of floating-point type.
+exponent_bits: number of exponent bits in lower-precision format
+mantissa_bits: number of mantissa bits in lower-precision format
+)doc");
+
REGISTER_OP("XlaCustomCall")
.Input("args: T")
.Output("output: dtype")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 902081c..aa07730 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -593,3 +593,7 @@
def optimization_barrier(*args):
return gen_xla_ops.xla_optimization_barrier(args)
+
+
+def reduce_precision(operand, exponent_bits, mantissa_bits):
+ return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)