Use GpuObjects instead of hardcoded OpenCL buffers in Winograd transformation kernels.
PiperOrigin-RevId: 420121399
Change-Id: Idf2cfc47e21af418c4c5a777992fccab74a49578
diff --git a/tensorflow/lite/delegates/gpu/common/task/BUILD b/tensorflow/lite/delegates/gpu/common/task/BUILD
index 888f06a..480689e 100644
--- a/tensorflow/lite/delegates/gpu/common/task/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/task/BUILD
@@ -10,15 +10,17 @@
srcs = ["arguments.cc"],
hdrs = ["arguments.h"],
deps = [
+ ":buffer_desc",
+ ":gpu_object_desc",
":serialization_base_cc_fbs",
+ ":tensor_desc",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
- "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
- "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.cc b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
index c7a43b2..17efcc2 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
@@ -22,8 +22,12 @@
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
namespace tflite {
@@ -141,6 +145,109 @@
return absl::OkStatus();
}
+std::string DataTypeToGlType(DataType data_type, int vec_size,
+ bool explicit_f16) {
+ if (data_type == DataType::FLOAT32) {
+ if (vec_size == 1) {
+ return "float";
+ } else {
+ return "vec" + std::to_string(vec_size);
+ }
+ } else if (data_type == DataType::FLOAT16) {
+ if (vec_size == 1) {
+ return explicit_f16 ? "float16_t" : "float";
+ } else {
+ if (explicit_f16) {
+ return "f16vec" + std::to_string(vec_size);
+ } else {
+ return "vec" + std::to_string(vec_size);
+ }
+ }
+ } else if (data_type == DataType::INT32) {
+ if (vec_size == 1) {
+ return "int";
+ } else {
+ return "ivec" + std::to_string(vec_size);
+ }
+ } else if (data_type == DataType::UINT32) {
+ if (vec_size == 1) {
+ return "uint";
+ } else {
+ return "uvec" + std::to_string(vec_size);
+ }
+ }
+ return "unsupported_type";
+}
+
+absl::Status BufferToKernelLanguage(const GpuInfo& gpu_info,
+ const std::string& buffer_name,
+ const BufferDescriptor* buffer_desc,
+ std::string* result) {
+ if (buffer_desc->element_size != 1) {
+ return absl::UnimplementedError("No support of vector types.");
+ }
+ const int elements_count =
+ buffer_desc->size /
+ (buffer_desc->element_size * SizeOf(buffer_desc->element_type));
+ if (gpu_info.IsGlsl()) {
+ const std::string gl_type =
+ DataTypeToGlType(buffer_desc->element_type, buffer_desc->element_size,
+ gpu_info.IsGlslSupportsExplicitFp16());
+ *result = "const ";
+ if (buffer_desc->element_type == DataType::FLOAT16 &&
+ !gpu_info.IsGlslSupportsExplicitFp16()) {
+ *result += "mediump ";
+ }
+ *result += gl_type + " " + buffer_name + "_buffer[] = " + gl_type + "[](\n";
+ } else if (gpu_info.IsApiMetal()) {
+ const std::string metal_type =
+ ToMetalDataType(buffer_desc->element_type, buffer_desc->element_size);
+ *result = "constant " + metal_type + " " + buffer_name + "_buffer[" +
+ std::to_string(elements_count) + "] = {\n";
+ } else if (gpu_info.IsApiOpenCl()) {
+ const std::string cl_type =
+ ToCLDataType(buffer_desc->element_type, buffer_desc->element_size);
+ *result = "__constant " + cl_type + " " + buffer_name + "_buffer[" +
+ std::to_string(elements_count) + "] = {\n";
+ } else {
+ return absl::UnimplementedError("Not supported API.");
+ }
+ if (buffer_desc->element_type == DataType::FLOAT16) {
+ std::string postfix = "f";
+ if (gpu_info.IsGlsl() && gpu_info.IsGlslSupportsExplicitFp16()) {
+ postfix = "hf";
+ }
+ const half* data_ptr =
+ reinterpret_cast<const half*>(buffer_desc->data.data());
+ for (int i = 0; i < elements_count; ++i) {
+ *result += " " +
+ absl::StrFormat("%.10f", static_cast<float>(data_ptr[i])) +
+ postfix;
+ if (i != elements_count - 1) {
+ *result += ",\n";
+ }
+ }
+ } else if (buffer_desc->element_type == DataType::FLOAT32) {
+ const float* data_ptr =
+ reinterpret_cast<const float*>(buffer_desc->data.data());
+ for (int i = 0; i < elements_count; ++i) {
+ *result += " " + absl::StrFormat("%.10f", data_ptr[i]) + "f";
+ if (i != elements_count - 1) {
+ *result += ",\n";
+ }
+ }
+ } else {
+ return absl::UnimplementedError("Not supported type.");
+ }
+ if (gpu_info.IsGlsl()) {
+ *result += ");\n";
+ } else {
+ *result += "};\n";
+ }
+
+ return absl::OkStatus();
+}
+
} // namespace
// Static
@@ -321,6 +428,7 @@
RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
GetActiveArguments(*code);
+ RETURN_IF_ERROR(ResolveKernelGlobalSpaceBuffers(gpu_info, code));
return absl::OkStatus();
}
@@ -461,5 +569,37 @@
}
}
+absl::Status Arguments::ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
+ std::string* code) {
+ for (auto it = objects_.begin(); it != objects_.end();) {
+ const auto* buffer_desc =
+ dynamic_cast<const BufferDescriptor*>(it->second.get());
+ if (!buffer_desc || buffer_desc->memory_type != MemoryType::CONSTANT) {
+ ++it;
+ continue;
+ }
+ bool is_kernel_global_space = false;
+ for (const auto& attribute : buffer_desc->attributes) {
+ if (attribute == "kernel_global_space") {
+ is_kernel_global_space = true;
+ break;
+ }
+ }
+ if (!is_kernel_global_space) {
+ ++it;
+ continue;
+ }
+ std::string declaration;
+ if (!BufferToKernelLanguage(gpu_info, it->first, buffer_desc, &declaration)
+ .ok()) {
+ ++it;
+ continue;
+ }
+ *code = declaration + *code;
+ objects_.erase(it++);
+ }
+ return absl::OkStatus();
+}
+
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.h b/tensorflow/lite/delegates/gpu/common/task/arguments.h
index 7c366d2..0c2b295 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.h
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.h
@@ -155,6 +155,9 @@
friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
Arguments* args);
+ absl::Status ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
+ std::string* code);
+
friend class cl::CLArguments;
friend class metal::MetalArguments;
diff --git a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
index 9603e12..22c5d38 100644
--- a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc
@@ -70,6 +70,17 @@
if (element_type == DataType::FLOAT16 &&
!gpu_info.IsGlslSupportsExplicitFp16()) {
if (memory_type == MemoryType::CONSTANT) {
+ bool is_kernel_global_space = false;
+ for (const auto& attribute : attributes) {
+ if (attribute == "kernel_global_space") {
+ is_kernel_global_space = true;
+ break;
+ }
+ }
+ if (is_kernel_global_space) {
+ *result = absl::StrCat("buffer[", args[0], "]");
+ return absl::OkStatus();
+ }
const std::string arg0 = "(" + args[0] + ")";
*result =
absl::StrCat("vec4(unpackHalf2x16(buffer[", arg0, " / 2][", arg0,
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index a963dbc..da8be05 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -15,6 +15,7 @@
#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
+#include <cstring>
#include <string>
#include <utility>
#include <vector>
@@ -29,22 +30,26 @@
namespace tflite {
namespace gpu {
namespace {
+void VectorToKernelBufferDesc(const std::vector<float>& data,
+ DataType data_type,
+ BufferDescriptor* buffer_desc) {
+ buffer_desc->element_type = data_type;
+ buffer_desc->element_size = 1;
+ buffer_desc->memory_type = MemoryType::CONSTANT;
+ buffer_desc->attributes.push_back("kernel_global_space");
+ buffer_desc->size = SizeOf(data_type) * data.size();
+ buffer_desc->data.resize(buffer_desc->size);
+ if (data_type == DataType::FLOAT32) {
+ memcpy(buffer_desc->data.data(), data.data(), buffer_desc->size);
+ } else {
+ half* hf_ptr = reinterpret_cast<half*>(buffer_desc->data.data());
+ for (int i = 0; i < data.size(); ++i) {
+ hf_ptr[i] = data[i];
+ }
+ }
+}
std::string GetKernelWinograd4x4To36(const OperationDef& op_def) {
std::string c;
- auto bt_mat = BtMatrixForWinograd4x4To6x6();
- c += "__constant FLT Bt[36] = {\n";
- for (int y = 0; y < 6; ++y) {
- c += "\t";
- for (int x = 0; x < 6; ++x) {
- c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f";
- if (!(x == 5 && y == 5)) {
- c += ", ";
- }
- }
- c += "\n";
- }
- c += "};\n";
-
const auto src_desc = op_def.src_tensors[0];
c += R"(
MAIN_FUNCTION($0) {
@@ -102,17 +107,18 @@
c += " FLT4 src = args.src_tensor.Read(coord_x, coord_y, S)" +
multiplier + ";\n";
}
- c += " I[0][" + s_x + "] += Bt[" + std::to_string(y) + "] * src;\n";
- c += " I[1][" + s_x + "] += Bt[" + std::to_string(y + 6) +
- "] * src;\n";
- c += " I[2][" + s_x + "] += Bt[" + std::to_string(y + 12) +
- "] * src;\n";
- c += " I[3][" + s_x + "] += Bt[" + std::to_string(y + 18) +
- "] * src;\n";
- c += " I[4][" + s_x + "] += Bt[" + std::to_string(y + 24) +
- "] * src;\n";
- c += " I[5][" + s_x + "] += Bt[" + std::to_string(y + 30) +
- "] * src;\n";
+ c += " I[0][" + s_x + "] += args.Bt.Read(" + std::to_string(y) +
+ ") * src;\n";
+ c += " I[1][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 6) +
+ ") * src;\n";
+ c += " I[2][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 12) +
+ ") * src;\n";
+ c += " I[3][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 18) +
+ ") * src;\n";
+ c += " I[4][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 24) +
+ ") * src;\n";
+ c += " I[5][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 30) +
+ ") * src;\n";
c += " }\n";
}
c += " }\n";
@@ -125,22 +131,22 @@
int dst_x = GLOBAL_ID_1 * args.tiles_x + GLOBAL_ID_0;
args.dst_tensor.GetAddress(dst_adress, dst_x, 0, S);
for (int y = 0; y < 6; ++y) {
- FLT4 value = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4];
+ FLT4 value = I[y][0] + args.Bt.Read(2) * I[y][2] + args.Bt.Read(4) * I[y][4];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
- value = Bt[7] * I[y][1] + Bt[8] * I[y][2] + Bt[9] * I[y][3] + Bt[10] * I[y][4];
+ value = args.Bt.Read(7) * I[y][1] + args.Bt.Read(8) * I[y][2] + args.Bt.Read(9) * I[y][3] + args.Bt.Read(10) * I[y][4];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
- value = Bt[13] * I[y][1] + Bt[14] * I[y][2] + Bt[15] * I[y][3] + Bt[16] * I[y][4];
+ value = args.Bt.Read(13) * I[y][1] + args.Bt.Read(14) * I[y][2] + args.Bt.Read(15) * I[y][3] + args.Bt.Read(16) * I[y][4];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
- value = Bt[19] * I[y][1] + Bt[20] * I[y][2] + Bt[21] * I[y][3] + Bt[22] * I[y][4];
+ value = args.Bt.Read(19) * I[y][1] + args.Bt.Read(20) * I[y][2] + args.Bt.Read(21) * I[y][3] + args.Bt.Read(22) * I[y][4];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
- value = Bt[25] * I[y][1] + Bt[26] * I[y][2] + Bt[27] * I[y][3] + Bt[28] * I[y][4];
+ value = args.Bt.Read(25) * I[y][1] + args.Bt.Read(26) * I[y][2] + args.Bt.Read(27) * I[y][3] + args.Bt.Read(28) * I[y][4];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
- value = Bt[31] * I[y][1] + Bt[33] * I[y][3] + I[y][5];
+ value = args.Bt.Read(31) * I[y][1] + args.Bt.Read(33) * I[y][3] + I[y][5];
args.dst_tensor.WriteLinear(value, dst_adress);
dst_adress += args.dst_tensor.Width();
}
@@ -150,17 +156,17 @@
c += R"(
int dst_x = GLOBAL_ID_1 * args.tiles_x + GLOBAL_ID_0;
for (int y = 0; y < 6; ++y) {
- FLT4 value = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4];
+ FLT4 value = I[y][0] + args.Bt.Read(2) * I[y][2] + args.Bt.Read(4) * I[y][4];
args.dst_tensor.Write(value, dst_x, y * 6 + 0, S);
- value = Bt[7] * I[y][1] + Bt[8] * I[y][2] + Bt[9] * I[y][3] + Bt[10] * I[y][4];
+ value = args.Bt.Read(7) * I[y][1] + args.Bt.Read(8) * I[y][2] + args.Bt.Read(9) * I[y][3] + args.Bt.Read(10) * I[y][4];
args.dst_tensor.Write(value, dst_x, y * 6 + 1, S);
- value = Bt[13] * I[y][1] + Bt[14] * I[y][2] + Bt[15] * I[y][3] + Bt[16] * I[y][4];
+ value = args.Bt.Read(13) * I[y][1] + args.Bt.Read(14) * I[y][2] + args.Bt.Read(15) * I[y][3] + args.Bt.Read(16) * I[y][4];
args.dst_tensor.Write(value, dst_x, y * 6 + 2, S);
- value = Bt[19] * I[y][1] + Bt[20] * I[y][2] + Bt[21] * I[y][3] + Bt[22] * I[y][4];
+ value = args.Bt.Read(19) * I[y][1] + args.Bt.Read(20) * I[y][2] + args.Bt.Read(21) * I[y][3] + args.Bt.Read(22) * I[y][4];
args.dst_tensor.Write(value, dst_x, y * 6 + 3, S);
- value = Bt[25] * I[y][1] + Bt[26] * I[y][2] + Bt[27] * I[y][3] + Bt[28] * I[y][4];
+ value = args.Bt.Read(25) * I[y][1] + args.Bt.Read(26) * I[y][2] + args.Bt.Read(27) * I[y][3] + args.Bt.Read(28) * I[y][4];
args.dst_tensor.Write(value, dst_x, y * 6 + 4, S);
- value = Bt[31] * I[y][1] + Bt[33] * I[y][3] + I[y][5];
+ value = args.Bt.Read(31) * I[y][1] + args.Bt.Read(33) * I[y][3] + I[y][5];
args.dst_tensor.Write(value, dst_x, y * 6 + 5, S);
}
}
@@ -171,20 +177,6 @@
std::string GetKernelWinograd36To4x4(const OperationDef& op_def) {
std::string c;
- auto at_mat = AtMatrixForWinograd4x4To6x6();
- c += "__constant FLT At[24] = {\n";
- for (int y = 0; y < 4; ++y) {
- c += "\t";
- for (int x = 0; x < 6; ++x) {
- c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f";
- if (!(x == 5 && y == 3)) {
- c += ", ";
- }
- }
- c += "\n";
- }
- c += "};\n";
-
const auto src_desc = op_def.src_tensors[0];
c += R"(
@@ -209,10 +201,10 @@
for (int y = 0; y < 6; ++y) {
for (int x = 0; x < 6; ++x, src_adress += args.src_tensor.Width()) {
FLT4 src = args.src_tensor.Read(src_adress);
- I[0][x] += src * At[y];
- I[1][x] += src * At[y + 6];
- I[2][x] += src * At[y + 12];
- I[3][x] += src * At[y + 18];
+ I[0][x] += src * args.At.Read(y);
+ I[1][x] += src * args.At.Read(y + 6);
+ I[2][x] += src * args.At.Read(y + 12);
+ I[3][x] += src * args.At.Read(y + 18);
}
}
)";
@@ -221,10 +213,10 @@
for (int y = 0; y < 6; ++y) {
for (int x = 0; x < 6; ++x) {
FLT4 src = args.src_tensor.Read(tile_id, y * 6 + x, Z);
- I[0][x] += src * At[y];
- I[1][x] += src * At[y + 6];
- I[2][x] += src * At[y + 12];
- I[3][x] += src * At[y + 18];
+ I[0][x] += src * args.At.Read(y);
+ I[1][x] += src * args.At.Read(y + 6);
+ I[2][x] += src * args.At.Read(y + 12);
+ I[3][x] += src * args.At.Read(y + 18);
}
}
)";
@@ -242,15 +234,15 @@
FLT4 t2 = I[y][1] - I[y][2];
FLT4 t3 = I[y][3] - I[y][4];
if (tile_x + 1 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
- FLT4 value = t2 * At[7] + t3 * At[9] + bias_val;
+ FLT4 value = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;
args.dst_tensor.Write(value, tile_x + 1, tile_y + y, Z);
}
if (tile_x + 2 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
- FLT4 value = t0 * At[13] + t1 * At[15] + bias_val;
+ FLT4 value = t0 * args.At.Read(13) + t1 * args.At.Read(15) + bias_val;
args.dst_tensor.Write(value, tile_x + 2, tile_y + y, Z);
}
if (tile_x + 3 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
- FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val;
+ FLT4 value = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I[y][5] + bias_val;
args.dst_tensor.Write(value, tile_x + 3, tile_y + y, Z);
}
}
@@ -295,6 +287,12 @@
desc.args_.AddInt("tiles_x");
desc.args_.AddInt("tiles_y");
+ BufferDescriptor buffer_desc;
+ VectorToKernelBufferDesc(BtMatrixForWinograd4x4To6x6(),
+ definition.GetDataType(), &buffer_desc);
+ desc.args_.AddObject(
+ "Bt", absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
+
desc.work_group_size_ = int3(8, 4, 1);
return desc;
}
@@ -317,21 +315,6 @@
std::string Winograd4x4To36TileX6::GetWinograd4x4To36TileX6Code(
const OperationDef& op_def, const GpuInfo& gpu_info) {
std::string c;
-
- auto bt_mat = BtMatrixForWinograd4x4To6x6();
- c += "__constant FLT Bt[36] = {\n";
- for (int y = 0; y < 6; ++y) {
- c += "\t";
- for (int x = 0; x < 6; ++x) {
- c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f";
- if (!(x == 5 && y == 5)) {
- c += ", ";
- }
- }
- c += "\n";
- }
- c += "};\n";
-
const auto& src_desc = op_def.src_tensors[0];
AddSrcTensor("src_tensor", op_def.src_tensors[0]);
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
@@ -352,8 +335,8 @@
c += " int tile_y = (DST_X / args.tiles_x) * 4;\n";
c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " FLT bt_ar[6];\n";
- c += " FLT4 t0 = args.bt.Read(DST_Y * 2 + 0);\n";
- c += " FLT4 t1 = args.bt.Read(DST_Y * 2 + 1);\n";
+ c += " FLT4 t0 = args.bt_non_uniform.Read(DST_Y * 2 + 0);\n";
+ c += " FLT4 t1 = args.bt_non_uniform.Read(DST_Y * 2 + 1);\n";
c += " DST_Y *= 6;\n";
c += " bt_ar[0] = t0.x;\n";
c += " bt_ar[1] = t0.y;\n";
@@ -463,39 +446,36 @@
c += " }\n";
}
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
+ c += " FLT4 r0 = I0 + args.Bt.Read(2) * I2 + args.Bt.Read(4) * I4;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
- "I4);\n";
+ c += " FLT4 r0 = args.Bt.Read(7) * I1 + args.Bt.Read(8) * I2 + "
+ "args.Bt.Read(9) * I3 + args.Bt.Read(10) * I4;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(Bt[13] * I1 + Bt[14] * I2 + Bt[15] * I3 + Bt[16] "
- "* "
- "I4);\n";
+ c += " FLT4 r0 = args.Bt.Read(13) * I1 + args.Bt.Read(14) * I2 + "
+ "args.Bt.Read(15) * I3 + args.Bt.Read(16) * I4;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(Bt[19] * I1 + Bt[20] * I2 + Bt[21] * I3 + Bt[22] "
- "* "
- "I4);\n";
+ c += " FLT4 r0 = args.Bt.Read(19) * I1 + args.Bt.Read(20) * I2 + "
+ "args.Bt.Read(21) * I3 + args.Bt.Read(22) * I4;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(Bt[25] * I1 + Bt[26] * I2 + Bt[27] * I3 + Bt[28] "
- "* "
- "I4);\n";
+ c += " FLT4 r0 = args.Bt.Read(25) * I1 + args.Bt.Read(26) * I2 + "
+ "args.Bt.Read(27) * I3 + args.Bt.Read(28) * I4;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
+ c += " FLT4 r0 = args.Bt.Read(31) * I1 + args.Bt.Read(33) * I3 + I5;\n";
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
@@ -520,8 +500,13 @@
desc.storage_type = LinearStorageType::TEXTURE_2D;
desc.element_type = definition_.GetDataType();
desc.UploadLinearData(bt_aligned);
- args_.AddObject("bt",
+ args_.AddObject("bt_non_uniform",
absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+
+ BufferDescriptor buffer_desc;
+ VectorToKernelBufferDesc(bt_mat, definition_.GetDataType(), &buffer_desc);
+ args_.AddObject("Bt",
+ absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
}
int3 Winograd4x4To36TileX6::SelectBestWorkGroup(
@@ -599,6 +584,12 @@
desc.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(
std::move(bias_desc)));
+ BufferDescriptor buffer_desc;
+ VectorToKernelBufferDesc(AtMatrixForWinograd4x4To6x6(),
+ definition.GetDataType(), &buffer_desc);
+ desc.args_.AddObject(
+ "At", absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
+
desc.work_group_size_ = int3(32, 1, 1);
return desc;
}
@@ -622,20 +613,6 @@
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
args_.AddInt("tiles_x");
- auto at_mat = AtMatrixForWinograd4x4To6x6();
- c += "__constant FLT At[24] = {\n";
- for (int y = 0; y < 4; ++y) {
- c += "\t";
- for (int x = 0; x < 6; ++x) {
- c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f";
- if (!(x == 5 && y == 3)) {
- c += ", ";
- }
- }
- c += "\n";
- }
- c += "};\n";
-
c += "MAIN_FUNCTION($0) {\n";
c += " int tile_id = GLOBAL_ID_0;\n";
c += " int DST_Y = GLOBAL_ID_1;\n";
@@ -649,8 +626,8 @@
c += " }\n";
c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " FLT at_ar[6];\n";
- c += " FLT4 t00 = args.at.Read(DST_Y * 2 + 0);\n";
- c += " FLT4 t01 = args.at.Read(DST_Y * 2 + 1);\n";
+ c += " FLT4 t00 = args.at_non_uniform.Read(DST_Y * 2 + 0);\n";
+ c += " FLT4 t01 = args.at_non_uniform.Read(DST_Y * 2 + 1);\n";
c += " at_ar[0] = t00.x;\n";
c += " at_ar[1] = t00.y;\n";
c += " at_ar[2] = t00.z;\n";
@@ -703,24 +680,27 @@
c += " FLT4 t1 = I3 + I4;\n";
c += " FLT4 bias_val = args.biases.Read(DST_Z);\n";
c += " {\n";
- c += " FLT4 r0 = TO_FLT4(I0 + t0 + t1) + bias_val;\n";
+ c += " FLT4 r0 = I0 + t0 + t1 + bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
c += " tile_x++;\n";
c += " }\n";
c += " FLT4 t2 = I1 - I2;\n";
c += " FLT4 t3 = I3 - I4;\n";
c += " if (tile_x < args.dst_tensor.Width()) {\n";
- c += " FLT4 r0 = TO_FLT4(t2 * At[7] + t3 * At[9]) + bias_val;\n";
+ c +=
+ " FLT4 r0 = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
c += " tile_x++;\n";
c += " }\n";
c += " if (tile_x < args.dst_tensor.Width()) {\n";
- c += " FLT4 r0 = TO_FLT4(t0 * At[13] + t1 * At[15]) + bias_val;\n";
+ c += " FLT4 r0 = t0 * args.At.Read(13) + t1 * args.At.Read(15) + "
+ "bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
c += " tile_x++;\n";
c += " }\n";
c += " if (tile_x < args.dst_tensor.Width()) {\n";
- c += " FLT4 r0 = TO_FLT4(t2 * At[19] + t3 * At[21] + I5) + bias_val;\n";
+ c += " FLT4 r0 = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I5 + "
+ "bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
c += " tile_x++;\n";
c += " }\n";
@@ -745,8 +725,13 @@
desc.storage_type = LinearStorageType::TEXTURE_2D;
desc.element_type = definition_.GetDataType();
desc.UploadLinearData(at_aligned);
- args_.AddObject("at",
+ args_.AddObject("at_non_uniform",
absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+
+ BufferDescriptor buffer_desc;
+ VectorToKernelBufferDesc(at_mat, definition_.GetDataType(), &buffer_desc);
+ args_.AddObject("At",
+ absl::make_unique<BufferDescriptor>(std::move(buffer_desc)));
}
int3 Winograd36To4x4Tile4x1::SelectBestWorkGroup(