Use the same precision in all versions of the Winograd transformation.
PiperOrigin-RevId: 419771522
Change-Id: I35936c2a4a3c7abc4a34771cdcf7e9e95eb98aa9
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
index 8af0501..a963dbc 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/winograd.cc
@@ -318,22 +318,8 @@
const OperationDef& op_def, const GpuInfo& gpu_info) {
std::string c;
- switch (op_def.precision) {
- case CalculationsPrecision::F32:
- case CalculationsPrecision::F32_F16:
- c += "#define ACCUM_FLT float\n";
- break;
- case CalculationsPrecision::F16:
- c += "#define ACCUM_FLT half\n";
- break;
- }
-
- const DataType accum_type = op_def.precision == CalculationsPrecision::F16
- ? DataType::FLOAT16
- : DataType::FLOAT32;
-
auto bt_mat = BtMatrixForWinograd4x4To6x6();
- c += "__constant ACCUM_FLT Bt[36] = {\n";
+ c += "__constant FLT Bt[36] = {\n";
for (int y = 0; y < 6; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
@@ -346,10 +332,8 @@
}
c += "};\n";
- std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
- auto src_desc = op_def.src_tensors[0];
- src_desc.SetStateVar("ACCUM_FLT", cl_type);
- AddSrcTensor("src_tensor", src_desc);
+ const auto& src_desc = op_def.src_tensors[0];
+ AddSrcTensor("src_tensor", op_def.src_tensors[0]);
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
args_.AddInt("padding_x");
args_.AddInt("padding_y");
@@ -366,10 +350,10 @@
c += " }\n";
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
c += " int tile_y = (DST_X / args.tiles_x) * 4;\n";
- c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
- c += " ACCUM_FLT bt_ar[6];\n";
- c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 0));\n";
- c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 1));\n";
+ c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
+ c += " FLT bt_ar[6];\n";
+ c += " FLT4 t0 = args.bt.Read(DST_Y * 2 + 0);\n";
+ c += " FLT4 t1 = args.bt.Read(DST_Y * 2 + 1);\n";
c += " DST_Y *= 6;\n";
c += " bt_ar[0] = t0.x;\n";
c += " bt_ar[1] = t0.y;\n";
@@ -380,11 +364,9 @@
auto read_src = [&](const std::string& src, const std::string& xs) {
std::string read_statement;
if (src_desc.IsLinear()) {
- read_statement =
- "args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset)";
+ read_statement = "args.src_tensor.Read(src_a_" + xs + " + offset)";
} else {
- read_statement =
- "args.src_tensor.Read<ACCUM_FLT>(xc" + xs + ", yc, DST_Z)";
+ read_statement = "args.src_tensor.Read(xc" + xs + ", yc, DST_Z)";
}
std::string multiplier;
if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
@@ -392,7 +374,7 @@
multiplier = " * m" + xs + "_x";
}
}
- c += " ACCUM_FLT4 " + src + " = " + read_statement + multiplier + ";\n";
+ c += " FLT4 " + src + " = " + read_statement + multiplier + ";\n";
};
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@@ -400,7 +382,7 @@
if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
" < args.src_tensor.Width());\n";
- c += " ACCUM_FLT m" + xs + "_x = TO_ACCUM_FLT(inx" + xs + ");\n";
+ c += " FLT m" + xs + "_x = INIT_FLT(inx" + xs + ");\n";
c += " xc" + xs + " = clamp(xc" + xs +
", 0, args.src_tensor.Width() - 1);\n";
}
@@ -424,9 +406,9 @@
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
- c += " ACCUM_FLT bt = bt_ar[0] * TO_ACCUM_FLT(iny);\n";
+ c += " FLT bt = bt_ar[0] * INIT_FLT(iny);\n";
} else {
- c += " ACCUM_FLT bt = bt_ar[0];\n";
+ c += " FLT bt = bt_ar[0];\n";
}
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@@ -443,9 +425,9 @@
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
- c += " ACCUM_FLT bt = bt_ar[" + ys + "] * TO_ACCUM_FLT(iny);\n";
+ c += " FLT bt = bt_ar[" + ys + "] * INIT_FLT(iny);\n";
} else {
- c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
+ c += " FLT bt = bt_ar[" + ys + "];\n";
}
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@@ -456,21 +438,21 @@
c += " }\n";
}
} else {
- c += " I0 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I1 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I2 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I3 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I4 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I5 = INIT_ACCUM_FLT4(0.0f);\n";
+ c += " I0 = INIT_FLT4(0.0f);\n";
+ c += " I1 = INIT_FLT4(0.0f);\n";
+ c += " I2 = INIT_FLT4(0.0f);\n";
+ c += " I3 = INIT_FLT4(0.0f);\n";
+ c += " I4 = INIT_FLT4(0.0f);\n";
+ c += " I5 = INIT_FLT4(0.0f);\n";
c += " for (int y = 0; y < 6; ++y) {\n";
c += " int yc = tile_y + args.padding_y + y;\n";
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
- c += " ACCUM_FLT bt = bt_ar[y] * TO_ACCUM_FLT(iny);\n";
+ c += " FLT bt = bt_ar[y] * INIT_FLT(iny);\n";
} else {
- c += " ACCUM_FLT bt = bt_ar[y];\n";
+ c += " FLT bt = bt_ar[y];\n";
}
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@@ -636,29 +618,12 @@
const OperationDef& op_def, const GpuInfo& gpu_info) {
std::string c;
- switch (op_def.precision) {
- case CalculationsPrecision::F32:
- case CalculationsPrecision::F32_F16:
- c += "#define ACCUM_FLT float\n";
- break;
- case CalculationsPrecision::F16:
- c += "#define ACCUM_FLT half\n";
- break;
- }
-
- const DataType accum_type = op_def.precision == CalculationsPrecision::F16
- ? DataType::FLOAT16
- : DataType::FLOAT32;
-
- std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
- auto src_desc = op_def.src_tensors[0];
- src_desc.SetStateVar("ACCUM_FLT", cl_type);
- AddSrcTensor("src_tensor", src_desc);
+ AddSrcTensor("src_tensor", op_def.src_tensors[0]);
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
args_.AddInt("tiles_x");
auto at_mat = AtMatrixForWinograd4x4To6x6();
- c += "__constant ACCUM_FLT At[24] = {\n";
+ c += "__constant FLT At[24] = {\n";
for (int y = 0; y < 4; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
@@ -682,10 +647,10 @@
"args.dst_tensor.Height() || DST_Z >= args.dst_tensor.Slices()) {\n";
c += " return; \n";
c += " }\n";
- c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
- c += " ACCUM_FLT at_ar[6];\n";
- c += " ACCUM_FLT4 t00 = TO_ACCUM_TYPE(args.at.Read(DST_Y * 2 + 0));\n";
- c += " ACCUM_FLT4 t01 = TO_ACCUM_TYPE(args.at.Read(DST_Y * 2 + 1));\n";
+ c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
+ c += " FLT at_ar[6];\n";
+ c += " FLT4 t00 = args.at.Read(DST_Y * 2 + 0);\n";
+ c += " FLT4 t01 = args.at.Read(DST_Y * 2 + 1);\n";
c += " at_ar[0] = t00.x;\n";
c += " at_ar[1] = t00.y;\n";
c += " at_ar[2] = t00.z;\n";
@@ -696,56 +661,54 @@
!(op_def.precision == CalculationsPrecision::F32 && gpu_info.IsMali());
if (manual_unroll) {
c += " {\n";
- c += " ACCUM_FLT at = at_ar[0];\n";
+ c += " FLT at = at_ar[0];\n";
for (int x = 0; x < 6; ++x) {
const std::string yc = std::to_string(x);
const std::string src = "src" + std::to_string(x);
- c += " ACCUM_FLT4 " + src +
- " = args.src_tensor.Read<ACCUM_FLT>(tile_id, " + yc + ", DST_Z);\n";
+ c += " FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
+ ", DST_Z);\n";
c += " I" + std::to_string(x) + " = at * " + src + ";\n";
}
c += " }\n";
for (int y = 1; y < 6; ++y) {
c += " {\n";
- c += " ACCUM_FLT at = at_ar[" + std::to_string(y) + "];\n";
+ c += " FLT at = at_ar[" + std::to_string(y) + "];\n";
for (int x = 0; x < 6; ++x) {
const std::string yc = std::to_string(y * 6 + x);
const std::string src = "src" + std::to_string(x);
- c += " ACCUM_FLT4 " + src +
- " = args.src_tensor.Read<ACCUM_FLT>(tile_id, " + yc +
+ c += " FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
", DST_Z);\n";
c += " I" + std::to_string(x) + " += at * " + src + ";\n";
}
c += " }\n";
}
} else {
- c += " I0 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I1 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I2 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I3 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I4 = INIT_ACCUM_FLT4(0.0f);\n";
- c += " I5 = INIT_ACCUM_FLT4(0.0f);\n";
+ c += " I0 = INIT_FLT4(0.0f);\n";
+ c += " I1 = INIT_FLT4(0.0f);\n";
+ c += " I2 = INIT_FLT4(0.0f);\n";
+ c += " I3 = INIT_FLT4(0.0f);\n";
+ c += " I4 = INIT_FLT4(0.0f);\n";
+ c += " I5 = INIT_FLT4(0.0f);\n";
c += " for (int y = 0; y < 6; ++y) {\n";
- c += " ACCUM_FLT at = at_ar[y];\n";
+ c += " FLT at = at_ar[y];\n";
for (int x = 0; x < 6; ++x) {
const std::string src = "src" + std::to_string(x);
- c += " ACCUM_FLT4 " + src +
- " = args.src_tensor.Read<ACCUM_FLT>(tile_id, y * 6 + " +
+ c += " FLT4 " + src + " = args.src_tensor.Read(tile_id, y * 6 + " +
std::to_string(x) + ", DST_Z);\n";
c += " I" + std::to_string(x) + " += at * " + src + ";\n";
}
c += " }\n";
}
- c += " ACCUM_FLT4 t0 = I1 + I2;\n";
- c += " ACCUM_FLT4 t1 = I3 + I4;\n";
+ c += " FLT4 t0 = I1 + I2;\n";
+ c += " FLT4 t1 = I3 + I4;\n";
c += " FLT4 bias_val = args.biases.Read(DST_Z);\n";
c += " {\n";
c += " FLT4 r0 = TO_FLT4(I0 + t0 + t1) + bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
c += " tile_x++;\n";
c += " }\n";
- c += " ACCUM_FLT4 t2 = I1 - I2;\n";
- c += " ACCUM_FLT4 t3 = I3 - I4;\n";
+ c += " FLT4 t2 = I1 - I2;\n";
+ c += " FLT4 t3 = I3 - I4;\n";
c += " if (tile_x < args.dst_tensor.Width()) {\n";
c += " FLT4 r0 = TO_FLT4(t2 * At[7] + t3 * At[9]) + bias_val;\n";
c += " args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";