ConcatZ: reverted to the pre-batch kernel and rewritten so the batch dimension is supported implicitly (batch is folded into the X grid dimension instead of using a separate batch_id/global_id(2) axis).
PiperOrigin-RevId: 272790908
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
index 6712379..ca03ecc 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
@@ -36,22 +36,27 @@
return true;
}
+std::string GetSrcDepthSizeVar(int src_index) {
+ return "src_size_" + std::to_string(src_index) + "_depth";
+}
+
std::string GetConcatKernelCode(
const OperationDef& op_def, const std::vector<int>& channels,
const std::vector<ElementwiseOperation*>& linked_operations) {
std::vector<TensorCodeGenerator> srcs(channels.size());
for (int i = 0; i < channels.size(); ++i) {
const std::string tensor_name = "src_data_" + std::to_string(i);
- const std::string uniform_name = "src_size_" + std::to_string(i);
- srcs[i] =
- TensorCodeGenerator(tensor_name, uniform_name, op_def.src_tensors[i]);
+ srcs[i] = TensorCodeGenerator(
+ tensor_name, {"dst_size.x", "dst_size.y", GetSrcDepthSizeVar(i)},
+ op_def.src_tensors[i]);
}
- TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
+ TensorCodeGenerator dst("dst_data",
+ {"dst_size.x", "dst_size.y", "dst_size.z"},
+ op_def.dst_tensors[0]);
std::string c = GetCommonDefines(op_def.precision);
const std::string postfix[] = {".x", ".y", ".z", ".w"};
- const std::string batch_id = op_def.batch_support ? "batch_id" : "";
c += "__kernel void main_function(\n";
for (const auto& src : srcs) {
c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -59,21 +64,13 @@
c += dst.GetDeclaration(AccessType::WRITE);
c += GetArgsDeclaration(linked_operations);
for (int i = 0; i < channels.size(); ++i) {
- const std::string uniform_name = "src_size_" + std::to_string(i);
- c += " int4 " + uniform_name + ",\n";
- }
- if (op_def.batch_support) {
- c += " int BATCH_SIZE, \n";
+ c += " int " + GetSrcDepthSizeVar(i) + ",\n";
}
c += " int4 dst_size\n";
c += ") {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
- c += " if (X >= dst_size.x || Y >= dst_size.y) return;\n";
- if (op_def.batch_support) {
- c += " int batch_id = get_global_id(2);\n";
- c += " if (batch_id >= BATCH_SIZE) return;\n";
- }
+ c += " if (X >= dst_size.x || Y >= dst_size.y) return; \n";
if (IsAllChannelsX4(channels)) {
// When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
@@ -81,37 +78,35 @@
// generation.
c += " int Z = 0;\n";
for (int i = 0; i < channels.size(); ++i) {
- const std::string uniform_name = "src_size_" + std::to_string(i);
const int depth = IntegralDivideRoundUp(channels[i], 4);
if (depth % 2 == 0) {
// We can read more at once inside of loop in case depth % 2 == 0
// it should be better for reading latency hiding
- c += " for (int i = 0; i < " + uniform_name + ".w; i += 2) {\n";
+ c += " for (int i = 0; i < " + GetSrcDepthSizeVar(i) + "; i += 2) {\n";
c += " FLT4 result0 = " +
- srcs[i].Read4D("X", "Y", "i", batch_id,
- TextureAddressMode::DONT_CARE) +
+ srcs[i].Read3D("X", "Y", "i", TextureAddressMode::DONT_CARE) +
";\n";
c += " FLT4 result1 = " +
- srcs[i].Read4D("X", "Y", "i + 1", batch_id,
- TextureAddressMode::DONT_CARE) +
+ srcs[i].Read3D("X", "Y", "i + 1", TextureAddressMode::DONT_CARE) +
";\n";
+ c += " " + dst.GetAddress("dst_adr0", "X", "Y", "Z") + "\n";
+ c += " " + dst.GetAddress("dst_adr1", "X", "Y", "Z + 1") + "\n";
const LinkingContext context_0{"result0", "X", "Y", "Z"};
const LinkingContext context_1{"result1", "X", "Y", "Z + 1"};
c += PostProcess(linked_operations, context_0);
c += PostProcess(linked_operations, context_1);
- c += " " + dst.Write4D("result0", "X", "Y", "Z", batch_id);
- c += " " + dst.Write4D("result1", "X", "Y", "Z + 1", batch_id);
+ c += " " + dst.Write3D("result0", "X", "Y", "Z");
+ c += " " + dst.Write3D("result1", "X", "Y", "Z + 1");
c += " Z += 2;\n";
c += " }\n";
} else {
- c += " for (int i = 0; i < " + uniform_name + ".w; ++i) {\n";
+ c += " for (int i = 0; i < " + GetSrcDepthSizeVar(i) + "; ++i) {\n";
c += " FLT4 result = " +
- srcs[i].Read4D("X", "Y", "i", batch_id,
- TextureAddressMode::DONT_CARE) +
+ srcs[i].Read3D("X", "Y", "i", TextureAddressMode::DONT_CARE) +
";\n";
const LinkingContext context{"result", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
- c += " " + dst.Write4D("result", "X", "Y", "Z", batch_id);
+ c += " " + dst.Write3D("result", "X", "Y", "Z");
c += " Z++;\n";
c += " }\n";
}
@@ -126,8 +121,8 @@
for (int d = 0; d < depth; ++d) {
const int channels_in_group = std::min(4, channels[i] - d * 4);
const std::string temp_name = "t" + std::to_string(read_index);
- c += " FLT4 " + temp_name + " = " +
- srcs[i].Read4D("X", "Y", std::to_string(d), batch_id,
+ c += " FLT4 " + temp_name + " = ";
+ c += srcs[i].Read3D("X", "Y", std::to_string(d),
TextureAddressMode::DONT_CARE) +
";\n";
for (int ch = 0; ch < channels_in_group; ++ch) {
@@ -139,8 +134,7 @@
c += " {\n";
const LinkingContext context{"result", "X", "Y", std::to_string(z)};
c += PostProcess(linked_operations, context);
- c += " " +
- dst.Write4D("result", "X", "Y", std::to_string(z), batch_id);
+ c += " " + dst.Write3D("result", "X", "Y", std::to_string(z));
c += " }\n";
z++;
}
@@ -152,7 +146,7 @@
c += " {\n";
const LinkingContext context{"result", "X", "Y", std::to_string(z)};
c += PostProcess(linked_operations, context);
- c += " " + dst.Write4D("result", "X", "Y", "Z", std::to_string(z));
+ c += " " + dst.Write3D("result", "X", "Y", std::to_string(z));
c += " }\n";
}
}
@@ -199,21 +193,16 @@
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
for (int i = 0; i < channels_.size(); ++i) {
- int4 size(src_[i]->Width(), src_[i]->Height(), channels_[i],
- IntegralDivideRoundUp(channels_[i], 4));
- RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
+ RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->Depth()));
}
- if (definition_.batch_support) {
- RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
- }
- RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+ RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
return OkStatus();
}
int3 ConcatZ::GetGridSize() const {
- const int grid_x = dst_[0]->Width();
+ const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
- const int grid_z = dst_[0]->Batch();
+ const int grid_z = 1;
return int3(grid_x, grid_y, grid_z);
}