[tfrt:jitrt:xla] Add data types conversion for types used in XLA

PiperOrigin-RevId: 449671126
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index f00d3bb..0747c7d 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -124,10 +124,44 @@
 
 static PrimitiveType ToPrimitiveType(tfrt::DType dtype) {
   switch (dtype) {
+    // Unsigned integer types.
+    case tfrt::DType::UI8:
+      return PrimitiveType::U8;
+    case tfrt::DType::UI16:
+      return PrimitiveType::U16;
+    case tfrt::DType::UI32:
+      return PrimitiveType::U32;
+    case tfrt::DType::UI64:
+      return PrimitiveType::U64;
+
+    // Signed integer types.
+    case tfrt::DType::I1:
+      return PrimitiveType::PRED;
+    case tfrt::DType::I8:
+      return PrimitiveType::S8;
+    case tfrt::DType::I16:
+      return PrimitiveType::S16;
+    case tfrt::DType::I32:
+      return PrimitiveType::S32;
+    case tfrt::DType::I64:
+      return PrimitiveType::S64;
+
+    // Floating point types.
+    case tfrt::DType::F16:
+      return PrimitiveType::F16;
     case tfrt::DType::F32:
       return PrimitiveType::F32;
     case tfrt::DType::F64:
       return PrimitiveType::F64;
+    case tfrt::DType::BF16:
+      return PrimitiveType::BF16;
+
+    // Complex types.
+    case tfrt::DType::Complex64:
+      return PrimitiveType::C64;
+    case tfrt::DType::Complex128:
+      return PrimitiveType::C128;
+
     default:
       LOG(FATAL) << "Unsupported data type: " << dtype;
   }