blob: 1f4ca704a95efaebd6f50bc1a3ccb35dd60b44a8 [file] [log] [blame]
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-unroll-batch-matmul %s | FileCheck %s
// ==== V1 tests ====
// Unrolls a tf.BatchMatMul whose operands carry two batch dimensions (2x3).
// Expected output: each operand is reshaped to a single batch dimension of 6,
// split along axis 0 into six slices that are each reshaped to a 2-D matrix,
// multiplied pairwise with tf.MatMul, packed along axis 0 into 6x4x6, and
// finally reshaped to the 2x3x4x6 result type.
func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDim
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul with two batch dimensions (2x3) and both adj_x and
// adj_y set. Expected output: each operand first goes through tf.Transpose
// with permutation [0, 1, 3, 2] (swapping the two innermost dimensions); the
// transposed values are then reshaped to batch 6, split, reshaped to 2-D,
// multiplied with tf.MatMul (transpose_a/transpose_b both false), packed, and
// reshaped to 2x3x4x6.
func @batchMatMulTwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDimAdjXY
// CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
// CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul whose operands have a single batch dimension (3).
// Expected output: tf.Split consumes %arg0 and %arg1 directly along axis 0,
// each of the three slices is reshaped to a 2-D matrix and multiplied pairwise
// with tf.MatMul, and the tf.Pack result (3x4x6) is returned as-is.
func @batchMatMulOneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulOneDim
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul whose single batch dimension has size 1.
// Expected output: %arg0 and %arg1 are each reshaped straight to a 2-D matrix
// (the expectations contain no tf.Split), one tf.MatMul is emitted, and a
// one-operand tf.Pack along axis 0 restores the 1x4x6 result type.
func @batchMatMulSingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32>
return %0 : tensor<1x4x6xf32>
// CHECK-LABEL: batchMatMulSingleBatch
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul whose left operand is a plain 4x5 matrix while the
// right operand has a batch dimension of 3. Expected output: only %arg1 is
// split and reshaped into three 2-D matrices; %arg0 is used unchanged as the
// left operand of all three tf.MatMul ops, whose results are packed into
// 3x4x6 and returned.
func @batchMatMulUnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulUnbatchedLeft
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul whose left operand has a batch dimension of 3 while
// the right operand is a plain 5x6 matrix. Expected output: only %arg0 is
// split and reshaped into three 2-D matrices; %arg1 is used unchanged as the
// right operand of all three tf.MatMul ops, whose results are packed into
// 3x4x6 and returned.
func @batchMatMulUnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulUnbatchedRight
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul whose operands are both plain 2-D matrices.
// Expected output: a single tf.MatMul applied directly to %arg0 and %arg1
// (transpose_a/transpose_b both false) whose result is returned.
func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulMatrix
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMul on plain 2-D matrices with adj_x and adj_y set.
// Expected output: each operand goes through tf.Transpose with permutation
// [1, 0], and the transposed values feed a single tf.MatMul whose
// transpose_a/transpose_b attributes are both false.
func @batchMatMulMatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulMatrixAdjXY
// CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
// CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
}
// -----
// ==== V2 tests ====
// Unrolls a tf.BatchMatMulV2 whose operands carry identical batch dimensions
// (2x3). Expected output: each operand is reshaped to a single batch dimension
// of 6, split along axis 0 into six slices that are each reshaped to a 2-D
// matrix, multiplied pairwise with tf.MatMul, packed along axis 0 into 6x4x6,
// and finally reshaped to the 2x3x4x6 result type.
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMulV2 with two batch dimensions (2x3) and both adj_x
// and adj_y set. Expected output: each operand first goes through
// tf.Transpose with permutation [0, 1, 3, 2] (swapping the two innermost
// dimensions); the transposed values are then reshaped to batch 6, split,
// reshaped to 2-D, multiplied with tf.MatMul (transpose_a/transpose_b both
// false), packed, and reshaped to 2x3x4x6.
func @batchMatMulV2TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDimAdjXY
// CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
// CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// Unrolls a tf.BatchMatMulV2 whose operands have different batch shapes
// (2x1 on the left, 1x3 on the right) and a 2x3 result batch shape.
// Expected output: the left operand is reshaped and split into two 2-D
// matrices and the right operand into three; six tf.MatMul ops then pair every
// left slice with every right slice in left-major order (LHS_1 with RHS_1..3,
// then LHS_2 with RHS_1..3), and the results are packed into 6x4x6 and
// reshaped to 2x3x4x6.
func @batchMatMulV2Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2Broadcast
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// BatchMatMulV2 with a single batch dimension of size 3 on both operands.
// Expected lowering: each operand is split into 3 slices along axis 0, every
// slice is reshaped to 2-D, slice i of the LHS is multiplied with slice i of
// the RHS by tf.MatMul, and the 3 products are packed along axis 0. The packed
// value is returned directly (no trailing reshape).
func @batchMatMulV2OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2OneDim
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV2 whose only batch dimension has size 1.
// Expected lowering: both function arguments are reshaped straight to 2-D,
// multiplied by a single tf.MatMul, and the lone product is packed along
// axis 0 to restore the leading dimension of 1.
func @batchMatMulV2SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> {
%batched = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32>
return %batched : tensor<1x4x6xf32>
// CHECK-LABEL: batchMatMulV2SingleBatch
// CHECK-DAG: %[[LHS_2D_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK-DAG: %[[RHS_2D_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_2D:.*]] = "tf.Reshape"(%arg0, %[[LHS_2D_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_2D:.*]] = "tf.Reshape"(%arg1, %[[RHS_2D_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%[[LHS_2D]], %[[RHS_2D]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[PACKED:.*]] = "tf.Pack"(%[[PRODUCT]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
// CHECK: return %[[PACKED]] : tensor<1x4x6xf32>
}
// -----
// BatchMatMulV2 where the LHS is a plain rank-2 matrix and only the RHS has a
// batch dimension (size 3).
// Expected lowering: only the RHS is split along axis 0 and reshaped to 2-D;
// %arg0 is used unchanged as the LHS of all 3 tf.MatMul ops, whose results are
// packed along axis 0 and returned.
func @batchMatMulV2UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2UnbatchedLeft
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV2 where the RHS is a plain rank-2 matrix and only the LHS has a
// batch dimension (size 3).
// Expected lowering: only the LHS is split along axis 0 and reshaped to 2-D;
// %arg1 is used unchanged as the RHS of all 3 tf.MatMul ops, whose results are
// packed along axis 0 and returned.
func @batchMatMulV2UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2UnbatchedRight
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV2 on two rank-2 operands (no batch dimensions).
// Expected lowering: a single tf.MatMul applied directly to the function
// arguments, with its result returned as-is.
func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%product = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %product : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV2Matrix
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[PRODUCT]] : tensor<4x6xf32>
}
// -----
// BatchMatMulV2 on two rank-2 operands with adj_x and adj_y both set.
// Expected lowering: each argument is transposed with permutation [1, 0]
// and the transposed values feed a single tf.MatMul whose transpose_a and
// transpose_b attributes are both false.
func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
%product = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
return %product : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV2MatrixAdjXY
// CHECK-DAG: %[[PERM:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[LHS_T:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
// CHECK: %[[RHS_T:.*]] = "tf.Transpose"(%arg1, %[[PERM]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%[[LHS_T]], %[[RHS_T]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[PRODUCT]] : tensor<4x6xf32>
}
// -----
// ==== V3 tests ====
// BatchMatMulV3 with two batch dimensions (2x3) on both operands.
// Expected lowering: each operand is reshaped so the batch dimensions collapse
// into one dimension of size 6, split into 6 slices along axis 0, and every
// slice is reshaped to 2-D. Slice i of the LHS is multiplied with slice i of
// the RHS by tf.MatMul; the 6 products are packed along axis 0 and the packed
// 6x4x6 tensor is reshaped back to 2x3x4x6.
func @batchMatMulV3TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV3TwoDim
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// BatchMatMulV3 with two batch dimensions (2x3) and adj_x / adj_y both set.
// Expected lowering: both operands are first transposed with permutation
// [0, 1, 3, 2] (swapping the two innermost dimensions). The transposed values
// then go through reshape (batch collapsed to 6), split into 6 slices,
// per-slice 2-D reshape, 6 pairwise tf.MatMul ops with transpose_a and
// transpose_b both false, pack along axis 0, and a final reshape to 2x3x4x6.
func @batchMatMulV3TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV3TwoDimAdjXY
// CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
// CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// BatchMatMulV3 whose batch dimensions differ between operands (LHS 2x1,
// RHS 1x3) while the result has batch dimensions 2x3.
// Expected lowering: the LHS is reshaped to 2x4x5 and split into 2 slices, the
// RHS is reshaped to 3x5x6 and split into 3 slices. Six tf.MatMul ops pair the
// slices with the LHS index varying slowest (LHS_1 with RHS_1..RHS_3, then
// LHS_2 with RHS_1..RHS_3); the products are packed along axis 0 and reshaped
// to 2x3x4x6.
func @batchMatMulV3Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV3Broadcast
// CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
// CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
}
// -----
// BatchMatMulV3 with a single batch dimension of size 3 on both operands.
// Expected lowering: each operand is split into 3 slices along axis 0, every
// slice is reshaped to 2-D, slice i of the LHS is multiplied with slice i of
// the RHS by tf.MatMul, and the 3 products are packed along axis 0. The packed
// value is returned directly (no trailing reshape).
func @batchMatMulV3OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV3OneDim
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV3 whose only batch dimension has size 1.
// Expected lowering: both function arguments are reshaped straight to 2-D,
// multiplied by a single tf.MatMul, and the lone product is packed along
// axis 0 to restore the leading dimension of 1.
func @batchMatMulV3SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> {
%batched = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32>
return %batched : tensor<1x4x6xf32>
// CHECK-LABEL: batchMatMulV3SingleBatch
// CHECK-DAG: %[[LHS_2D_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK-DAG: %[[RHS_2D_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_2D:.*]] = "tf.Reshape"(%arg0, %[[LHS_2D_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[RHS_2D:.*]] = "tf.Reshape"(%arg1, %[[RHS_2D_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%[[LHS_2D]], %[[RHS_2D]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[PACKED:.*]] = "tf.Pack"(%[[PRODUCT]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
// CHECK: return %[[PACKED]] : tensor<1x4x6xf32>
}
// -----
// BatchMatMulV3 where the LHS is a plain rank-2 matrix and only the RHS has a
// batch dimension (size 3).
// Expected lowering: only the RHS is split along axis 0 and reshaped to 2-D;
// %arg0 is used unchanged as the LHS of all 3 tf.MatMul ops, whose results are
// packed along axis 0 and returned.
func @batchMatMulV3UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV3UnbatchedLeft
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
// CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV3 where the RHS is a plain rank-2 matrix and only the LHS has a
// batch dimension (size 3).
// Expected lowering: only the LHS is split along axis 0 and reshaped to 2-D;
// %arg1 is used unchanged as the RHS of all 3 tf.MatMul ops, whose results are
// packed along axis 0 and returned.
func @batchMatMulV3UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV3UnbatchedRight
// CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
// CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
}
// -----
// BatchMatMulV3 on two rank-2 operands (no batch dimensions).
// Expected lowering: a single tf.MatMul applied directly to the function
// arguments, with its result returned as-is.
func @batchMatMulV3Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%product = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %product : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV3Matrix
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[PRODUCT]] : tensor<4x6xf32>
}
// -----
// BatchMatMulV3 on two rank-2 operands with adj_x and adj_y both set.
// Expected lowering: each argument is transposed with permutation [1, 0]
// and the transposed values feed a single tf.MatMul whose transpose_a and
// transpose_b attributes are both false.
func @batchMatMulV3MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
%product = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
return %product : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV3MatrixAdjXY
// CHECK-DAG: %[[PERM:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[LHS_T:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
// CHECK: %[[RHS_T:.*]] = "tf.Transpose"(%arg1, %[[PERM]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
// CHECK: %[[PRODUCT:.*]] = "tf.MatMul"(%[[LHS_T]], %[[RHS_T]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[PRODUCT]] : tensor<4x6xf32>
}
// -----
// BatchMatMulV3 with i8 operands and an i32 result.
// The expectations require the tf.BatchMatMulV3 op to still be present after
// the pass and its result to be what the function returns, i.e. this
// mixed-type case is left untouched rather than rewritten.
// NOTE(review): presumably the pass skips it because operand and result
// element types differ; confirm against the pass implementation.
// The op result is captured in a FileCheck variable instead of hard-coding the
// printer-assigned name, matching the other tests in this file.
func @batchMatMulV3MatrixInt8(%arg0: tensor<4x5xi8>, %arg1: tensor<5x6xi8>) -> tensor<4x6xi32> {
%0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32>
return %0 : tensor<4x6xi32>
// CHECK-LABEL: batchMatMulV3MatrixInt8
// CHECK: %[[RESULT:.*]] = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32>
// CHECK: return %[[RESULT]] : tensor<4x6xi32>
}