Added possibility to use Textures2D as Tensors.
Need for runtime weights represented as textures 2d.
PiperOrigin-RevId: 356673672
Change-Id: I16afb3adca5a25fa46225f08433fe61a52c83138
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 1d3f3b9..53bc7cf 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -515,6 +515,7 @@
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+ "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index 2c28157..381ccca 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -24,6 +24,7 @@
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
namespace tflite {
namespace gpu {
@@ -332,6 +333,17 @@
resources->buffers.push_back({"buffer", memory_});
return absl::OkStatus();
}
+ const auto* texture2d_desc =
+ dynamic_cast<const Texture2DDescriptor*>(obj_ptr);
+ if (texture2d_desc) {
+ if (descriptor_.storage_type != TensorStorageType::TEXTURE_2D) {
+ return absl::InvalidArgumentError(
+ "Tensor can be used with Texture2DDescriptor only wtih "
+ "TensorStorageType::TEXTURE_2D.");
+ }
+ resources->images2d.push_back({"image2d", memory_});
+ return absl::OkStatus();
+ }
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(obj_ptr);
if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
diff --git a/tensorflow/lite/delegates/gpu/common/task/BUILD b/tensorflow/lite/delegates/gpu/common/task/BUILD
index 07af9eb..965dd02 100644
--- a/tensorflow/lite/delegates/gpu/common/task/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/task/BUILD
@@ -69,6 +69,7 @@
"//tensorflow/lite/delegates/gpu/common/task:compiler_options",
"//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+ "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
"//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
index 5c620cc..9400269 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
@@ -169,6 +169,13 @@
args_.AddObjectRef(buffer_name, AccessType::READ, std::move(desc_new));
}
+void GPUOperation::AddSrcTexture2D(const std::string& texture_name,
+ const Texture2DDescriptor& desc) {
+ src_tensors_names_.push_back(texture_name);
+ auto desc_new = absl::make_unique<Texture2DDescriptor>(desc);
+ args_.AddObjectRef(texture_name, AccessType::READ, std::move(desc_new));
+}
+
void GPUOperation::AddDstTensor(const std::string& tensor_name,
const TensorDescriptor& desc) {
dst_tensors_names_.push_back(tensor_name);
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
index 8e20f52..47b5c0c 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
@@ -30,6 +30,7 @@
#include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -126,6 +127,8 @@
const TensorDescriptor& desc);
void AddSrcBuffer(const std::string& buffer_name,
const BufferDescriptor& desc);
+ void AddSrcTexture2D(const std::string& texture_name,
+ const Texture2DDescriptor& desc);
void AddDstTensor(const std::string& tensor_name,
const TensorDescriptor& desc);
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 8066560..94b5d81 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -242,6 +242,7 @@
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
"//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+ "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
],
)
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
index 955768f..d0ddd4a 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
@@ -18,6 +18,7 @@
#include <memory>
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
namespace tflite {
@@ -237,6 +238,17 @@
resources->buffers.push_back({"buffer", memory_});
return absl::OkStatus();
}
+ const auto* texture2d_desc =
+ dynamic_cast<const Texture2DDescriptor*>(obj_ptr);
+ if (texture2d_desc) {
+ if (descriptor_.storage_type != TensorStorageType::TEXTURE_2D) {
+ return absl::InvalidArgumentError(
+ "Tensor can be used with Texture2DDescriptor only wtih "
+ "TensorStorageType::TEXTURE_2D.");
+ }
+ resources->images2d.push_back({"image2d", texture_mem_});
+ return absl::OkStatus();
+ }
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(obj_ptr);
if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");