Enabled support of IMAGE_BUFFER for NVidia.
PiperOrigin-RevId: 293064178
Change-Id: I97b3eda664a5299e59f7253086ab6442102209aa
diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc
index 71a7592..06cd04a 100644
--- a/tensorflow/lite/delegates/gpu/cl/environment.cc
+++ b/tensorflow/lite/delegates/gpu/cl/environment.cc
@@ -205,7 +205,7 @@
case TensorStorageType::TEXTURE_ARRAY:
return !device_.IsAMD() && device_.SupportsTextureArray();
case TensorStorageType::IMAGE_BUFFER:
- return (device_.IsAdreno() || device_.IsAMD()) &&
+ return (device_.IsAdreno() || device_.IsAMD() || device_.IsNvidia()) &&
device_.SupportsImageBuffer();
case TensorStorageType::TEXTURE_3D:
return !device_.IsAMD() && device_.SupportsImage3D();
@@ -224,10 +224,13 @@
} else {
return TensorStorageType::TEXTURE_2D;
}
- } else if (gpu.IsPowerVR() || gpu.IsNvidia()) {
+ } else if (gpu.IsPowerVR()) {
return TensorStorageType::TEXTURE_2D;
} else if (gpu.IsMali()) {
return TensorStorageType::BUFFER;
+ } else if (gpu.IsNvidia()) {
+ return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
+ : TensorStorageType::BUFFER;
} else if (gpu.IsAMD()) {
return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
: TensorStorageType::BUFFER;