Pooling updated to properly use the src tensor description.
PiperOrigin-RevId: 455103825
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
index b480ae1..0cc7d6d 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
@@ -611,7 +611,7 @@
case OperationType::POOLING_2D: {
auto attr =
absl::any_cast<Pooling2DAttributes>(node.operation.attributes);
- *gpu_op = SelectPooling(attr, op_def);
+ *gpu_op = SelectPooling(attr, gpu_info, op_def);
return absl::OkStatus();
}
case OperationType::PRELU: {
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
index 9ac2ee4..2f3c769 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
@@ -68,8 +68,9 @@
}
std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes& attr,
+ const GpuInfo& gpu_info,
const OperationDef& op_def) {
- return std::make_unique<GPUOperation>(CreatePooling(op_def, attr));
+ return std::make_unique<GPUOperation>(CreatePooling(op_def, gpu_info, attr));
}
std::unique_ptr<GPUOperation> SelectMaxUnpooling(
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
index 9439e21..8705582 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
@@ -38,6 +38,7 @@
const OperationDef& op_def);
std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes& attr,
+ const GpuInfo& gpu_info,
const OperationDef& op_def);
std::unique_ptr<GPUOperation> SelectMaxUnpooling(
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
index 406d1c5..3f096fd 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
@@ -778,13 +778,13 @@
srcs = ["pooling.cc"],
hdrs = ["pooling.h"],
deps = [
+ "//tensorflow/lite/delegates/gpu/common:gpu_info",
"//tensorflow/lite/delegates/gpu/common:operations",
+ "//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
- "//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
"//tensorflow/lite/delegates/gpu/common/task:util",
- "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
],
)
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/pooling.cc b/tensorflow/lite/delegates/gpu/common/tasks/pooling.cc
index 766f5a1..299bc48 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/pooling.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/pooling.cc
@@ -17,14 +17,17 @@
#include <string>
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
-#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
namespace tflite {
namespace gpu {
namespace {
std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
+ const GpuInfo& gpu_info,
bool stride_correction,
GPUOperation* op) {
auto src_desc = op_def.src_tensors[0];
@@ -67,10 +70,6 @@
dst_coord += ", " + dst_coords[i];
}
- const bool manual_clamp =
- op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER ||
- op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER;
-
std::string c;
c += "MAIN_FUNCTION($0) {\n";
c += " int X = GLOBAL_ID_0;\n";
@@ -119,12 +118,13 @@
}
c += " bool outside = outside_y || x_c < 0 || x_c >= "
"args.src_tensor.Width();\n";
- if (manual_clamp) {
+ if (op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info) &&
+ op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
+ c += " r += args.src_tensor.Read<float>(" + src_coord + ");\n";
+ } else {
c += " r += !outside ? args.src_tensor.Read<float>(" + src_coord +
") : "
"INIT_FLOAT4(0.0f);\n";
- } else {
- c += " r += args.src_tensor.Read<float>(" + src_coord + ");\n";
}
c += " window_size += !outside ? 1.0 : 0.0;\n";
c += " }\n";
@@ -283,6 +283,7 @@
} // namespace
GPUOperation CreatePooling(const OperationDef& definition,
+ const GpuInfo& gpu_info,
const Pooling2DAttributes& attr) {
GPUOperation op(definition);
op.args_.AddInt("kernel_size_x", attr.kernel.w);
@@ -294,7 +295,8 @@
const bool stride_correction =
definition.IsBatchSupported() && attr.strides.w != 1;
if (attr.type == PoolingType::AVERAGE) {
- op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+ op.code_ = GetAveragePoolingKernelCode(definition, gpu_info,
+ stride_correction, &op);
} else if (attr.type == PoolingType::MAX) {
op.code_ = GetMaxPoolingKernelCode(definition, stride_correction,
attr.output_indices, &op);
@@ -304,6 +306,7 @@
}
GPUOperation CreatePooling(const OperationDef& definition,
+ const GpuInfo& gpu_info,
const Pooling3DAttributes& attr) {
GPUOperation op(definition);
op.args_.AddInt("kernel_size_x", attr.kernel.w);
@@ -318,7 +321,8 @@
const bool stride_correction =
definition.IsBatchSupported() && attr.strides.w != 1;
if (attr.type == PoolingType::AVERAGE) {
- op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+ op.code_ = GetAveragePoolingKernelCode(definition, gpu_info,
+ stride_correction, &op);
} else if (attr.type == PoolingType::MAX) {
op.code_ = GetMaxPoolingKernelCode(definition, stride_correction,
attr.output_indices, &op);
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/pooling.h b/tensorflow/lite/delegates/gpu/common/tasks/pooling.h
index deaf6f3..0f094b4 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/pooling.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/pooling.h
@@ -19,16 +19,16 @@
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
-#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
GPUOperation CreatePooling(const OperationDef& definition,
+ const GpuInfo& gpu_info,
const Pooling2DAttributes& attr);
GPUOperation CreatePooling(const OperationDef& definition,
+ const GpuInfo& gpu_info,
const Pooling3DAttributes& attr);
} // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/pooling_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/pooling_test_util.cc
index 98b7d22..1869b91 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/pooling_test_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/pooling_test_util.cc
@@ -15,6 +15,7 @@
#include "tensorflow/lite/delegates/gpu/common/tasks/pooling_test_util.h"
+#include <memory>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
@@ -46,9 +47,9 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- GPUOperation operation = CreatePooling(op_def, attr);
+ GPUOperation operation = CreatePooling(op_def, env->GetGpuInfo(), attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<GPUOperation>(std::move(operation)),
+ src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 1, 1, 2), &dst_tensor));
RETURN_IF_ERROR(PointWiseNear({3.0f, 4.0f}, dst_tensor.data, eps));
}
@@ -77,9 +78,9 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- GPUOperation operation = CreatePooling(op_def, attr);
+ GPUOperation operation = CreatePooling(op_def, env->GetGpuInfo(), attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<GPUOperation>(std::move(operation)),
+ src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 2, 2, 1), &dst_tensor));
RETURN_IF_ERROR(
PointWiseNear({1.5f, 2.0f, 2.5f, 3.0f}, dst_tensor.data, eps));
@@ -109,9 +110,9 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- GPUOperation operation = CreatePooling(op_def, attr);
+ GPUOperation operation = CreatePooling(op_def, env->GetGpuInfo(), attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<GPUOperation>(std::move(operation)),
+ src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 1, 1, 2), &dst_tensor));
RETURN_IF_ERROR(PointWiseNear({8.0f, 7.0f}, dst_tensor.data, eps));
}
@@ -143,9 +144,9 @@
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
TensorFloat32 dst_tensor_ind;
- GPUOperation operation = CreatePooling(op_def, attr);
+ GPUOperation operation = CreatePooling(op_def, env->GetGpuInfo(), attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- {src_tensor}, absl::make_unique<GPUOperation>(std::move(operation)),
+ {src_tensor}, std::make_unique<GPUOperation>(std::move(operation)),
{BHWC(1, 1, 1, 2), BHWC(1, 1, 1, 2)},
{&dst_tensor, &dst_tensor_ind}));
RETURN_IF_ERROR(PointWiseNear({8.0f, 7.0f}, dst_tensor.data, eps));
@@ -172,10 +173,10 @@
dst_0.SetBHWCShape(BHWC(1, 1, 1, 2));
dst_1.SetBHWCShape(BHWC(1, 1, 1, 2));
- GPUOperation operation = CreatePooling(op_def, attr);
+ GPUOperation operation = CreatePooling(op_def, env->GetGpuInfo(), attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
{&src_0}, {&dst_0, &dst_1},
- absl::make_unique<GPUOperation>(std::move(operation))));
+ std::make_unique<GPUOperation>(std::move(operation))));
TensorFloat32 dst_tensor;
dst_0.DownloadData(&dst_tensor);