Switches MatrixDiagOp to MlirXlaOpKernel for lowering from TF to HLO in the tf2xla bridge.

PiperOrigin-RevId: 416379755
Change-Id: I3ff5713562c6e3d39a10d14d4259245dd2617e2a
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
index f81e45c..2fcbf22 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -329,7 +330,7 @@
   static constexpr int kNumV1Inputs = 1;
 };
 
-REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
+REGISTER_XLA_OP(Name("MatrixDiag"), MlirXlaOpKernel);
 REGISTER_XLA_OP(Name("MatrixDiagV2")
                     .CompileTimeConstantInput("k")
                     .CompileTimeConstantInput("num_rows")