Winograd tasks converted to the unified representation.
Added Metal tests.
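
For context, the unified representation relies on portable macros (MAIN_FUNCTION, GLOBAL_ID_0/1/2, TO_ACCUM_FLT, ...) that each backend rewrites into its own dialect before compiling the kernel source. Below is a minimal, hypothetical sketch of that substitution step: the OpenCL entries mirror the literal strings removed from winograd.cc in this change, the Metal TO_ACCUM_FLT entry mirrors the compute_task.cc hunk further down, and the remaining Metal entries are assumed placeholders rather than the actual mapping.

  // Illustrative helper only; not part of this change.
  #include <map>
  #include <string>

  // Portable macro -> OpenCL spelling. These mirror the hard-coded strings
  // that the old, OpenCL-only kernel source used.
  std::map<std::string, std::string> OpenClMacroTable() {
    return {
        {"MAIN_FUNCTION", "__kernel void main_function"},
        {"GLOBAL_ID_0", "get_global_id(0)"},
        {"TO_ACCUM_FLT", "(ACCUM_FLT)"},  // cast-style conversion
    };
  }

  // Portable macro -> Metal spelling. Only the TO_ACCUM_FLT entry is taken
  // from compute_task.cc in this change (it maps to the accumulator type,
  // e.g. "float" or "half"); the other entries are assumed for illustration.
  std::map<std::string, std::string> MetalMacroTable(const std::string& accum_type) {
    return {
        {"MAIN_FUNCTION", "kernel void ComputeFunction"},  // assumed spelling
        {"GLOBAL_ID_0", "int(gid.x)"},                     // assumed thread-id access
        {"TO_ACCUM_FLT", accum_type},
    };
  }

With a substitution of this kind in place, the same kernel source string in winograd.cc can be compiled by both the OpenCL and Metal backends, which is why the shared winograd_test_util tests can now be added to the Metal test targets below.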

PiperOrigin-RevId: 355904479
Change-Id: I68da32580635e466813b1bfb0b86082b6887a4b8
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index f5c5ea8..2b423de 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -87,11 +87,10 @@
   args_.AddInt("tiles_total");
   args_.AddInt("tiles_x");
 
-  c += "__kernel void main_function(\n";
-  c += "$0) {\n";
-  c += "  int DST_X = get_global_id(0);\n";
-  c += "  int DST_Y = get_global_id(1);\n";
-  c += "  int DST_Z = get_global_id(2);\n";
+  c += "MAIN_FUNCTION($0) {\n";
+  c += "  int DST_X = GLOBAL_ID_0;\n";
+  c += "  int DST_Y = GLOBAL_ID_1;\n";
+  c += "  int DST_Z = GLOBAL_ID_2;\n";
   c += "  if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
        "args.dst_tensor.Slices()) {\n";
   c += "    return; \n";
@@ -127,7 +126,7 @@
     for (int x = 0; x < 6; ++x) {
       const std::string xs = std::to_string(x);
       c += "  int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
-      c += "  ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
+      c += "  ACCUM_FLT m" + xs + "_x = TO_ACCUM_FLT(xc" + xs + " >= 0 && xc" +
            xs + " < args.src_tensor.Width());\n";
       c += "  bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
            " < args.src_tensor.Width());\n";
@@ -148,7 +147,7 @@
   if (is_buffer || is_image_buffer) {
     c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
     c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
-    c += "    ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
+    c += "    ACCUM_FLT bt = bt_ar[0] * TO_ACCUM_FLT(iny);\n";
   } else {
     c += "    ACCUM_FLT bt = bt_ar[0];\n";
   }
@@ -166,7 +165,7 @@
     if (is_buffer || is_image_buffer) {
       c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
       c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
-      c += "    ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
+      c += "    ACCUM_FLT bt = bt_ar[" + ys + "] * TO_ACCUM_FLT(iny);\n";
     } else {
       c += "    ACCUM_FLT bt = bt_ar[" + ys + "];\n";
     }
@@ -271,6 +270,10 @@
 void Winograd4x4To36TileX6::GetPossibleKernelWorkGroups(
     TuningType tuning_type, const GpuInfo& gpu_info,
     const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
+  if (gpu_info.IsIntel()) {
+    work_groups->push_back(int3(4, 6, 1));
+    return;
+  }
   switch (tuning_type) {
     case TuningType::kExhaustive:
       GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
@@ -338,11 +341,10 @@
   }
   c += "};\n";
 
-  c += "__kernel void main_function(\n";
-  c += "$0) {\n";
-  c += "  int tile_id = get_global_id(0);\n";
-  c += "  int DST_Y = get_global_id(1);\n";
-  c += "  int DST_Z = get_global_id(2);\n";
+  c += "MAIN_FUNCTION($0) {\n";
+  c += "  int tile_id = GLOBAL_ID_0;\n";
+  c += "  int DST_Y = GLOBAL_ID_1;\n";
+  c += "  int DST_Z = GLOBAL_ID_2;\n";
   c += "  int tile_x = (tile_id % args.tiles_x) * 4;\n";
   c += "  int tile_y = (tile_id / args.tiles_x) * 4 + DST_Y;\n";
 
@@ -458,6 +460,10 @@
 void Winograd36To4x4Tile4x1::GetPossibleKernelWorkGroups(
     TuningType tuning_type, const GpuInfo& gpu_info,
     const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
+  if (gpu_info.IsIntel()) {
+    work_groups->push_back(int3(8, 4, 1));
+    return;
+  }
   switch (tuning_type) {
     case TuningType::kExhaustive:
       GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
index fa22aeb..947a0f3 100644
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
@@ -133,6 +133,7 @@
     @"INIT_ACCUM_FLT4(value)" :
         [NSString stringWithFormat:@"%@4(value)", accumulatorType],
     @"TO_ACCUM_TYPE" : toAccumulatorType4,
+    @"TO_ACCUM_FLT" : accumulatorType,
     @"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
     @"SIMDGROUP_BARRIER" : barrier,
     @"SIMD_LOCAL_MEM_BARRIER" : barrier,
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 221eb15..2e95242 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -643,6 +643,7 @@
         ":test_util",
         ":winograd",
         "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util",
     ],
 )
 
@@ -723,6 +724,7 @@
         "//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util",
         "//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util",
         "//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util",
+        "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util",
         "//tensorflow/lite/delegates/gpu/metal:common",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
index c8cace0..5eb482d 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
@@ -22,10 +22,11 @@
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/winograd_test_util.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
 
 @interface WinogradTest : XCTestCase
 @end
@@ -324,24 +325,34 @@
 }  // namespace gpu
 }  // namespace tflite
 
-- (void)testWinograd4x4To36 {
+- (void)testWinograd4x4To36Metal {
   auto status = tflite::gpu::metal::Winograd4x4To36Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
-- (void)testWinograd4x4To36TileX6 {
+- (void)testWinograd4x4To36TileX6Metal {
   auto status = tflite::gpu::metal::Winograd4x4To36TileX6Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
-- (void)testWinograd36To4x4 {
+- (void)testWinograd36To4x4Metal {
   auto status = tflite::gpu::metal::Winograd36To4x4Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
-- (void)testWinograd36To4x4Tile4x1 {
+- (void)testWinograd36To4x4Tile4x1Metal {
   auto status = tflite::gpu::metal::Winograd36To4x4Tile4x1Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testWinograd4x4To36TileX6 {
+  auto status = tflite::gpu::Winograd4x4To36TileX6Test(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testWinograd36To4x4Tile4x1 {
+  auto status = tflite::gpu::Winograd36To4x4Tile4x1Test(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end