Use the same precision in all versions of the Winograd transformation.

PiperOrigin-RevId: 419771522
Change-Id: I35936c2a4a3c7abc4a34771cdcf7e9e95eb98aa9
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index 8af0501..a963dbc 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -318,22 +318,8 @@
     const OperationDef& op_def, const GpuInfo& gpu_info) {
   std::string c;
 
-  switch (op_def.precision) {
-    case CalculationsPrecision::F32:
-    case CalculationsPrecision::F32_F16:
-      c += "#define ACCUM_FLT float\n";
-      break;
-    case CalculationsPrecision::F16:
-      c += "#define ACCUM_FLT half\n";
-      break;
-  }
-
-  const DataType accum_type = op_def.precision == CalculationsPrecision::F16
-                                  ? DataType::FLOAT16
-                                  : DataType::FLOAT32;
-
   auto bt_mat = BtMatrixForWinograd4x4To6x6();
-  c += "__constant ACCUM_FLT Bt[36] = {\n";
+  c += "__constant FLT Bt[36] = {\n";
   for (int y = 0; y < 6; ++y) {
     c += "\t";
     for (int x = 0; x < 6; ++x) {
@@ -346,10 +332,8 @@
   }
   c += "};\n";
 
-  std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
-  auto src_desc = op_def.src_tensors[0];
-  src_desc.SetStateVar("ACCUM_FLT", cl_type);
-  AddSrcTensor("src_tensor", src_desc);
+  const auto& src_desc = op_def.src_tensors[0];
+  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
   args_.AddInt("padding_x");
   args_.AddInt("padding_y");
@@ -366,10 +350,10 @@
   c += "  }\n";
   c += "  int tile_x = (DST_X % args.tiles_x) * 4;\n";
   c += "  int tile_y = (DST_X / args.tiles_x) * 4;\n";
-  c += "  ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
-  c += "  ACCUM_FLT bt_ar[6];\n";
-  c += "  ACCUM_FLT4 t0 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 0));\n";
-  c += "  ACCUM_FLT4 t1 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 1));\n";
+  c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
+  c += "  FLT bt_ar[6];\n";
+  c += "  FLT4 t0 = args.bt.Read(DST_Y * 2 + 0);\n";
+  c += "  FLT4 t1 = args.bt.Read(DST_Y * 2 + 1);\n";
   c += "  DST_Y *= 6;\n";
   c += "  bt_ar[0] = t0.x;\n";
   c += "  bt_ar[1] = t0.y;\n";
@@ -380,11 +364,9 @@
   auto read_src = [&](const std::string& src, const std::string& xs) {
     std::string read_statement;
     if (src_desc.IsLinear()) {
-      read_statement =
-          "args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset)";
+      read_statement = "args.src_tensor.Read(src_a_" + xs + " + offset)";
     } else {
-      read_statement =
-          "args.src_tensor.Read<ACCUM_FLT>(xc" + xs + ", yc, DST_Z)";
+      read_statement = "args.src_tensor.Read(xc" + xs + ", yc, DST_Z)";
     }
     std::string multiplier;
     if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
@@ -392,7 +374,7 @@
         multiplier = " * m" + xs + "_x";
       }
     }
-    c += "    ACCUM_FLT4 " + src + " = " + read_statement + multiplier + ";\n";
+    c += "    FLT4 " + src + " = " + read_statement + multiplier + ";\n";
   };
   for (int x = 0; x < 6; ++x) {
     const std::string xs = std::to_string(x);
@@ -400,7 +382,7 @@
     if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
       c += "  bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
            " < args.src_tensor.Width());\n";
-      c += "  ACCUM_FLT m" + xs + "_x = TO_ACCUM_FLT(inx" + xs + ");\n";
+      c += "  FLT m" + xs + "_x = INIT_FLT(inx" + xs + ");\n";
       c += "  xc" + xs + " = clamp(xc" + xs +
            ", 0, args.src_tensor.Width() - 1);\n";
     }
@@ -424,9 +406,9 @@
       c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
       c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
       c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
-      c += "    ACCUM_FLT bt = bt_ar[0] * TO_ACCUM_FLT(iny);\n";
+      c += "    FLT bt = bt_ar[0] * INIT_FLT(iny);\n";
     } else {
-      c += "    ACCUM_FLT bt = bt_ar[0];\n";
+      c += "    FLT bt = bt_ar[0];\n";
     }
     for (int x = 0; x < 6; ++x) {
       const std::string xs = std::to_string(x);
@@ -443,9 +425,9 @@
         c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
         c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
         c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
-        c += "    ACCUM_FLT bt = bt_ar[" + ys + "] * TO_ACCUM_FLT(iny);\n";
+        c += "    FLT bt = bt_ar[" + ys + "] * INIT_FLT(iny);\n";
       } else {
-        c += "    ACCUM_FLT bt = bt_ar[" + ys + "];\n";
+        c += "    FLT bt = bt_ar[" + ys + "];\n";
       }
       for (int x = 0; x < 6; ++x) {
         const std::string xs = std::to_string(x);
@@ -456,21 +438,21 @@
       c += "  }\n";
     }
   } else {
-    c += "  I0 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I1 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I2 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I3 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I4 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I5 = INIT_ACCUM_FLT4(0.0f);\n";
+    c += "  I0 = INIT_FLT4(0.0f);\n";
+    c += "  I1 = INIT_FLT4(0.0f);\n";
+    c += "  I2 = INIT_FLT4(0.0f);\n";
+    c += "  I3 = INIT_FLT4(0.0f);\n";
+    c += "  I4 = INIT_FLT4(0.0f);\n";
+    c += "  I5 = INIT_FLT4(0.0f);\n";
     c += "  for (int y = 0; y < 6; ++y) {\n";
     c += "    int yc = tile_y + args.padding_y + y;\n";
     if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
       c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
       c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
       c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
-      c += "    ACCUM_FLT bt = bt_ar[y] * TO_ACCUM_FLT(iny);\n";
+      c += "    FLT bt = bt_ar[y] * INIT_FLT(iny);\n";
     } else {
-      c += "    ACCUM_FLT bt = bt_ar[y];\n";
+      c += "    FLT bt = bt_ar[y];\n";
     }
     for (int x = 0; x < 6; ++x) {
       const std::string xs = std::to_string(x);
@@ -636,29 +618,12 @@
     const OperationDef& op_def, const GpuInfo& gpu_info) {
   std::string c;
 
-  switch (op_def.precision) {
-    case CalculationsPrecision::F32:
-    case CalculationsPrecision::F32_F16:
-      c += "#define ACCUM_FLT float\n";
-      break;
-    case CalculationsPrecision::F16:
-      c += "#define ACCUM_FLT half\n";
-      break;
-  }
-
-  const DataType accum_type = op_def.precision == CalculationsPrecision::F16
-                                  ? DataType::FLOAT16
-                                  : DataType::FLOAT32;
-
-  std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
-  auto src_desc = op_def.src_tensors[0];
-  src_desc.SetStateVar("ACCUM_FLT", cl_type);
-  AddSrcTensor("src_tensor", src_desc);
+  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
   args_.AddInt("tiles_x");
 
   auto at_mat = AtMatrixForWinograd4x4To6x6();
-  c += "__constant ACCUM_FLT At[24] = {\n";
+  c += "__constant FLT At[24] = {\n";
   for (int y = 0; y < 4; ++y) {
     c += "\t";
     for (int x = 0; x < 6; ++x) {
@@ -682,10 +647,10 @@
        "args.dst_tensor.Height() || DST_Z >= args.dst_tensor.Slices()) {\n";
   c += "    return; \n";
   c += "  }\n";
-  c += "  ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
-  c += "  ACCUM_FLT at_ar[6];\n";
-  c += "  ACCUM_FLT4 t00 = TO_ACCUM_TYPE(args.at.Read(DST_Y * 2 + 0));\n";
-  c += "  ACCUM_FLT4 t01 = TO_ACCUM_TYPE(args.at.Read(DST_Y * 2 + 1));\n";
+  c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
+  c += "  FLT at_ar[6];\n";
+  c += "  FLT4 t00 = args.at.Read(DST_Y * 2 + 0);\n";
+  c += "  FLT4 t01 = args.at.Read(DST_Y * 2 + 1);\n";
   c += "  at_ar[0] = t00.x;\n";
   c += "  at_ar[1] = t00.y;\n";
   c += "  at_ar[2] = t00.z;\n";
@@ -696,56 +661,54 @@
       !(op_def.precision == CalculationsPrecision::F32 && gpu_info.IsMali());
   if (manual_unroll) {
     c += "  {\n";
-    c += "    ACCUM_FLT at = at_ar[0];\n";
+    c += "    FLT at = at_ar[0];\n";
     for (int x = 0; x < 6; ++x) {
       const std::string yc = std::to_string(x);
       const std::string src = "src" + std::to_string(x);
-      c += "    ACCUM_FLT4 " + src +
-           " = args.src_tensor.Read<ACCUM_FLT>(tile_id, " + yc + ", DST_Z);\n";
+      c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
+           ", DST_Z);\n";
       c += "    I" + std::to_string(x) + " = at * " + src + ";\n";
     }
     c += "  }\n";
     for (int y = 1; y < 6; ++y) {
       c += "  {\n";
-      c += "    ACCUM_FLT at = at_ar[" + std::to_string(y) + "];\n";
+      c += "    FLT at = at_ar[" + std::to_string(y) + "];\n";
       for (int x = 0; x < 6; ++x) {
         const std::string yc = std::to_string(y * 6 + x);
         const std::string src = "src" + std::to_string(x);
-        c += "    ACCUM_FLT4 " + src +
-             " = args.src_tensor.Read<ACCUM_FLT>(tile_id, " + yc +
+        c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
              ", DST_Z);\n";
         c += "    I" + std::to_string(x) + " += at * " + src + ";\n";
       }
       c += "  }\n";
     }
   } else {
-    c += "  I0 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I1 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I2 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I3 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I4 = INIT_ACCUM_FLT4(0.0f);\n";
-    c += "  I5 = INIT_ACCUM_FLT4(0.0f);\n";
+    c += "  I0 = INIT_FLT4(0.0f);\n";
+    c += "  I1 = INIT_FLT4(0.0f);\n";
+    c += "  I2 = INIT_FLT4(0.0f);\n";
+    c += "  I3 = INIT_FLT4(0.0f);\n";
+    c += "  I4 = INIT_FLT4(0.0f);\n";
+    c += "  I5 = INIT_FLT4(0.0f);\n";
     c += "  for (int y = 0; y < 6; ++y) {\n";
-    c += "    ACCUM_FLT at = at_ar[y];\n";
+    c += "    FLT at = at_ar[y];\n";
     for (int x = 0; x < 6; ++x) {
       const std::string src = "src" + std::to_string(x);
-      c += "    ACCUM_FLT4 " + src +
-           " = args.src_tensor.Read<ACCUM_FLT>(tile_id, y * 6 + " +
+      c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, y * 6 + " +
            std::to_string(x) + ", DST_Z);\n";
       c += "    I" + std::to_string(x) + " += at * " + src + ";\n";
     }
     c += "  }\n";
   }
-  c += "  ACCUM_FLT4 t0 = I1 + I2;\n";
-  c += "  ACCUM_FLT4 t1 = I3 + I4;\n";
+  c += "  FLT4 t0 = I1 + I2;\n";
+  c += "  FLT4 t1 = I3 + I4;\n";
   c += "  FLT4 bias_val = args.biases.Read(DST_Z);\n";
   c += "  {\n";
   c += "    FLT4 r0 = TO_FLT4(I0 + t0 + t1) + bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
   c += "    tile_x++;\n";
   c += "  }\n";
-  c += "  ACCUM_FLT4 t2 = I1 - I2;\n";
-  c += "  ACCUM_FLT4 t3 = I3 - I4;\n";
+  c += "  FLT4 t2 = I1 - I2;\n";
+  c += "  FLT4 t3 = I3 - I4;\n";
   c += "  if (tile_x < args.dst_tensor.Width()) {\n";
   c += "    FLT4 r0 = TO_FLT4(t2 * At[7] + t3 * At[9]) + bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";