Metal tensor simplified. Data loading only through tensor_descriptor.

PiperOrigin-RevId: 463579502
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
index d92d6c0..594588d 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -345,18 +345,25 @@
 
 absl::Status InferenceContext::SetInputTensor(ValueId id,
                                               const TensorFloat32& tensor) {
-  return GetTensor(id)->WriteData(device_, tensor);
+  MetalSpatialTensor* gpu_tensor = GetTensor(id);
+  TensorDescriptor descriptor_with_data = gpu_tensor->GetDescriptor();
+  descriptor_with_data.UploadData(tensor);
+  return gpu_tensor->UploadDescriptorData(descriptor_with_data, device_);
 }
 
 absl::Status InferenceContext::GetOutputTensor(ValueId id,
                                                TensorFloat32* result) {
-  const auto& gpu_tensor = *GetTensor(id);
-  const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(),
-                              gpu_tensor.Width(), gpu_tensor.Channels());
+  const MetalSpatialTensor* gpu_tensor = GetTensor(id);
+  const auto dst_shape = BHWC(gpu_tensor->Batch(), gpu_tensor->Height(),
+                              gpu_tensor->Width(), gpu_tensor->Channels());
   result->id = id;
   result->shape = dst_shape;
   result->data.resize(dst_shape.DimensionsProduct());
-  return gpu_tensor.ReadData(device_, result);
+
+  TensorDescriptor desc;
+  RETURN_IF_ERROR(gpu_tensor->ToDescriptor(&desc, device_));
+  desc.DownloadData(result);
+  return absl::OkStatus();
 }
 
 void InferenceContext::BindTensorsToOperations() {
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
index 3ac8c2d..ceeb067 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
@@ -305,50 +305,6 @@
   }
 }
 
-absl::Status MetalSpatialTensor::IsValid(const BHWC& shape) const {
-  if (shape.b != shape_.b) {
-    return absl::InvalidArgumentError(
-        "Shape batch does not match tensor batch");
-  }
-  if (shape.w != shape_.w) {
-    return absl::InvalidArgumentError(
-        "Shape width does not match tensor width");
-  }
-  if (shape.h != shape_.h) {
-    return absl::InvalidArgumentError(
-        "Shape height does not match tensor height");
-  }
-  if (shape.c != shape_.c) {
-    return absl::InvalidArgumentError(
-        "Shape channels does not match tensor channels");
-  }
-  return absl::OkStatus();
-}
-
-absl::Status MetalSpatialTensor::IsValid(const BHWDC& shape) const {
-  if (shape.b != shape_.b) {
-    return absl::InvalidArgumentError(
-        "Shape batch does not match tensor batch");
-  }
-  if (shape.w != shape_.w) {
-    return absl::InvalidArgumentError(
-        "Shape width does not match tensor width");
-  }
-  if (shape.h != shape_.h) {
-    return absl::InvalidArgumentError(
-        "Shape height does not match tensor height");
-  }
-  if (shape.d != shape_.d) {
-    return absl::InvalidArgumentError(
-        "Shape depth does not match tensor depth");
-  }
-  if (shape.c != shape_.c) {
-    return absl::InvalidArgumentError(
-        "Shape channels does not match tensor channels");
-  }
-  return absl::OkStatus();
-}
-
 uint64_t MetalSpatialTensor::GetMemorySizeInBytes() const {
   const int flt_size = SizeOf(descriptor_.GetDataType());
   const int flt4_size = 4 * flt_size;
@@ -366,24 +322,6 @@
   }
 }
 
-int MetalSpatialTensor::GetAlignedChannels() const {
-  return descriptor_.GetStorageType() == TensorStorageType::SINGLE_TEXTURE_2D
-             ? shape_.c
-             : AlignByN(shape_.c, 4);
-}
-
-absl::Status MetalSpatialTensor::WriteData(
-    id<MTLDevice> device,
-    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(device, src.data.data());
-}
-
-absl::Status MetalSpatialTensor::WriteData(
-    id<MTLDevice> device,
-    const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(device, src.data.data());
-}
-
 absl::Status MetalSpatialTensor::CreateFromDescriptor(
     const TensorDescriptor& desc, id<MTLDevice> device) {
   shape_ = desc.GetBHWDCShape();
@@ -400,6 +338,11 @@
   return absl::OkStatus();
 }
 
+absl::Status MetalSpatialTensor::UploadDescriptorData(
+    const TensorDescriptor& desc, id<MTLDevice> device) {
+  return WriteData(device, desc.GetData().data());
+}
+
 absl::Status MetalSpatialTensor::ToDescriptor(TensorDescriptor* desc,
                                               id<MTLDevice> device) const {
   *desc = descriptor_;
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
index 517ea01..21f585c 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
@@ -69,23 +69,10 @@
 
   uint64_t GetMemorySizeInBytes() const;
 
-  absl::Status WriteData(
-      id<MTLDevice> device,
-      const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src);
-  absl::Status WriteData(
-      id<MTLDevice> device,
-      const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src);
-  template <DataType T>
-  absl::Status WriteData(id<MTLDevice> device, const tflite::gpu::Tensor<BHWC, T>& src);
-  template <DataType T>
-  absl::Status WriteData(id<MTLDevice> device, const tflite::gpu::Tensor<BHWDC, T>& src);
-  template <DataType T>
-  absl::Status ReadData(id<MTLDevice> device, tflite::gpu::Tensor<BHWC, T>* dst) const;
-  template <DataType T>
-  absl::Status ReadData(id<MTLDevice> device, tflite::gpu::Tensor<BHWDC, T>* dst) const;
-
   absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
                                     id<MTLDevice> device);
+  absl::Status UploadDescriptorData(const TensorDescriptor& desc,
+                                    id<MTLDevice> device);
   absl::Status ToDescriptor(TensorDescriptor* desc, id<MTLDevice> device) const;
 
   absl::Status SetBufferHandle(id<MTLBuffer> buffer);
@@ -101,17 +88,9 @@
       int row_bytes_alignment, MetalSpatialTensor* result,
       uint64_t buffer_offset);
 
-  absl::Status IsValid(const BHWC& shape) const;
-  absl::Status IsValid(const BHWDC& shape) const;
-
-  template <typename T>
-  absl::Status WriteDataBHWDC(id<MTLDevice> device, const T* in);
   absl::Status WriteData(id<MTLDevice> device, const void* ptr);
-  template <typename T>
-  absl::Status ReadDataBHWDC(id<MTLDevice> device, T* out) const;
   absl::Status ReadData(id<MTLDevice> device, void* ptr) const;
 
-  int GetAlignedChannels() const;
   int3 GetFullTensorRegion() const;
   void Release();
 
@@ -144,69 +123,6 @@
 
 TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info);
 
-template <DataType T>
-absl::Status MetalSpatialTensor::WriteData(id<MTLDevice> device,
-                                           const tflite::gpu::Tensor<BHWC, T>& src) {
-  RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(device, src.data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::WriteData(id<MTLDevice> device,
-                                           const tflite::gpu::Tensor<BHWDC, T>& src) {
-  RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(device, src.data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::ReadData(id<MTLDevice> device,
-                                          tflite::gpu::Tensor<BHWC, T>* dst) const {
-  RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(device, dst->data.data());
-}
-
-template <DataType T>
-absl::Status MetalSpatialTensor::ReadData(id<MTLDevice> device,
-                                          tflite::gpu::Tensor<BHWDC, T>* dst) const {
-  RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(device, dst->data.data());
-}
-
-template <typename T>
-absl::Status MetalSpatialTensor::WriteDataBHWDC(id<MTLDevice> device, const T* in) {
-  std::unique_ptr<uint8_t[]> data_copy;
-  data_copy.reset(new uint8_t[GetMemorySizeInBytes()]);
-  if (descriptor_.GetDataType() == DataType::FLOAT16) {
-    // rearrangement and conversion from float32 to float16
-    DataFromBHWDC(reinterpret_cast<const float*>(in), shape_, descriptor_,
-                  reinterpret_cast<half*>(data_copy.get()));
-  } else {
-    // rearrangement
-    DataFromBHWDC(in, shape_, descriptor_, reinterpret_cast<T*>(data_copy.get()));
-  }
-
-  return WriteData(device, data_copy.get());
-}
-
-template <typename T>
-absl::Status MetalSpatialTensor::ReadDataBHWDC(id<MTLDevice> device, T* out) const {
-  std::unique_ptr<uint8_t[]> data_copy;
-  data_copy.reset(new uint8_t[GetMemorySizeInBytes()]);
-
-  RETURN_IF_ERROR(ReadData(device, data_copy.get()));
-
-  if (descriptor_.GetDataType() == DataType::FLOAT16) {
-    // rearrangement and conversion from float32 to float16
-    DataToBHWDC(reinterpret_cast<half*>(data_copy.get()), shape_, descriptor_,
-                reinterpret_cast<float*>(out));
-  } else {
-    // rearrangement
-    DataToBHWDC(reinterpret_cast<T*>(data_copy.get()), shape_, descriptor_, out);
-  }
-
-  return absl::OkStatus();
-}
-
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
index c9e0b1d..daa35d0 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor_test.mm
@@ -74,11 +74,12 @@
   }
 
   tflite::gpu::metal::MetalSpatialTensor tensor;
-  tflite::gpu::TensorDescriptor descriptor_with_shape = descriptor;
-  descriptor_with_shape.SetBHWCShape(shape);
-  RETURN_IF_ERROR(CreateTensor(device, descriptor_with_shape, &tensor));
-  RETURN_IF_ERROR(tensor.WriteData(device, tensor_cpu));
-  RETURN_IF_ERROR(tensor.ReadData(device, &tensor_gpu));
+  tflite::gpu::TensorDescriptor descriptor_with_data = descriptor;
+  descriptor_with_data.UploadData(tensor_cpu);
+  RETURN_IF_ERROR(tensor.CreateFromDescriptor(descriptor_with_data, device));
+  tflite::gpu::TensorDescriptor output_descriptor;
+  RETURN_IF_ERROR(tensor.ToDescriptor(&output_descriptor, device));
+  output_descriptor.DownloadData(&tensor_gpu);
 
   for (int i = 0; i < tensor_gpu.data.size(); ++i) {
     if (tensor_gpu.data[i] != tensor_cpu.data[i]) {
@@ -154,11 +155,12 @@
   }
 
   tflite::gpu::metal::MetalSpatialTensor tensor;
-  tflite::gpu::TensorDescriptor descriptor_with_shape = descriptor;
-  descriptor_with_shape.SetBHWDCShape(shape);
-  RETURN_IF_ERROR(CreateTensor(device, descriptor_with_shape, &tensor));
-  RETURN_IF_ERROR(tensor.WriteData(device, tensor_cpu));
-  RETURN_IF_ERROR(tensor.ReadData(device, &tensor_gpu));
+  tflite::gpu::TensorDescriptor descriptor_with_data = descriptor;
+  descriptor_with_data.UploadData(tensor_cpu);
+  RETURN_IF_ERROR(tensor.CreateFromDescriptor(descriptor_with_data, device));
+  tflite::gpu::TensorDescriptor output_descriptor;
+  RETURN_IF_ERROR(tensor.ToDescriptor(&output_descriptor, device));
+  output_descriptor.DownloadData(&tensor_gpu);
 
   for (int i = 0; i < tensor_gpu.data.size(); ++i) {
     if (tensor_gpu.data[i] != tensor_cpu.data[i]) {