// blob: 5cf3de77d354d75904274c40bd55c85a70324a32 [file] [log] [blame]
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s
// Transposing a splat constant: the expected output is a single constant that
// already has the transposed shape (10x5 -> 5x10) and is returned directly,
// with no transpose op left in the function.
// CHECK-LABEL: func @transpose_splat_constant
func.func @transpose_splat_constant() -> tensor<5x10xf32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<5x10xf32>
%cst = mhlo.constant dense<1.000000e+00> : tensor<10x5xf32>
%0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x5xf32>) -> tensor<5x10xf32>
// CHECK-NEXT: return [[CST]]
func.return %0 : tensor<5x10xf32>
}
// -----
// Identity permutation [0, 1, 2, 3]: the transpose is a no-op, so the expected
// output returns the function argument directly.
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
func.return %0 : tensor<2x3x9x5xi32>
}
// -----
// Negative test: permutation [1, 0, 3, 2] swaps dimensions whose sizes are all
// greater than one, so the transpose is expected to remain in the output,
// still applied to the function argument.
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
func.return %0 : tensor<3x2x5x9xi32>
}
// -----
// Negative test: the operand and result types are both 4x4, but permutation
// [1, 0] is not the identity, so the transpose is expected to remain. Equal
// input/output shapes alone must not be treated as a no-op.
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
func.return %0 : tensor<4x4xi32>
}
// -----
// Two back-to-back transposes are expected to merge into one transpose of the
// original argument. Applying [0, 3, 1, 2] and then [0, 1, 3, 2] is equivalent
// to the single permutation [0, 3, 2, 1] (3x4x16x2 -> 3x2x16x4).
// CHECK-LABEL: @eliminate_redundant_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @eliminate_redundant_transpose(%arg : tensor<3x4x16x2xf32>) -> tensor<3x2x16x4xf32> {
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}: (tensor<3x4x16x2xf32>) -> tensor<3x2x4x16xf32>
%1 = "mhlo.transpose"(%0) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}: (tensor<3x2x4x16xf32>) -> tensor<3x2x16x4xf32>
// CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.transpose"([[ARG]])
// CHECK-SAME: dense<[0, 3, 2, 1]
// CHECK-NEXT: return [[RET]]
func.return %1 : tensor<3x2x16x4xf32>
}
// -----
// Permutation [1, 0, 2] on 10x1x512 only moves the size-1 dimension; the
// relative order of the non-unit dimensions (10, 512) is unchanged. The
// expected output replaces the transpose with a reshape of the argument.
// CHECK-LABEL: @simplify_transpose_case1
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @simplify_transpose_case1(%arg : tensor<10x1x512xf32>) -> tensor<1x10x512xf32> {
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x1x512xf32>) -> tensor<1x10x512xf32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]])
func.return %0 : tensor<1x10x512xf32>
}
// -----
// Same idea as the single-unit-dimension case, with two size-1 dimensions:
// permutation [1, 3, 0, 2] on 10x1x512x1 moves both unit dimensions to the
// front while keeping 10 before 512. The expected output is a reshape of the
// argument instead of a transpose.
// CHECK-LABEL: @simplify_transpose_case2
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @simplify_transpose_case2(%arg : tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> {
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 3, 0, 2]> : tensor<4xi64>}: (tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]])
func.return %0 : tensor<1x1x10x512xf32>
}
// -----
// Negative test: same permutation [1, 0, 2] as the static size-1 case, but the
// moved dimension is dynamic (`?`) rather than a static 1. The transpose is
// expected to remain in the output rather than being turned into a reshape.
// CHECK-LABEL: @not_simplify_transpose_dynamic_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @not_simplify_transpose_dynamic_shape(%arg : tensor<10x?x512xf32>) -> tensor<?x10x512xf32> {
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x?x512xf32>) -> tensor<?x10x512xf32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
func.return %0 : tensor<?x10x512xf32>
}
// -----
// A broadcast_in_dim followed by a transpose is expected to collapse into a
// single broadcast_in_dim of the original argument. The operand dimension was
// mapped to result dimension 3 (size 64); after permutation [0, 3, 1, 2] that
// dimension sits at position 1, so broadcast_dimensions becomes dense<1>.
// CHECK-LABEL: func @broadcast_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose(%arg0 : tensor<64xf32>) -> tensor<5x64x31x95xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<5x31x95x64xf32>
%1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
// CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
// CHECK-SAME: dense<1>
// CHECK-NEXT: return [[RET]]
func.return %1 : tensor<5x64x31x95xf32>
}
// -----
// Scalar (rank-0) operand variant: broadcast_dimensions is empty, so there is
// nothing to remap. The expected output is still a single broadcast_in_dim of
// the argument with an empty broadcast_dimensions attribute.
// CHECK-LABEL: func @broadcast_transpose_non_dim
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose_non_dim(%arg0 : tensor<f32>) -> tensor<5x64x31x95xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x31x95x64xf32>
%1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
// CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
// CHECK-SAME: dense<>
// CHECK-NEXT: return [[RET]]
func.return %1 : tensor<5x64x31x95xf32>
}
// -----
// Rank-2 operand variant: operand dimensions map to result dimensions [2, 3]
// (sizes 95 and 64). After permutation [0, 3, 1, 2], size 95 is at position 3
// and size 64 is at position 1, so the single merged broadcast_in_dim is
// expected to carry broadcast_dimensions = [3, 1].
// CHECK-LABEL: func @broadcast_transpose_multi_dim
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose_multi_dim(%arg0 : tensor<95x64xf32>) -> tensor<5x64x31x95xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<95x64xf32>) -> tensor<5x31x95x64xf32>
%1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
// CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
// CHECK-SAME: dense<[3, 1]>
// CHECK-NEXT: return [[RET]]
func.return %1 : tensor<5x64x31x95xf32>
}