[tfrt:jit] Detect combiner in linalg.generic reductions during tiling.
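
An illustrative input, adapted from the reduction_codegen.mlir test below:
the tiling patterns previously cloned the whole reduction body into the
accumulator `linalg.generic`, which is only correct when the body consists
of nothing but the combiner. They now detect the single combiner op in the
payload (here `arith.addf`) and clone only that, so bodies with extra ops
(e.g. an `arith.mulf` feeding the combiner) tile correctly.

  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%input : tensor<?x?xf32>)
      outs(%fill : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %add = arith.addf %in, %out : f32  // <- the combiner
    linalg.yield %add : f32
  } -> tensor<?xf32>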

PiperOrigin-RevId: 413636681
Change-Id: I0a0a33effede696ef46e7d0494f889136486786c
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
index a714a70..3a0ddcd 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
@@ -70,6 +70,7 @@
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
index e53d873..26f8c5b 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
@@ -15,14 +15,17 @@
 
 #include <utility>
 
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h"
 
@@ -38,7 +41,9 @@
 using mlir::cast;
 using mlir::dyn_cast;
 using mlir::failure;
+using mlir::FailureOr;
 using mlir::Identifier;
+using mlir::Location;
 using mlir::LogicalResult;
 using mlir::MLIRContext;
 using mlir::OpBuilder;
@@ -46,6 +51,8 @@
 using mlir::OpRewritePattern;
 using mlir::PatternRewriter;
 using mlir::RankedTensorType;
+using mlir::ShapedType;
+using mlir::SmallVector;
 using mlir::success;
 using mlir::Value;
 using mlir::ValueRange;
@@ -65,6 +72,16 @@
 using mlir::tensor::ExtractSliceOp;
 using mlir::tensor::InsertSliceOp;
 
+// Detects the combiner in the body of a LinalgOp, if any. Currently, only
+// ops with a single combiner are supported.
+FailureOr<Operation *> DetectCombiner(LinalgOp linalg_op) {
+  SmallVector<Operation *, 4> combiners;
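+  // matchReduction traces the value yielded for output 0 back through the
+  // payload and collects the ops that combine it with the carried output.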
+  if (!matchReduction(linalg_op.getRegionOutputArgs(), 0, combiners) ||
+      combiners.size() != 1)
+    return failure();
+  return combiners.front();
+}
+
 // Tiles a GenericOp that models a reduction and then fuses its inputs and
 // outputs. Currently, only the FillOp that initializes the output is fused into
 // the TiledLoopOp.
@@ -82,6 +99,7 @@
     if (failed(filter.checkAndNotify(rewriter, linalg_op))) return failure();
 
     if (linalg_op.getNumOutputs() != 1) return failure();
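+    // Only match the 2-D case (one parallel and one reduction loop); 1-D
+    // reductions are handled by a separate pattern later in this file.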
+    if (linalg_op.getNumLoops() != 2) return failure();
 
     auto tiled_op = tileLinalgOp(rewriter, linalg_op, options);
     if (failed(tiled_op)) return failure();
@@ -222,29 +240,36 @@
   //
   //   linalg.yield %insert_output_slice, %update_cloned_output
   // }
-  void CombineReducedTileWithOutput(PatternRewriter &rewriter,
-                                    LinalgOp tiled_op, Value partial_result,
-                                    ExtractSliceOp extract_output_slice,
-                                    InsertSliceOp insert_output_slice) const {
+  LogicalResult CombineReducedTileWithOutput(
+      PatternRewriter &rewriter, LinalgOp tiled_op, Value partial_result,
+      ExtractSliceOp extract_output_slice,
+      InsertSliceOp insert_output_slice) const {
     rewriter.setInsertionPointAfter(tiled_op);
     auto num_parallel_loops = tiled_op.getNumParallelLoops();
-    mlir::SmallVector<mlir::StringRef, 3> parallel_iter_types(
+    SmallVector<mlir::StringRef, 3> parallel_iter_types(
         num_parallel_loops, mlir::getParallelIteratorTypeName());
     auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops);
 
+    auto combiner_or = DetectCombiner(tiled_op);
+    if (failed(combiner_or)) return failure();
+    Operation *combiner = combiner_or.getValue();
+
     auto accumulator = rewriter.create<GenericOp>(
         tiled_op.getLoc(), partial_result.getType(),
         makeArrayRef(partial_result),
         makeArrayRef(extract_output_slice.result()),
-        makeArrayRef({id_map, id_map}), parallel_iter_types);
+        makeArrayRef({id_map, id_map}), parallel_iter_types,
+        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
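+          // Clone only the detected combiner, remapping its operands to the
+          // block arguments (elements of the partial result and output slice).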
+          BlockAndValueMapping bvm;
+          bvm.map(combiner->getOperands(), args);
+          Value result_val = b.clone(*combiner, bvm)->getResult(0);
+          b.create<YieldOp>(nested_loc, result_val);
+        });
 
-    auto reduce_tile = mlir::cast<GenericOp>(tiled_op);
-    BlockAndValueMapping bvm;
-    rewriter.cloneRegionBefore(reduce_tile.region(), accumulator.region(),
-                               accumulator.region().end(), bvm);
     rewriter.updateRootInPlace(insert_output_slice, [&]() {
       insert_output_slice.sourceMutable().assign(accumulator.getResult(0));
     });
+    return success();
   }
 
   // Unfortunately, there is no way to modify the results of the loop in place. So
@@ -297,8 +322,10 @@
         CloneAndAppendInitTensorToTiledLoop(rewriter, fill, tiled_loop);
     FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, cloned_output_bb_arg,
              extract_output_slice, insert_output_slice);
-    CombineReducedTileWithOutput(rewriter, tiled_op, tiled_op_result,
-                                 extract_output_slice, insert_output_slice);
+    if (mlir::failed(CombineReducedTileWithOutput(
+            rewriter, tiled_op, tiled_op_result, extract_output_slice,
+            insert_output_slice)))
+      return failure();
 
     // Update the results.
     TiledLoopOp updated_loop =
@@ -365,13 +392,19 @@
   LogicalResult matchAndRewrite(GenericOp linalg_op,
                                 PatternRewriter &rewriter) const override {
     if (failed(filter.checkAndNotify(rewriter, linalg_op))) return failure();
+    if (linalg_op.getNumOutputs() != 1) return failure();
+
+    // Check that the op has a single loop and that every input uses a 1-D
+    // identity indexing map.
     if (linalg_op.getNumLoops() != 1) return failure();
+    auto indexing_maps = linalg_op.getIndexingMaps();
+    for (auto affine_map : makeArrayRef(indexing_maps).drop_back()) {
+      if (!affine_map.isIdentity()) return failure();
+    }
 
-    // This condition has to be relaxed to support fused inputs.
-    if (linalg_op.getNumInputs() != 1) return failure();
-
-    mlir::Location loc = linalg_op.getLoc();
+    Location loc = linalg_op.getLoc();
     Value input = linalg_op.getInputOperand(0)->get();
+    // All inputs have the same size because of the identity indexing maps.
+    SmallVector<Value> inputs = linalg_op.inputs();
     Value input_size = rewriter.create<mlir::tensor::DimOp>(loc, input, 0);
 
     auto fill_op = linalg_op.outputs().front().getDefiningOp<FillOp>();
@@ -391,42 +424,28 @@
     GenericOp tiled_reduction;
     auto tiled_loop_op = rewriter.create<TiledLoopOp>(
         loc, makeArrayRef(zero), makeArrayRef(input_size),
-        makeArrayRef(vector_size_value), makeArrayRef(input),
-        makeArrayRef(new_fill),
+        makeArrayRef(vector_size_value), inputs, makeArrayRef(new_fill),
         rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
-        [&](OpBuilder &b, mlir::Location nested_loc, ValueRange ivs,
+        [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
             ValueRange inputs, ValueRange outputs) {
-          auto tile_sizes = mlir::linalg::computeTileSizes(
-              b, nested_loc, ivs, vector_size_value, input_size);
-
-          // Extract slice of input.
-          Value slice = mlir::linalg::makeTiledShape(
-              b, nested_loc, inputs[0], vector_size_value,
-              rewriter.getMultiDimIdentityMap(1), ivs[0], input_size,
-              tile_sizes);
-
-          // Pad input tile.
-          Value pad = PadTensorOp::createPadHighOp(
-              RankedTensorType::get({vector_size}, element_type), slice,
-              neutral_value, false, nested_loc, b);
-
-          // Reshape input tile to tensor<1xVECTOR_SIZExELEM_TYPE>.
-          llvm::SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
-          Value expand_shape = b.create<TensorExpandShapeOp>(
-              nested_loc, RankedTensorType::get({1, vector_size}, element_type),
-              pad, indices);
-
-          // Create `linalg.generic` to reduce
-          // tensor<1xVECTOR_SIZExELEM_TYPE>->tensor<VECTOR_SIZExELEM_TYPE>.
-          mlir::SmallVector<mlir::StringRef, 2> iter_types{
+          SmallVector<Value, 2> reshaped_tiled_inputs =
+              TileAndReshapeInputTensors(b, nested_loc, ivs, inputs,
+                                         neutral_value, input_size,
+                                         vector_size_value);
+          // Create `linalg.generic` to combine the
+          // `tensor<1xVECTOR_SIZExELEM_TYPE>` inputs with the
+          // `tensor<VECTOR_SIZExELEM_TYPE>` output.
+          SmallVector<mlir::StringRef, 2> iter_types{
               mlir::getReductionIteratorTypeName(),
               mlir::getParallelIteratorTypeName()};
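+          // Identity maps for every input; the output map keeps only the
+          // parallel dimension d1.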
+          SmallVector<mlir::AffineMap, 2> indexing_maps(
+              inputs.size(), rewriter.getMultiDimIdentityMap(2));
+          indexing_maps.push_back(
+              mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1)));
           tiled_reduction = b.create<GenericOp>(
-              nested_loc, outputs[0].getType(), makeArrayRef({expand_shape}),
-              makeArrayRef({outputs[0]}),
-              makeArrayRef({b.getMultiDimIdentityMap(2),
-                            mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1))}),
-              iter_types, /*bodyBuild=*/nullptr);
+              nested_loc, outputs[0].getType(), reshaped_tiled_inputs,
+              makeArrayRef({outputs[0]}), indexing_maps, iter_types,
+              /*bodyBuild=*/nullptr);
           mlir::Region &region = tiled_reduction.region();
           OpBuilder::InsertionGuard g(rewriter);
           rewriter.cloneRegionBefore(linalg_op.region(), region, region.end());
@@ -434,9 +453,10 @@
         });
     // Create `linalg.generic` to reduce
     // tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>.
-    BlockAndValueMapping bvm;
-    bvm.map(input, tiled_loop_op.getResult(0));
-    auto final_reduction = rewriter.clone(*linalg_op.getOperation(), bvm);
+    auto final_reduction_or =
+        ReduceVectorIntoOutput(rewriter, linalg_op, tiled_loop_op.getResult(0));
+    if (failed(final_reduction_or)) return failure();
+    auto final_reduction = final_reduction_or.getValue();
     rewriter.replaceOp(linalg_op, final_reduction->getResults());
 
     tiled_loop_op->walk([&](GenericOp op) {
@@ -446,6 +466,69 @@
     return success();
   }
 
+  // Tiles, pads, and reshapes every input argument of type tensor<?xELEM_TYPE>
+  // into tensor<1xVECTOR_SIZExELEM_TYPE>.
+  SmallVector<Value, 2> TileAndReshapeInputTensors(
+      OpBuilder &b, Location nested_loc, ValueRange ivs, ValueRange inputs,
+      Value neutral_value, Value input_size, Value vector_size_value) const {
+    SmallVector<Value, 2> reshaped_tiled_inputs;
+
+    SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
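+    // A single reassociation group [0, 1] expands tensor<VECTOR_SIZExELEM_TYPE>
+    // into tensor<1xVECTOR_SIZExELEM_TYPE> below.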
+    auto identity_1d_map = b.getMultiDimIdentityMap(1);
+    auto iv = ivs.front();
+
+    auto tile_sizes = mlir::linalg::computeTileSizes(
+        b, nested_loc, ivs, vector_size_value, input_size);
+    for (auto input : inputs) {
+      // Extract slice of input.
+      Value slice = mlir::linalg::makeTiledShape(
+          b, nested_loc, input, vector_size_value, identity_1d_map, iv,
+          input_size, tile_sizes);
+      auto element_type = slice.getType().cast<ShapedType>().getElementType();
+
+      // Pad input tile.
+      Value pad = PadTensorOp::createPadHighOp(
+          RankedTensorType::get({vector_size}, element_type), slice,
+          neutral_value, false, nested_loc, b);
+
+      // Reshape input tile to tensor<1xVECTOR_SIZExELEM_TYPE>.
+      Value expand_shape = b.create<TensorExpandShapeOp>(
+          nested_loc, RankedTensorType::get({1, vector_size}, element_type),
+          pad, indices);
+      reshaped_tiled_inputs.push_back(expand_shape);
+    }
+    return reshaped_tiled_inputs;
+  }
+
+  // Creates `linalg.generic` to reduce
+  // tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>. To do this, we match
+  // the combiner in the original "untiled" linalg_op.
+  FailureOr<GenericOp> ReduceVectorIntoOutput(PatternRewriter &rewriter,
+                                              LinalgOp linalg_op,
+                                              Value partial_result) const {
+    SmallVector<mlir::StringRef, 3> reduction_iter_type(
+        1, mlir::getReductionIteratorTypeName());
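+    // The output is rank-0, so its indexing map is (d0) -> ().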
+    auto map = mlir::AffineMap::get(1, 0, llvm::None, rewriter.getContext());
+
+    auto combiner_or = DetectCombiner(linalg_op);
+    if (failed(combiner_or)) return failure();
+    Operation *combiner = combiner_or.getValue();
+
+    auto accumulator = rewriter.create<GenericOp>(
+        linalg_op.getLoc(), linalg_op->getResultTypes(),
+        makeArrayRef(partial_result),
+        makeArrayRef(linalg_op.getOutputOperand(0)->get()),
+        makeArrayRef({rewriter.getMultiDimIdentityMap(1), map}),
+        reduction_iter_type,
+        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
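+          // As above, clone only the detected combiner and remap its operands
+          // to the block arguments of the new body.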
+          BlockAndValueMapping bvm;
+          bvm.map(combiner->getOperands(), args);
+          Value result_val = b.clone(*combiner, bvm)->getResult(0);
+          b.create<YieldOp>(nested_loc, result_val);
+        });
+    return accumulator;
+  }
+
  private:
   LinalgTransformationFilter filter;
   int64_t vector_size;
@@ -456,8 +539,7 @@
   auto reduction = mlir::dyn_cast<GenericOp>(op);
   if (!reduction) return false;
 
-  if (reduction.getNumOutputs() != 1 || reduction.getNumLoops() > 2)
-    return false;
+  if (reduction.getNumLoops() > 2) return false;
   return reduction.getNumReductionLoops() == 1;
 }
 
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
index fb5d67f..e256c79 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
@@ -1,61 +1,69 @@
 // RUN: tf-tfrt-opt -tf-cpurt-codegen-reduction %s --split-input-file |\
 // RUN: FileCheck %s
 
-func @reduce_row_sum_2d(%input: tensor<?x?xf32>) -> tensor<?xf32> {
+func @reduce_row_sum_2d(%lhs: tensor<?x?xf32>,
+                        %rhs: tensor<?x?xf32>) -> tensor<?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %0 = tensor.dim %input, %c0 : tensor<?x?xf32>
+  %0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
 
   %init = linalg.init_tensor [%0] : tensor<?xf32>
   %fill = linalg.fill(%cst, %init) : f32, tensor<?xf32> -> tensor<?xf32>
-  %sum = linalg.generic {
+  %sum_of_prod = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0)>],
     iterator_types = ["parallel", "reduction"]}
-    ins(%input : tensor<?x?xf32>)
+    ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%fill : tensor<?xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %add = arith.addf %in, %out : f32
+  ^bb0(%l: f32, %r: f32, %o: f32):
+    %prod = arith.mulf %l, %r : f32
+    %add = arith.addf %prod, %o : f32
     linalg.yield %add : f32
   } -> tensor<?xf32>
-  return %sum : tensor<?xf32>
+  return %sum_of_prod : tensor<?xf32>
 }
 // CHECK-LABEL: func @reduce_row_sum_2d(
-// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x?xf32>) -> tensor<?xf32>
+// CHECK-SAME:    %[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?xf32>
 
 // CHECK-DAG:  %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 
-// CHECK:      %[[DIM_0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D:.*]]
+// CHECK:      %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]]
 // CHECK:      %[[INIT:.*]] = linalg.init_tensor [%[[DIM_0]]] : [[TY_1D:.*]]
 // CHECK:      %[[CLONE:.*]] = linalg.init_tensor [%[[DIM_0]]] : [[TY_1D:.*]]
 // CHECK:      %[[FILL:.*]] = linalg.fill(%[[C0_F32]], %[[INIT]])
-// CHECK:      %[[DIM_0_:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D]]
-// CHECK:      %[[DIM_1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : [[TY_2D]]
+// CHECK:      %[[DIM_0_:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]]
+// CHECK:      %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]]
 
 // CHECK:      linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
 // CHECK-SAME:   to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C4]])
-// CHECK-SAME:   ins (%[[IN_:.*]] = %[[INPUT]]: [[TY_2D]])
+// CHECK-SAME:   ins (%[[LHS_:.*]] = %[[LHS]]: [[TY_2D]],
+// CHECK-SAME:        %[[RHS_:.*]] = %[[RHS]]: [[TY_2D]])
 // CHECK-SAME:   outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]],
 // CHECK-SAME:         %[[CLONE_:.*]] = %[[CLONE]]: [[TY_1D]])
 
-// CHECK:      %[[IN_SUB:.*]] = tensor.extract_slice %[[IN_]][%[[I]], %[[J]]]
+// CHECK:      %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]], %[[J]]]
+// CHECK:      %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]], %[[J]]]
 // CHECK:      %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]]]
 // CHECK:      %[[CLONE_SUB:.*]] = tensor.extract_slice %[[CLONE_]][%[[I]]]
 
 // CHECK:      %[[FILL_SUB:.*]] = linalg.fill(%[[C0_F32]], %[[CLONE_SUB]])
 
-// CHECK:      %[[SUM_SUB:.*]] = linalg.generic
-// CHECK-SAME:   ins(%[[IN_SUB]] : [[TY_2D]])
+// CHECK:      %[[SUM_OF_PROD_SUB:.*]] = linalg.generic
+// CHECK-SAME:   ins(%[[LHS_SUB]], %[[RHS_SUB]] : [[TY_2D]], [[TY_2D]])
 // CHECK-SAME:   outs(%[[FILL_SUB]] : [[TY_1D]])
+// CHECK:          mulf
 // CHECK:          addf
 // CHECK-NEXT:     linalg.yield
 
 // CHECK:      %[[ACC:.*]] = linalg.generic
-// CHECK-SAME:   ins(%[[SUM_SUB]] : [[TY_1D]])
+// CHECK-SAME:   ins(%[[SUM_OF_PROD_SUB]] : [[TY_1D]])
 // CHECK-SAME:   outs(%[[OUT_SUB]] : [[TY_1D]]) {
+// CHECK-NOT:      mulf
 // CHECK:          addf
 // CHECK-NEXT:     linalg.yield
 
@@ -181,57 +189,70 @@
 
 // -----
 
-func @reduce_sum_1d(%input: tensor<?xf32>) -> tensor<f32> {
+func @reduce_sum_1d(%lhs: tensor<?xf32>, %rhs: tensor<?xf32>) -> tensor<f32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %0 = tensor.dim %input, %c0 : tensor<?xf32>
+  %0 = tensor.dim %lhs, %c0 : tensor<?xf32>
 
   %init = linalg.init_tensor [] : tensor<f32>
   %fill = linalg.fill(%cst, %init) : f32, tensor<f32> -> tensor<f32>
   %sum = linalg.generic {
     indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> ()>],
     iterator_types = ["reduction"]}
-    ins(%input : tensor<?xf32>)
+    ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
     outs(%fill : tensor<f32>) {
-  ^bb0(%in: f32, %out: f32):
-    %add = arith.addf %in, %out : f32
+  ^bb0(%l: f32, %r: f32, %out: f32):
+    %prod = arith.mulf %l, %r : f32
+    %add = arith.addf %prod, %out : f32
     linalg.yield %add : f32
   } -> tensor<f32>
   return %sum : tensor<f32>
 }
 
 // CHECK-LABEL: func @reduce_sum_1d(
-// CHECK-SAME:    %[[INPUT:.*]]: tensor<?xf32>) -> tensor<f32> {
+// CHECK-SAME:    %[[LHS:.*]]: tensor<?xf32>, %[[RHS:.*]]: tensor<?xf32>)
      // CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
      // CHECK: %[[C0:.*]] = arith.constant 0 : index
      // CHECK: %[[C8:.*]] = arith.constant 8 : index
 
      // CHECK: %[[INIT:.*]] = linalg.init_tensor [] : tensor<f32>
      // CHECK: %[[FILL:.*]] = linalg.fill(%[[C0_F32]], %[[INIT]])
-     // CHECK: %[[INPUT_SIZE:.*]] = tensor.dim %[[INPUT]], %[[C0]]
+     // CHECK: %[[INPUT_SIZE:.*]] = tensor.dim %[[LHS]], %[[C0]]
 
      // CHECK: %[[TMP_INIT:.*]] = linalg.init_tensor [8] : tensor<8xf32>
      // CHECK: %[[TMP_FILL:.*]] = linalg.fill(%[[C0_F32]], %[[TMP_INIT]])
      // CHECK: %[[TMP_SUM:.*]] = linalg.tiled_loop (%[[I:.*]]) = (%[[C0]])
 // CHECK-SAME:   to (%[[INPUT_SIZE]]) step (%[[C8]])
-// CHECK-SAME:   ins (%[[INPUT_:.*]] = %[[INPUT]]: tensor<?xf32>)
+// CHECK-SAME:   ins (%[[LHS_:.*]] = %[[LHS]]: tensor<?xf32>,
+// CHECK-SAME:        %[[RHS_:.*]] = %[[RHS]]: tensor<?xf32>)
 // CHECK-SAME:   outs (%[[TMP_INIT_:.*]] = %[[TMP_FILL]]: tensor<8xf32>)
 
-     // CHECK: %[[IN_SUB:.*]] = tensor.extract_slice %[[INPUT_]][%[[I]]]
-     // CHECK: %[[PAD:.*]] = linalg.pad_tensor %[[IN_SUB]]
-     // CHECK: %[[RESHAPE:.*]] = linalg.tensor_expand_shape %[[PAD]]
+     // CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]]]
+     // CHECK: %[[LHS_PAD:.*]] = linalg.pad_tensor %[[LHS_SUB]]
+     // CHECK: %[[LHS_RESHAPE:.*]] = linalg.tensor_expand_shape %[[LHS_PAD]]
 // CHECK-SAME:   {{\[\[}}0, 1]]
 // CHECK-SAME:   : tensor<8xf32> into tensor<1x8xf32>
 
-     // CHECK: %[[SUM:.*]] = linalg.generic
-// CHECK-SAME:   ins(%[[RESHAPE]] : tensor<1x8xf32>)
+     // CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]]]
+     // CHECK: %[[RHS_PAD:.*]] = linalg.pad_tensor %[[RHS_SUB]]
+     // CHECK: %[[RHS_RESHAPE:.*]] = linalg.tensor_expand_shape %[[RHS_PAD]]
+// CHECK-SAME:   {{\[\[}}0, 1]]
+// CHECK-SAME:   : tensor<8xf32> into tensor<1x8xf32>
+
+     // CHECK: %[[SUM_OF_PROD:.*]] = linalg.generic
+// CHECK-SAME:   ins(%[[LHS_RESHAPE]], %[[RHS_RESHAPE]]
+// CHECK-SAME:       tensor<1x8xf32>, tensor<1x8xf32>)
 // CHECK-SAME:   outs(%[[TMP_INIT_]] : tensor<8xf32>) {
-     // CHECK:   ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32):
-     // CHECK:     %[[ADD:.*]] = arith.addf %[[A]], %[[B]] : f32
+     // CHECK:   ^bb0(%[[L:.*]]: f32, %[[R:.*]]: f32, %[[O:.*]]: f32):
+     // CHECK:     %[[MUL:.*]] = arith.mulf %[[L]], %[[R]] : f32
+     // CHECK:     %[[ADD:.*]] = arith.addf %[[MUL]], %[[O]] : f32
      // CHECK:       linalg.yield %[[ADD]] : f32
      // CHECK:     } -> tensor<8xf32>
-     // CHECK:   linalg.yield %[[SUM]] : tensor<8xf32>
+     // CHECK:   linalg.yield %[[SUM_OF_PROD]] : tensor<8xf32>
      // CHECK: }
      // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[TMP_SUM]] : tensor<8xf32>) outs(%[[FILL]] : tensor<f32>)
+//  CHECK-NOT:  mulf
+//      CHECK:  addf