Added GpuInfo argument to GPUObjectDescriptor::GetGPUResources.
Sometimes resource set different for different device/api.
PiperOrigin-RevId: 373642633
Change-Id: I757a885dc52664c2b914ab50f430ee406ece4d0c
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
index 4db81a6..fefe314 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
@@ -213,7 +213,7 @@
const std::map<std::string, std::string>& linkables, CLContext* context,
Arguments* args, std::string* code) {
RETURN_IF_ERROR(AllocateObjects(*args, context));
- RETURN_IF_ERROR(AddObjectArgs(args));
+ RETURN_IF_ERROR(AddObjectArgs(gpu_info, args));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, *args, linkables, code));
object_refs_ = std::move(args->object_refs_);
args->GetActiveArguments(kArgsPrefix, *code);
@@ -232,7 +232,7 @@
absl::Status CLArguments::Init(const GpuInfo& gpu_info, Arguments* args,
CLContext* context) {
RETURN_IF_ERROR(AllocateObjects(*args, context));
- RETURN_IF_ERROR(AddObjectArgs(args));
+ RETURN_IF_ERROR(AddObjectArgs(gpu_info, args));
object_refs_ = std::move(args->object_refs_);
const bool use_f32_for_halfs = gpu_info.IsPowerVR();
CopyArguments(*args, use_f32_for_halfs);
@@ -251,12 +251,13 @@
return absl::OkStatus();
}
-absl::Status CLArguments::AddObjectArgs(Arguments* args) {
+absl::Status CLArguments::AddObjectArgs(const GpuInfo& gpu_info,
+ Arguments* args) {
for (auto& t : args->objects_) {
- AddGPUResources(t.first, t.second->GetGPUResources(), args);
+ AddGPUResources(t.first, t.second->GetGPUResources(gpu_info), args);
}
for (auto& t : args->object_refs_) {
- AddGPUResources(t.first, t.second->GetGPUResources(), args);
+ AddGPUResources(t.first, t.second->GetGPUResources(gpu_info), args);
}
return absl::OkStatus();
}
@@ -347,7 +348,7 @@
return absl::NotFoundError(
absl::StrCat("No object with name - ", object_name));
}
- auto names = desc_ptr->GetGPUResources().GetNames();
+ auto names = desc_ptr->GetGPUResources(gpu_info).GetNames();
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
if (tensor_desc && (selector == "Write" || selector == "Linking")) {
auto it = linkables.find(object_name);
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
index f1160a5..08987db 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
@@ -64,7 +64,7 @@
private:
absl::Status AllocateObjects(const Arguments& args, CLContext* context);
- absl::Status AddObjectArgs(Arguments* args);
+ absl::Status AddObjectArgs(const GpuInfo& gpu_info, Arguments* args);
absl::Status ResolveSelectorsPass(
const GpuInfo& gpu_info, const Arguments& args,
diff --git a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
index 9c97369..4549434 100644
--- a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
@@ -27,7 +27,7 @@
void BufferDescriptor::Release() { data.clear(); }
-GPUResources BufferDescriptor::GetGPUResources() const {
+GPUResources BufferDescriptor::GetGPUResources(const GpuInfo& gpu_info) const {
GPUResources resources;
GPUBufferDescriptor desc;
desc.data_type = element_type;
diff --git a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.h b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.h
index ea37a2c..ecca244 100644
--- a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.h
@@ -48,7 +48,7 @@
const std::vector<std::string>& template_args,
std::string* result) const override;
- GPUResources GetGPUResources() const override;
+ GPUResources GetGPUResources(const GpuInfo& gpu_info) const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformGetPtrSelector(
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h b/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h
index 6559602..753681d 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h
@@ -180,7 +180,9 @@
std::string* result) const {
return absl::UnimplementedError("No implementation of perform selector");
}
- virtual GPUResources GetGPUResources() const { return GPUResources(); }
+ virtual GPUResources GetGPUResources(const GpuInfo& gpu_info) const {
+ return GPUResources();
+ }
virtual void Release() {}
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
index e587ea2..ef9db12 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
@@ -94,7 +94,7 @@
return *this;
}
-GPUResources TensorDescriptor::GetGPUResources() const {
+GPUResources TensorDescriptor::GetGPUResources(const GpuInfo& gpu_info) const {
GPUResources resources;
resources.ints.push_back("slice_stride");
if (HasAxis(Axis::WIDTH)) {
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
index f2bc321..1d1e7db 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
@@ -65,7 +65,7 @@
const std::vector<std::string>& template_args,
std::string* result) const override;
- GPUResources GetGPUResources() const override;
+ GPUResources GetGPUResources(const GpuInfo& gpu_info) const override;
void Release() override { data.clear(); }
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.cc
index 942896e..9ba3729 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.cc
@@ -26,7 +26,8 @@
void TensorLinearDescriptor::Release() { data.clear(); }
-GPUResources TensorLinearDescriptor::GetGPUResources() const {
+GPUResources TensorLinearDescriptor::GetGPUResources(
+ const GpuInfo& gpu_info) const {
GPUResources resources;
resources.ints.push_back("length");
if (storage_type == LinearStorageType::BUFFER) {
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h b/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h
index 52c4b23..81f5c8a 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h
@@ -54,7 +54,7 @@
const std::vector<std::string>& template_args,
std::string* result) const override;
- GPUResources GetGPUResources() const override;
+ GPUResources GetGPUResources(const GpuInfo& gpu_info) const override;
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
const std::vector<std::string>& args,
std::string* result) const;
diff --git a/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.cc b/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.cc
index de7eaf5..3624eb7 100644
--- a/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.cc
@@ -22,7 +22,8 @@
void Texture2DDescriptor::Release() { data.clear(); }
-GPUResources Texture2DDescriptor::GetGPUResources() const {
+GPUResources Texture2DDescriptor::GetGPUResources(
+ const GpuInfo& gpu_info) const {
GPUResources resources;
GPUImage2DDescriptor desc;
desc.data_type = element_type;
diff --git a/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h b/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h
index 32ab44f..4178799 100644
--- a/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h
@@ -47,7 +47,7 @@
const std::vector<std::string>& template_args,
std::string* result) const override;
- GPUResources GetGPUResources() const override;
+ GPUResources GetGPUResources(const GpuInfo& gpu_info) const override;
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
const std::vector<std::string>& args,
std::string* result) const;
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
index 1517489..5f63720 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
@@ -180,7 +180,7 @@
const std::map<std::string, std::string>& linkables, MetalDevice* device,
Arguments* args, std::string* code) {
RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
- RETURN_IF_ERROR(AddObjectArgs(args));
+ RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), args));
RETURN_IF_ERROR(
ResolveSelectorsPass(device->GetInfo(), *args, linkables, code));
object_refs_ = std::move(args->object_refs_);
@@ -466,12 +466,13 @@
return absl::OkStatus();
}
-absl::Status MetalArguments::AddObjectArgs(Arguments* args) {
+absl::Status MetalArguments::AddObjectArgs(const GpuInfo& gpu_info,
+ Arguments* args) {
for (auto& t : args->objects_) {
- AddGPUResources(t.first, t.second->GetGPUResources(), args);
+ AddGPUResources(t.first, t.second->GetGPUResources(gpu_info), args);
}
for (auto& t : args->object_refs_) {
- AddGPUResources(t.first, t.second->GetGPUResources(), args);
+ AddGPUResources(t.first, t.second->GetGPUResources(gpu_info), args);
}
return absl::OkStatus();
}
@@ -733,7 +734,7 @@
return absl::NotFoundError(
absl::StrCat("No object with name - ", object_name));
}
- auto names = desc_ptr->GetGPUResources().GetNames();
+ auto names = desc_ptr->GetGPUResources(gpu_info).GetNames();
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
if (tensor_desc && (selector == "Write" || selector == "Linking")) {
auto it = linkables.find(object_name);
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
index f8bc12d..2dff3ae 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
@@ -72,7 +72,7 @@
std::string* code);
absl::Status AllocateObjects(const Arguments& args, id<MTLDevice> device);
- absl::Status AddObjectArgs(Arguments* args);
+ absl::Status AddObjectArgs(const GpuInfo& gpu_info, Arguments* args);
void AddGPUResources(const std::string& name, const GPUResources& resources,
Arguments* args);