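ConcatXY reverted and rewritten to support batch implicitly: batch is folded into the width dimension instead of being handled through a separate batch index in the kernel.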
PiperOrigin-RevId: 272784208
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
index d940134..789ba83 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
@@ -33,15 +33,17 @@
std::vector<TensorCodeGenerator> srcs(tensors_count);
for (int i = 0; i < tensors_count; ++i) {
const std::string tensor_name = "src_data_" + std::to_string(i);
- const std::string uniform_name = "src_size_" + std::to_string(i);
- srcs[i] =
- TensorCodeGenerator(tensor_name, uniform_name, op_def.src_tensors[i]);
+ const std::string width = "src_size_" + std::to_string(i) + ".x";
+ const std::string height = "src_size_" + std::to_string(i) + ".y";
+ srcs[i] = TensorCodeGenerator(tensor_name, {width, height, "dst_size.z"},
+ op_def.src_tensors[i]);
}
- TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
+ TensorCodeGenerator dst("dst_data",
+ {"dst_size.x", "dst_size.y", "dst_size.z"},
+ op_def.dst_tensors[0]);
std::string c = GetCommonDefines(op_def.precision);
- const std::string batch_id = op_def.batch_support ? "batch_id" : "";
c += "__kernel void main_function(\n";
for (const auto& src : srcs) {
c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -52,38 +54,22 @@
const std::string uniform_name = "src_size_" + std::to_string(i);
c += " int4 " + uniform_name + ",\n";
}
- for (int i = 0; i < tensors_count; ++i) {
- const std::string uniform_name = "dst_offset_" + std::to_string(i);
- c += " int2 " + uniform_name + ",\n";
- }
- if (op_def.batch_support) {
- c += " int BATCH_SIZE, \n";
- }
c += " int4 dst_size \n";
c += ") {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
- if (op_def.batch_support) {
- c += " int batch_id = get_global_id(2) / dst_size.w;\n";
- c += " int Z = get_global_id(2) - batch_id * dst_size.w;\n";
- c += " if (Z >= dst_size.w || batch_id >= BATCH_SIZE) return;\n";
- } else {
- c += " int Z = get_global_id(2);\n";
- c += " if (Z >= dst_size.w) return;\n";
- }
+ c += " int Z = get_global_id(2);\n";
+ c += " if (Z >= dst_size.z) return;\n";
for (int i = 0; i < tensors_count; ++i) {
- const std::string offset_name = "dst_offset_" + std::to_string(i);
const std::string size_name = "src_size_" + std::to_string(i);
c += " if (X < " + size_name + ".x && Y < " + size_name + ".y) { \n";
- c +=
- " FLT4 result = " +
- srcs[i].Read4D("X", "Y", "Z", batch_id, TextureAddressMode::DONT_CARE) +
- ";\n";
- c += " int dst_x = X + " + offset_name + ".x;\n";
- c += " int dst_y = Y + " + offset_name + ".y;\n";
+ c += " FLT4 result = " +
+ srcs[i].Read3D("X", "Y", "Z", TextureAddressMode::DONT_CARE) + ";\n";
+ c += " int dst_x = X + " + size_name + ".z;\n";
+ c += " int dst_y = Y + " + size_name + ".w;\n";
const LinkingContext context{"result", "dst_x", "dst_y", "Z"};
c += PostProcess(linked_operations, context);
- c += " " + dst.Write4D("result", "dst_x", "dst_y", "Z", batch_id);
+ c += " " + dst.Write3D("result", "dst_x", "dst_y", "Z");
c += " } \n";
}
c += "}\n";
@@ -125,24 +111,17 @@
}
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
- int max_src_width = 0;
- int max_src_height = 0;
- for (int i = 0; i < tensors_count_; ++i) {
- RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->GetSizeWithDepth()));
- max_src_width = std::max(max_src_width, src_[i]->Width());
- max_src_height = std::max(max_src_height, src_[i]->Height());
- }
int x_offset = 0;
int y_offset = 0;
for (int i = 0; i < tensors_count_; ++i) {
- RETURN_IF_ERROR(kernel_.SetBytesAuto(int2(x_offset, y_offset)));
- x_offset += attr_.axis == Axis::WIDTH ? src_[i]->Width() : 0;
- y_offset += attr_.axis == Axis::HEIGHT ? src_[i]->Height() : 0;
+ const int width = src_[i]->Width() * src_[i]->Batch();
+ const int height = src_[i]->Height();
+ RETURN_IF_ERROR(
+ kernel_.SetBytesAuto(int4(width, height, x_offset, y_offset)));
+ x_offset += attr_.axis == Axis::WIDTH ? width : 0;
+ y_offset += attr_.axis == Axis::HEIGHT ? height : 0;
}
- if (definition_.batch_support) {
- RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
- }
- RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+ RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
return OkStatus();
}
@@ -154,9 +133,9 @@
max_src_height = std::max(max_src_height, src_[i]->Height());
}
- const int grid_x = max_src_width;
+ const int grid_x = max_src_width * dst_[0]->Batch();
const int grid_y = max_src_height;
- const int grid_z = dst_[0]->Depth() * dst_[0]->Batch();
+ const int grid_z = dst_[0]->Depth();
return int3(grid_x, grid_y, grid_z);
}