Use MlirXlaOpKernel for 7 ops.

These ops now use MlirXlaOpKernel to lower from TF to HLO in the tf2xla bridge. They are a subset of the ops that work with MlirXlaOpKernel; to be cautious, I'm migrating only a small number first.
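
The migration pattern is the same for every op in this change: the hand-written kernel (or its XLA_MAKE_BINARY / XLAJIT_MAKE_UNARY expansion) is dropped and MlirXlaOpKernel is registered in its place. As a rough sketch only (using a hypothetical "SomeOp" and kernel class name, not part of this diff):

  #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
  #include "tensorflow/compiler/tf2xla/xla_op_registry.h"

  namespace tensorflow {
  namespace {

  // Before: a dedicated XlaOpKernel subclass lowered the op directly.
  //   REGISTER_XLA_OP(Name("SomeOp"), SomeOpXlaKernel);
  // After: registration points at MlirXlaOpKernel, so the op is lowered
  // from TF to HLO through the MLIR-based tf2xla bridge instead.
  REGISTER_XLA_OP(Name("SomeOp"), MlirXlaOpKernel);

  }  // namespace
  }  // namespace tensorflow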

PiperOrigin-RevId: 369944953
Change-Id: If20004db5d8498bdbf9c4c33a1119f2f7609ed1b
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index 33bdf9a..1542493 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -15,6 +15,7 @@
 
 #include <numeric>
 
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,7 +73,7 @@
   TensorFormat data_format_;
 };
 
-REGISTER_XLA_OP(Name("BiasAdd"), BiasOp);
+REGISTER_XLA_OP(Name("BiasAdd"), MlirXlaOpKernel);
 REGISTER_XLA_OP(Name("BiasAddV1"), BiasOp);
 
 class BiasAddGradOp : public XlaOpKernel {
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 39f4bee..ee3c711 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -60,7 +61,7 @@
 XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions));
 XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
 XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
+REGISTER_XLA_OP(Name("Div"), MlirXlaOpKernel);
 
 XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
 XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
@@ -142,14 +143,7 @@
 XLA_MAKE_BINARY(FloorDiv,
                 FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
 
-xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
-                     const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
-  auto zero = xla::ZerosLike(x);
-  auto is_zero = xla::Eq(x, zero);
-  return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
-}
-XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
+REGISTER_XLA_OP(Name("Xlogy"), MlirXlaOpKernel);
 
 xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
                        const BCast& broadcast_helper) {
@@ -298,12 +292,7 @@
 
 XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
 
-xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
-  std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
-  return xla::Zeta(x, q);
-}
-
-XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
+REGISTER_XLA_OP(Name("Zeta"), MlirXlaOpKernel);
 
 #undef XLA_MAKE_BINARY
 
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index b722ecf..4b73e52 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -15,6 +15,7 @@
 
 // XLA specific pooling ops.
 
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -380,7 +381,7 @@
                 errors::InvalidArgument("Invalid data format"));
   }
 };
-REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
+REGISTER_XLA_OP(Name("MaxPoolGrad"), MlirXlaOpKernel);
 REGISTER_XLA_OP(Name("MaxPoolGradV2")
                     .CompileTimeConstantInput("ksize")
                     .CompileTimeConstantInput("strides"),
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 6fe6b16..2567ce4 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -16,6 +16,7 @@
 // Native XLA implementations of simple unary Ops
 
 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -84,7 +85,7 @@
 XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x));
 XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x));
 
-XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));
+REGISTER_XLA_OP(Name("Rsqrt"), MlirXlaOpKernel);
 
 XLAJIT_MAKE_UNARY(Sigmoid, xla::Logistic(x));
 
@@ -128,7 +129,7 @@
 XLAJIT_MAKE_UNARY(Ndtri, xla::ScalarLike(x, std::sqrt(2.0)) *
                              xla::ErfInv(xla::ScalarLike(x, 2.0) * x -
                                          xla::ScalarLike(x, 1.0)));
-XLAJIT_MAKE_UNARY(Lgamma, xla::Lgamma(x));
+REGISTER_XLA_OP(Name("Lgamma"), MlirXlaOpKernel);
 XLAJIT_MAKE_UNARY(Digamma, xla::Digamma(x));
 XLAJIT_MAKE_UNARY(BesselI0e, xla::BesselI0e(x));
 XLAJIT_MAKE_UNARY(BesselI1e, xla::BesselI1e(x));