[tfrt:jit] Detect combiner in linalg.generic reductions during tiling.
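
A `linalg.generic` reduction folds each input element into the accumulator with
a scalar "combiner" op in its body. A minimal example of the reductions this
pass targets, adapted from the pre-existing `reduce_row_sum_2d` test in this
change (the `arith.addf` is the combiner that `matchReduction` should detect):

  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%input : tensor<?x?xf32>)
      outs(%fill : tensor<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %add = arith.addf %in, %out : f32  // combiner
      linalg.yield %add : f32
  } -> tensor<?xf32>

Instead of cloning the whole reduction region, the tiling patterns now clone
only the detected combiner into the accumulator `linalg.generic` that merges
each partial tile result into the output slice. This allows reduction bodies
with additional ops (e.g. the fused `arith.mulf` in the updated tests) and
multiple inputs.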
PiperOrigin-RevId: 413636681
Change-Id: I0a0a33effede696ef46e7d0494f889136486786c
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
index a714a70..3a0ddcd 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD
@@ -70,6 +70,7 @@
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
index e53d873..26f8c5b 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_codegen_reduction.cc
@@ -15,14 +15,17 @@
#include <utility>
+#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h"
@@ -38,7 +41,9 @@
using mlir::cast;
using mlir::dyn_cast;
using mlir::failure;
+using mlir::FailureOr;
using mlir::Identifier;
+using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
@@ -46,6 +51,8 @@
using mlir::OpRewritePattern;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
+using mlir::ShapedType;
+using mlir::SmallVector;
using mlir::success;
using mlir::Value;
using mlir::ValueRange;
@@ -65,6 +72,16 @@
using mlir::tensor::ExtractSliceOp;
using mlir::tensor::InsertSliceOp;
+// Detects the combiner in the body of a LinalgOp, if any. Currently, only
+// ops with a single combiner are supported.
+FailureOr<Operation *> DetectCombiner(LinalgOp linalg_op) {
+ SmallVector<Operation *, 4> combiners;
+ if (!matchReduction(linalg_op.getRegionOutputArgs(), 0, combiners) ||
+ combiners.size() != 1)
+ return failure();
+ return combiners.front();
+}
+
// Tiles a GenericOp that models a reduction and then fuses its inputs and
// outputs. Currently, only the FillOp that initializes the output is fused into
// the TiledLoopOp.
@@ -82,6 +99,7 @@
if (failed(filter.checkAndNotify(rewriter, linalg_op))) return failure();
if (linalg_op.getNumOutputs() != 1) return failure();
+ if (linalg_op.getNumLoops() != 2) return failure();
auto tiled_op = tileLinalgOp(rewriter, linalg_op, options);
if (failed(tiled_op)) return failure();
@@ -222,29 +240,36 @@
//
// linalg.yield %insert_output_slice, %update_cloned_output
// }
- void CombineReducedTileWithOutput(PatternRewriter &rewriter,
- LinalgOp tiled_op, Value partial_result,
- ExtractSliceOp extract_output_slice,
- InsertSliceOp insert_output_slice) const {
+ LogicalResult CombineReducedTileWithOutput(
+ PatternRewriter &rewriter, LinalgOp tiled_op, Value partial_result,
+ ExtractSliceOp extract_output_slice,
+ InsertSliceOp insert_output_slice) const {
rewriter.setInsertionPointAfter(tiled_op);
auto num_parallel_loops = tiled_op.getNumParallelLoops();
- mlir::SmallVector<mlir::StringRef, 3> parallel_iter_types(
+ SmallVector<mlir::StringRef, 3> parallel_iter_types(
num_parallel_loops, mlir::getParallelIteratorTypeName());
auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops);
+ auto combiner_or = DetectCombiner(tiled_op);
+ if (failed(combiner_or)) return failure();
+ Operation *combiner = combiner_or.getValue();
+
auto accumulator = rewriter.create<GenericOp>(
tiled_op.getLoc(), partial_result.getType(),
makeArrayRef(partial_result),
makeArrayRef(extract_output_slice.result()),
- makeArrayRef({id_map, id_map}), parallel_iter_types);
+ makeArrayRef({id_map, id_map}), parallel_iter_types,
+ [&](OpBuilder &b, Location nested_loc, ValueRange args) {
+ BlockAndValueMapping bvm;
+ bvm.map(combiner->getOperands(), args);
+ Value result_val = b.clone(*combiner, bvm)->getResult(0);
+ b.create<YieldOp>(nested_loc, result_val);
+ });
- auto reduce_tile = mlir::cast<GenericOp>(tiled_op);
- BlockAndValueMapping bvm;
- rewriter.cloneRegionBefore(reduce_tile.region(), accumulator.region(),
- accumulator.region().end(), bvm);
rewriter.updateRootInPlace(insert_output_slice, [&]() {
insert_output_slice.sourceMutable().assign(accumulator.getResult(0));
});
+ return success();
}
// Unfortunately, there is no way to modify the results of the loop in place. So
@@ -297,8 +322,10 @@
CloneAndAppendInitTensorToTiledLoop(rewriter, fill, tiled_loop);
FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, cloned_output_bb_arg,
extract_output_slice, insert_output_slice);
- CombineReducedTileWithOutput(rewriter, tiled_op, tiled_op_result,
- extract_output_slice, insert_output_slice);
+ if (mlir::failed(CombineReducedTileWithOutput(
+ rewriter, tiled_op, tiled_op_result, extract_output_slice,
+ insert_output_slice)))
+ return failure();
// Update the results.
TiledLoopOp updated_loop =
@@ -365,13 +392,19 @@
LogicalResult matchAndRewrite(GenericOp linalg_op,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, linalg_op))) return failure();
+ if (linalg_op.getNumOutputs() != 1) return failure();
+
+ // Check if all inputs have a 1D identity map.
if (linalg_op.getNumLoops() != 1) return failure();
+ auto indexing_maps = linalg_op.getIndexingMaps();
+ for (auto affine_map : makeArrayRef(indexing_maps).drop_back()) {
+ if (!affine_map.isIdentity()) return failure();
+ }
- // This condition has to be relaxed to support fused inputs.
- if (linalg_op.getNumInputs() != 1) return failure();
-
- mlir::Location loc = linalg_op.getLoc();
+ Location loc = linalg_op.getLoc();
Value input = linalg_op.getInputOperand(0)->get();
+  // All inputs have the same size since they use identity indexing maps.
+ SmallVector<Value> inputs = linalg_op.inputs();
Value input_size = rewriter.create<mlir::tensor::DimOp>(loc, input, 0);
auto fill_op = linalg_op.outputs().front().getDefiningOp<FillOp>();
@@ -391,42 +424,28 @@
GenericOp tiled_reduction;
auto tiled_loop_op = rewriter.create<TiledLoopOp>(
loc, makeArrayRef(zero), makeArrayRef(input_size),
- makeArrayRef(vector_size_value), makeArrayRef(input),
- makeArrayRef(new_fill),
+ makeArrayRef(vector_size_value), inputs, makeArrayRef(new_fill),
rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
- [&](OpBuilder &b, mlir::Location nested_loc, ValueRange ivs,
+ [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
ValueRange inputs, ValueRange outputs) {
- auto tile_sizes = mlir::linalg::computeTileSizes(
- b, nested_loc, ivs, vector_size_value, input_size);
-
- // Extract slice of input.
- Value slice = mlir::linalg::makeTiledShape(
- b, nested_loc, inputs[0], vector_size_value,
- rewriter.getMultiDimIdentityMap(1), ivs[0], input_size,
- tile_sizes);
-
- // Pad input tile.
- Value pad = PadTensorOp::createPadHighOp(
- RankedTensorType::get({vector_size}, element_type), slice,
- neutral_value, false, nested_loc, b);
-
- // Reshape input tile to tensor<1xVECTOR_SIZExELEM_TYPE>.
- llvm::SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
- Value expand_shape = b.create<TensorExpandShapeOp>(
- nested_loc, RankedTensorType::get({1, vector_size}, element_type),
- pad, indices);
-
- // Create `linalg.generic` to reduce
- // tensor<1xVECTOR_SIZExELEM_TYPE>->tensor<VECTOR_SIZExELEM_TYPE>.
- mlir::SmallVector<mlir::StringRef, 2> iter_types{
+ SmallVector<Value, 2> reshaped_tiled_inputs =
+ TileAndReshapeInputTensors(b, nested_loc, ivs, inputs,
+ neutral_value, input_size,
+ vector_size_value);
+          // Create `linalg.generic` to combine the
+          // `tensor<1xVECTOR_SIZExELEM_TYPE>` inputs with the
+          // `tensor<VECTOR_SIZExELEM_TYPE>` output.
+ SmallVector<mlir::StringRef, 2> iter_types{
mlir::getReductionIteratorTypeName(),
mlir::getParallelIteratorTypeName()};
+ SmallVector<mlir::AffineMap, 2> indexing_maps(
+ inputs.size(), rewriter.getMultiDimIdentityMap(2));
+ indexing_maps.push_back(
+ mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1)));
tiled_reduction = b.create<GenericOp>(
- nested_loc, outputs[0].getType(), makeArrayRef({expand_shape}),
- makeArrayRef({outputs[0]}),
- makeArrayRef({b.getMultiDimIdentityMap(2),
- mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1))}),
- iter_types, /*bodyBuild=*/nullptr);
+ nested_loc, outputs[0].getType(), reshaped_tiled_inputs,
+ makeArrayRef({outputs[0]}), indexing_maps, iter_types,
+ /*bodyBuild=*/nullptr);
mlir::Region &region = tiled_reduction.region();
OpBuilder::InsertionGuard g(rewriter);
rewriter.cloneRegionBefore(linalg_op.region(), region, region.end());
@@ -434,9 +453,10 @@
});
// Create `linalg.generic` to reduce
// tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>.
- BlockAndValueMapping bvm;
- bvm.map(input, tiled_loop_op.getResult(0));
- auto final_reduction = rewriter.clone(*linalg_op.getOperation(), bvm);
+ auto final_reduction_or =
+ ReduceVectorIntoOutput(rewriter, linalg_op, tiled_loop_op.getResult(0));
+ if (failed(final_reduction_or)) return failure();
+ auto final_reduction = final_reduction_or.getValue();
rewriter.replaceOp(linalg_op, final_reduction->getResults());
tiled_loop_op->walk([&](GenericOp op) {
@@ -446,6 +466,69 @@
return success();
}
+ // Tiles, pads and reshapes every input argument of type tensor<?xELEM_TYPE>
+ // into tensor<1xVECTOR_SIZExELEM_TYPE>.
+ SmallVector<Value, 2> TileAndReshapeInputTensors(
+ OpBuilder &b, Location nested_loc, ValueRange ivs, ValueRange inputs,
+ Value neutral_value, Value input_size, Value vector_size_value) const {
+ SmallVector<Value, 2> reshaped_tiled_inputs;
+
+ SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
+ auto identity_1d_map = b.getMultiDimIdentityMap(1);
+ auto iv = ivs.front();
+
+ auto tile_sizes = mlir::linalg::computeTileSizes(
+ b, nested_loc, ivs, vector_size_value, input_size);
+ for (auto input : inputs) {
+ // Extract slice of input.
+ Value slice = mlir::linalg::makeTiledShape(
+ b, nested_loc, input, vector_size_value, identity_1d_map, iv,
+ input_size, tile_sizes);
+ auto element_type = slice.getType().cast<ShapedType>().getElementType();
+
+ // Pad input tile.
+ Value pad = PadTensorOp::createPadHighOp(
+ RankedTensorType::get({vector_size}, element_type), slice,
+ neutral_value, false, nested_loc, b);
+
+ // Reshape input tile to tensor<1xVECTOR_SIZExELEM_TYPE>.
+ Value expand_shape = b.create<TensorExpandShapeOp>(
+ nested_loc, RankedTensorType::get({1, vector_size}, element_type),
+ pad, indices);
+ reshaped_tiled_inputs.push_back(expand_shape);
+ }
+ return reshaped_tiled_inputs;
+ }
+
+ // Creates `linalg.generic` to reduce
+  // tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>. To do so, we match the
+  // combiner in the original "untiled" linalg_op.
+ FailureOr<GenericOp> ReduceVectorIntoOutput(PatternRewriter &rewriter,
+ LinalgOp linalg_op,
+ Value partial_result) const {
+ SmallVector<mlir::StringRef, 3> reduction_iter_type(
+ 1, mlir::getReductionIteratorTypeName());
+ auto map = mlir::AffineMap::get(1, 0, llvm::None, rewriter.getContext());
+
+ auto combiner_or = DetectCombiner(linalg_op);
+ if (failed(combiner_or)) return failure();
+ Operation *combiner = combiner_or.getValue();
+
+ auto accumulator = rewriter.create<GenericOp>(
+ linalg_op.getLoc(), linalg_op->getResultTypes(),
+ makeArrayRef(partial_result),
+ makeArrayRef(linalg_op.getOutputOperand(0)->get()),
+ makeArrayRef({rewriter.getMultiDimIdentityMap(1), map}),
+ reduction_iter_type,
+ [&](OpBuilder &b, Location nested_loc, ValueRange args) {
+ BlockAndValueMapping bvm;
+ bvm.map(combiner->getOperands(), args);
+ Value result_val = b.clone(*combiner, bvm)->getResult(0);
+ b.create<YieldOp>(nested_loc, result_val);
+ });
+ return accumulator;
+ }
+
private:
LinalgTransformationFilter filter;
int64_t vector_size;
@@ -456,8 +539,7 @@
auto reduction = mlir::dyn_cast<GenericOp>(op);
if (!reduction) return false;
- if (reduction.getNumOutputs() != 1 || reduction.getNumLoops() > 2)
- return false;
+ if (reduction.getNumLoops() > 2) return false;
return reduction.getNumReductionLoops() == 1;
}
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
index fb5d67f..e256c79 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/reduction_codegen.mlir
@@ -1,61 +1,69 @@
// RUN: tf-tfrt-opt -tf-cpurt-codegen-reduction %s --split-input-file |\
// RUN: FileCheck %s
-func @reduce_row_sum_2d(%input: tensor<?x?xf32>) -> tensor<?xf32> {
+func @reduce_row_sum_2d(%lhs: tensor<?x?xf32>,
+ %rhs: tensor<?x?xf32>) -> tensor<?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %0 = tensor.dim %input, %c0 : tensor<?x?xf32>
+ %0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
%init = linalg.init_tensor [%0] : tensor<?xf32>
%fill = linalg.fill(%cst, %init) : f32, tensor<?xf32> -> tensor<?xf32>
- %sum = linalg.generic {
+ %sum_of_prod = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
- ins(%input : tensor<?x?xf32>)
+ ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?xf32>) {
- ^bb0(%in: f32, %out: f32):
- %add = arith.addf %in, %out : f32
+ ^bb0(%l: f32, %r: f32, %o: f32):
+ %prod = arith.mulf %l, %r : f32
+ %add = arith.addf %prod, %o : f32
linalg.yield %add : f32
} -> tensor<?xf32>
- return %sum : tensor<?xf32>
+ return %sum_of_prod : tensor<?xf32>
}
// CHECK-LABEL: func @reduce_row_sum_2d(
-// CHECK-SAME: %[[INPUT:.*]]: tensor<?x?xf32>) -> tensor<?xf32>
+// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM_0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D:.*]]
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_0]]] : [[TY_1D:.*]]
// CHECK: %[[CLONE:.*]] = linalg.init_tensor [%[[DIM_0]]] : [[TY_1D:.*]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[C0_F32]], %[[INIT]])
-// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D]]
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : [[TY_2D]]
+// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]]
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]]
// CHECK: linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME: to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C4]])
-// CHECK-SAME: ins (%[[IN_:.*]] = %[[INPUT]]: [[TY_2D]])
+// CHECK-SAME: ins (%[[LHS_:.*]] = %[[LHS]]: [[TY_2D]],
+// CHECK-SAME: %[[RHS_:.*]] = %[[RHS]]: [[TY_2D]])
// CHECK-SAME: outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]],
// CHECK-SAME: %[[CLONE_:.*]] = %[[CLONE]]: [[TY_1D]])
-// CHECK: %[[IN_SUB:.*]] = tensor.extract_slice %[[IN_]][%[[I]], %[[J]]]
+// CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]], %[[J]]]
+// CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]], %[[J]]]
// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]]]
// CHECK: %[[CLONE_SUB:.*]] = tensor.extract_slice %[[CLONE_]][%[[I]]]
// CHECK: %[[FILL_SUB:.*]] = linalg.fill(%[[C0_F32]], %[[CLONE_SUB]])
-// CHECK: %[[SUM_SUB:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[IN_SUB]] : [[TY_2D]])
+// CHECK: %[[SUM_OF_PROD_SUB:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : [[TY_2D]], [[TY_2D]])
// CHECK-SAME: outs(%[[FILL_SUB]] : [[TY_1D]])
+// CHECK: mulf
// CHECK: addf
// CHECK-NEXT: linalg.yield
// CHECK: %[[ACC:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[SUM_SUB]] : [[TY_1D]])
+// CHECK-SAME: ins(%[[SUM_OF_PROD_SUB]] : [[TY_1D]])
// CHECK-SAME: outs(%[[OUT_SUB]] : [[TY_1D]]) {
+// CHECK-NOT: mulf
// CHECK: addf
// CHECK-NEXT: linalg.yield
@@ -181,57 +189,70 @@
// -----
-func @reduce_sum_1d(%input: tensor<?xf32>) -> tensor<f32> {
+func @reduce_sum_1d(%lhs: tensor<?xf32>, %rhs: tensor<?xf32>) -> tensor<f32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %0 = tensor.dim %input, %c0 : tensor<?xf32>
+ %0 = tensor.dim %lhs, %c0 : tensor<?xf32>
%init = linalg.init_tensor [] : tensor<f32>
%fill = linalg.fill(%cst, %init) : f32, tensor<f32> -> tensor<f32>
%sum = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
iterator_types = ["reduction"]}
- ins(%input : tensor<?xf32>)
+ ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
outs(%fill : tensor<f32>) {
- ^bb0(%in: f32, %out: f32):
- %add = arith.addf %in, %out : f32
+ ^bb0(%l: f32, %r: f32, %out: f32):
+ %prod = arith.mulf %l, %r : f32
+ %add = arith.addf %prod, %out : f32
linalg.yield %add : f32
} -> tensor<f32>
return %sum : tensor<f32>
}
// CHECK-LABEL: func @reduce_sum_1d(
-// CHECK-SAME: %[[INPUT:.*]]: tensor<?xf32>) -> tensor<f32> {
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>, %[[RHS:.*]]: tensor<?xf32>)
// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [] : tensor<f32>
// CHECK: %[[FILL:.*]] = linalg.fill(%[[C0_F32]], %[[INIT]])
- // CHECK: %[[INPUT_SIZE:.*]] = tensor.dim %[[INPUT]], %[[C0]]
+ // CHECK: %[[INPUT_SIZE:.*]] = tensor.dim %[[LHS]], %[[C0]]
// CHECK: %[[TMP_INIT:.*]] = linalg.init_tensor [8] : tensor<8xf32>
// CHECK: %[[TMP_FILL:.*]] = linalg.fill(%[[C0_F32]], %[[TMP_INIT]])
// CHECK: %[[TMP_SUM:.*]] = linalg.tiled_loop (%[[I:.*]]) = (%[[C0]])
// CHECK-SAME: to (%[[INPUT_SIZE]]) step (%[[C8]])
-// CHECK-SAME: ins (%[[INPUT_:.*]] = %[[INPUT]]: tensor<?xf32>)
+// CHECK-SAME: ins (%[[LHS_:.*]] = %[[LHS]]: tensor<?xf32>,
+// CHECK-SAME: %[[RHS_:.*]] = %[[RHS]]: tensor<?xf32>)
// CHECK-SAME: outs (%[[TMP_INIT_:.*]] = %[[TMP_FILL]]: tensor<8xf32>)
- // CHECK: %[[IN_SUB:.*]] = tensor.extract_slice %[[INPUT_]][%[[I]]]
- // CHECK: %[[PAD:.*]] = linalg.pad_tensor %[[IN_SUB]]
- // CHECK: %[[RESHAPE:.*]] = linalg.tensor_expand_shape %[[PAD]]
+ // CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]]]
+ // CHECK: %[[LHS_PAD:.*]] = linalg.pad_tensor %[[LHS_SUB]]
+ // CHECK: %[[LHS_RESHAPE:.*]] = linalg.tensor_expand_shape %[[LHS_PAD]]
// CHECK-SAME: {{\[\[}}0, 1]]
// CHECK-SAME: : tensor<8xf32> into tensor<1x8xf32>
- // CHECK: %[[SUM:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<1x8xf32>)
+ // CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]]]
+ // CHECK: %[[RHS_PAD:.*]] = linalg.pad_tensor %[[RHS_SUB]]
+ // CHECK: %[[RHS_RESHAPE:.*]] = linalg.tensor_expand_shape %[[RHS_PAD]]
+// CHECK-SAME: {{\[\[}}0, 1]]
+// CHECK-SAME: : tensor<8xf32> into tensor<1x8xf32>
+
+ // CHECK: %[[SUM_OF_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS_RESHAPE]], %[[RHS_RESHAPE]]
+// CHECK-SAME: tensor<1x8xf32>, tensor<1x8xf32>)
// CHECK-SAME: outs(%[[TMP_INIT_]] : tensor<8xf32>) {
- // CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32):
- // CHECK: %[[ADD:.*]] = arith.addf %[[A]], %[[B]] : f32
+ // CHECK: ^bb0(%[[L:.*]]: f32, %[[R:.*]]: f32, %[[O:.*]]: f32):
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[L]], %[[R]] : f32
+ // CHECK: %[[ADD:.*]] = arith.addf %[[MUL]], %[[O]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<8xf32>
- // CHECK: linalg.yield %[[SUM]] : tensor<8xf32>
+ // CHECK: linalg.yield %[[SUM_OF_PROD]] : tensor<8xf32>
// CHECK: }
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[TMP_SUM]] : tensor<8xf32>) outs(%[[FILL]] : tensor<f32>)
+// CHECK-NOT: mulf
+// CHECK: addf