blob: c271fbe69dd8df82f3b8a10fad384a97969baeb6 [file] [log] [blame]
// RUN: mlir-hlo-opt %s --split-input-file --gml-compose-set-ops | \
// RUN: FileCheck %s
// CHECK-LABEL: @tile_of_tile
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index, %[[A:.*]]: index, %[[B:.*]]: index
// Two stacked tiles over a partially dynamic space fold into one tile on the
// root space. Per dimension the composed tile uses
//   offset = outer offset + outer stride * inner offset
//   size   = inner size
//   stride = outer stride * inner stride
// Dim 0 folds completely (%i + 2 * 0 = %i, stride 2 * 1 = 2); dim 1 needs
// arith ops for the offset (%j + %a * %k) and for the stride (%a * %b).
func.func @tile_of_tile(%src : tensor<?x?xf32>, %row_off : index,
    %col_off : index, %inner_col_off : index, %dyn_dim : index,
    %inner_cols : index, %col_stride : index, %inner_col_stride : index)
    -> tensor<4x?xf32> {
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [1024, %[[M]]] : !gml_st.tile<1024x?>
// CHECK-DAG: %[[AK:.*]] = arith.muli %[[A]], %[[K]]
// CHECK-DAG: %[[J_PLUS_AK:.*]] = arith.addi %[[J]], %[[AK]]
// CHECK-DAG: %[[AB:.*]] = arith.muli %[[A]], %[[B]]
// CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[I]], %[[J_PLUS_AK]]] [4, %[[N]]] [2, %[[AB]]] : !gml_st.tile<1024x?> to !gml_st.tile<4x?>
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<?x?xf32>[!gml_st.tile<4x?>]
// CHECK: return %[[RES]]
  %root = gml_st.space [1024, %dyn_dim] : !gml_st.tile<1024x?>
  %outer = gml_st.tile %root [%row_off, %col_off] [4, 128] [2, %col_stride]
      : !gml_st.tile<1024x?> to !gml_st.tile<4x128>
  %inner = gml_st.tile %outer [0, %inner_col_off] [4, %inner_cols]
      [1, %inner_col_stride] : !gml_st.tile<4x128> to !gml_st.tile<4x?>
  %slice = gml_st.materialize %src[%inner]
      : tensor<?x?xf32>[!gml_st.tile<4x?>]
  func.return %slice : tensor<4x?xf32>
}
// -----
// CHECK-LABEL: @tile_of_tile_of_tile_all_constant
// CHECK-SAME: %[[ARG:.*]]: tensor<4096x2048xf32>
// Three fully constant tiles fold into a single constant tile:
//   after level 2: offsets (0 + 1*2, 32 + 2*16) = (2, 64), strides (1*4, 2*0) = (4, 0)
//   after level 3: offsets (2 + 4*4, 64 + 0*8) = (18, 64), strides (4*1, 0*1) = (4, 0)
// The size is taken from the innermost tile (128x64).
func.func @tile_of_tile_of_tile_all_constant(%src : tensor<4096x2048xf32>)
    -> tensor<128x64xf32> {
// CHECK: %[[SPACE:.*]] = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048>
// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [18, 64] [128, 64] [4, 0] : !gml_st.tile<4096x2048> to !gml_st.tile<128x64>
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<4096x2048xf32>[!gml_st.tile<128x64>]
// CHECK: return %[[RES]]
  %root = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048>
  %level1 = gml_st.tile %root [0, 32] [2048, 256] [1, 2]
      : !gml_st.tile<4096x2048> to !gml_st.tile<2048x256>
  %level2 = gml_st.tile %level1 [2, 16] [256, 128] [4, 0]
      : !gml_st.tile<2048x256> to !gml_st.tile<256x128>
  %level3 = gml_st.tile %level2 [4, 8] [128, 64] [1, 1]
      : !gml_st.tile<256x128> to !gml_st.tile<128x64>
  %slice = gml_st.materialize %src[%level3]
      : tensor<4096x2048xf32>[!gml_st.tile<128x64>]
  func.return %slice : tensor<128x64xf32>
}
// -----
// CHECK-LABEL: @tile_chain_w_zeroes_and_ones
// CHECK-SAME: %[[ARG:.*]]: tensor<8192x4096x2048xf32>, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index, %[[O:.*]]: index, %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
// The outer tile uses strides 0, 1 and 2. This exercises the multiply-by-zero
// and multiply-by-one shortcuts of the composition:
//   offsets (0 + 0*%i, 16 + 1*%j, 32 + 2*%k) -> (0, %j + 16, 2*%k + 32)
//   strides (0*%a, 1*%b, 2*%c)               -> (0, %b, 2*%c)
// The sizes are the fully dynamic inner sizes.
func.func @tile_chain_w_zeroes_and_ones(%src : tensor<8192x4096x2048xf32>,
    %off0 : index, %off1 : index, %off2 : index,
    %size0 : index, %size1 : index, %size2 : index,
    %stride0 : index, %stride1 : index, %stride2 : index)
    -> tensor<?x?x?xf32> {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2
// CHECK-DAG: %[[C16:.*]] = arith.constant 16
// CHECK-DAG: %[[C32:.*]] = arith.constant 32
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048>
// CHECK-DAG: %[[TWO_K:.*]] = arith.muli %[[K]], %[[C2]]
// CHECK-DAG: %[[SIXTEEN_PLUS_J:.*]] = arith.addi %[[J]], %[[C16]]
// CHECK-DAG: %[[TWO_K_PLUS_32:.*]] = arith.addi %[[TWO_K]], %[[C32]]
// CHECK-DAG: %[[C_TIMES_C2:.*]] = arith.muli %[[C]], %[[C2]]
// CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [0, %[[SIXTEEN_PLUS_J]], %[[TWO_K_PLUS_32]]] [%[[M]], %[[N]], %[[O]]] [0, %[[B]], %[[C_TIMES_C2]]] : !gml_st.tile<8192x4096x2048> to !gml_st.tile<?x?x?>
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<8192x4096x2048xf32>[!gml_st.tile<?x?x?>]
// CHECK: return %[[RES]]
  %root = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048>
  %outer = gml_st.tile %root [0, 16, 32] [2048, 1024, 512] [0, 1, 2]
      : !gml_st.tile<8192x4096x2048> to !gml_st.tile<2048x1024x512>
  %inner = gml_st.tile %outer [%off0, %off1, %off2] [%size0, %size1, %size2]
      [%stride0, %stride1, %stride2]
      : !gml_st.tile<2048x1024x512> to !gml_st.tile<?x?x?>
  %slice = gml_st.materialize %src[%inner]
      : tensor<8192x4096x2048xf32>[!gml_st.tile<?x?x?>]
  func.return %slice : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: @tile_of_tile_arith_shortcuts_add
// CHECK-SAME: %[[ARG:.*]]: tensor<32x32x32xf32>, %[[I:.*]]: index, %[[J:.*]]: index
// With unit strides, the composed offsets are (0 + %j, %i + 0, %i + %j).
// The additive shortcuts drop the zero terms, so only the last dimension
// needs an arith.addi. Strides stay (1, 1, 1) and sizes come from the inner
// tile. The space, tile and materialize are checked with their full types,
// so a wrong root space or result type is caught.
func.func @tile_of_tile_arith_shortcuts_add(%arg : tensor<32x32x32xf32>,
    %i : index, %j : index) -> tensor<8x8x8xf32> {
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [32, 32, 32] : !gml_st.tile<32x32x32>
// CHECK-DAG: %[[I_PLUS_J:.*]] = arith.addi %[[I]], %[[J]]
// CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[J]], %[[I]], %[[I_PLUS_J]]] [8, 8, 8] [1, 1, 1] : !gml_st.tile<32x32x32> to !gml_st.tile<8x8x8>
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<32x32x32xf32>[!gml_st.tile<8x8x8>]
// CHECK: return %[[RES]]
  %space = gml_st.space [32, 32, 32] : !gml_st.tile<32x32x32>
  %tile = gml_st.tile %space [0, %i, %i] [16, 16, 16] [1, 1, 1]
      : !gml_st.tile<32x32x32> to !gml_st.tile<16x16x16>
  %tile_of_tile = gml_st.tile %tile [%j, 0, %j] [8, 8, 8] [1, 1, 1]
      : !gml_st.tile<16x16x16> to !gml_st.tile<8x8x8>
  %result = gml_st.materialize %arg[%tile_of_tile]
      : tensor<32x32x32xf32>[!gml_st.tile<8x8x8>]
  func.return %result : tensor<8x8x8xf32>
}
// -----
// CHECK-LABEL: @tile_of_tile_arith_shortcuts_mul
// CHECK-SAME: %[[ARG:.*]]: tensor<32x32x32x32x32xf32>, %[[A:.*]]: index, %[[B:.*]]: index
// All offsets are zero, so the composed offsets stay zero. The composed
// strides are (0*%b, 1*%b, %a*0, %a*1, %a*%b) -> (0, %b, 0, %a, %a*%b).
// The multiplicative shortcuts leave a single arith.muli, for the last
// dimension. The space, tile and materialize are checked with their full
// types, so a wrong root space or result type is caught.
func.func @tile_of_tile_arith_shortcuts_mul(%arg : tensor<32x32x32x32x32xf32>,
    %a : index, %b : index) -> tensor<8x8x8x8x8xf32> {
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [32, 32, 32, 32, 32] : !gml_st.tile<32x32x32x32x32>
// CHECK-DAG: %[[AB:.*]] = arith.muli %[[A]], %[[B]]
// CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [0, 0, 0, 0, 0] [8, 8, 8, 8, 8] [0, %[[B]], 0, %[[A]], %[[AB]]] : !gml_st.tile<32x32x32x32x32> to !gml_st.tile<8x8x8x8x8>
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<32x32x32x32x32xf32>[!gml_st.tile<8x8x8x8x8>]
// CHECK: return %[[RES]]
  %space = gml_st.space [32, 32, 32, 32, 32] : !gml_st.tile<32x32x32x32x32>
  %tile = gml_st.tile %space
      [0, 0, 0, 0, 0] [16, 16, 16, 16, 16] [0, 1, %a, %a, %a]
      : !gml_st.tile<32x32x32x32x32> to !gml_st.tile<16x16x16x16x16>
  %tile_of_tile = gml_st.tile %tile
      [0, 0, 0, 0, 0] [8, 8, 8, 8, 8] [%b, %b, 0, 1, %b]
      : !gml_st.tile<16x16x16x16x16> to !gml_st.tile<8x8x8x8x8>
  %result = gml_st.materialize %arg[%tile_of_tile]
      : tensor<32x32x32x32x32xf32>[!gml_st.tile<8x8x8x8x8>]
  func.return %result : tensor<8x8x8x8x8xf32>
}
// -----
// CHECK-LABEL: @point_of_tile
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index, %[[M:.*]]: index, %[[A:.*]]: index
// A point taken inside a tile is rewritten as a point on the root space at
// outer offset + outer stride * point index. That is (%i + 2 * 0, %j + %a * %k),
// so only the column coordinate needs arith ops.
func.func @point_of_tile(%src : tensor<?x?xf32>, %row_off : index,
    %col_off : index, %col_idx : index, %dyn_dim : index,
    %col_stride : index) -> f32 {
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [1024, %[[M]]] : !gml_st.tile<1024x?>
// CHECK-DAG: %[[AK:.*]] = arith.muli %[[A]], %[[K]]
// CHECK-DAG: %[[J_PLUS_AK:.*]] = arith.addi %[[J]], %[[AK]]
// CHECK-DAG: %[[POINT:.*]] = gml_st.point %[[SPACE]] [%[[I]], %[[J_PLUS_AK]]] : !gml_st.tile<1024x?> to !gml_st.point
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<?x?xf32>[!gml_st.point]
// CHECK: return %[[RES]]
  %root = gml_st.space [1024, %dyn_dim] : !gml_st.tile<1024x?>
  %outer = gml_st.tile %root [%row_off, %col_off] [4, 128] [2, %col_stride]
      : !gml_st.tile<1024x?> to !gml_st.tile<4x128>
  %pt = gml_st.point %outer [0, %col_idx]
      : !gml_st.tile<4x128> to !gml_st.point
  %elem = gml_st.materialize %src[%pt]
      : tensor<?x?xf32>[!gml_st.point]
  func.return %elem : f32
}
// -----
// CHECK-LABEL: @point_of_tile_of_tile_all_constant
// CHECK-SAME: %[[ARG:.*]]: tensor<4096x2048xf32>
// The two constant tiles compose to offsets (0 + 1*2, 32 + 2*16) = (2, 64)
// with strides (1*4, 2*0) = (4, 0). The point (4, 8) therefore lands at
// (2 + 4*4, 64 + 0*8) = (18, 64) in the root space.
func.func @point_of_tile_of_tile_all_constant(%src : tensor<4096x2048xf32>)
    -> f32 {
// CHECK: %[[SPACE:.*]] = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048>
// CHECK: %[[POINT:.*]] = gml_st.point %[[SPACE]] [18, 64] : !gml_st.tile<4096x2048> to !gml_st.point
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<4096x2048xf32>[!gml_st.point]
// CHECK: return %[[RES]]
  %root = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048>
  %level1 = gml_st.tile %root [0, 32] [2048, 256] [1, 2]
      : !gml_st.tile<4096x2048> to !gml_st.tile<2048x256>
  %level2 = gml_st.tile %level1 [2, 16] [256, 128] [4, 0]
      : !gml_st.tile<2048x256> to !gml_st.tile<256x128>
  %pt = gml_st.point %level2 [4, 8]
      : !gml_st.tile<256x128> to !gml_st.point
  %elem = gml_st.materialize %src[%pt]
      : tensor<4096x2048xf32>[!gml_st.point]
  func.return %elem : f32
}
// -----
// CHECK-LABEL: @point_chain_w_zeroes_and_ones
// CHECK-SAME: %[[ARG:.*]]: tensor<8192x4096x2048xf32>, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index
// The outer tile has strides 0, 1 and 2, so the point coordinates become
// (0 + 0*%i, 16 + 1*%j, 32 + 2*%k) -> (0, %j + 16, 2*%k + 32).
// The first index argument drops out entirely.
func.func @point_chain_w_zeroes_and_ones(%src : tensor<8192x4096x2048xf32>,
    %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2
// CHECK-DAG: %[[C16:.*]] = arith.constant 16
// CHECK-DAG: %[[C32:.*]] = arith.constant 32
// CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048>
// CHECK-DAG: %[[TWO_K:.*]] = arith.muli %[[K]], %[[C2]]
// CHECK-DAG: %[[SIXTEEN_PLUS_J:.*]] = arith.addi %[[J]], %[[C16]]
// CHECK-DAG: %[[TWO_K_PLUS_32:.*]] = arith.addi %[[TWO_K]], %[[C32]]
// CHECK-DAG: %[[POINT:.*]] = gml_st.point %[[SPACE]] [0, %[[SIXTEEN_PLUS_J]], %[[TWO_K_PLUS_32]]] : !gml_st.tile<8192x4096x2048> to !gml_st.point
// CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<8192x4096x2048xf32>[!gml_st.point]
// CHECK: return %[[RES]]
  %root = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048>
  %outer = gml_st.tile %root [0, 16, 32] [2048, 1024, 512] [0, 1, 2]
      : !gml_st.tile<8192x4096x2048> to !gml_st.tile<2048x1024x512>
  %pt = gml_st.point %outer [%idx0, %idx1, %idx2]
      : !gml_st.tile<2048x1024x512> to !gml_st.point
  %elem = gml_st.materialize %src[%pt]
      : tensor<8192x4096x2048xf32>[!gml_st.point]
  func.return %elem : f32
}
// -----
// CHECK-LABEL: @point_of_transpose_dims_of_tile_all_constant
// CHECK-SAME: %[[ARG:.*]]: tensor<2048x4096xf32>
// Transposing the tile also swaps the root space, giving 2048x4096.
// Equivalently, the 128x256 tile at offsets (0, 32) with strides (1, 2)
// becomes a 256x128 tile at offsets (32, 0) with strides (2, 1). The point
// (4, 8) then maps to (32 + 2*4, 0 + 1*8) = (40, 8) on the swapped space.
func.func @point_of_transpose_dims_of_tile_all_constant(
    %src : tensor<2048x4096xf32>) -> f32 {
// CHECK: %[[SPACE:.*]] = gml_st.space [2048, 4096] : !gml_st.tile<2048x4096>
// CHECK: %[[POINT:.*]] = gml_st.point %[[SPACE]] [40, 8] : !gml_st.tile<2048x4096> to !gml_st.point
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<2048x4096xf32>[!gml_st.point]
// CHECK: return %[[RES]]
  %root = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048>
  %outer = gml_st.tile %root [0, 32] [128, 256] [1, 2]
      : !gml_st.tile<4096x2048> to !gml_st.tile<128x256>
  %swapped = gml_st.transpose_dims %outer, [1, 0]
      : !gml_st.tile<128x256> to !gml_st.tile<256x128>
  %pt = gml_st.point %swapped [4, 8]
      : !gml_st.tile<256x128> to !gml_st.point
  %elem = gml_st.materialize %src[%pt]
      : tensor<2048x4096xf32>[!gml_st.point]
  func.return %elem : f32
}
// -----
// CHECK-LABEL: @transpose_dims_of_transpose_dims_of_tile
// CHECK-SAME: %[[ARG:.*]]: tensor<10x?x5xf32>, %[[SIZE:.*]]: index
// Permuting by [1, 0, 2] and then by [2, 1, 0] composes to [2, 0, 1].
// The ?x5x10 space becomes 10x?x5, and the tile's offsets [0, 0, 3],
// sizes [%size, 5, 4] and strides [%size, 1, 2] are permuted the same way
// into [3, 0, 0], [4, %size, 5] and [2, %size, 1].
// The materialize source is matched through the captured ARG rather than
// the printer-generated name, matching the rest of this file.
func.func @transpose_dims_of_transpose_dims_of_tile(
    %arg : tensor<10x?x5xf32>, %size: index) -> tensor<4x?x5xf32> {
// CHECK: %[[SPACE:.*]] = gml_st.space [10, %[[SIZE]], 5] : !gml_st.tile<10x?x5>
// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [3, 0, 0] [4, %[[SIZE]], 5] [2, %[[SIZE]], 1] : !gml_st.tile<10x?x5> to !gml_st.tile<4x?x5>
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<10x?x5xf32>[!gml_st.tile<4x?x5>]
// CHECK: return %[[RES]] : tensor<4x?x5xf32>
  %s = gml_st.space [%size, 5, 10] : !gml_st.tile<?x5x10>
  %t = gml_st.tile %s [0, 0, 3] [%size, 5, 4] [%size, 1, 2]
      : !gml_st.tile<?x5x10> to !gml_st.tile<?x5x4>
  %tt = gml_st.transpose_dims %t, [1, 0, 2]
      : !gml_st.tile<?x5x4> to !gml_st.tile<5x?x4>
  %tt2 = gml_st.transpose_dims %tt, [2, 1, 0]
      : !gml_st.tile<5x?x4> to !gml_st.tile<4x?x5>
  %res = gml_st.materialize %arg[%tt2]
      : tensor<10x?x5xf32>[!gml_st.tile<4x?x5>]
  func.return %res : tensor<4x?x5xf32>
}
// -----
// CHECK-LABEL: @transpose_dims_of_space
// CHECK-SAME: %[[ARG:.*]]: tensor<5x10x?xf32>, %[[SIZE:.*]]: index
// Transposing a bare space folds into a new space whose dims are permuted:
// result dim i takes input dim perm[i]. So [1, 2, 0] applied to
// (%size, 5, 10) gives (5, 10, %size).
// The materialize source is matched through the captured ARG rather than
// the printer-generated name, matching the rest of this file.
func.func @transpose_dims_of_space(
    %arg : tensor<5x10x?xf32>, %size: index) -> tensor<5x10x?xf32> {
// CHECK: %[[SPACE:.*]] = gml_st.space [5, 10, %[[SIZE]]] : !gml_st.tile<5x10x?>
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[SPACE]]] : tensor<5x10x?xf32>[!gml_st.tile<5x10x?>]
// CHECK: return %[[RES]] : tensor<5x10x?xf32>
  %s = gml_st.space [%size, 5, 10] : !gml_st.tile<?x5x10>
  %tt = gml_st.transpose_dims %s, [1, 2, 0]
      : !gml_st.tile<?x5x10> to !gml_st.tile<5x10x?>
  %res = gml_st.materialize %arg[%tt]
      : tensor<5x10x?xf32>[!gml_st.tile<5x10x?>]
  func.return %res : tensor<5x10x?xf32>
}
// -----
// CHECK-LABEL: @drop_dims_of_space
// CHECK-SAME: %[[ARG:.*]]: tensor<?x10xf32>, %[[SIZE:.*]]: index
// drop_dims on a bare space folds into a new space made of the listed dims.
// [0, 2] on (%size, 5, 10) keeps (%size, 10), as the ?x5x10 -> ?x10
// result type shows.
// The materialize source is matched through the captured ARG rather than
// the printer-generated name, matching the rest of this file.
func.func @drop_dims_of_space(
    %arg : tensor<?x10xf32>, %size: index) -> tensor<?x10xf32> {
// CHECK: %[[SPACE:.*]] = gml_st.space [%[[SIZE]], 10] : !gml_st.tile<?x10>
// CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[SPACE]]] : tensor<?x10xf32>[!gml_st.tile<?x10>]
// CHECK: return %[[RES]] : tensor<?x10xf32>
  %s = gml_st.space [%size, 5, 10] : !gml_st.tile<?x5x10>
  %tt = gml_st.drop_dims %s, [0, 2]
      : !gml_st.tile<?x5x10> to !gml_st.tile<?x10>
  %res = gml_st.materialize %arg[%tt]
      : tensor<?x10xf32>[!gml_st.tile<?x10>]
  func.return %res : tensor<?x10xf32>
}