Extended gpu/common/gpu_info to support AppleInfo.
Using gpu/common/gpu_info for Metal backend.

PiperOrigin-RevId: 341696861
Change-Id: I0c71c0691c76af4f7282d9aa886b41d61a3ccf96
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 3dfab18..069230e 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -88,6 +88,7 @@
         "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates/gpu/common:convert",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_builder",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
@@ -99,7 +100,6 @@
         "//tensorflow/lite/delegates/gpu/metal:api",
         "//tensorflow/lite/delegates/gpu/metal:buffer_convert",
         "//tensorflow/lite/delegates/gpu/metal:compiled_model",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 40a4e4b..5a816fd 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -24,26 +24,28 @@
 namespace gpu {
 namespace {
 
-GpuVendor GetGpuVendor(const std::string& renderer) {
-  if (renderer.find("mali") != renderer.npos) {
-    return GpuVendor::kMali;
-  }
-  if (renderer.find("adreno") != renderer.npos) {
-    return GpuVendor::kQualcomm;
-  }
-  if (renderer.find("powervr") != renderer.npos) {
-    return GpuVendor::kPowerVR;
-  }
-  if (renderer.find("intel") != renderer.npos) {
-    return GpuVendor::kIntel;
-  }
-  if (renderer.find("nvidia") != renderer.npos) {
-    return GpuVendor::kNvidia;
+GpuVendor GetGpuVendor(const std::string& gpu_description) {
+  const std::map<std::string, GpuVendor> kMapping = {
+      {"adreno", GpuVendor::kQualcomm},
+      {"apple", GpuVendor::kApple},
+      {"qualcomm", GpuVendor::kQualcomm},
+      {"mali", GpuVendor::kMali},
+      {"powervr", GpuVendor::kPowerVR},
+      {"advanced micro devices", GpuVendor::kAMD},
+      {"intel", GpuVendor::kIntel},
+      {"nvidia", GpuVendor::kNvidia},
+      {"amd", GpuVendor::kAMD},
+      {"power", GpuVendor::kPowerVR},
+  };
+  for (const auto& v : kMapping) {
+    if (gpu_description.find(v.first) != std::string::npos) {
+      return v.second;
+    }
   }
   return GpuVendor::kUnknown;
 }
 
-AdrenoGpu GetAdrenoGpuVersion(const std::string& device_name) {
+AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
   const std::map<std::string, AdrenoGpu> kMapping = {
       // Adreno 6xx series
       {"685", AdrenoGpu::kAdreno685},
@@ -93,7 +95,7 @@
   };
 
   for (const auto& v : kMapping) {
-    if (device_name.find(v.first) != std::string::npos) {
+    if (gpu_description.find(v.first) != std::string::npos) {
       return v.second;
     }
   }
@@ -212,6 +214,70 @@
   }
 }
 
+AppleInfo::AppleInfo(const std::string& gpu_description) {
+  const std::map<std::string, AppleGpu> kMapping = {
+      {"apple a7 gpu", AppleGpu::kA7},     {"apple a8 gpu", AppleGpu::kA8},
+      {"apple a8x gpu", AppleGpu::kA8X},   {"apple a9 gpu", AppleGpu::kA9},
+      {"apple a9x gpu", AppleGpu::kA9X},   {"apple a10 gpu", AppleGpu::kA10},
+      {"apple a10x gpu", AppleGpu::kA10X}, {"apple a11 gpu", AppleGpu::kA11},
+      {"apple a12 gpu", AppleGpu::kA12},   {"apple a12x gpu", AppleGpu::kA12X},
+      {"apple a12z gpu", AppleGpu::kA12Z}, {"apple a13 gpu", AppleGpu::kA13},
+      {"apple a14 gpu", AppleGpu::kA14},
+  };
+  auto it = kMapping.find(gpu_description);
+  if (it != kMapping.end()) {
+    gpu_type = it->second;
+  } else {
+    gpu_type = AppleGpu::kUnknown;
+  }
+}
+
+bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
+  return gpu_type == AppleGpu::kA7 || gpu_type == AppleGpu::kA8 ||
+         gpu_type == AppleGpu::kA8X;
+}
+
+bool AppleInfo::IsBionic() const {
+  return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
+         gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
+         gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14;
+}
+
+bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
+
+int AppleInfo::GetComputeUnitsCount() const {
+  switch (gpu_type) {
+    case AppleGpu::kA7:
+      return 4;
+    case AppleGpu::kA8:
+      return 4;
+    case AppleGpu::kA8X:
+      return 8;
+    case AppleGpu::kA9:
+      return 6;
+    case AppleGpu::kA9X:
+      return 12;
+    case AppleGpu::kA10:
+      return 6;
+    case AppleGpu::kA10X:
+      return 12;
+    case AppleGpu::kA11:
+      return 3;
+    case AppleGpu::kA12:
+      return 4;
+    case AppleGpu::kA12X:
+      return 7;
+    case AppleGpu::kA12Z:
+      return 8;
+    case AppleGpu::kA13:
+      return 4;
+    case AppleGpu::kA14:
+      return 4;
+    case AppleGpu::kUnknown:
+      return 1;
+  }
+}
+
 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
                                      GpuInfo* gpu_info) {
   std::string lowered = gpu_description;
@@ -219,6 +285,9 @@
   gpu_info->vendor = GetGpuVendor(lowered);
   if (gpu_info->IsAdreno()) {
     gpu_info->adreno_info = AdrenoInfo(lowered);
+  } else if (gpu_info->IsApple()) {
+    gpu_info->apple_info = AppleInfo(lowered);
+    gpu_info->supported_subgroup_sizes = {32};
   }
 }
 
@@ -236,5 +305,26 @@
 
 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
 
+bool GpuInfo::IsRoundToNearestSupported() const {
+  if (IsApple()) {
+    return apple_info.IsRoundToNearestSupported();
+  } else {
+    return true;
+  }
+}
+
+bool GpuInfo::IsWaveSizeEqualTo32() const {
+  return supported_subgroup_sizes.size() == 1 &&
+         supported_subgroup_sizes[0] == 32;
+}
+
+int GpuInfo::GetComputeUnitsCount() const {
+  if (IsApple()) {
+    return apple_info.GetComputeUnitsCount();
+  } else {
+    return 1;
+  }
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index 053021b..ec32823 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -116,6 +116,38 @@
   bool support_one_layer_texture_array = true;
 };
 
+enum class AppleGpu {
+  kUnknown,
+  kA7,
+  kA8,
+  kA8X,
+  kA9,
+  kA9X,
+  kA10,
+  kA10X,
+  kA11,
+  kA12,
+  kA12X,
+  kA12Z,
+  kA13,
+  kA14,
+};
+
+struct AppleInfo {
+  AppleInfo() = default;
+  explicit AppleInfo(const std::string& gpu_description);
+  AppleGpu gpu_type;
+
+  bool IsLocalMemoryPreferredOverGlobal() const;
+
+  bool IsBionic() const;
+
+  // floating point rounding mode
+  bool IsRoundToNearestSupported() const;
+
+  int GetComputeUnitsCount() const;
+};
+
 struct GpuInfo {
   bool IsAdreno() const;
   bool IsApple() const;
@@ -125,6 +157,14 @@
   bool IsAMD() const;
   bool IsIntel() const;
 
+  // floating point rounding mode
+  bool IsRoundToNearestSupported() const;
+
+  // returns true if device have fixed wave size equal to 32
+  bool IsWaveSizeEqualTo32() const;
+
+  int GetComputeUnitsCount() const;
+
   GpuVendor vendor = GpuVendor::kUnknown;
 
   std::string renderer_name;
@@ -141,7 +181,10 @@
   int max_image_units = 0;
   int max_array_texture_layers = 0;
 
+  std::vector<int> supported_subgroup_sizes;
+
   AdrenoInfo adreno_info;
+  AppleInfo apple_info;
 };
 
 inline bool IsOpenGl31OrAbove(const GpuInfo& gpu_info) {
@@ -149,8 +192,10 @@
          gpu_info.major_version > 3;
 }
 
-// Currently it initializes vendor and AdrenoInfo if
-// vendor is kQualcomm
+// Currently it initializes:
+// vendor
+// AdrenoInfo if vendor is kQualcomm
+// AppleInfo if vendor is kApple
 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
                                      GpuInfo* gpu_info);
 
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 8d00eee..81b8434 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -26,8 +26,8 @@
     deps = [
         ":compiled_model",
         ":compute_task_descriptor",
-        ":device_info",
         ":runtime_options",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
@@ -188,12 +188,6 @@
     ],
 )
 
-cc_library(
-    name = "device_info",
-    srcs = ["device_info.cc"],
-    hdrs = ["device_info.h"],
-)
-
 objc_library(
     name = "gpu_object",
     hdrs = ["gpu_object.h"],
@@ -309,10 +303,10 @@
     ],
     sdk_frameworks = ["XCTest"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/metal:buffer",
         "//tensorflow/lite/delegates/gpu/metal:common",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index acb9ec7..561b082 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
@@ -25,7 +26,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h
index 09eb651..f7cdfa4 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.h
+++ b/tensorflow/lite/delegates/gpu/metal/api.h
@@ -16,10 +16,10 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_API_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_API_H_
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/common_test.mm b/tensorflow/lite/delegates/gpu/metal/common_test.mm
index 9fecc59..3e2db54 100644
--- a/tensorflow/lite/delegates/gpu/metal/common_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/common_test.mm
@@ -22,7 +22,7 @@
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 
 using ::tflite::gpu::metal::GetBestSupportedMetalDevice;
 using ::tflite::gpu::metal::CreateComputeProgram;
diff --git a/tensorflow/lite/delegates/gpu/metal/device_info.cc b/tensorflow/lite/delegates/gpu/metal/device_info.cc
deleted file mode 100644
index 250ca9b..0000000
--- a/tensorflow/lite/delegates/gpu/metal/device_info.cc
+++ /dev/null
@@ -1,150 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
-
-#include <map>
-#include <string>
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-namespace {
-GpuVendor GetVendorFromString(const std::string& device_name) {
-  const std::map<std::string, GpuVendor> kMapping = {
-      {"Apple", GpuVendor::kApple},
-      {"Intel", GpuVendor::kIntel},
-      {"AMD", GpuVendor::kAMD},
-  };
-  for (const auto& v : kMapping) {
-    if (device_name.find(v.first) != std::string::npos) {
-      return v.second;
-    }
-  }
-  return GpuVendor::kUnknown;
-}
-}  // namespace
-
-AppleGPUInfo::AppleGPUInfo(const std::string& device_name) {
-  const std::map<std::string, AppleGPU> kMapping = {
-      {"Apple A7 GPU", AppleGPU::kA7},     {"Apple A8 GPU", AppleGPU::kA8},
-      {"Apple A8X GPU", AppleGPU::kA8X},   {"Apple A9 GPU", AppleGPU::kA9},
-      {"Apple A9X GPU", AppleGPU::kA9X},   {"Apple A10 GPU", AppleGPU::kA10},
-      {"Apple A10X GPU", AppleGPU::kA10X}, {"Apple A11 GPU", AppleGPU::kA11},
-      {"Apple A12 GPU", AppleGPU::kA12},   {"Apple A12X GPU", AppleGPU::kA12X},
-      {"Apple A12Z GPU", AppleGPU::kA12Z}, {"Apple A13 GPU", AppleGPU::kA13},
-      {"Apple A14 GPU", AppleGPU::kA14},
-  };
-  auto it = kMapping.find(device_name);
-  if (it != kMapping.end()) {
-    gpu_type = it->second;
-  } else {
-    gpu_type = AppleGPU::kUnknown;
-  }
-}
-
-bool AppleGPUInfo::IsLocalMemoryPreferredOverGlobal() const {
-  return gpu_type == AppleGPU::kA7 ||
-         gpu_type == AppleGPU::kA8 ||
-         gpu_type == AppleGPU::kA8X;
-}
-
-bool AppleGPUInfo::IsBionic() const {
-  return gpu_type == AppleGPU::kA11 || gpu_type == AppleGPU::kA12 ||
-         gpu_type == AppleGPU::kA12X || gpu_type == AppleGPU::kA12Z ||
-         gpu_type == AppleGPU::kA13 || gpu_type == AppleGPU::kA14;
-}
-
-bool AppleGPUInfo::IsRoundToNearestSupported() const {
-  return IsBionic();
-}
-
-bool AppleGPUInfo::IsWaveSizeEqualTo32() const {
-  return true;
-}
-
-int AppleGPUInfo::GetComputeUnitsCount() const {
-  switch (gpu_type) {
-    case AppleGPU::kA7:
-      return 4;
-    case AppleGPU::kA8:
-      return 4;
-    case AppleGPU::kA8X:
-      return 8;
-    case AppleGPU::kA9:
-      return 6;
-    case AppleGPU::kA9X:
-      return 12;
-    case AppleGPU::kA10:
-      return 6;
-    case AppleGPU::kA10X:
-      return 12;
-    case AppleGPU::kA11:
-      return 3;
-    case AppleGPU::kA12:
-      return 4;
-    case AppleGPU::kA12X:
-      return 7;
-    case AppleGPU::kA12Z:
-      return 8;
-    case AppleGPU::kA13:
-      return 4;
-    case AppleGPU::kA14:
-      return 4;
-    case AppleGPU::kUnknown:
-      return 1;
-  }
-}
-
-GpuInfo::GpuInfo(const std::string& device_name)
-    : vendor(GetVendorFromString(device_name)) {
-  if (vendor == GpuVendor::kApple) {
-    apple_info = AppleGPUInfo(device_name);
-  }
-}
-
-bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
-
-bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
-
-bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
-
-bool GpuInfo::IsRoundToNearestSupported() const {
-  if (vendor == GpuVendor::kApple) {
-    return apple_info.IsRoundToNearestSupported();
-  } else {
-    return true;
-  }
-}
-
-bool GpuInfo::IsWaveSizeEqualTo32() const {
-  if (vendor == GpuVendor::kApple) {
-    return apple_info.IsWaveSizeEqualTo32();
-  } else {
-    return false;
-  }
-}
-
-int GpuInfo::GetComputeUnitsCount() const {
-  if (vendor == GpuVendor::kApple) {
-    return apple_info.GetComputeUnitsCount();
-  } else {
-    return 1;
-  }
-}
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/device_info.h b/tensorflow/lite/delegates/gpu/metal/device_info.h
deleted file mode 100644
index f77d695..0000000
--- a/tensorflow/lite/delegates/gpu/metal/device_info.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_DEVICE_INFO_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_DEVICE_INFO_H_
-
-#include <string>
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-// The VendorID returned by the GPU driver.
-enum class GpuVendor {
-  kApple,
-  kQualcomm,
-  kMali,
-  kPowerVR,
-  kNvidia,
-  kAMD,
-  kIntel,
-  kUnknown
-};
-
-enum class AppleGPU {
-  kUnknown,
-  kA7,
-  kA8,
-  kA8X,
-  kA9,
-  kA9X,
-  kA10,
-  kA10X,
-  kA11,
-  kA12,
-  kA12X,
-  kA12Z,
-  kA13,
-  kA14,
-};
-
-struct AppleGPUInfo {
-  AppleGPUInfo() = default;
-  explicit AppleGPUInfo(const std::string& device_name);
-  AppleGPU gpu_type;
-
-  bool IsLocalMemoryPreferredOverGlobal() const;
-
-  bool IsBionic() const;
-
-  // floating point rounding mode
-  bool IsRoundToNearestSupported() const;
-
-  // returns true if device have fixed wave size equal to 32
-  bool IsWaveSizeEqualTo32() const;
-
-  int GetComputeUnitsCount() const;
-};
-
-struct GpuInfo {
-  GpuInfo() = default;
-  explicit GpuInfo(const std::string& device_name);
-
-  GpuVendor vendor = GpuVendor::kUnknown;
-
-  AppleGPUInfo apple_info;
-
-  bool IsIntel() const;
-  bool IsApple() const;
-  bool IsAMD() const;
-
-  // floating point rounding mode
-  bool IsRoundToNearestSupported() const;
-
-  // returns true if device have fixed wave size equal to 32
-  bool IsWaveSizeEqualTo32() const;
-
-  int GetComputeUnitsCount() const;
-};
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_DEVICE_INFO_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 4033784..ae35555 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -123,6 +123,7 @@
     hdrs = ["conv.h"],
     deps = [
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
@@ -130,7 +131,6 @@
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/common:winograd_util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],
@@ -228,11 +228,11 @@
     hdrs = ["elementwise.h"],
     deps = [
         "//tensorflow/lite/delegates/gpu/common:convert",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -265,13 +265,13 @@
     srcs = ["fully_connected.cc"],
     hdrs = ["fully_connected.h"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],
@@ -691,13 +691,13 @@
     srcs = ["softmax.cc"],
     hdrs = ["softmax.h"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
     ],
 )
@@ -766,12 +766,12 @@
     srcs = ["transpose_conv.cc"],
     hdrs = ["transpose_conv.h"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],
@@ -821,6 +821,7 @@
     ],
     deps = [
         "//tensorflow/lite/delegates/gpu/common:convert",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
@@ -831,7 +832,6 @@
         "//tensorflow/lite/delegates/gpu/metal:api",
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:compiled_model",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@FP16",
@@ -907,11 +907,11 @@
     sdk_frameworks = ["XCTest"],
     deps = [
         ":test_util",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:common",
-        "//tensorflow/lite/delegates/gpu/metal:device_info",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
     ],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
index 136609c..967004e 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
@@ -25,6 +25,7 @@
 
 #include "absl/strings/substitute.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
@@ -32,7 +33,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -729,7 +729,7 @@
          attr.padding.appended.h == 0;
 }
 
-int GetMaximumPossibleWavesCount(const AppleGPUInfo& apple_info,
+int GetMaximumPossibleWavesCount(const AppleInfo& apple_info,
                                  const BHWC& dst_shape) {
   if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
     return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
@@ -738,7 +738,7 @@
   }
 }
 
-int GetRecommendedBlockSize(const AppleGPUInfo& apple_info,
+int GetRecommendedBlockSize(const AppleInfo& apple_info,
                             const BHWC& dst_shape) {
   const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
   const int cu_count = apple_info.GetComputeUnitsCount();
@@ -753,7 +753,7 @@
   }
 }
 
-ConvParams GetConvParamsForA7A8(const AppleGPUInfo& apple_info,
+ConvParams GetConvParamsForA7A8(const AppleInfo& apple_info,
                                 const Convolution2DAttributes& attr,
                                 const BHWC& dst_shape) {
   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
@@ -830,7 +830,7 @@
   return params;
 }
 
-ConvParams GetConvParamsForA9AndHigher(const AppleGPUInfo& apple_info,
+ConvParams GetConvParamsForA9AndHigher(const AppleInfo& apple_info,
                                        const Convolution2DAttributes& attr,
                                        const BHWC& dst_shape) {
   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
index 4187547..b2f6371 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
@@ -18,10 +18,10 @@
 
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
index 7842412..90fe191 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
@@ -297,7 +297,8 @@
   outputs_v0[1].data.resize(dst_shape.DimensionsProduct());
 
   std::string device_name = std::string([[device name] UTF8String]);
-  tflite::gpu::metal::GpuInfo gpu_info(device_name);
+  tflite::gpu::GpuInfo gpu_info;
+  tflite::gpu::GetGpuInfoFromDeviceDescription(device_name, &gpu_info);
   auto tasks_v0 = ConvolutionGeneric(0, 0, 1, dst_shape, attr, gpu_info, options);
 
   auto status = RunGraph(tasks_v0, device, inputs_v0, &outputs_v0);
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc
index 79aee94..2be17c0 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc
@@ -24,13 +24,13 @@
 #include <vector>
 
 #include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h
index 769ffa3..9f07b8a 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h
@@ -18,10 +18,10 @@
 
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
index 3dc5bca..7b03fd1 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
@@ -20,12 +20,12 @@
 #include <utility>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
index 81f45d9..f27c23f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
@@ -18,10 +18,10 @@
 
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm
index d7c0507..0b84340 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm
@@ -34,7 +34,7 @@
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
@@ -80,7 +80,8 @@
 
   id<MTLDevice> device = MTLCreateSystemDefaultDevice();
   std::string device_name = std::string([[device name] UTF8String]);
-  GpuInfo gpu_info(device_name);
+  GpuInfo gpu_info;
+  GetGpuInfoFromDeviceDescription(device_name, &gpu_info);
   RuntimeOptions options;
   options.storage_precision = RuntimeOptions::Precision::FP32;
   options.accumulator_precision = RuntimeOptions::Precision::FP32;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
index fcf06c4..66fbc3f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
@@ -22,12 +22,12 @@
 #include <vector>
 
 #include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
index 5a4410d..56b9c3f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
@@ -18,10 +18,10 @@
 
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm
index b4a8b91..0f9bbb3 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.mm
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm
@@ -43,7 +43,7 @@
 #include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
 #include "tensorflow/lite/delegates/gpu/metal/common.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
-#include "tensorflow/lite/delegates/gpu/metal/device_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -338,7 +338,8 @@
     }
 
     std::string device_name = std::string([[metal_device_ name] UTF8String]);
-    GpuInfo gpu_info(device_name);
+    GpuInfo gpu_info;
+    GetGpuInfoFromDeviceDescription(device_name, &gpu_info);
     size_t storage_type_size;
     RuntimeOptions runtime_options;
     if (options_.allow_precision_loss) {