Added CreateTensorShared which accept only tensor descriptor instead of tensor descriptor and shape.

PiperOrigin-RevId: 463059695
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 2a4812d..b15c1495 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -588,9 +588,8 @@
             *context, shared_buffers_[buffer_index].GetMemoryPtr(), tensor_desc,
             width_pixel_alignment, &shared_buffer_tensors_[tensor_index]));
       } else {
-        RETURN_IF_ERROR(CreateSharedTensor(
-            *context, shared_buffers_[buffer_index].GetMemoryPtr(),
-            tensor_desc.GetBHWCShape(), tensor_desc,
+        RETURN_IF_ERROR(CreateTensorShared(
+            *context, shared_buffers_[buffer_index].GetMemoryPtr(), tensor_desc,
             &shared_buffer_tensors_[tensor_index]));
       }
       created_tensors[tensor_index] = true;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
index b83b8d1..aa6a890 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
@@ -184,11 +184,15 @@
     RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
 
     Tensor src_tensor;
-    RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_,
-                                       src_tensor_descriptor_, &src_tensor));
+    TensorDescriptor descriptor_with_shape = src_tensor_descriptor_;
+    descriptor_with_shape.SetBHWCShape(shape_);
+    RETURN_IF_ERROR(CreateTensorShared(*context_, in_memory,
+                                       descriptor_with_shape, &src_tensor));
     Tensor dst_tensor;
-    RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_,
-                                       dst_tensor_descriptor_, &dst_tensor));
+    descriptor_with_shape = dst_tensor_descriptor_;
+    descriptor_with_shape.SetBHWCShape(shape_);
+    RETURN_IF_ERROR(CreateTensorShared(*context_, out_memory,
+                                       descriptor_with_shape, &dst_tensor));
     RETURN_IF_ERROR(cl_args_.SetObjectRef("src_tensor", &src_tensor));
     RETURN_IF_ERROR(cl_args_.SetObjectRef("dst_tensor", &dst_tensor));
     RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
@@ -281,8 +285,10 @@
     cl_mem in_memory;
     RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory));
     Tensor tensor;
-    RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_,
-                                       tensor_descriptor_, &tensor));
+    TensorDescriptor descriptor_with_shape = tensor_descriptor_;
+    descriptor_with_shape.SetBHWCShape(shape_);
+    RETURN_IF_ERROR(CreateTensorShared(*context_, in_memory,
+                                       descriptor_with_shape, &tensor));
     return DispatchKernel(output->memobj, &tensor);
   }
 };
@@ -376,8 +382,10 @@
     cl_mem out_memory;
     RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
     Tensor tensor;
-    RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_,
-                                       tensor_descriptor_, &tensor));
+    TensorDescriptor descriptor_with_shape = tensor_descriptor_;
+    descriptor_with_shape.SetBHWCShape(shape_);
+    RETURN_IF_ERROR(CreateTensorShared(*context_, out_memory,
+                                       descriptor_with_shape, &tensor));
     return DispatchKernel(input->memobj, &tensor);
   }
 };
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index ef2bc03..502df43 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -255,24 +255,6 @@
   }
   return absl::OkStatus();
 }
-
-absl::Status CreateTensorShared(const CLContext& context, const BHWDC& shape,
-                                const TensorDescriptor& descriptor,
-                                cl_mem memory, Tensor* result) {
-  const bool memory_owner = false;
-  if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
-    cl_mem image_memory;
-    RETURN_IF_ERROR(CreateImageBufferFromBuffer(
-        context, memory, descriptor.GetDataType(),
-        shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4),
-        &image_memory));
-    *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
-  } else {
-    *result = Tensor(memory, memory_owner, shape, descriptor);
-  }
-  return absl::OkStatus();
-}
-
 }  // namespace
 
 Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWC& shape,
@@ -625,19 +607,31 @@
   return absl::OkStatus();
 }
 
+absl::Status CreateTensorShared(const CLContext& context, cl_mem memory,
+                                const TensorDescriptor& descriptor,
+                                Tensor* result) {
+  const BHWDC& shape = descriptor.GetBHWDCShape();
+  const bool memory_owner = false;
+  if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
+    cl_mem image_memory;
+    RETURN_IF_ERROR(CreateImageBufferFromBuffer(
+        context, memory, descriptor.GetDataType(),
+        shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4),
+        &image_memory));
+    *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
+  } else {
+    *result = Tensor(memory, memory_owner, shape, descriptor);
+  }
+  return absl::OkStatus();
+}
+
 absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
                                 const BHWC& shape,
                                 const TensorDescriptor& descriptor,
                                 Tensor* result) {
-  const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
-  return CreateTensorShared(context, shape5D, descriptor, memory, result);
-}
-
-absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
-                                const BHWDC& shape,
-                                const TensorDescriptor& descriptor,
-                                Tensor* result) {
-  return CreateTensorShared(context, shape, descriptor, memory, result);
+  TensorDescriptor descriptor_with_shape = descriptor;
+  descriptor_with_shape.SetBHWCShape(shape);
+  return CreateTensorShared(context, memory, descriptor, result);
 }
 
 absl::Status CreateTensorSharedImage2DBuffer(const CLContext& context,
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index d23907e..8bbe61e 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -146,13 +146,12 @@
 absl::Status CreateTensor(const CLContext& context,
                           const TensorDescriptor& descriptor, Tensor* result);
 
-absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
-                                const BHWC& shape,
+absl::Status CreateTensorShared(const CLContext& context, cl_mem memory,
                                 const TensorDescriptor& descriptor,
                                 Tensor* result);
 
 absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
-                                const BHWDC& shape,
+                                const BHWC& shape,
                                 const TensorDescriptor& descriptor,
                                 Tensor* result);