Added kernels for elementwise operations with input broadcasting (the first input, or both inputs of two-input operations).
PiperOrigin-RevId: 469004696
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
index 1b07e74..307ffb7 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
@@ -222,6 +222,26 @@
ASSERT_TRUE(status.ok()) << status.error_message();
}
+TEST_F(OpenCLOperationTest, CosBroadcast) {  // Runs the shared CosBroadcastTest with this fixture's exec_env_.
+ auto status = CosBroadcastTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();  // On failure, print the status text.
+}
+
+TEST_F(OpenCLOperationTest, MaximumScalarBroadcastInput) {  // Runs the shared MaximumScalarBroadcastInputTest with this fixture's exec_env_.
+ auto status = MaximumScalarBroadcastInputTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();  // On failure, print the status text.
+}
+
+TEST_F(OpenCLOperationTest, MulLinearBroadcastInput) {  // Runs the shared MulLinearBroadcastInputTest with this fixture's exec_env_.
+ auto status = MulLinearBroadcastInputTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();  // On failure, print the status text.
+}
+
+TEST_F(OpenCLOperationTest, MulBroadcastBothInputs) {  // Runs the shared MulBroadcastBothInputsTest with this fixture's exec_env_.
+ auto status = MulBroadcastBothInputsTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();  // On failure, print the status text.
+}
+
} // namespace
} // namespace cl
} // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
index 6ad3f9d..26ec90d 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
@@ -20,6 +20,7 @@
#include <utility>
#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
namespace tflite {
@@ -227,7 +228,7 @@
// Creates simple two input (first input is runtime tensor and second input is
// scalar argument) operation, for example sub, div, pow, etc.
-GPUOperation CreateElementwiseOneRuntimeOneScalar(
+ElementwiseDescriptor CreateElementwiseOneRuntimeOneScalar(
const OperationDef& definition, const OperationType& op_type,
float scalar_parameter, bool swap_inputs) {
ElementwiseDescriptor op_desc;
@@ -239,12 +240,12 @@
op_desc.code = "FLT4 second_val = INIT_FLT4(args.scalar);\n";
op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
"second_val", swap_inputs);
- return CreateGpuOperation(definition, std::move(op_desc));
+ return op_desc;
}
// Creates simple two input(first input is runtime tensor and second input is
// constant linear tensor) operation, for example sub, div and etc.
-GPUOperation CreateElementwiseTwoInput(
+ElementwiseDescriptor CreateElementwiseTwoInput(
const GpuInfo& gpu_info, const OperationDef& definition,
const OperationType& op_type,
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
@@ -265,12 +266,12 @@
}
op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
"second_val", swap_inputs);
- return CreateGpuOperation(definition, std::move(op_desc));
+ return op_desc;
}
// Creates simple two input(first input is runtime tensor and second input is
// constant HWC tensor) operation, for example sub, div and etc.
-GPUOperation CreateElementwiseTwoInput(
+ElementwiseDescriptor CreateElementwiseTwoInput(
const GpuInfo& gpu_info, const OperationDef& definition,
const OperationType& op_type,
const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
@@ -298,24 +299,13 @@
op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
"second_val", swap_inputs);
- return CreateGpuOperation(definition, std::move(op_desc));
+ return op_desc;
}
-} // namespace
-
-GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
- const OperationDef& definition,
- const OperationType& op_type) {
- ElementwiseDescriptor op_desc;
- op_desc.code = GetOneInputCode(gpu_info, op_type, definition.precision,
- "in_value", "out_value");
- return CreateGpuOperation(definition, std::move(op_desc));
-}
-
-GPUOperation CreateElementwise(const GpuInfo& gpu_info,
- const OperationDef& definition,
- const OperationType& op_type,
- const ElementwiseAttributes& attr) {
+ElementwiseDescriptor CreateElementwiseDesc(const GpuInfo& gpu_info,
+ const OperationDef& definition,
+ const OperationType& op_type,
+ const ElementwiseAttributes& attr) {
const float* scalar = absl::get_if<float>(&attr.param);
const auto* linear_tensor =
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.param);
@@ -333,10 +323,29 @@
return CreateElementwiseTwoInput(gpu_info, definition, op_type, *hwc_tensor,
attr.runtime_tensor_is_second);
} else {
- return GPUOperation(definition);
+ return ElementwiseDescriptor();
}
}
+} // namespace
+
+GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,  // Builds a GPUOperation from the one-input code generated for op_type.
+ const OperationDef& definition,
+ const OperationType& op_type) {
+ ElementwiseDescriptor op_desc;
+ op_desc.code = GetOneInputCode(gpu_info, op_type, definition.precision,  // Generated code reads "in_value" and writes "out_value".
+ "in_value", "out_value");
+ return CreateGpuOperation(definition, std::move(op_desc));  // The descriptor is handed over to the operation.
+}
+
+GPUOperation CreateElementwise(const GpuInfo& gpu_info,  // Wraps the descriptor produced by CreateElementwiseDesc in a GPUOperation.
+ const OperationDef& definition,
+ const OperationType& op_type,
+ const ElementwiseAttributes& attr) {  // attr.param selects the scalar / linear / HWC variant inside CreateElementwiseDesc.
+ return CreateGpuOperation(
+ definition, CreateElementwiseDesc(gpu_info, definition, op_type, attr));
+}
+
GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
const OperationType& op_type,
const BHWC& shape) {
@@ -346,5 +355,121 @@
return CreateGpuOperation(definition, std::move(op_desc), shape);
}
+namespace {
+std::string GetKernelBodyCode(const TensorDescriptor& dst_desc) {  // Returns the kernel source template; callers fill "$0" through absl::Substitute.
+ std::string c;
+ c += "MAIN_FUNCTION($$0) {\n";  // "$$" is Substitute's escape for '$', so a literal "$0" stays in the argument list.
+ if (dst_desc.HasAxis(Axis::BATCH)) {  // With a batch axis, GLOBAL_ID_0 is split into X and B.
+ c += " int linear_id = GLOBAL_ID_0;\n";
+ c += " int X = linear_id / args.dst_tensor.Batch();\n";
+ c += " int B = linear_id % args.dst_tensor.Batch();\n";
+ c += " args.dst_tensor.SetBatchRef(B);\n";
+ } else {
+ c += " int X = GLOBAL_ID_0;\n";  // No batch axis: "B" is not declared in the generated kernel.
+ }
+ c += " int Y = GLOBAL_ID_1;\n";
+ c += " int S = GLOBAL_ID_2;\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "S >= args.dst_tensor.Slices()) return; \n";  // Generated guard: ids outside the destination extents return early.
+ c += " args.dst_tensor::type result;\n";  // The substituted op code is expected to assign "result".
+ c += " $0\n";  // Placeholder replaced by the caller's read + compute code.
+ c += " args.dst_tensor.Write(result, X, Y, S);\n";
+ c += "} \n";
+ return c;
+}
+std::string GetReadBroadcastedValueCode(const BHWC& src_shape,  // Returns a Substitute template: $0 = tensor argument name, $1 = name of the declared value.
+ const TensorDescriptor& src_desc,
+ const BHWC& dst_shape) {
+ const std::string x_coord = src_shape.w != dst_shape.w ? "0" : "X";  // An axis whose size differs from the destination is read at index 0.
+ const std::string y_coord = src_shape.h != dst_shape.h ? "0" : "Y";
+ const std::string s_coord = src_shape.c != dst_shape.c ? "0" : "S";
+ std::string coords = absl::StrCat(x_coord, ", ", y_coord, ", ", s_coord);
+ if (src_desc.HasAxis(Axis::BATCH)) {  // A batch coordinate is appended only when the source descriptor has a batch axis.
+ const std::string b_coord = src_shape.b != dst_shape.b ? "0" : "B";  // NOTE(review): GetKernelBodyCode declares "B" only when the destination has a batch axis - confirm callers never pair a batched source with an unbatched destination.
+ coords += ", " + b_coord;
+ }
+ std::string read_value_code =
+ absl::StrCat("args.$0::type $1 = args.$0.Read(", coords, ");\n");
+ if (src_shape.c != dst_shape.c) {  // Channel broadcast: copy .x into .y/.z/.w. NOTE(review): assumes the source has a single channel - TODO confirm.
+ read_value_code += " $1.y = $1.x;\n";
+ read_value_code += " $1.z = $1.x;\n";
+ read_value_code += " $1.w = $1.x;\n";
+ }
+ return read_value_code;
+}
+} // namespace
+
+GPUOperation CreateElementwiseOneInputWithBroadcast(  // One-input elementwise op whose input is broadcast from input_shape to output_shape.
+ const GpuInfo& gpu_info, const OperationDef& definition,
+ const OperationType& op_type, const BHWC& input_shape,
+ const BHWC& output_shape) {
+ GPUOperation op(definition);
+ op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ std::string c;
+ c += " " + absl::Substitute(  // Declares "first_value", read from src_tensor with broadcast coordinates.
+ GetReadBroadcastedValueCode(
+ input_shape, definition.src_tensors[0], output_shape),
+ "src_tensor", "first_value");
+ c += " " + GetOneInputCode(gpu_info, op_type, definition.precision,  // Applies op_type: input "first_value", output "result".
+ "first_value", "result");
+ op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // Splices the read + compute code into the template's $0 slot.
+ return op;
+}
+
+GPUOperation CreateElementwiseWithBroadcast(const GpuInfo& gpu_info,  // Two-input op (runtime tensor + constant from attr) with the runtime input broadcast to output_shape.
+ const OperationDef& definition,
+ const OperationType& op_type,
+ const ElementwiseAttributes& attr,
+ const BHWC& input_shape,
+ const BHWC& output_shape) {
+ ElementwiseDescriptor op_desc =
+ CreateElementwiseDesc(gpu_info, definition, op_type, attr);  // Reuses the descriptor builder for its args and code.
+
+ GPUOperation op(definition);
+ op.args_ = std::move(op_desc.args);  // Takes over the descriptor's arguments; op_desc.code is still used below.
+ op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ std::string c;
+ c += " " + absl::Substitute(  // Declares "first_value", read from src_tensor with broadcast coordinates.
+ GetReadBroadcastedValueCode(
+ input_shape, definition.src_tensors[0], output_shape),
+ "src_tensor", "first_value");
+ c += " " + absl::StrReplaceAll(op_desc.code, {{"in_value", "first_value"},  // Renames the descriptor's placeholder identifiers to the names used in this kernel.
+ {"out_value", "result"},
+ {"X_COORD", "X"},
+ {"Y_COORD", "Y"},
+ {"S_COORD", "S"},
+ {"B_COORD", "B"}});
+ op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // Splices the read + compute code into the template's $0 slot.
+ return op;
+}
+
+GPUOperation CreateElementwiseTwoInputWithBroadcast(  // Two runtime inputs, each broadcast independently to output_shape.
+ const OperationDef& definition, const OperationType& op_type,
+ const BHWC& first_input_shape, const BHWC& second_input_shape,
+ const BHWC& output_shape) {
+ GPUOperation op(definition);
+ op.AddSrcTensor("src0_tensor", definition.src_tensors[0]);
+ op.AddSrcTensor("src1_tensor", definition.src_tensors[1]);
+ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ std::string c;
+ c += " " + absl::Substitute(GetReadBroadcastedValueCode(  // Declares "first_value" from src0_tensor.
+ first_input_shape, definition.src_tensors[0],
+ output_shape),
+ "src0_tensor", "first_value");
+ c += " " + absl::Substitute(GetReadBroadcastedValueCode(  // Declares "second_value" from src1_tensor.
+ second_input_shape,
+ definition.src_tensors[1], output_shape),
+ "src1_tensor", "second_value");
+ c += " " +
+ GetTwoInputCode(op_type, "result", "first_value", "second_value", false);  // Last argument is the swap-inputs flag; operands keep their order.
+ op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // Splices the read + compute code into the template's $0 slot.
+ return op;
+}
+
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
index 5c41a8c..0e8e366 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
@@ -31,6 +31,14 @@
const OperationDef& definition,
const OperationType& op_type);
+// Creates a simple one-input operation without any parameters, for example
+// log, sin, cos, etc.
+// The input is broadcast along axes where input_shape != output_shape.
+GPUOperation CreateElementwiseOneInputWithBroadcast(
+ const GpuInfo& gpu_info, const OperationDef& definition,
+ const OperationType& op_type, const BHWC& input_shape,
+ const BHWC& output_shape);
+
// Creates simple two input(first input is runtime tensor and second input is
// constant or linear/hwc tensor) operation, for example sub, div and etc.
GPUOperation CreateElementwise(const GpuInfo& gpu_info,
@@ -38,12 +46,30 @@
const OperationType& op_type,
const ElementwiseAttributes& attr);
+// Creates a simple two-input operation (first input is a runtime tensor, second
+// input is a scalar or a constant linear/HWC tensor), for example sub, div, etc.
+// The runtime input is broadcast along axes where input_shape != output_shape.
+GPUOperation CreateElementwiseWithBroadcast(const GpuInfo& gpu_info,
+ const OperationDef& definition,
+ const OperationType& op_type,
+ const ElementwiseAttributes& attr,
+ const BHWC& input_shape,
+ const BHWC& output_shape);
+
// Creates simple two input(2 runtime tensors) operation, for example
// sub, div and etc.
GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
const OperationType& op_type,
const BHWC& shape);
+// Creates a simple two-input operation (two runtime tensors), for example
+// sub, div, etc.
+// Can broadcast the first and second inputs simultaneously.
+GPUOperation CreateElementwiseTwoInputWithBroadcast(
+ const OperationDef& definition, const OperationType& op_type,
+ const BHWC& first_input_shape, const BHWC& second_input_shape,
+ const BHWC& output_shape);
+
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
index 5af7fe5..07e57af 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
@@ -1221,5 +1221,131 @@
return absl::OkStatus();
}
+absl::Status CosBroadcastTest(TestExecutionEnvironment* env) {  // Checks COS on an input broadcast along the channel axis.
+ TensorFloat32 src_tensor;
+ src_tensor.shape = BHWC(1, 2, 1, 1);  // Single-channel source; the output below has 2 channels.
+ src_tensor.data = {0.7f, -1.5f};
+
+ for (auto precision : env->GetSupportedPrecisions()) {  // Repeats the check for every precision/storage pair the environment reports.
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ for (auto storage : env->GetSupportedStorages(data_type)) {
+ const float eps = precision == CalculationsPrecision::F32 ? 5e-5f : 1e-3f;  // Looser tolerance for non-F32 precisions.
+ OperationDef op_def;
+ op_def.precision = precision;
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ BHWC output_shape(1, 2, 1, 2);  // Differs from the source only in the channel count.
+ GPUOperation operation = CreateElementwiseOneInputWithBroadcast(
+ env->GetGpuInfo(), op_def, OperationType::COS, src_tensor.shape,
+ output_shape);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
+ output_shape, &dst_tensor));
+ RETURN_IF_ERROR(PointWiseNear(
+ {std::cos(0.7f), std::cos(0.7f), std::cos(-1.5f), std::cos(-1.5f)},  // Each source value is expected once per output channel.
+ dst_tensor.data, eps));
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status MaximumScalarBroadcastInputTest(TestExecutionEnvironment* env) {  // Checks MAXIMUM against a scalar with the runtime input broadcast along channels.
+ TensorFloat32 src_tensor_0;
+ src_tensor_0.shape = BHWC(1, 2, 1, 1);  // Single-channel source; the output below has 2 channels.
+ src_tensor_0.data = {2.0f, -3.0f};
+
+ ElementwiseAttributes attr;
+ attr.param = -2.0f;  // Scalar second operand.
+
+ for (auto precision : env->GetSupportedPrecisions()) {  // Repeats the check for every precision/storage pair the environment reports.
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ for (auto storage : env->GetSupportedStorages(data_type)) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // Looser tolerance for non-F32 precisions.
+ OperationDef op_def;
+ op_def.precision = precision;
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ BHWC output_shape(1, 2, 1, 2);  // Differs from the source only in the channel count.
+ GPUOperation operation = CreateElementwiseWithBroadcast(
+ env->GetGpuInfo(), op_def, OperationType::MAXIMUM, attr,
+ src_tensor_0.shape, output_shape);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor_0, std::make_unique<GPUOperation>(std::move(operation)),
+ output_shape, &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({2.0f, 2.0f, -2.0f, -2.0f}, dst_tensor.data, eps));  // max(2, -2) = 2 and max(-3, -2) = -2, each repeated per channel.
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status MulLinearBroadcastInputTest(TestExecutionEnvironment* env) {  // Checks MUL by a constant linear tensor with the runtime input broadcast along channels.
+ TensorFloat32 src_tensor_0;
+ src_tensor_0.shape = BHWC(1, 2, 1, 1);  // Single-channel source; the output below has 2 channels.
+ src_tensor_0.data = {2.0f, -3.0f};
+
+ ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> linear_tensor;
+ linear_tensor.shape = Linear(2);  // One multiplier per output channel.
+ linear_tensor.data = {0.5f, 2.0f};
+ ElementwiseAttributes attr;
+ attr.param = linear_tensor;
+
+ for (auto precision : env->GetSupportedPrecisions()) {  // Repeats the check for every precision/storage pair the environment reports.
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ for (auto storage : env->GetSupportedStorages(data_type)) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // Looser tolerance for non-F32 precisions.
+ OperationDef op_def;
+ op_def.precision = precision;
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ BHWC output_shape(1, 2, 1, 2);  // Differs from the source only in the channel count.
+ GPUOperation operation = CreateElementwiseWithBroadcast(
+ env->GetGpuInfo(), op_def, OperationType::MUL, attr,
+ src_tensor_0.shape, output_shape);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor_0, std::make_unique<GPUOperation>(std::move(operation)),
+ output_shape, &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({1.0f, 4.0f, -1.5f, -6.0f}, dst_tensor.data, eps));  // {2*0.5, 2*2, -3*0.5, -3*2}.
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status MulBroadcastBothInputsTest(TestExecutionEnvironment* env) {  // Checks MUL of two runtime tensors that are broadcast along different axes.
+ TensorFloat32 src_tensor_0, src_tensor_1;
+ src_tensor_0.shape = BHWC(1, 1, 2, 1);  // Two elements along width, one channel.
+ src_tensor_1.shape = BHWC(1, 1, 1, 2);  // One element along width, two channels.
+ src_tensor_0.data = {1.0f, 2.0f};
+ src_tensor_1.data = {3.0f, 4.0f};
+
+ for (auto precision : env->GetSupportedPrecisions()) {  // Repeats the check for every precision/storage pair the environment reports.
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ for (auto storage : env->GetSupportedStorages(data_type)) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // Looser tolerance for non-F32 precisions.
+ OperationDef op_def;
+ op_def.precision = precision;
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ BHWC output_shape(1, 1, 2, 2);  // First input is broadcast along channels, second along width.
+ GPUOperation operation = CreateElementwiseTwoInputWithBroadcast(
+ op_def, OperationType::MUL, src_tensor_0.shape, src_tensor_1.shape,
+ output_shape);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ {src_tensor_0, src_tensor_1},
+ std::make_unique<GPUOperation>(std::move(operation)), output_shape,
+ &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({3.0f, 4.0f, 6.0f, 8.0f}, dst_tensor.data, eps));  // {1*3, 1*4, 2*3, 2*4}.
+ }
+ }
+ return absl::OkStatus();
+}
+
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
index b3db4cc..c7df18d 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
@@ -62,6 +62,10 @@
absl::Status GreaterEqualTest(TestExecutionEnvironment* env);
absl::Status EqualTest(TestExecutionEnvironment* env);
absl::Status NotEqualTest(TestExecutionEnvironment* env);
+absl::Status CosBroadcastTest(TestExecutionEnvironment* env);  // COS with a broadcast input.
+absl::Status MaximumScalarBroadcastInputTest(TestExecutionEnvironment* env);  // MAXIMUM with a scalar and a broadcast input.
+absl::Status MulLinearBroadcastInputTest(TestExecutionEnvironment* env);  // MUL by a linear tensor with a broadcast input.
+absl::Status MulBroadcastBothInputsTest(TestExecutionEnvironment* env);  // MUL of two runtime tensors, both broadcast.
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
index 29057a1..f449964 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
@@ -221,4 +221,24 @@
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
+- (void)testCosBroadcast {  // Runs the shared CosBroadcastTest with this test class's exec_env_.
+ auto status = CosBroadcastTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());  // On failure, print the status message.
+}
+
+- (void)testMaximumScalarBroadcastInput {  // Runs the shared MaximumScalarBroadcastInputTest with this test class's exec_env_.
+ auto status = MaximumScalarBroadcastInputTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());  // On failure, print the status message.
+}
+
+- (void)testMulLinearBroadcastInput {  // Runs the shared MulLinearBroadcastInputTest with this test class's exec_env_.
+ auto status = MulLinearBroadcastInputTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());  // On failure, print the status message.
+}
+
+- (void)testMulBroadcastBothInputs {  // Runs the shared MulBroadcastBothInputsTest with this test class's exec_env_.
+ auto status = MulBroadcastBothInputsTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());  // On failure, print the status message.
+}
+
@end