[tfrt:jitrt:xla] Add data types conversion for types used in XLA
PiperOrigin-RevId: 449671126
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index f00d3bb..0747c7d 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -124,10 +124,44 @@
static PrimitiveType ToPrimitiveType(tfrt::DType dtype) {
switch (dtype) {
+ // Unsigned integer types.
+ case tfrt::DType::UI8:
+ return PrimitiveType::U8;
+ case tfrt::DType::UI16:
+ return PrimitiveType::U16;
+ case tfrt::DType::UI32:
+ return PrimitiveType::U32;
+ case tfrt::DType::UI64:
+ return PrimitiveType::U64;
+
+ // Signed integer types.
+ case tfrt::DType::I1:
+ return PrimitiveType::PRED;
+ case tfrt::DType::I8:
+ return PrimitiveType::S8;
+ case tfrt::DType::I16:
+ return PrimitiveType::S16;
+ case tfrt::DType::I32:
+ return PrimitiveType::S32;
+ case tfrt::DType::I64:
+ return PrimitiveType::S64;
+
+ // Floating point types.
+ case tfrt::DType::F16:
+ return PrimitiveType::F16;
case tfrt::DType::F32:
return PrimitiveType::F32;
case tfrt::DType::F64:
return PrimitiveType::F64;
+ case tfrt::DType::BF16:
+ return PrimitiveType::BF16;
+
+ // Complex types.
+ case tfrt::DType::Complex64:
+ return PrimitiveType::C64;
+ case tfrt::DType::Complex128:
+ return PrimitiveType::C128;
+
default:
LOG(FATAL) << "Unsupported data type: " << dtype;
}