Metal tensor simplified: data loading now goes only through TensorDescriptor.
PiperOrigin-RevId: 463579502
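
As a quick illustration of the resulting usage pattern, below is a minimal sketch assembled only from calls visible in the diff that follows (GetDescriptor/UploadDescriptorData/ToDescriptor on MetalSpatialTensor, UploadData/DownloadData on TensorDescriptor). The RoundTripExample wrapper, its signature, and the single include are illustrative assumptions, not part of this change:

    // Assumed to transitively provide MetalSpatialTensor, TensorDescriptor,
    // TensorFloat32, and RETURN_IF_ERROR.
    #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

    // Illustrative only: host data now flows through TensorDescriptor in both
    // directions instead of the removed WriteData/ReadData tensor overloads.
    absl::Status RoundTripExample(id<MTLDevice> device,
                                  const tflite::gpu::TensorFloat32& input,
                                  tflite::gpu::TensorFloat32* output,
                                  tflite::gpu::metal::MetalSpatialTensor* gpu_tensor) {
      // Upload: copy the tensor's layout, attach the host data to it, then hand
      // the descriptor to the GPU tensor.
      tflite::gpu::TensorDescriptor descriptor_with_data = gpu_tensor->GetDescriptor();
      descriptor_with_data.UploadData(input);
      RETURN_IF_ERROR(gpu_tensor->UploadDescriptorData(descriptor_with_data, device));

      // ... run whatever GPU work is needed ...

      // Download: size the host tensor from the GPU tensor's dimensions, read the
      // GPU tensor back into a descriptor, then unpack it into the host tensor.
      output->shape = tflite::gpu::BHWC(gpu_tensor->Batch(), gpu_tensor->Height(),
                                        gpu_tensor->Width(), gpu_tensor->Channels());
      output->data.resize(output->shape.DimensionsProduct());
      tflite::gpu::TensorDescriptor output_descriptor;
      RETURN_IF_ERROR(gpu_tensor->ToDescriptor(&output_descriptor, device));
      output_descriptor.DownloadData(output);
      return absl::OkStatus();
    }
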
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
index d92d6c0..594588d 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -345,18 +345,25 @@
absl::Status InferenceContext::SetInputTensor(ValueId id,
const TensorFloat32& tensor) {
- return GetTensor(id)->WriteData(device_, tensor);
+ MetalSpatialTensor* gpu_tensor = GetTensor(id);
+ TensorDescriptor descriptor_with_data = gpu_tensor->GetDescriptor();
+ descriptor_with_data.UploadData(tensor);
+ return gpu_tensor->UploadDescriptorData(descriptor_with_data, device_);
}
absl::Status InferenceContext::GetOutputTensor(ValueId id,
TensorFloat32* result) {
- const auto& gpu_tensor = *GetTensor(id);
- const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(),
- gpu_tensor.Width(), gpu_tensor.Channels());
+ const MetalSpatialTensor* gpu_tensor = GetTensor(id);
+ const auto dst_shape = BHWC(gpu_tensor->Batch(), gpu_tensor->Height(),
+ gpu_tensor->Width(), gpu_tensor->Channels());
result->id = id;
result->shape = dst_shape;
result->data.resize(dst_shape.DimensionsProduct());
- return gpu_tensor.ReadData(device_, result);
+
+ TensorDescriptor desc;
+ RETURN_IF_ERROR(gpu_tensor->ToDescriptor(&desc, device_));
+ desc.DownloadData(result);
+ return absl::OkStatus();
}
void InferenceContext::BindTensorsToOperations() {
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
index 3ac8c2d..ceeb067 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
@@ -305,50 +305,6 @@
}
}
-absl::Status MetalSpatialTensor::IsValid(const BHWC& shape) const {
- if (shape.b != shape_.b) {
- return absl::InvalidArgumentError(
- "Shape batch does not match tensor batch");
- }
- if (shape.w != shape_.w) {
- return absl::InvalidArgumentError(
- "Shape width does not match tensor width");
- }
- if (shape.h != shape_.h) {
- return absl::InvalidArgumentError(
- "Shape height does not match tensor height");
- }
- if (shape.c != shape_.c) {
- return absl::InvalidArgumentError(
- "Shape channels does not match tensor channels");
- }
- return absl::OkStatus();
-}
-
-absl::Status MetalSpatialTensor::IsValid(const BHWDC& shape) const {
- if (shape.b != shape_.b) {
- return absl::InvalidArgumentError(
- "Shape batch does not match tensor batch");
- }
- if (shape.w != shape_.w) {
- return absl::InvalidArgumentError(
- "Shape width does not match tensor width");
- }
- if (shape.h != shape_.h) {
- return absl::InvalidArgumentError(
- "Shape height does not match tensor height");
- }
- if (shape.d != shape_.d) {
- return absl::InvalidArgumentError(
- "Shape depth does not match tensor depth");
- }
- if (shape.c != shape_.c) {
- return absl::InvalidArgumentError(
- "Shape channels does not match tensor channels");
- }
- return absl::OkStatus();
-}
-
uint64_t MetalSpatialTensor::GetMemorySizeInBytes() const {
const int flt_size = SizeOf(descriptor_.GetDataType());
const int flt4_size = 4 * flt_size;
@@ -366,24 +322,6 @@
}
}
-int MetalSpatialTensor::GetAlignedChannels() const {
- return descriptor_.GetStorageType() == TensorStorageType::SINGLE_TEXTURE_2D
- ? shape_.c
- : AlignByN(shape_.c, 4);
-}
-
-absl::Status MetalSpatialTensor::WriteData(
- id<MTLDevice> device,
- const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
- return WriteDataBHWDC(device, src.data.data());
-}
-
-absl::Status MetalSpatialTensor::WriteData(
- id<MTLDevice> device,
- const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
- return WriteDataBHWDC(device, src.data.data());
-}
-
absl::Status MetalSpatialTensor::CreateFromDescriptor(
const TensorDescriptor& desc, id<MTLDevice> device) {
shape_ = desc.GetBHWDCShape();
@@ -400,6 +338,11 @@
return absl::OkStatus();
}
+absl::Status MetalSpatialTensor::UploadDescriptorData(
+ const TensorDescriptor& desc, id<MTLDevice> device) {
+ return WriteData(device, desc.GetData().data());
+}
+
absl::Status MetalSpatialTensor::ToDescriptor(TensorDescriptor* desc,
id<MTLDevice> device) const {
*desc = descriptor_;
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
index 517ea01..21f585c 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
@@ -69,23 +69,10 @@
uint64_t GetMemorySizeInBytes() const;
- absl::Status WriteData(
- id<MTLDevice> device,
- const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src);
- absl::Status WriteData(
- id<MTLDevice> device,
- const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src);
- template <DataType T>
- absl::Status WriteData(id<MTLDevice> device, const tflite::gpu::Tensor<BHWC, T>& src);
- template <DataType T>
- absl::Status WriteData(id<MTLDevice> device, const tflite::gpu::Tensor<BHWDC, T>& src);
- template <DataType T>
- absl::Status ReadData(id<MTLDevice> device, tflite::gpu::Tensor<BHWC, T>* dst) const;
- template <DataType T>
- absl::Status ReadData(id<MTLDevice> device, tflite::gpu::Tensor<BHWDC, T>* dst) const;
-
absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
id<MTLDevice> device);
+ absl::Status UploadDescriptorData(const TensorDescriptor& desc,
+ id<MTLDevice> device);
absl::Status ToDescriptor(TensorDescriptor* desc, id<MTLDevice> device) const;
absl::Status SetBufferHandle(id<MTLBuffer> buffer);
@@ -101,17 +88,9 @@
int row_bytes_alignment, MetalSpatialTensor* result,
uint64_t buffer_offset);
- absl::Status IsValid(const BHWC& shape) const;
- absl::Status IsValid(const BHWDC& shape) const;
-
- template <typename T>
- absl::Status WriteDataBHWDC(id<MTLDevice> device, const T* in);
absl::Status WriteData(id<MTLDevice> device, const void* ptr);
- template <typename T>
- absl::Status ReadDataBHWDC(id<MTLDevice> device, T* out) const;
absl::Status ReadData(id<MTLDevice> device, void* ptr) const;
- int GetAlignedChannels() const;
int3 GetFullTensorRegion() const;
void Release();
@@ -144,69 +123,6 @@
TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info);
-template <DataType T>
-absl::Status MetalSpatialTensor::WriteData(id<MTLDevice> device,
- const tflite::gpu::Tensor<BHWC, T>& src) {
- RETURN_IF_ERROR(IsValid(src.shape));
- return WriteDataBHWDC(device, src.data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::WriteData(id<MTLDevice> device,
- const tflite::gpu::Tensor<BHWDC, T>& src) {
- RETURN_IF_ERROR(IsValid(src.shape));
- return WriteDataBHWDC(device, src.data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::ReadData(id<MTLDevice> device,
- tflite::gpu::Tensor<BHWC, T>* dst) const {
- RETURN_IF_ERROR(IsValid(dst->shape));
- return ReadDataBHWDC(device, dst->data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::ReadData(id<MTLDevice> device,
- tflite::gpu::Tensor<BHWDC, T>* dst) const {
- RETURN_IF_ERROR(IsValid(dst->shape));
- return ReadDataBHWDC(device, dst->data.data());
-}
-
-template <typename T>
-absl::Status MetalSpatialTensor::WriteDataBHWDC(id<MTLDevice> device, const T* in) {
- std::unique_ptr<uint8_t[]> data_copy;
- data_copy.reset(new uint8_t[GetMemorySizeInBytes()]);
- if (descriptor_.GetDataType() == DataType::FLOAT16) {
- // rearrangement and conversion from float32 to float16
- DataFromBHWDC(reinterpret_cast<const float*>(in), shape_, descriptor_,
- reinterpret_cast<half*>(data_copy.get()));
- } else {
- // rearrangement
- DataFromBHWDC(in, shape_, descriptor_, reinterpret_cast<T*>(data_copy.get()));
- }
-
- return WriteData(device, data_copy.get());
-}
-
-template <typename T>
-absl::Status MetalSpatialTensor::ReadDataBHWDC(id<MTLDevice> device, T* out) const {
- std::unique_ptr<uint8_t[]> data_copy;
- data_copy.reset(new uint8_t[GetMemorySizeInBytes()]);
-
- RETURN_IF_ERROR(ReadData(device, data_copy.get()));
-
- if (descriptor_.GetDataType() == DataType::FLOAT16) {
- // rearrangement and conversion from float32 to float16
- DataToBHWDC(reinterpret_cast<half*>(data_copy.get()), shape_, descriptor_,
- reinterpret_cast<float*>(out));
- } else {
- // rearrangement
- DataToBHWDC(reinterpret_cast<T*>(data_copy.get()), shape_, descriptor_, out);
- }
-
- return absl::OkStatus();
-}
-
} // namespace metal
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
index c9e0b1d..daa35d0 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
@@ -74,11 +74,12 @@
}
tflite::gpu::metal::MetalSpatialTensor tensor;
- tflite::gpu::TensorDescriptor descriptor_with_shape = descriptor;
- descriptor_with_shape.SetBHWCShape(shape);
- RETURN_IF_ERROR(CreateTensor(device, descriptor_with_shape, &tensor));
- RETURN_IF_ERROR(tensor.WriteData(device, tensor_cpu));
- RETURN_IF_ERROR(tensor.ReadData(device, &tensor_gpu));
+ tflite::gpu::TensorDescriptor descriptor_with_data = descriptor;
+ descriptor_with_data.UploadData(tensor_cpu);
+ RETURN_IF_ERROR(tensor.CreateFromDescriptor(descriptor_with_data, device));
+ tflite::gpu::TensorDescriptor output_descriptor;
+ RETURN_IF_ERROR(tensor.ToDescriptor(&output_descriptor, device));
+ output_descriptor.DownloadData(&tensor_gpu);
for (int i = 0; i < tensor_gpu.data.size(); ++i) {
if (tensor_gpu.data[i] != tensor_cpu.data[i]) {
@@ -154,11 +155,12 @@
}
tflite::gpu::metal::MetalSpatialTensor tensor;
- tflite::gpu::TensorDescriptor descriptor_with_shape = descriptor;
- descriptor_with_shape.SetBHWDCShape(shape);
- RETURN_IF_ERROR(CreateTensor(device, descriptor_with_shape, &tensor));
- RETURN_IF_ERROR(tensor.WriteData(device, tensor_cpu));
- RETURN_IF_ERROR(tensor.ReadData(device, &tensor_gpu));
+ tflite::gpu::TensorDescriptor descriptor_with_data = descriptor;
+ descriptor_with_data.UploadData(tensor_cpu);
+ RETURN_IF_ERROR(tensor.CreateFromDescriptor(descriptor_with_data, device));
+ tflite::gpu::TensorDescriptor output_descriptor;
+ RETURN_IF_ERROR(tensor.ToDescriptor(&output_descriptor, device));
+ output_descriptor.DownloadData(&tensor_gpu);
for (int i = 0; i < tensor_gpu.data.size(); ++i) {
if (tensor_gpu.data[i] != tensor_cpu.data[i]) {