Added CreateTensorShared which accept only tensor descriptor instead of tensor descriptor and shape.
PiperOrigin-RevId: 463059695
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 2a4812d..b15c1495 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -588,9 +588,8 @@
*context, shared_buffers_[buffer_index].GetMemoryPtr(), tensor_desc,
width_pixel_alignment, &shared_buffer_tensors_[tensor_index]));
} else {
- RETURN_IF_ERROR(CreateSharedTensor(
- *context, shared_buffers_[buffer_index].GetMemoryPtr(),
- tensor_desc.GetBHWCShape(), tensor_desc,
+ RETURN_IF_ERROR(CreateTensorShared(
+ *context, shared_buffers_[buffer_index].GetMemoryPtr(), tensor_desc,
&shared_buffer_tensors_[tensor_index]));
}
created_tensors[tensor_index] = true;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
index b83b8d1..aa6a890 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
@@ -184,11 +184,15 @@
RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
Tensor src_tensor;
- RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_,
- src_tensor_descriptor_, &src_tensor));
+ TensorDescriptor descriptor_with_shape = src_tensor_descriptor_;
+ descriptor_with_shape.SetBHWCShape(shape_);
+ RETURN_IF_ERROR(CreateTensorShared(*context_, in_memory,
+ descriptor_with_shape, &src_tensor));
Tensor dst_tensor;
- RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_,
- dst_tensor_descriptor_, &dst_tensor));
+ descriptor_with_shape = dst_tensor_descriptor_;
+ descriptor_with_shape.SetBHWCShape(shape_);
+ RETURN_IF_ERROR(CreateTensorShared(*context_, out_memory,
+ descriptor_with_shape, &dst_tensor));
RETURN_IF_ERROR(cl_args_.SetObjectRef("src_tensor", &src_tensor));
RETURN_IF_ERROR(cl_args_.SetObjectRef("dst_tensor", &dst_tensor));
RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
@@ -281,8 +285,10 @@
cl_mem in_memory;
RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory));
Tensor tensor;
- RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_,
- tensor_descriptor_, &tensor));
+ TensorDescriptor descriptor_with_shape = tensor_descriptor_;
+ descriptor_with_shape.SetBHWCShape(shape_);
+ RETURN_IF_ERROR(CreateTensorShared(*context_, in_memory,
+ descriptor_with_shape, &tensor));
return DispatchKernel(output->memobj, &tensor);
}
};
@@ -376,8 +382,10 @@
cl_mem out_memory;
RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
Tensor tensor;
- RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_,
- tensor_descriptor_, &tensor));
+ TensorDescriptor descriptor_with_shape = tensor_descriptor_;
+ descriptor_with_shape.SetBHWCShape(shape_);
+ RETURN_IF_ERROR(CreateTensorShared(*context_, out_memory,
+ descriptor_with_shape, &tensor));
return DispatchKernel(input->memobj, &tensor);
}
};
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index ef2bc03..502df43 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -255,24 +255,6 @@
}
return absl::OkStatus();
}
-
-absl::Status CreateTensorShared(const CLContext& context, const BHWDC& shape,
- const TensorDescriptor& descriptor,
- cl_mem memory, Tensor* result) {
- const bool memory_owner = false;
- if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
- cl_mem image_memory;
- RETURN_IF_ERROR(CreateImageBufferFromBuffer(
- context, memory, descriptor.GetDataType(),
- shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4),
- &image_memory));
- *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
- } else {
- *result = Tensor(memory, memory_owner, shape, descriptor);
- }
- return absl::OkStatus();
-}
-
} // namespace
Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWC& shape,
@@ -625,19 +607,31 @@
return absl::OkStatus();
}
+absl::Status CreateTensorShared(const CLContext& context, cl_mem memory,
+ const TensorDescriptor& descriptor,
+ Tensor* result) {
+ const BHWDC& shape = descriptor.GetBHWDCShape();
+ const bool memory_owner = false;
+ if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
+ cl_mem image_memory;
+ RETURN_IF_ERROR(CreateImageBufferFromBuffer(
+ context, memory, descriptor.GetDataType(),
+ shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4),
+ &image_memory));
+ *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
+ } else {
+ *result = Tensor(memory, memory_owner, shape, descriptor);
+ }
+ return absl::OkStatus();
+}
+
absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
const BHWC& shape,
const TensorDescriptor& descriptor,
Tensor* result) {
- const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
- return CreateTensorShared(context, shape5D, descriptor, memory, result);
-}
-
-absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
- const BHWDC& shape,
- const TensorDescriptor& descriptor,
- Tensor* result) {
- return CreateTensorShared(context, shape, descriptor, memory, result);
+ TensorDescriptor descriptor_with_shape = descriptor;
+ descriptor_with_shape.SetBHWCShape(shape);
+ return CreateTensorShared(context, memory, descriptor, result);
}
absl::Status CreateTensorSharedImage2DBuffer(const CLContext& context,
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index d23907e..8bbe61e 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -146,13 +146,12 @@
absl::Status CreateTensor(const CLContext& context,
const TensorDescriptor& descriptor, Tensor* result);
-absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
- const BHWC& shape,
+absl::Status CreateTensorShared(const CLContext& context, cl_mem memory,
const TensorDescriptor& descriptor,
Tensor* result);
absl::Status CreateSharedTensor(const CLContext& context, cl_mem memory,
- const BHWDC& shape,
+ const BHWC& shape,
const TensorDescriptor& descriptor,
Tensor* result);