Winograd tasks converted to unified representation.
Added Metal tests.
PiperOrigin-RevId: 355904479
Change-Id: I68da32580635e466813b1bfb0b86082b6887a4b8
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index f5c5ea8..2b423de 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -87,11 +87,10 @@
args_.AddInt("tiles_total");
args_.AddInt("tiles_x");
- c += "__kernel void main_function(\n";
- c += "$0) {\n";
- c += " int DST_X = get_global_id(0);\n";
- c += " int DST_Y = get_global_id(1);\n";
- c += " int DST_Z = get_global_id(2);\n";
+ c += "MAIN_FUNCTION($0) {\n";
+ c += " int DST_X = GLOBAL_ID_0;\n";
+ c += " int DST_Y = GLOBAL_ID_1;\n";
+ c += " int DST_Z = GLOBAL_ID_2;\n";
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
"args.dst_tensor.Slices()) {\n";
c += " return; \n";
@@ -127,7 +126,7 @@
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
- c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
+ c += " ACCUM_FLT m" + xs + "_x = TO_ACCUM_FLT(xc" + xs + " >= 0 && xc" +
xs + " < args.src_tensor.Width());\n";
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
" < args.src_tensor.Width());\n";
@@ -148,7 +147,7 @@
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
- c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
+ c += " ACCUM_FLT bt = bt_ar[0] * TO_ACCUM_FLT(iny);\n";
} else {
c += " ACCUM_FLT bt = bt_ar[0];\n";
}
@@ -166,7 +165,7 @@
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
- c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
+ c += " ACCUM_FLT bt = bt_ar[" + ys + "] * TO_ACCUM_FLT(iny);\n";
} else {
c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
}
@@ -271,6 +270,10 @@
void Winograd4x4To36TileX6::GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
+ if (gpu_info.IsIntel()) {
+ work_groups->push_back(int3(4, 6, 1));
+ return;
+ }
switch (tuning_type) {
case TuningType::kExhaustive:
GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
@@ -338,11 +341,10 @@
}
c += "};\n";
- c += "__kernel void main_function(\n";
- c += "$0) {\n";
- c += " int tile_id = get_global_id(0);\n";
- c += " int DST_Y = get_global_id(1);\n";
- c += " int DST_Z = get_global_id(2);\n";
+ c += "MAIN_FUNCTION($0) {\n";
+ c += " int tile_id = GLOBAL_ID_0;\n";
+ c += " int DST_Y = GLOBAL_ID_1;\n";
+ c += " int DST_Z = GLOBAL_ID_2;\n";
c += " int tile_x = (tile_id % args.tiles_x) * 4;\n";
c += " int tile_y = (tile_id / args.tiles_x) * 4 + DST_Y;\n";
@@ -458,6 +460,10 @@
void Winograd36To4x4Tile4x1::GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
+ if (gpu_info.IsIntel()) {
+ work_groups->push_back(int3(8, 4, 1));
+ return;
+ }
switch (tuning_type) {
case TuningType::kExhaustive:
GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
index fa22aeb..947a0f3 100644
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
@@ -133,6 +133,7 @@
@"INIT_ACCUM_FLT4(value)" :
[NSString stringWithFormat:@"%@4(value)", accumulatorType],
@"TO_ACCUM_TYPE" : toAccumulatorType4,
+ @"TO_ACCUM_FLT" : accumulatorType,
@"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
@"SIMDGROUP_BARRIER" : barrier,
@"SIMD_LOCAL_MEM_BARRIER" : barrier,
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 221eb15..2e95242 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -643,6 +643,7 @@
":test_util",
":winograd",
"//tensorflow/lite/delegates/gpu/common:winograd_util",
+ "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util",
],
)
@@ -723,6 +724,7 @@
"//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util",
+ "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util",
"//tensorflow/lite/delegates/gpu/metal:common",
],
)
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
index c8cace0..5eb482d 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
@@ -22,10 +22,11 @@
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/winograd_test_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
@interface WinogradTest : XCTestCase
@end
@@ -324,24 +325,34 @@
} // namespace gpu
} // namespace tflite
-- (void)testWinograd4x4To36 {
+- (void)testWinograd4x4To36Metal {
auto status = tflite::gpu::metal::Winograd4x4To36Test(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
-- (void)testWinograd4x4To36TileX6 {
+- (void)testWinograd4x4To36TileX6Metal {
auto status = tflite::gpu::metal::Winograd4x4To36TileX6Test(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
-- (void)testWinograd36To4x4 {
+- (void)testWinograd36To4x4Metal {
auto status = tflite::gpu::metal::Winograd36To4x4Test(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
-- (void)testWinograd36To4x4Tile4x1 {
+- (void)testWinograd36To4x4Tile4x1Metal {
auto status = tflite::gpu::metal::Winograd36To4x4Tile4x1Test(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
+- (void)testWinograd4x4To36TileX6 {
+ auto status = tflite::gpu::Winograd4x4To36TileX6Test(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testWinograd36To4x4Tile4x1 {
+ auto status = tflite::gpu::Winograd36To4x4Tile4x1Test(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
@end