Added kernels for elementwise operations with broadcast of the first input (and of both inputs for operations on two runtime tensors).

PiperOrigin-RevId: 469004696
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
index 1b07e74..307ffb7 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
@@ -222,6 +222,26 @@
   ASSERT_TRUE(status.ok()) << status.error_message();
 }
 
+TEST_F(OpenCLOperationTest, CosBroadcast) {
+  auto status = CosBroadcastTest(&exec_env_);  // shared body: COS on a 1-channel input, 2-channel output
+  ASSERT_TRUE(status.ok()) << status.error_message();
+}
+
+TEST_F(OpenCLOperationTest, MaximumScalarBroadcastInput) {
+  auto status = MaximumScalarBroadcastInputTest(&exec_env_);  // shared body: MAXIMUM with a scalar, channel axis broadcast
+  ASSERT_TRUE(status.ok()) << status.error_message();
+}
+
+TEST_F(OpenCLOperationTest, MulLinearBroadcastInput) {
+  auto status = MulLinearBroadcastInputTest(&exec_env_);  // shared body: MUL by a 2-element linear constant, channel axis broadcast
+  ASSERT_TRUE(status.ok()) << status.error_message();
+}
+
+TEST_F(OpenCLOperationTest, MulBroadcastBothInputs) {
+  auto status = MulBroadcastBothInputsTest(&exec_env_);  // shared body: MUL of two runtime tensors, each broadcast on a different axis
+  ASSERT_TRUE(status.ok()) << status.error_message();
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
index 6ad3f9d..26ec90d 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc
@@ -20,6 +20,7 @@
 #include <utility>
 
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
 #include "absl/strings/substitute.h"
 
 namespace tflite {
@@ -227,7 +228,7 @@
 
 // Creates simple two input (first input is runtime tensor and second input is
 // scalar argument) operation, for example sub, div, pow, etc.
-GPUOperation CreateElementwiseOneRuntimeOneScalar(
+ElementwiseDescriptor CreateElementwiseOneRuntimeOneScalar(
     const OperationDef& definition, const OperationType& op_type,
     float scalar_parameter, bool swap_inputs) {
   ElementwiseDescriptor op_desc;
@@ -239,12 +240,12 @@
   op_desc.code = "FLT4 second_val = INIT_FLT4(args.scalar);\n";
   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
                                   "second_val", swap_inputs);
-  return CreateGpuOperation(definition, std::move(op_desc));
+  return op_desc;
 }
 
 // Creates simple two input(first input is runtime tensor and second input is
 // constant linear tensor) operation, for example sub, div and etc.
-GPUOperation CreateElementwiseTwoInput(
+ElementwiseDescriptor CreateElementwiseTwoInput(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const OperationType& op_type,
     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
@@ -265,12 +266,12 @@
   }
   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
                                   "second_val", swap_inputs);
-  return CreateGpuOperation(definition, std::move(op_desc));
+  return op_desc;
 }
 
 // Creates simple two input(first input is runtime tensor and second input is
 // constant HWC tensor) operation, for example sub, div and etc.
-GPUOperation CreateElementwiseTwoInput(
+ElementwiseDescriptor CreateElementwiseTwoInput(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const OperationType& op_type,
     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
@@ -298,24 +299,13 @@
   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
                                   "second_val", swap_inputs);
 
-  return CreateGpuOperation(definition, std::move(op_desc));
+  return op_desc;
 }
 
-}  // namespace
-
-GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
-                                       const OperationDef& definition,
-                                       const OperationType& op_type) {
-  ElementwiseDescriptor op_desc;
-  op_desc.code = GetOneInputCode(gpu_info, op_type, definition.precision,
-                                 "in_value", "out_value");
-  return CreateGpuOperation(definition, std::move(op_desc));
-}
-
-GPUOperation CreateElementwise(const GpuInfo& gpu_info,
-                               const OperationDef& definition,
-                               const OperationType& op_type,
-                               const ElementwiseAttributes& attr) {
+ElementwiseDescriptor CreateElementwiseDesc(const GpuInfo& gpu_info,
+                                            const OperationDef& definition,
+                                            const OperationType& op_type,
+                                            const ElementwiseAttributes& attr) {
   const float* scalar = absl::get_if<float>(&attr.param);
   const auto* linear_tensor =
       absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.param);
@@ -333,10 +323,29 @@
     return CreateElementwiseTwoInput(gpu_info, definition, op_type, *hwc_tensor,
                                      attr.runtime_tensor_is_second);
   } else {
-    return GPUOperation(definition);
+    return ElementwiseDescriptor();
   }
 }
 
+}  // namespace
+
+GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
+                                       const OperationDef& definition,
+                                       const OperationType& op_type) {
+  ElementwiseDescriptor op_desc;  // carries only a code snippet here; no extra args are added
+  op_desc.code = GetOneInputCode(gpu_info, op_type, definition.precision,
+                                 "in_value", "out_value");  // names of the input and output variables used in the snippet
+  return CreateGpuOperation(definition, std::move(op_desc));  // wrap the descriptor into a GPUOperation
+}
+
+GPUOperation CreateElementwise(const GpuInfo& gpu_info,
+                               const OperationDef& definition,
+                               const OperationType& op_type,
+                               const ElementwiseAttributes& attr) {
+  return CreateGpuOperation(  // descriptor building lives in CreateElementwiseDesc so other factories can reuse it
+      definition, CreateElementwiseDesc(gpu_info, definition, op_type, attr));
+}
+
 GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
                                        const OperationType& op_type,
                                        const BHWC& shape) {
@@ -346,5 +355,121 @@
   return CreateGpuOperation(definition, std::move(op_desc), shape);
 }
 
+namespace {
+std::string GetKernelBodyCode(const TensorDescriptor& dst_desc) {  // Returns a kernel skeleton; "$0" marks where the per-op statements go.
+  std::string c;
+  c += "MAIN_FUNCTION($$0) {\n";  // "$$" is absl::Substitute's escape for '$', so a literal "$0" remains after substitution
+  if (dst_desc.HasAxis(Axis::BATCH)) {
+    c += "  int linear_id = GLOBAL_ID_0;\n";  // with a batch axis, GLOBAL_ID_0 encodes X * Batch + B
+    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
+    c += "  int B = linear_id % args.dst_tensor.Batch();\n";  // B is declared only in this branch
+    c += "  args.dst_tensor.SetBatchRef(B);\n";
+  } else {
+    c += "  int X = GLOBAL_ID_0;\n";
+  }
+  c += "  int Y = GLOBAL_ID_1;\n";
+  c += "  int S = GLOBAL_ID_2;\n";
+  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "  // work items outside the destination return early
+       "S >= args.dst_tensor.Slices()) return; \n";
+  c += "  args.dst_tensor::type result;\n";  // the substituted statements are expected to assign this variable
+  c += "  $0\n";  // replaced by the caller-supplied statements via absl::Substitute
+  c += "  args.dst_tensor.Write(result, X, Y, S);\n";
+  c += "} \n";
+  return c;
+}
+std::string GetReadBroadcastedValueCode(const BHWC& src_shape,  // Returns a Substitute template: $0 = tensor arg name, $1 = variable to declare.
+                                        const TensorDescriptor& src_desc,
+                                        const BHWC& dst_shape) {
+  const std::string x_coord = src_shape.w != dst_shape.w ? "0" : "X";  // an axis whose size differs from dst is read at index 0
+  const std::string y_coord = src_shape.h != dst_shape.h ? "0" : "Y";
+  const std::string s_coord = src_shape.c != dst_shape.c ? "0" : "S";
+  std::string coords = absl::StrCat(x_coord, ", ", y_coord, ", ", s_coord);
+  if (src_desc.HasAxis(Axis::BATCH)) {  // NOTE(review): "B" is declared by GetKernelBodyCode only when dst has a BATCH axis; confirm src never has BATCH without dst
+    const std::string b_coord = src_shape.b != dst_shape.b ? "0" : "B";
+    coords += ", " + b_coord;
+  }
+  std::string read_value_code =
+      absl::StrCat("args.$0::type $1 = args.$0.Read(", coords, ");\n");
+  if (src_shape.c != dst_shape.c) {  // channel broadcast: copy component x into y, z and w
+    read_value_code += "  $1.y = $1.x;\n";
+    read_value_code += "  $1.z = $1.x;\n";
+    read_value_code += "  $1.w = $1.x;\n";
+  }
+  return read_value_code;
+}
+}  // namespace
+
+GPUOperation CreateElementwiseOneInputWithBroadcast(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const OperationType& op_type, const BHWC& input_shape,
+    const BHWC& output_shape) {
+  GPUOperation op(definition);  // kernel source is assembled by hand below instead of going through ElementwiseDescriptor
+  op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+  op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;  // presumably matches the X/B, Y, S decoding in GetKernelBodyCode; confirm
+  std::string c;  // statements spliced into the kernel skeleton
+  c += "  " + absl::Substitute(
+                  GetReadBroadcastedValueCode(
+                      input_shape, definition.src_tensors[0], output_shape),
+                  "src_tensor", "first_value");  // declares "first_value" from a (possibly broadcast) read of src_tensor
+  c += "  " + GetOneInputCode(gpu_info, op_type, definition.precision,
+                              "first_value", "result");  // op snippet with "first_value" as input and "result" as output
+  op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // fills the skeleton's "$0" slot
+  return op;
+}
+
+GPUOperation CreateElementwiseWithBroadcast(const GpuInfo& gpu_info,
+                                            const OperationDef& definition,
+                                            const OperationType& op_type,
+                                            const ElementwiseAttributes& attr,
+                                            const BHWC& input_shape,
+                                            const BHWC& output_shape) {
+  ElementwiseDescriptor op_desc =  // reuse the snippet and args built for the non-broadcast path
+      CreateElementwiseDesc(gpu_info, definition, op_type, attr);
+
+  GPUOperation op(definition);
+  op.args_ = std::move(op_desc.args);  // the snippet refers to these args (e.g. args.scalar in the scalar case)
+  op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+  op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;  // presumably matches the X/B, Y, S decoding in GetKernelBodyCode; confirm
+  std::string c;  // statements spliced into the kernel skeleton
+  c += "  " + absl::Substitute(
+                  GetReadBroadcastedValueCode(
+                      input_shape, definition.src_tensors[0], output_shape),
+                  "src_tensor", "first_value");  // declares "first_value" from a (possibly broadcast) read of src_tensor
+  c += "  " + absl::StrReplaceAll(op_desc.code, {{"in_value", "first_value"},
+                                                 {"out_value", "result"},
+                                                 {"X_COORD", "X"},
+                                                 {"Y_COORD", "Y"},
+                                                 {"S_COORD", "S"},
+                                                 {"B_COORD", "B"}});  // NOTE(review): "B" exists in the kernel only when dst has a BATCH axis; confirm B_COORD cannot appear otherwise
+  op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // fills the skeleton's "$0" slot
+  return op;
+}
+
+GPUOperation CreateElementwiseTwoInputWithBroadcast(
+    const OperationDef& definition, const OperationType& op_type,
+    const BHWC& first_input_shape, const BHWC& second_input_shape,
+    const BHWC& output_shape) {
+  GPUOperation op(definition);  // kernel source is assembled by hand below
+  op.AddSrcTensor("src0_tensor", definition.src_tensors[0]);
+  op.AddSrcTensor("src1_tensor", definition.src_tensors[1]);
+  op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;  // presumably matches the X/B, Y, S decoding in GetKernelBodyCode; confirm
+  std::string c;  // statements spliced into the kernel skeleton
+  c += "  " + absl::Substitute(GetReadBroadcastedValueCode(
+                                   first_input_shape, definition.src_tensors[0],
+                                   output_shape),
+                               "src0_tensor", "first_value");  // each input is compared against output_shape independently
+  c += "  " + absl::Substitute(GetReadBroadcastedValueCode(
+                                   second_input_shape,
+                                   definition.src_tensors[1], output_shape),
+                               "src1_tensor", "second_value");
+  c += "  " +
+       GetTwoInputCode(op_type, "result", "first_value", "second_value", false);  // final argument is the swap_inputs flag: operands keep their order
+  op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);  // fills the skeleton's "$0" slot
+  return op;
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
index 5c41a8c..0e8e366 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.h
@@ -31,6 +31,14 @@
                                        const OperationDef& definition,
                                        const OperationType& op_type);
 
+// Creates simple one input operation without any parameters, for example
+// log, sin, cos, etc. Broadcasts the input: on each axis where input_shape
+// and output_shape differ, the input is read at index 0 (size not checked).
+GPUOperation CreateElementwiseOneInputWithBroadcast(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const OperationType& op_type, const BHWC& input_shape,
+    const BHWC& output_shape);
+
 // Creates simple two input(first input is runtime tensor and second input is
 // constant or linear/hwc tensor) operation, for example sub, div and etc.
 GPUOperation CreateElementwise(const GpuInfo& gpu_info,
@@ -38,12 +46,30 @@
                                const OperationType& op_type,
                                const ElementwiseAttributes& attr);
 
+// Creates simple two input (first input is runtime tensor and second input is
+// a scalar or a constant linear/hwc tensor) operation, for example sub, div,
+// etc. The runtime input is broadcast where it differs from output_shape.
+GPUOperation CreateElementwiseWithBroadcast(const GpuInfo& gpu_info,
+                                            const OperationDef& definition,
+                                            const OperationType& op_type,
+                                            const ElementwiseAttributes& attr,
+                                            const BHWC& input_shape,
+                                            const BHWC& output_shape);
+
 // Creates simple two input(2 runtime tensors) operation, for example
 // sub, div and etc.
 GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
                                        const OperationType& op_type,
                                        const BHWC& shape);
 
+// Creates simple two input (2 runtime tensors) operation, for example
+// sub, div, etc.
+// Each input is broadcast independently against output_shape.
+GPUOperation CreateElementwiseTwoInputWithBroadcast(
+    const OperationDef& definition, const OperationType& op_type,
+    const BHWC& first_input_shape, const BHWC& second_input_shape,
+    const BHWC& output_shape);
+
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
index 5af7fe5..07e57af 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
@@ -1221,5 +1221,131 @@
   return absl::OkStatus();
 }
 
+absl::Status CosBroadcastTest(TestExecutionEnvironment* env) {  // COS with the input broadcast along the channel axis
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 1);  // H = 2, C = 1
+  src_tensor.data = {0.7f, -1.5f};
+
+  for (auto precision : env->GetSupportedPrecisions()) {  // run for every supported precision/storage pair
+    auto data_type = DeduceDataTypeFromPrecision(precision);
+    for (auto storage : env->GetSupportedStorages(data_type)) {
+      const float eps = precision == CalculationsPrecision::F32 ? 5e-5f : 1e-3f;  // looser tolerance for non-F32 precisions
+      OperationDef op_def;
+      op_def.precision = precision;
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      BHWC output_shape(1, 2, 1, 2);  // C = 2: only the channel count differs from the input
+      GPUOperation operation = CreateElementwiseOneInputWithBroadcast(
+          env->GetGpuInfo(), op_def, OperationType::COS, src_tensor.shape,
+          output_shape);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
+          output_shape, &dst_tensor));
+      RETURN_IF_ERROR(PointWiseNear(
+          {std::cos(0.7f), std::cos(0.7f), std::cos(-1.5f), std::cos(-1.5f)},
+          dst_tensor.data, eps));  // each input value appears once per output channel
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status MaximumScalarBroadcastInputTest(TestExecutionEnvironment* env) {  // MAXIMUM(input, scalar) with the input broadcast along channels
+  TensorFloat32 src_tensor_0;
+  src_tensor_0.shape = BHWC(1, 2, 1, 1);  // H = 2, C = 1
+  src_tensor_0.data = {2.0f, -3.0f};
+
+  ElementwiseAttributes attr;
+  attr.param = -2.0f;  // scalar second operand
+
+  for (auto precision : env->GetSupportedPrecisions()) {  // run for every supported precision/storage pair
+    auto data_type = DeduceDataTypeFromPrecision(precision);
+    for (auto storage : env->GetSupportedStorages(data_type)) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // looser tolerance for non-F32 precisions
+      OperationDef op_def;
+      op_def.precision = precision;
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      BHWC output_shape(1, 2, 1, 2);  // C = 2: only the channel count differs from the input
+      GPUOperation operation = CreateElementwiseWithBroadcast(
+          env->GetGpuInfo(), op_def, OperationType::MAXIMUM, attr,
+          src_tensor_0.shape, output_shape);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor_0, std::make_unique<GPUOperation>(std::move(operation)),
+          output_shape, &dst_tensor));
+      RETURN_IF_ERROR(
+          PointWiseNear({2.0f, 2.0f, -2.0f, -2.0f}, dst_tensor.data, eps));  // max(2, -2) = 2 and max(-3, -2) = -2, once per channel
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status MulLinearBroadcastInputTest(TestExecutionEnvironment* env) {  // MUL by a linear constant with the input broadcast along channels
+  TensorFloat32 src_tensor_0;
+  src_tensor_0.shape = BHWC(1, 2, 1, 1);  // H = 2, C = 1
+  src_tensor_0.data = {2.0f, -3.0f};
+
+  ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> linear_tensor;
+  linear_tensor.shape = Linear(2);  // one constant per output channel
+  linear_tensor.data = {0.5f, 2.0f};
+  ElementwiseAttributes attr;
+  attr.param = linear_tensor;
+
+  for (auto precision : env->GetSupportedPrecisions()) {  // run for every supported precision/storage pair
+    auto data_type = DeduceDataTypeFromPrecision(precision);
+    for (auto storage : env->GetSupportedStorages(data_type)) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // looser tolerance for non-F32 precisions
+      OperationDef op_def;
+      op_def.precision = precision;
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      BHWC output_shape(1, 2, 1, 2);  // C = 2: only the channel count differs from the input
+      GPUOperation operation = CreateElementwiseWithBroadcast(
+          env->GetGpuInfo(), op_def, OperationType::MUL, attr,
+          src_tensor_0.shape, output_shape);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor_0, std::make_unique<GPUOperation>(std::move(operation)),
+          output_shape, &dst_tensor));
+      RETURN_IF_ERROR(
+          PointWiseNear({1.0f, 4.0f, -1.5f, -6.0f}, dst_tensor.data, eps));  // {2, -3} each multiplied by {0.5, 2.0}
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status MulBroadcastBothInputsTest(TestExecutionEnvironment* env) {  // MUL of two runtime tensors, each broadcast on a different axis
+  TensorFloat32 src_tensor_0, src_tensor_1;
+  src_tensor_0.shape = BHWC(1, 1, 2, 1);  // W = 2, C = 1: differs from the output in channels
+  src_tensor_1.shape = BHWC(1, 1, 1, 2);  // W = 1, C = 2: differs from the output in width
+  src_tensor_0.data = {1.0f, 2.0f};
+  src_tensor_1.data = {3.0f, 4.0f};
+
+  for (auto precision : env->GetSupportedPrecisions()) {  // run for every supported precision/storage pair
+    auto data_type = DeduceDataTypeFromPrecision(precision);
+    for (auto storage : env->GetSupportedStorages(data_type)) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;  // looser tolerance for non-F32 precisions
+      OperationDef op_def;
+      op_def.precision = precision;
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      BHWC output_shape(1, 1, 2, 2);  // W = 2, C = 2
+      GPUOperation operation = CreateElementwiseTwoInputWithBroadcast(
+          op_def, OperationType::MUL, src_tensor_0.shape, src_tensor_1.shape,
+          output_shape);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          {src_tensor_0, src_tensor_1},
+          std::make_unique<GPUOperation>(std::move(operation)), output_shape,
+          &dst_tensor));
+      RETURN_IF_ERROR(
+          PointWiseNear({3.0f, 4.0f, 6.0f, 8.0f}, dst_tensor.data, eps));  // out[w][c] = src0[w] * src1[c]
+    }
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
index b3db4cc..c7df18d 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h
@@ -62,6 +62,10 @@
 absl::Status GreaterEqualTest(TestExecutionEnvironment* env);
 absl::Status EqualTest(TestExecutionEnvironment* env);
 absl::Status NotEqualTest(TestExecutionEnvironment* env);
+absl::Status CosBroadcastTest(TestExecutionEnvironment* env);
+absl::Status MaximumScalarBroadcastInputTest(TestExecutionEnvironment* env);
+absl::Status MulLinearBroadcastInputTest(TestExecutionEnvironment* env);
+absl::Status MulBroadcastBothInputsTest(TestExecutionEnvironment* env);
 
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
index 29057a1..f449964 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
@@ -221,4 +221,24 @@
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testCosBroadcast {
+  auto status = CosBroadcastTest(&exec_env_);  // shared body: COS on a 1-channel input, 2-channel output
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testMaximumScalarBroadcastInput {
+  auto status = MaximumScalarBroadcastInputTest(&exec_env_);  // shared body: MAXIMUM with a scalar, channel axis broadcast
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testMulLinearBroadcastInput {
+  auto status = MulLinearBroadcastInputTest(&exec_env_);  // shared body: MUL by a 2-element linear constant, channel axis broadcast
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testMulBroadcastBothInputs {
+  auto status = MulBroadcastBothInputsTest(&exec_env_);  // shared body: MUL of two runtime tensors, each broadcast on a different axis
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end