Improved GetComputeUnitsCount for GpuInfo.
PiperOrigin-RevId: 454194098
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 1895c17..1d131c0 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -506,6 +506,32 @@
return IsValhallGen1() || IsValhallGen2() || IsValhallGen3();
}
+int MaliInfo::GetApproximateComputeUnitsCount() const {
+  // Estimates the number of compute units (shader cores) for this Mali GPU.
+  // Each value below is a fixed guess chosen per architecture family or per
+  // GPU model; the trailing comments give the core range for that model.
+  if (IsMidgard()) return 8;   // Mali Midgard can have 1-16 cores.
+  if (IsBifrost()) return 16;  // Mali Bifrost can have 1-32 cores.
+  if (!IsValhall()) return 4;  // Fallback for any other architecture.
+
+  switch (gpu_version) {
+    case MaliGpu::kG57:   // 1-6 cores
+    case MaliGpu::kG68:   // 4-6 cores
+    case MaliGpu::kG310:  // up to 6 cores
+    case MaliGpu::kG510:  // up to 6 cores
+    case MaliGpu::kG610:  // up to 6 cores
+      return 6;
+    case MaliGpu::kG77:  // 7-16 cores
+    case MaliGpu::kG78:  // 7-24 cores
+      return 16;
+    case MaliGpu::kG710:  // 7-16 cores
+      return 10;
+    default:
+      // Valhall model without a dedicated estimate.
+      return 4;
+  }
+}
+
void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
GpuApi gpu_api, GpuInfo* gpu_info) {
gpu_info->gpu_api = gpu_api;
@@ -782,13 +808,21 @@
if (IsApple()) {
return apple_info.GetComputeUnitsCount();
}
- if (IsAMD() && IsApiVulkan()) {
- return amd_info.GetComputeUnitsCount();
+ if (IsAMD()) {
+ if (amd_info.GetComputeUnitsCount() != 0) {
+ return amd_info.GetComputeUnitsCount();
+ } else {
+ // approximate number
+ return 16;
+ }
}
if (IsAdreno()) {
return adreno_info.GetComputeUnitsCount();
}
- return 1;
+  if (IsMali()) {  // Mali: use the model-based estimate instead of the default.
+    return mali_info.GetApproximateComputeUnitsCount();
+  }
+ return 4;
}
int GpuInfo::GetMaxWorkGroupSizeForX() const {
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index c9d17af..b9a5dbe 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -100,8 +100,8 @@
struct AMDInfo {
AMDInfo() = default;
- int shader_engines;
- int compute_units_per_shader_engine;
+ int shader_engines = 0;
+ int compute_units_per_shader_engine = 0;
int GetComputeUnitsCount() const {
return shader_engines * compute_units_per_shader_engine;
}
@@ -250,6 +250,9 @@
bool IsValhallGen2() const;
bool IsValhallGen3() const;
bool IsValhall() const;
+
+ // Returns an approximate (not exact) compute units count based on GPU name.
+ int GetApproximateComputeUnitsCount() const;
};
struct OpenGlInfo {