blob: 7533511097cbfc5e56d93ff5569b5e772691a87b [file] [log] [blame]
// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
// An addf writes the temporary %temp_result; a mulf then reads that temporary
// and writes the function argument %result. The checks below expect both
// generics to end up inside a single loop nest, with no generic left outside.
func.func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
                  %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
  %temp_result = memref.alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
    outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = arith.addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
    outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = arith.mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  memref.dealloc %temp_result : memref<6x6xf32>
  func.return
}
// The labels end in "(" so they cannot also match @fusion_of_three or
// @fusion_4d, which share the "@fusion" prefix. The constant patterns end in
// " :" so that e.g. a constant 1 cannot prefix-match a constant 10.
// CHECK-LABEL: func @fusion(
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion(
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-DAG: %[[C3:.*]] = arith.constant 3 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion(
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
// A chain of three generics: a broadcast of %arg1 along d1 into %0, a subf of
// %arg0 and %0 into %1, and an exp of %1 into the function argument %arg2.
// The checks below expect all three generics inside one loop nest.
func.func @fusion_of_three(%arg0: memref<100x10xf32>,
                           %arg1: memref<100xf32>,
                           %arg2: memref<100x10xf32>) {
  %0 = memref.alloc() : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg1 : memref<100xf32>)
    outs(%0 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  }
  %1 = memref.alloc() : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
    outs(%1 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %2 = arith.subf %arg3, %arg4 : f32
    linalg.yield %2 : f32
  }
  memref.dealloc %0 : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%1 : memref<100x10xf32>)
    outs(%arg2 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %2 = math.exp %arg3 : f32
    linalg.yield %2 : f32
  }
  memref.dealloc %1 : memref<100x10xf32>
  func.return
}
// The trailing " :" on the constant patterns keeps them from prefix-matching
// the 10 and 100 loop bounds that this shape introduces.
// CHECK-LABEL: func @fusion_of_three
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp
// TILED-LABEL: func @fusion_of_three
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-DAG: %[[C3:.*]] = arith.constant 3 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: linalg.generic
// TILED: subf
// TILED: linalg.generic
// TILED: exp
// PLOOP-LABEL: func @fusion_of_three
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: linalg.generic
// PLOOP: subf
// PLOOP: linalg.generic
// PLOOP: exp
// -----
#identity_4d = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#elementwise_4d_trait = {
  indexing_maps = [#identity_4d, #identity_4d, #identity_4d],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
// Rank-4 version of the add-then-multiply pipeline: the sum goes to a local
// buffer, and the product of that buffer with %mul_rhs goes to the argument
// %dst. With the default tile sizes all four loops are expected; with
// tile-sizes=2,3 only two loops are expected.
func.func @fusion_4d(%mul_rhs: memref<6x6x6x6xf32>,
                     %add_lhs: memref<6x6x6x6xf32>,
                     %add_rhs: memref<6x6x6x6xf32>,
                     %dst: memref<6x6x6x6xf32>) {
  %sum = memref.alloc() : memref<6x6x6x6xf32>
  linalg.generic #elementwise_4d_trait
      ins(%add_lhs, %add_rhs : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
      outs(%sum : memref<6x6x6x6xf32>) {
  ^bb0(%a: f32, %b: f32, %unused_sum: f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  }
  linalg.generic #elementwise_4d_trait
      ins(%sum, %mul_rhs : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
      outs(%dst : memref<6x6x6x6xf32>) {
  ^bb0(%x: f32, %y: f32, %unused_dst: f32):
    %p = arith.mulf %x, %y : f32
    linalg.yield %p : f32
  }
  memref.dealloc %sum : memref<6x6x6x6xf32>
  func.return
}
// CHECK-LABEL: func @fusion_4d
// CHECK: %[[C1:.*]] = arith.constant 1
// CHECK-NOT: linalg.generic
// CHECK-COUNT-4: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion_4d
// TILED-DAG: %[[C2:.*]] = arith.constant 2
// TILED-DAG: %[[C3:.*]] = arith.constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion_4d
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
// Same add-then-multiply computation as the first @fusion case. The
// difference is that the output buffer %result is allocated inside the
// function and returned, not passed in as an argument. The checks still
// expect both generics inside a single loop nest.
func.func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
                  %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
  %temp_result = memref.alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
    outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = arith.addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  %result = memref.alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
    outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = arith.mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  memref.dealloc %temp_result : memref<6x6xf32>
  func.return %result : memref<6x6xf32>
}
// The labels end in "(" so they match only a function named exactly @fusion,
// not @fusion_of_three or @fusion_4d.
// CHECK-LABEL: func @fusion(
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion(
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-DAG: %[[C3:.*]] = arith.constant 3 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion(
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
// The only generic writes the local allocation %1. %1 leaves the function only
// through a memref.reshape whose result is returned. The checks expect the
// generic to be tiled regardless.
func.func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %1 = memref.alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %13 = math.abs %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = memref.reshape %1(%arg1)
      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  func.return %2 : memref<*xf32>
}
// CHECK-LABEL: func @view_result
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: math.abs
// CHECK: memref.reshape
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: math.abs
// TILED: memref.reshape
// PLOOP-LABEL: func @view_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: math.abs
// PLOOP: memref.reshape
// -----
// Confirm that tiling information is propagated through ops implementing
// RegionBranchOpInterface (here, scf.if). memref.reshape is used only to
// produce a value that the scf.if branches can yield and the function return.
func.func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %1 = memref.alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %13 = math.abs %arg3 : f32
    linalg.yield %13 : f32
  }
  // Both branches reshape %1 and yield the result, so the returned value
  // reaches %1 only through the scf.if results.
  %true = arith.constant 1 : i1
  %3 = scf.if %true -> memref<*xf32> {
    %2 = memref.reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  } else {
    %2 = memref.reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  }
  func.return %3 : memref<*xf32>
}
// CHECK-LABEL: func @branching_result
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: math.abs
// CHECK: scf.if
// CHECK: memref.reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref.reshape
// CHECK: scf.yield
// TILED-LABEL: func @branching_result
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: math.abs
// TILED: scf.if
// TILED: memref.reshape
// TILED: scf.yield
// TILED: else
// TILED: memref.reshape
// TILED: scf.yield
// PLOOP-LABEL: func @branching_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: math.abs
// PLOOP: scf.if
// PLOOP: memref.reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref.reshape
// PLOOP: scf.yield
// -----
// Confirm that tiling information is propagated through the
// bufferization.to_tensor, tensor.cast and bufferization.to_memref operations.
func.func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
    -> memref<?xf32> {
  %1 = memref.alloc() : memref<32xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %13 = math.abs %arg3 : f32
    linalg.yield %13 : f32
  }
  // The returned memref is derived from %1 via to_tensor -> tensor.cast ->
  // to_memref; the generic writing %1 is still expected to be tiled.
  %2 = bufferization.to_tensor %1 : memref<32xf32>
  %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
  %4 = bufferization.to_memref %3 : memref<?xf32>
  func.return %4 : memref<?xf32>
}
// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = arith.constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: math.abs
// CHECK: bufferization.to_tensor
// CHECK: tensor.cast
// CHECK: bufferization.to_memref
// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = arith.constant 2 :
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: math.abs
// TILED: bufferization.to_tensor
// TILED: tensor.cast
// TILED: bufferization.to_memref
// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: math.abs
// PLOOP: bufferization.to_tensor
// PLOOP: tensor.cast
// PLOOP: bufferization.to_memref