Added batch support for elementwise operations.
PiperOrigin-RevId: 272974762
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
index b2ca756..7a1a070 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
@@ -26,13 +26,16 @@
namespace cl {
std::string Add::GetElementWiseCode(
- const TensorDescriptor& src_descriptor,
- const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
+ const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
- TensorCodeGenerator src_tensor("src_data", "src_size", src_descriptor);
- TensorCodeGenerator dst_tensor("dst_data", "dst_size", dst_descriptor);
+ TensorCodeGenerator src_tensor("src_data",
+ {"src_size.x", "src_size.y", "src_size.z"},
+ op_def.src_tensors[0]);
+ TensorCodeGenerator dst_tensor("dst_data",
+ {"dst_size.x", "dst_size.y", "dst_size.z"},
+ op_def.dst_tensors[0]);
- std::string c = GetCommonDefines(precision);
+ std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
@@ -45,7 +48,7 @@
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
- c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) { \n";
+ c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
c += " return; \n";
c += " } \n";
c += " FLT4 src = (FLT4)(0.0);\n";
@@ -106,8 +109,9 @@
absl::StrCat("src_data_", link_index_, "_", i);
const std::string size_name =
"src_size_" + std::to_string(link_index_) + "_" + std::to_string(i);
- TensorCodeGenerator src_tensor(tensor_name, size_name,
- definition_.src_tensors[i]);
+ TensorCodeGenerator src_tensor(
+ tensor_name, {size_name + ".x", size_name + ".y", size_name + ".z"},
+ definition_.src_tensors[i]);
if (src_depthes_[i] != dst_depth_) {
absl::StrAppend(&result, " if (", context.z_coord, " < ",
src_depthes_[i], ") {\n");
@@ -149,15 +153,13 @@
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[i]->GetMemoryPtr()));
}
for (int i = 1; i < src_depthes_.size(); ++i) {
- RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetSizeWithDepth()));
+ RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetWBatchedHDB()));
}
return OkStatus();
}
Status Add::Compile(const CreationContext& creation_context) {
- const auto code =
- GetElementWiseCode(definition_.src_tensors[0], definition_.dst_tensors[0],
- definition_.precision, linked_operations_);
+ const auto code = GetElementWiseCode(definition_, linked_operations_);
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
index cad591b..ac6243c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
@@ -51,8 +51,7 @@
private:
std::string GetElementWiseCode(
- const TensorDescriptor& src_descriptor,
- const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
+ const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations);
int link_index_;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
index a3305f8..085c4e9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
@@ -25,14 +25,16 @@
namespace {
std::string GetElementWiseCode(
- const TensorDescriptor& src_descriptor,
- const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
- const ElementwiseOperation& op,
+ const OperationDef& op_def, const ElementwiseOperation& op,
const std::vector<ElementwiseOperation*>& linked_operations) {
- TensorCodeGenerator src_tensor("src_data", "src_size", src_descriptor);
- TensorCodeGenerator dst_tensor("dst_data", "dst_size", dst_descriptor);
+ TensorCodeGenerator src_tensor("src_data",
+ {"src_size.x", "src_size.y", "src_size.z"},
+ op_def.src_tensors[0]);
+ TensorCodeGenerator dst_tensor("dst_data",
+ {"dst_size.x", "dst_size.y", "dst_size.z"},
+ op_def.dst_tensors[0]);
- std::string c = GetCommonDefines(precision);
+ std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
@@ -45,7 +47,7 @@
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
- c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) { \n";
+ c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
c += " return; \n";
c += " } \n";
c += " FLT4 src = " +
@@ -144,22 +146,20 @@
RETURN_IF_ERROR(BindArguments(&kernel_));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
- RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
- RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+ RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
+ RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
return OkStatus();
}
int3 ElementwiseOperation::GetGridSize() const {
- const int grid_x = dst_[0]->Width();
+ const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Depth();
return int3(grid_x, grid_y, grid_z);
}
Status ElementwiseOperation::Compile(const CreationContext& creation_context) {
- const auto code =
- GetElementWiseCode(definition_.src_tensors[0], definition_.dst_tensors[0],
- definition_.precision, *this, linked_operations_);
+ const auto code = GetElementWiseCode(definition_, *this, linked_operations_);
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);