Improved GetComputeUnitsCount for GpuInfo.

PiperOrigin-RevId: 454194098
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 1895c17..1d131c0 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -506,6 +506,32 @@
   return IsValhallGen1() || IsValhallGen2() || IsValhallGen3();
 }
 
+int MaliInfo::GetApproximateComputeUnitsCount() const {
+  if (IsMidgard()) {
+    // Mali Midgard can have 1-16 cores
+    return 8;
+  } else if (IsBifrost()) {
+    // Mali Bifrost can have 1-32 cores
+    return 16;
+  } else if (IsValhall()) {
+    if (gpu_version == MaliGpu::kG57) {
+      return 6;  // Mali-G57 can have 1-6 cores
+    } else if (gpu_version == MaliGpu::kG77) {
+      return 16;  // Mali-G77 can have 7-16 cores
+    } else if (gpu_version == MaliGpu::kG68) {
+      return 6;  // Mali-G68 can have 4-6 cores
+    } else if (gpu_version == MaliGpu::kG78) {
+      return 16;  // Mali-G78 can have 7-24 cores
+    } else if (gpu_version == MaliGpu::kG310 || gpu_version == MaliGpu::kG510 ||
+               gpu_version == MaliGpu::kG610) {
+      return 6;  // Mali-G310/G510/G610 can have up to 6 cores
+    } else if (gpu_version == MaliGpu::kG710) {
+      return 10;  // Mali-G710 can have 7–16 cores
+    }
+  }
+  return 4;
+}
+
 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
                                      GpuApi gpu_api, GpuInfo* gpu_info) {
   gpu_info->gpu_api = gpu_api;
@@ -782,13 +808,21 @@
   if (IsApple()) {
     return apple_info.GetComputeUnitsCount();
   }
-  if (IsAMD() && IsApiVulkan()) {
-    return amd_info.GetComputeUnitsCount();
+  if (IsAMD()) {
+    if (amd_info.GetComputeUnitsCount() != 0) {
+      return amd_info.GetComputeUnitsCount();
+    } else {
+      // approximate number
+      return 16;
+    }
   }
   if (IsAdreno()) {
     return adreno_info.GetComputeUnitsCount();
   }
-  return 1;
+  if (IsMali()) {
+    mali_info.GetApproximateComputeUnitsCount();
+  }
+  return 4;
 }
 
 int GpuInfo::GetMaxWorkGroupSizeForX() const {
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index c9d17af..b9a5dbe 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -100,8 +100,8 @@
 
 struct AMDInfo {
   AMDInfo() = default;
-  int shader_engines;
-  int compute_units_per_shader_engine;
+  int shader_engines = 0;
+  int compute_units_per_shader_engine = 0;
   int GetComputeUnitsCount() const {
     return shader_engines * compute_units_per_shader_engine;
   }
@@ -250,6 +250,9 @@
   bool IsValhallGen2() const;
   bool IsValhallGen3() const;
   bool IsValhall() const;
+
+  // returns approximate compute units count using GPU name
+  int GetApproximateComputeUnitsCount() const;
 };
 
 struct OpenGlInfo {