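Updated TensorDescriptor::CanCreateTensorWithShape to compute sizes via GetSizeInBytesForShape and SizeOf(data_type) instead of assuming a 4- or 2-byte element, so it handles more data types.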
PiperOrigin-RevId: 432410812
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
index 9fc01ab..3aadbae 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
@@ -1220,12 +1220,7 @@
absl::Status TensorDescriptor::CanCreateTensorWithShape(
const GpuInfo& gpu_info, const BHWDC& shape) const {
const int slices = DivideRoundUp(shape.c, 4);
- const uint64_t flt_size = data_type == DataType::FLOAT32 ? 4 : 2;
- const uint64_t channels = storage_type == TensorStorageType::SINGLE_TEXTURE_2D
- ? shape.c
- : slices * 4;
- const uint64_t allocation_size =
- flt_size * channels * shape.b * shape.w * shape.h * shape.d;
+ const uint64_t allocation_size = GetSizeInBytesForShape(shape);
const std::string common_desc = "Shape - " + ToString(shape) +
", data type - " + ToString(data_type) + ".";
if (allocation_size > gpu_info.GetMaxMemoryAllocationSize()) {
@@ -1236,12 +1231,9 @@
}
switch (storage_type) {
case TensorStorageType::BUFFER: {
- const uint64_t flt4_size = 4 * (data_type == DataType::FLOAT32 ? 4 : 2);
- const uint64_t buffer_size =
- flt4_size * shape.b * shape.w * shape.h * shape.d * slices;
- if (buffer_size > gpu_info.GetMaxBufferSize()) {
+ if (allocation_size > gpu_info.GetMaxBufferSize()) {
return absl::ResourceExhaustedError(absl::StrCat(
- "Buffer with size - ", buffer_size,
+ "Buffer with size - ", allocation_size,
" bytes can not be created. Max buffer size for this GPU - ",
gpu_info.GetMaxBufferSize(), " bytes. ", common_desc));
} else {
@@ -1249,18 +1241,16 @@
}
}
case TensorStorageType::IMAGE_BUFFER: {
- const uint64_t flt4_size = 4 * (data_type == DataType::FLOAT32 ? 4 : 2);
- const uint64_t buffer_size =
- flt4_size * shape.b * shape.w * shape.h * shape.d * slices;
- const uint64_t image_width = buffer_size / flt4_size;
+ const uint64_t element_size = 4 * SizeOf(data_type);
+ const uint64_t image_width = allocation_size / element_size;
if (image_width > gpu_info.GetMaxImageBufferWidth()) {
return absl::ResourceExhaustedError(absl::StrCat(
"Image buffer with width - ", image_width,
" can not be created. Max image buffer width for this GPU - ",
gpu_info.GetMaxImageBufferWidth(), ". ", common_desc));
- } else if (buffer_size > gpu_info.GetMaxBufferSize()) {
+ } else if (allocation_size > gpu_info.GetMaxBufferSize()) {
return absl::ResourceExhaustedError(absl::StrCat(
- "Buffer with size - ", buffer_size,
+ "Buffer with size - ", allocation_size,
" bytes can not be created. Max buffer size for this GPU - ",
gpu_info.GetMaxBufferSize(), " bytes. ", common_desc));
} else {