PR #43167: [INTEL MKL] Added missed bfloat16 CPU support for op math.rsqrt
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43167
Copybara import of the project:
--
17ca1c7430c3f31b43a032ac9bf6b4ec18b9f9ae by Xiaoming (Jason) Cui <xiaoming.cui@intel.com>:
[INTEL MKL] Added missed bfloat16 CPU support for op math.rsqrt
PiperOrigin-RevId: 332304599
Change-Id: I9c66d29ed7cf1010388d70f30f84fef97de6dde6
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
index e051e4d..21e3bf4 100644
--- a/tensorflow/core/kernels/cwise_op_rsqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -16,15 +16,15 @@
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, bfloat16,
- double, complex64, complex128);
+REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
#endif
-REGISTER6(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
- Eigen::half, bfloat16, double, complex64, complex128);
+REGISTER5(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
+ Eigen::half, double, complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(SimpleBinaryOp, GPU, "RsqrtGrad", functor::rsqrt_grad, float,
Eigen::half, double);
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
index c0f61f0..9d46ed3 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -405,7 +405,6 @@
self._compareCpu(z, compute_f32(np.log), math_ops.log)
self._compareCpu(z, compute_f32(np.log1p), math_ops.log1p)
self._compareCpu(y, np.sign, math_ops.sign)
- self._compareCpu(z, self._rsqrt, math_ops.rsqrt)
self._compareBoth(x, compute_f32(np.sin), math_ops.sin)
self._compareBoth(x, compute_f32(np.cos), math_ops.cos)
self._compareBoth(x, compute_f32(np.tan), math_ops.tan)