Use GpuObjects instead of hardcoded OpenCL buffers in Winograd transformation kernels.
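
The Winograd transform matrices (Bt/At) used to be emitted directly into the generated
OpenCL source as __constant FLT arrays, which tied the code generator to OpenCL. They are
now attached as GPU objects (a BufferDescriptor with MemoryType::CONSTANT and the
"kernel_global_space" attribute) and read via args.Bt.Read(i) / args.At.Read(i);
Arguments::ResolveKernelGlobalSpaceBuffers then inlines such buffers as constant arrays in
the kernel language of the target API (OpenCL, Metal, or GLSL). A minimal sketch of the new
pattern follows; it mirrors the VectorToKernelBufferDesc() helper added in winograd.cc, but
the helper name MakeKernelConstantBuffer is hypothetical and the include paths and
surrounding operation setup are assumed and may differ slightly:

    #include <cstring>
    #include <vector>

    #include "tensorflow/lite/delegates/gpu/common/data_type.h"
    #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
    #include "tensorflow/lite/delegates/gpu/common/types.h"

    namespace tflite {
    namespace gpu {

    // Hypothetical helper (the patch's own version is VectorToKernelBufferDesc).
    // Builds a read-only buffer whose contents are inlined into the kernel source
    // as a compile-time constant array for any target API, instead of being
    // hardcoded as an OpenCL __constant array by the code generator.
    BufferDescriptor MakeKernelConstantBuffer(const std::vector<float>& data,
                                              DataType data_type) {
      BufferDescriptor desc;
      desc.element_type = data_type;
      desc.element_size = 1;
      desc.memory_type = MemoryType::CONSTANT;
      desc.attributes.push_back("kernel_global_space");
      desc.size = SizeOf(data_type) * data.size();
      desc.data.resize(desc.size);
      if (data_type == DataType::FLOAT32) {
        std::memcpy(desc.data.data(), data.data(), desc.size);
      } else {
        // FLOAT16 path: convert element by element into the raw byte storage.
        half* hf_ptr = reinterpret_cast<half*>(desc.data.data());
        for (size_t i = 0; i < data.size(); ++i) {
          hf_ptr[i] = data[i];
        }
      }
      return desc;
    }

    }  // namespace gpu
    }  // namespace tflite

Typical use inside an operation constructor (object name "Bt" as in the patch):

    BufferDescriptor bt_desc = MakeKernelConstantBuffer(
        BtMatrixForWinograd4x4To6x6(), definition.GetDataType());
    args_.AddObject("Bt",
                    absl::make_unique<BufferDescriptor>(std::move(bt_desc)));

after which the generated kernel reads the matrix with args.Bt.Read(i).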

PiperOrigin-RevId: 420121399
Change-Id: Idf2cfc47e21af418c4c5a777992fccab74a49578
diff --git a/tensorflow/lite/delegates/gpu/common/task/BUILD b/tensorflow/lite/delegates/gpu/common/task/BUILD
index 888f06a..480689e 100644
--- a/tensorflow/lite/delegates/gpu/common/task/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/task/BUILD
@@ -10,15 +10,17 @@
     srcs = ["arguments.cc"],
     hdrs = ["arguments.h"],
     deps = [
+        ":buffer_desc",
+        ":gpu_object_desc",
         ":serialization_base_cc_fbs",
+        ":tensor_desc",
         "//tensorflow/lite/delegates/gpu/common:access_type",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
-        "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
-        "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.cc b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
index c7a43b2..17efcc2 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
@@ -22,8 +22,12 @@
 #include "absl/strings/ascii.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
@@ -141,6 +145,109 @@
   return absl::OkStatus();
 }
 
+std::string DataTypeToGlType(DataType data_type, int vec_size,
+                             bool explicit_f16) {
+  if (data_type == DataType::FLOAT32) {
+    if (vec_size == 1) {
+      return "float";
+    } else {
+      return "vec" + std::to_string(vec_size);
+    }
+  } else if (data_type == DataType::FLOAT16) {
+    if (vec_size == 1) {
+      return explicit_f16 ? "float16_t" : "float";
+    } else {
+      if (explicit_f16) {
+        return "f16vec" + std::to_string(vec_size);
+      } else {
+        return "vec" + std::to_string(vec_size);
+      }
+    }
+  } else if (data_type == DataType::INT32) {
+    if (vec_size == 1) {
+      return "int";
+    } else {
+      return "ivec" + std::to_string(vec_size);
+    }
+  } else if (data_type == DataType::UINT32) {
+    if (vec_size == 1) {
+      return "uint";
+    } else {
+      return "uvec" + std::to_string(vec_size);
+    }
+  }
+  return "unsupported_type";
+}
+
+absl::Status BufferToKernelLanguage(const GpuInfo& gpu_info,
+                                    const std::string& buffer_name,
+                                    const BufferDescriptor* buffer_desc,
+                                    std::string* result) {
+  if (buffer_desc->element_size != 1) {
+    return absl::UnimplementedError("No support of vector types.");
+  }
+  const int elements_count =
+      buffer_desc->size /
+      (buffer_desc->element_size * SizeOf(buffer_desc->element_type));
+  if (gpu_info.IsGlsl()) {
+    const std::string gl_type =
+        DataTypeToGlType(buffer_desc->element_type, buffer_desc->element_size,
+                         gpu_info.IsGlslSupportsExplicitFp16());
+    *result = "const ";
+    if (buffer_desc->element_type == DataType::FLOAT16 &&
+        !gpu_info.IsGlslSupportsExplicitFp16()) {
+      *result += "mediump ";
+    }
+    *result += gl_type + " " + buffer_name + "_buffer[] = " + gl_type + "[](\n";
+  } else if (gpu_info.IsApiMetal()) {
+    const std::string metal_type =
+        ToMetalDataType(buffer_desc->element_type, buffer_desc->element_size);
+    *result = "constant " + metal_type + " " + buffer_name + "_buffer[" +
+              std::to_string(elements_count) + "] = {\n";
+  } else if (gpu_info.IsApiOpenCl()) {
+    const std::string cl_type =
+        ToCLDataType(buffer_desc->element_type, buffer_desc->element_size);
+    *result = "__constant " + cl_type + " " + buffer_name + "_buffer[" +
+              std::to_string(elements_count) + "] = {\n";
+  } else {
+    return absl::UnimplementedError("Not supported API.");
+  }
+  if (buffer_desc->element_type == DataType::FLOAT16) {
+    std::string postfix = "f";
+    if (gpu_info.IsGlsl() && gpu_info.IsGlslSupportsExplicitFp16()) {
+      postfix = "hf";
+    }
+    const half* data_ptr =
+        reinterpret_cast<const half*>(buffer_desc->data.data());
+    for (int i = 0; i < elements_count; ++i) {
+      *result += "  " +
+                 absl::StrFormat("%.10f", static_cast<float>(data_ptr[i])) +
+                 postfix;
+      if (i != elements_count - 1) {
+        *result += ",\n";
+      }
+    }
+  } else if (buffer_desc->element_type == DataType::FLOAT32) {
+    const float* data_ptr =
+        reinterpret_cast<const float*>(buffer_desc->data.data());
+    for (int i = 0; i < elements_count; ++i) {
+      *result += "  " + absl::StrFormat("%.10f", data_ptr[i]) + "f";
+      if (i != elements_count - 1) {
+        *result += ",\n";
+      }
+    }
+  } else {
+    return absl::UnimplementedError("Not supported type.");
+  }
+  if (gpu_info.IsGlsl()) {
+    *result += ");\n";
+  } else {
+    *result += "};\n";
+  }
+
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 // Static
@@ -321,6 +428,7 @@
   RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
   RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
   GetActiveArguments(*code);
+  RETURN_IF_ERROR(ResolveKernelGlobalSpaceBuffers(gpu_info, code));
   return absl::OkStatus();
 }
 
@@ -461,5 +569,37 @@
   }
 }
 
+absl::Status Arguments::ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
+                                                        std::string* code) {
+  for (auto it = objects_.begin(); it != objects_.end();) {
+    const auto* buffer_desc =
+        dynamic_cast<const BufferDescriptor*>(it->second.get());
+    if (!buffer_desc || buffer_desc->memory_type != MemoryType::CONSTANT) {
+      ++it;
+      continue;
+    }
+    bool is_kernel_global_space = false;
+    for (const auto& attribute : buffer_desc->attributes) {
+      if (attribute == "kernel_global_space") {
+        is_kernel_global_space = true;
+        break;
+      }
+    }
+    if (!is_kernel_global_space) {
+      ++it;
+      continue;
+    }
+    std::string declaration;
+    if (!BufferToKernelLanguage(gpu_info, it->first, buffer_desc, &declaration)
+             .ok()) {
+      ++it;
+      continue;
+    }
+    *code = declaration + *code;
+    objects_.erase(it++);
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.h b/tensorflow/lite/delegates/gpu/common/task/arguments.h
index 7c366d2..0c2b295 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.h
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.h
@@ -155,6 +155,9 @@
   friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
                              Arguments* args);
 
+  absl::Status ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
+                                               std::string* code);
+
   friend class cl::CLArguments;
   friend class metal::MetalArguments;
 
diff --git a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
index 9603e12..22c5d38 100644
--- a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
@@ -70,6 +70,17 @@
     if (element_type == DataType::FLOAT16 &&
         !gpu_info.IsGlslSupportsExplicitFp16()) {
       if (memory_type == MemoryType::CONSTANT) {
+        bool is_kernel_global_space = false;
+        for (const auto& attribute : attributes) {
+          if (attribute == "kernel_global_space") {
+            is_kernel_global_space = true;
+            break;
+          }
+        }
+        if (is_kernel_global_space) {
+          *result = absl::StrCat("buffer[", args[0], "]");
+          return absl::OkStatus();
+        }
         const std::string arg0 = "(" + args[0] + ")";
         *result =
             absl::StrCat("vec4(unpackHalf2x16(buffer[", arg0, " / 2][", arg0,
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index a963dbc..da8be05 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
 
+#include <cstring>
 #include <string>
 #include <utility>
 #include <vector>
@@ -29,22 +30,26 @@
 namespace tflite {
 namespace gpu {
 namespace {
+void VectorToKernelBufferDesc(const std::vector<float>& data,
+                              DataType data_type,
+                              BufferDescriptor* buffer_desc) {
+  buffer_desc->element_type = data_type;
+  buffer_desc->element_size = 1;
+  buffer_desc->memory_type = MemoryType::CONSTANT;
+  buffer_desc->attributes.push_back("kernel_global_space");
+  buffer_desc->size = SizeOf(data_type) * data.size();
+  buffer_desc->data.resize(buffer_desc->size);
+  if (data_type == DataType::FLOAT32) {
+    memcpy(buffer_desc->data.data(), data.data(), buffer_desc->size);
+  } else {
+    half* hf_ptr = reinterpret_cast<half*>(buffer_desc->data.data());
+    for (int i = 0; i < data.size(); ++i) {
+      hf_ptr[i] = data[i];
+    }
+  }
+}
 std::string GetKernelWinograd4x4To36(const OperationDef& op_def) {
   std::string c;
-  auto bt_mat = BtMatrixForWinograd4x4To6x6();
-  c += "__constant FLT Bt[36] = {\n";
-  for (int y = 0; y < 6; ++y) {
-    c += "\t";
-    for (int x = 0; x < 6; ++x) {
-      c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f";
-      if (!(x == 5 && y == 5)) {
-        c += ", ";
-      }
-    }
-    c += "\n";
-  }
-  c += "};\n";
-
   const auto src_desc = op_def.src_tensors[0];
   c += R"(
 MAIN_FUNCTION($0) {
@@ -102,17 +107,18 @@
         c += "      FLT4 src = args.src_tensor.Read(coord_x, coord_y, S)" +
              multiplier + ";\n";
       }
-      c += "      I[0][" + s_x + "] += Bt[" + std::to_string(y) + "] * src;\n";
-      c += "      I[1][" + s_x + "] += Bt[" + std::to_string(y + 6) +
-           "] * src;\n";
-      c += "      I[2][" + s_x + "] += Bt[" + std::to_string(y + 12) +
-           "] * src;\n";
-      c += "      I[3][" + s_x + "] += Bt[" + std::to_string(y + 18) +
-           "] * src;\n";
-      c += "      I[4][" + s_x + "] += Bt[" + std::to_string(y + 24) +
-           "] * src;\n";
-      c += "      I[5][" + s_x + "] += Bt[" + std::to_string(y + 30) +
-           "] * src;\n";
+      c += "      I[0][" + s_x + "] += args.Bt.Read(" + std::to_string(y) +
+           ") * src;\n";
+      c += "      I[1][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 6) +
+           ") * src;\n";
+      c += "      I[2][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 12) +
+           ") * src;\n";
+      c += "      I[3][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 18) +
+           ") * src;\n";
+      c += "      I[4][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 24) +
+           ") * src;\n";
+      c += "      I[5][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 30) +
+           ") * src;\n";
       c += "    }\n";
     }
     c += "  }\n";
@@ -125,22 +131,22 @@
   int dst_x = GLOBAL_ID_1 * args.tiles_x + GLOBAL_ID_0;
   args.dst_tensor.GetAddress(dst_adress, dst_x, 0, S);
   for (int y = 0; y < 6; ++y) {
-    FLT4 value = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4];
+    FLT4 value = I[y][0] + args.Bt.Read(2) * I[y][2] + args.Bt.Read(4) * I[y][4];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
-    value = Bt[7] * I[y][1] + Bt[8] * I[y][2] + Bt[9] * I[y][3] + Bt[10] * I[y][4];
+    value = args.Bt.Read(7) * I[y][1] + args.Bt.Read(8) * I[y][2] + args.Bt.Read(9) * I[y][3] + args.Bt.Read(10) * I[y][4];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
-    value = Bt[13] * I[y][1] + Bt[14] * I[y][2] + Bt[15] * I[y][3] + Bt[16] * I[y][4];
+    value = args.Bt.Read(13) * I[y][1] + args.Bt.Read(14) * I[y][2] + args.Bt.Read(15) * I[y][3] + args.Bt.Read(16) * I[y][4];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
-    value = Bt[19] * I[y][1] + Bt[20] * I[y][2] + Bt[21] * I[y][3] + Bt[22] * I[y][4];
+    value = args.Bt.Read(19) * I[y][1] + args.Bt.Read(20) * I[y][2] + args.Bt.Read(21) * I[y][3] + args.Bt.Read(22) * I[y][4];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
-    value = Bt[25] * I[y][1] + Bt[26] * I[y][2] + Bt[27] * I[y][3] + Bt[28] * I[y][4];
+    value = args.Bt.Read(25) * I[y][1] + args.Bt.Read(26) * I[y][2] + args.Bt.Read(27) * I[y][3] + args.Bt.Read(28) * I[y][4];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
-    value = Bt[31] * I[y][1] + Bt[33] * I[y][3] + I[y][5];
+    value = args.Bt.Read(31) * I[y][1] + args.Bt.Read(33) * I[y][3] + I[y][5];
     args.dst_tensor.WriteLinear(value, dst_adress);
     dst_adress += args.dst_tensor.Width();
   }
@@ -150,17 +156,17 @@
     c += R"(
   int dst_x = GLOBAL_ID_1 * args.tiles_x + GLOBAL_ID_0;
   for (int y = 0; y < 6; ++y) {
-    FLT4 value = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4];
+    FLT4 value = I[y][0] + args.Bt.Read(2) * I[y][2] + args.Bt.Read(4) * I[y][4];
     args.dst_tensor.Write(value, dst_x, y * 6 + 0, S);
-    value = Bt[7] * I[y][1] + Bt[8] * I[y][2] + Bt[9] * I[y][3] + Bt[10] * I[y][4];
+    value = args.Bt.Read(7) * I[y][1] + args.Bt.Read(8) * I[y][2] + args.Bt.Read(9) * I[y][3] + args.Bt.Read(10) * I[y][4];
     args.dst_tensor.Write(value, dst_x, y * 6 + 1, S);
-    value = Bt[13] * I[y][1] + Bt[14] * I[y][2] + Bt[15] * I[y][3] + Bt[16] * I[y][4];
+    value = args.Bt.Read(13) * I[y][1] + args.Bt.Read(14) * I[y][2] + args.Bt.Read(15) * I[y][3] + args.Bt.Read(16) * I[y][4];
     args.dst_tensor.Write(value, dst_x, y * 6 + 2, S);
-    value = Bt[19] * I[y][1] + Bt[20] * I[y][2] + Bt[21] * I[y][3] + Bt[22] * I[y][4];
+    value = args.Bt.Read(19) * I[y][1] + args.Bt.Read(20) * I[y][2] + args.Bt.Read(21) * I[y][3] + args.Bt.Read(22) * I[y][4];
     args.dst_tensor.Write(value, dst_x, y * 6 + 3, S);
-    value = Bt[25] * I[y][1] + Bt[26] * I[y][2] + Bt[27] * I[y][3] + Bt[28] * I[y][4];
+    value = args.Bt.Read(25) * I[y][1] + args.Bt.Read(26) * I[y][2] + args.Bt.Read(27) * I[y][3] + args.Bt.Read(28) * I[y][4];
     args.dst_tensor.Write(value, dst_x, y * 6 + 4, S);
-    value = Bt[31] * I[y][1] + Bt[33] * I[y][3] + I[y][5];
+    value = args.Bt.Read(31) * I[y][1] + args.Bt.Read(33) * I[y][3] + I[y][5];
     args.dst_tensor.Write(value, dst_x, y * 6 + 5, S);
   }
 }
@@ -171,20 +177,6 @@
 
 std::string GetKernelWinograd36To4x4(const OperationDef& op_def) {
   std::string c;
-  auto at_mat = AtMatrixForWinograd4x4To6x6();
-  c += "__constant FLT At[24] = {\n";
-  for (int y = 0; y < 4; ++y) {
-    c += "\t";
-    for (int x = 0; x < 6; ++x) {
-      c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f";
-      if (!(x == 5 && y == 3)) {
-        c += ", ";
-      }
-    }
-    c += "\n";
-  }
-  c += "};\n";
-
   const auto src_desc = op_def.src_tensors[0];
 
   c += R"(
@@ -209,10 +201,10 @@
   for (int y = 0; y < 6; ++y) {
     for (int x = 0; x < 6; ++x, src_adress += args.src_tensor.Width()) {
       FLT4 src = args.src_tensor.Read(src_adress);
-      I[0][x] += src * At[y];
-      I[1][x] += src * At[y + 6];
-      I[2][x] += src * At[y + 12];
-      I[3][x] += src * At[y + 18];
+      I[0][x] += src * args.At.Read(y);
+      I[1][x] += src * args.At.Read(y + 6);
+      I[2][x] += src * args.At.Read(y + 12);
+      I[3][x] += src * args.At.Read(y + 18);
     }
   }
 )";
@@ -221,10 +213,10 @@
   for (int y = 0; y < 6; ++y) {
     for (int x = 0; x < 6; ++x) {
       FLT4 src = args.src_tensor.Read(tile_id, y * 6 + x, Z);
-      I[0][x] += src * At[y];
-      I[1][x] += src * At[y + 6];
-      I[2][x] += src * At[y + 12];
-      I[3][x] += src * At[y + 18];
+      I[0][x] += src * args.At.Read(y);
+      I[1][x] += src * args.At.Read(y + 6);
+      I[2][x] += src * args.At.Read(y + 12);
+      I[3][x] += src * args.At.Read(y + 18);
     }
   }
 )";
@@ -242,15 +234,15 @@
     FLT4 t2 = I[y][1] - I[y][2];
     FLT4 t3 = I[y][3] - I[y][4];
     if (tile_x + 1 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
-      FLT4 value = t2 * At[7] + t3 * At[9] + bias_val;
+      FLT4 value = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;
       args.dst_tensor.Write(value, tile_x + 1, tile_y + y, Z);
     }
     if (tile_x + 2 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
-      FLT4 value = t0 * At[13] + t1 * At[15] + bias_val;
+      FLT4 value = t0 * args.At.Read(13) + t1 * args.At.Read(15) + bias_val;
       args.dst_tensor.Write(value, tile_x + 2, tile_y + y, Z);
     }
     if (tile_x + 3 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
-      FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val;
+      FLT4 value = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I[y][5] + bias_val;
       args.dst_tensor.Write(value, tile_x + 3, tile_y + y, Z);
     }
   }
@@ -295,6 +287,12 @@
   desc.args_.AddInt("tiles_x");
   desc.args_.AddInt("tiles_y");
 
+  BufferDescriptor buffer_desc;
+  VectorToKernelBufferDesc(BtMatrixForWinograd4x4To6x6(),
+                           definition.GetDataType(), &buffer_desc);
+  desc.args_.AddObject(
+      "Bt", absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
+
   desc.work_group_size_ = int3(8, 4, 1);
   return desc;
 }
@@ -317,21 +315,6 @@
 std::string Winograd4x4To36TileX6::GetWinograd4x4To36TileX6Code(
     const OperationDef& op_def, const GpuInfo& gpu_info) {
   std::string c;
-
-  auto bt_mat = BtMatrixForWinograd4x4To6x6();
-  c += "__constant FLT Bt[36] = {\n";
-  for (int y = 0; y < 6; ++y) {
-    c += "\t";
-    for (int x = 0; x < 6; ++x) {
-      c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f";
-      if (!(x == 5 && y == 5)) {
-        c += ", ";
-      }
-    }
-    c += "\n";
-  }
-  c += "};\n";
-
   const auto& src_desc = op_def.src_tensors[0];
   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
@@ -352,8 +335,8 @@
   c += "  int tile_y = (DST_X / args.tiles_x) * 4;\n";
   c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
   c += "  FLT bt_ar[6];\n";
-  c += "  FLT4 t0 = args.bt.Read(DST_Y * 2 + 0);\n";
-  c += "  FLT4 t1 = args.bt.Read(DST_Y * 2 + 1);\n";
+  c += "  FLT4 t0 = args.bt_non_uniform.Read(DST_Y * 2 + 0);\n";
+  c += "  FLT4 t1 = args.bt_non_uniform.Read(DST_Y * 2 + 1);\n";
   c += "  DST_Y *= 6;\n";
   c += "  bt_ar[0] = t0.x;\n";
   c += "  bt_ar[1] = t0.y;\n";
@@ -463,39 +446,36 @@
     c += "  }\n";
   }
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
+  c += "    FLT4 r0 = I0 + args.Bt.Read(2) * I2 + args.Bt.Read(4) * I4;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
-       "I4);\n";
+  c += "    FLT4 r0 = args.Bt.Read(7) * I1 + args.Bt.Read(8) * I2 + "
+       "args.Bt.Read(9) * I3 + args.Bt.Read(10) * I4;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(Bt[13] * I1 + Bt[14] * I2 + Bt[15] * I3 + Bt[16] "
-       "* "
-       "I4);\n";
+  c += "    FLT4 r0 = args.Bt.Read(13) * I1 + args.Bt.Read(14) * I2 + "
+       "args.Bt.Read(15) * I3 + args.Bt.Read(16) * I4;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(Bt[19] * I1 + Bt[20] * I2 + Bt[21] * I3 + Bt[22] "
-       "* "
-       "I4);\n";
+  c += "    FLT4 r0 = args.Bt.Read(19) * I1 + args.Bt.Read(20) * I2 + "
+       "args.Bt.Read(21) * I3 + args.Bt.Read(22) * I4;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(Bt[25] * I1 + Bt[26] * I2 + Bt[27] * I3 + Bt[28] "
-       "* "
-       "I4);\n";
+  c += "    FLT4 r0 = args.Bt.Read(25) * I1 + args.Bt.Read(26) * I2 + "
+       "args.Bt.Read(27) * I3 + args.Bt.Read(28) * I4;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
+  c += "    FLT4 r0 = args.Bt.Read(31) * I1 + args.Bt.Read(33) * I3 + I5;\n";
   c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
@@ -520,8 +500,13 @@
   desc.storage_type = LinearStorageType::TEXTURE_2D;
   desc.element_type = definition_.GetDataType();
   desc.UploadLinearData(bt_aligned);
-  args_.AddObject("bt",
+  args_.AddObject("bt_non_uniform",
                   absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+
+  BufferDescriptor buffer_desc;
+  VectorToKernelBufferDesc(bt_mat, definition_.GetDataType(), &buffer_desc);
+  args_.AddObject("Bt",
+                  absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
 }
 
 int3 Winograd4x4To36TileX6::SelectBestWorkGroup(
@@ -599,6 +584,12 @@
   desc.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(
                                      std::move(bias_desc)));
 
+  BufferDescriptor buffer_desc;
+  VectorToKernelBufferDesc(AtMatrixForWinograd4x4To6x6(),
+                           definition.GetDataType(), &buffer_desc);
+  desc.args_.AddObject(
+      "At", absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
+
   desc.work_group_size_ = int3(32, 1, 1);
   return desc;
 }
@@ -622,20 +613,6 @@
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
   args_.AddInt("tiles_x");
 
-  auto at_mat = AtMatrixForWinograd4x4To6x6();
-  c += "__constant FLT At[24] = {\n";
-  for (int y = 0; y < 4; ++y) {
-    c += "\t";
-    for (int x = 0; x < 6; ++x) {
-      c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f";
-      if (!(x == 5 && y == 3)) {
-        c += ", ";
-      }
-    }
-    c += "\n";
-  }
-  c += "};\n";
-
   c += "MAIN_FUNCTION($0) {\n";
   c += "  int tile_id = GLOBAL_ID_0;\n";
   c += "  int DST_Y = GLOBAL_ID_1;\n";
@@ -649,8 +626,8 @@
   c += "  }\n";
   c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
   c += "  FLT at_ar[6];\n";
-  c += "  FLT4 t00 = args.at.Read(DST_Y * 2 + 0);\n";
-  c += "  FLT4 t01 = args.at.Read(DST_Y * 2 + 1);\n";
+  c += "  FLT4 t00 = args.at_non_uniform.Read(DST_Y * 2 + 0);\n";
+  c += "  FLT4 t01 = args.at_non_uniform.Read(DST_Y * 2 + 1);\n";
   c += "  at_ar[0] = t00.x;\n";
   c += "  at_ar[1] = t00.y;\n";
   c += "  at_ar[2] = t00.z;\n";
@@ -703,24 +680,27 @@
   c += "  FLT4 t1 = I3 + I4;\n";
   c += "  FLT4 bias_val = args.biases.Read(DST_Z);\n";
   c += "  {\n";
-  c += "    FLT4 r0 = TO_FLT4(I0 + t0 + t1) + bias_val;\n";
+  c += "    FLT4 r0 = I0 + t0 + t1 + bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
   c += "    tile_x++;\n";
   c += "  }\n";
   c += "  FLT4 t2 = I1 - I2;\n";
   c += "  FLT4 t3 = I3 - I4;\n";
   c += "  if (tile_x < args.dst_tensor.Width()) {\n";
-  c += "    FLT4 r0 = TO_FLT4(t2 * At[7] + t3 * At[9]) + bias_val;\n";
+  c +=
+      "    FLT4 r0 = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
   c += "    tile_x++;\n";
   c += "  }\n";
   c += "  if (tile_x < args.dst_tensor.Width()) {\n";
-  c += "    FLT4 r0 = TO_FLT4(t0 * At[13] + t1 * At[15]) + bias_val;\n";
+  c += "    FLT4 r0 = t0 * args.At.Read(13) + t1 * args.At.Read(15) + "
+       "bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
   c += "    tile_x++;\n";
   c += "  }\n";
   c += "  if (tile_x < args.dst_tensor.Width()) {\n";
-  c += "    FLT4 r0 = TO_FLT4(t2 * At[19] + t3 * At[21] + I5) + bias_val;\n";
+  c += "    FLT4 r0 = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I5 + "
+       "bias_val;\n";
   c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
   c += "    tile_x++;\n";
   c += "  }\n";
@@ -745,8 +725,13 @@
   desc.storage_type = LinearStorageType::TEXTURE_2D;
   desc.element_type = definition_.GetDataType();
   desc.UploadLinearData(at_aligned);
-  args_.AddObject("at",
+  args_.AddObject("at_non_uniform",
                   absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+
+  BufferDescriptor buffer_desc;
+  VectorToKernelBufferDesc(at_mat, definition_.GetDataType(), &buffer_desc);
+  args_.AddObject("At",
+                  absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
 }
 
 int3 Winograd36To4x4Tile4x1::SelectBestWorkGroup(