| // RUN: tf-opt -split-input-file -verify-diagnostics -tf-unroll-batch-matmul %s | FileCheck %s |
| |
| // ==== V1 tests ==== |
| |
| // tf.BatchMatMul with two leading batch dimensions (2x3) and no adjoint flags. |
| // The checks below expect the unrolled form: each operand is reshaped so the |
| // batch dimensions collapse into one dimension of size 6, split along axis 0 |
| // into six 1xMxN slices, and each slice is reshaped to a 2-D matrix. Six |
| // tf.MatMul ops pair the slices index by index, tf.Pack stacks the products |
| // along axis 0, and a final tf.Reshape restores the 2x3x4x6 result shape. |
| func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulTwoDim |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul with two leading batch dimensions (2x3) and both adj_x and |
| // adj_y set to true. The checks expect each operand to first go through a |
| // tf.Transpose with permutation [0, 1, 3, 2] (swapping the two innermost |
| // dimensions), after which the reshape / split / per-slice tf.MatMul / pack / |
| // reshape sequence is applied to the transposed values. The emitted tf.MatMul |
| // ops are expected to carry transpose_a = false and transpose_b = false. |
| func @batchMatMulTwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulTwoDimAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32> |
| // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32> |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul with a single leading batch dimension of size 3. The checks |
| // expect the operands to be fed to tf.Split directly (no batch-flattening |
| // tf.Reshape in front of it), each slice reshaped to a 2-D matrix, three |
| // tf.MatMul ops, and the tf.Pack result returned as-is with no trailing |
| // tf.Reshape. |
| func @batchMatMulOneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulOneDim |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul whose single batch dimension has size 1. The checks expect |
| // each operand to be reshaped straight from 1xMxN to a 2-D matrix, a single |
| // tf.MatMul on the two reshaped values, and a one-input tf.Pack along axis 0 |
| // that restores the leading batch dimension of the result. |
| // The pattern variables are numbered from 1 (LHS_1, RHS_1, MATMUL_1) to match |
| // the naming used by the other tests in this file. |
| func @batchMatMulSingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32> |
| return %0 : tensor<1x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulSingleBatch |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul where the left operand is a plain 2-D matrix and only the |
| // right operand carries a batch dimension (size 3). The checks expect only the |
| // right operand to be split and reshaped; %arg0 is used unchanged as the left |
| // input of each of the three tf.MatMul ops, whose results are packed along |
| // axis 0 and returned. |
| func @batchMatMulUnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulUnbatchedLeft |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul where only the left operand carries a batch dimension |
| // (size 3) and the right operand is a plain 2-D matrix. The checks expect only |
| // the left operand to be split and reshaped; %arg1 is used unchanged as the |
| // right input of each of the three tf.MatMul ops, whose results are packed |
| // along axis 0 and returned. |
| func @batchMatMulUnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulUnbatchedRight |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul applied to two plain 2-D matrices (no batch dimensions). The |
| // checks expect a single tf.MatMul taking %arg0 and %arg1 directly, with its |
| // result returned as-is. |
| func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulMatrix |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMul applied to two plain 2-D matrices with adj_x and adj_y both |
| // true. The checks expect each operand to go through a tf.Transpose with |
| // permutation [1, 0], followed by a single tf.MatMul on the transposed values |
| // with transpose_a = false and transpose_b = false. |
| func @batchMatMulMatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulMatrixAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32> |
| // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| // ==== V2 tests ==== |
| |
| // tf.BatchMatMulV2 with two leading batch dimensions (2x3), identical on both |
| // operands, and no adjoint flags. The checks expect the unrolled form: each |
| // operand is reshaped so the batch dimensions collapse into one dimension of |
| // size 6, split along axis 0 into six slices, and each slice is reshaped to a |
| // 2-D matrix. Six tf.MatMul ops pair the slices index by index, tf.Pack stacks |
| // the products along axis 0, and a final tf.Reshape restores 2x3x4x6. |
| func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2TwoDim |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 with two leading batch dimensions (2x3) and both adj_x and |
| // adj_y set to true. The checks expect each operand to first go through a |
| // tf.Transpose with permutation [0, 1, 3, 2] (swapping the two innermost |
| // dimensions), after which the reshape / split / per-slice tf.MatMul / pack / |
| // reshape sequence is applied to the transposed values. The emitted tf.MatMul |
| // ops are expected to carry transpose_a = false and transpose_b = false. |
| func @batchMatMulV2TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2TwoDimAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32> |
| // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32> |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 whose operands have different batch shapes: 2x1 on the left |
| // and 1x3 on the right, producing a 2x3 batch in the result. The checks expect |
| // the left operand to be flattened and split into two slices and the right |
| // operand into three. Six tf.MatMul ops then cover every (left, right) pair |
| // with the right slice varying fastest: LHS_1 with RHS_1..RHS_3, followed by |
| // LHS_2 with RHS_1..RHS_3. The products are packed along axis 0 and reshaped |
| // to 2x3x4x6. |
| func @batchMatMulV2Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2Broadcast |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 where both operands carry one batch dimension of size 3. |
| // The output verified below splits each operand along axis 0, reshapes every |
| // slice to 2-D, multiplies the matching LHS/RHS slices with tf.MatMul and |
| // returns the tf.Pack of the three products directly. |
| func @batchMatMulV2OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2OneDim |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 where both operands have a single batch dimension of |
| // size 1. The output verified below reshapes each function argument straight |
| // to 2-D, emits one tf.MatMul and wraps it in a one-element tf.Pack to restore |
| // the leading batch dimension. |
| func @batchMatMulV2SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32> |
| return %0 : tensor<1x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2SingleBatch |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 with a 2-D (unbatched) LHS and an RHS with one batch |
| // dimension of size 3. The output verified below splits and reshapes only the |
| // RHS; %arg0 is fed unchanged as the LHS of each of the three tf.MatMul ops, |
| // whose results are packed into the 3x4x6 result. |
| func @batchMatMulV2UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2UnbatchedLeft |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 with an LHS that has one batch dimension of size 3 and a |
| // 2-D (unbatched) RHS. The output verified below splits and reshapes only the |
| // LHS; %arg1 is fed unchanged as the RHS of each of the three tf.MatMul ops, |
| // whose results are packed into the 3x4x6 result. |
| func @batchMatMulV2UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2UnbatchedRight |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 on two plain 2-D matrices (no batch dimensions). The |
| // output verified below is a single tf.MatMul applied directly to the |
| // function arguments, whose result is returned as-is. |
| func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2Matrix |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV2 on two 2-D matrices with adj_x = adj_y = true. The output |
| // verified below applies an explicit tf.Transpose with permutation [1, 0] to |
| // each operand and then emits a tf.MatMul whose transpose_a / transpose_b |
| // attributes are both false. |
| func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV2MatrixAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32> |
| // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| // ==== V3 tests ==== |
| |
| // tf.BatchMatMulV3 where both operands have two batch dimensions (2x3). |
| // The output verified below collapses the batch dimensions into one of size 6 |
| // via tf.Reshape, splits each operand along axis 0, reshapes the slices to |
| // 2-D, multiplies matching LHS/RHS slices with tf.MatMul, packs the six |
| // products and reshapes the pack back to 2x3x4x6. |
| func @batchMatMulV3TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3TwoDim |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 with two batch dimensions (2x3) and adj_x = adj_y = true. |
| // The output verified below first applies a tf.Transpose with permutation |
| // [0, 1, 3, 2] to each operand (swapping the two innermost dimensions), then |
| // follows the reshape / split / per-slice tf.MatMul / pack / reshape sequence |
| // on the transposed values. Every tf.MatMul has transpose_a / transpose_b |
| // set to false. |
| func @batchMatMulV3TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3TwoDimAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32> |
| // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32> |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 with broadcasting batch dimensions: LHS batch shape is |
| // 2x1 and RHS batch shape is 1x3, giving a 2x3 result batch shape. |
| // The output verified below splits the LHS into 2 slices and the RHS into 3, |
| // then emits six tf.MatMul ops pairing every LHS slice with every RHS slice |
| // (LHS index varying slowest), packs them to 6x4x6 and reshapes to 2x3x4x6. |
| func @batchMatMulV3Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> |
| return %0 : tensor<2x3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3Broadcast |
| // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>} |
| // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32> |
| // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> |
| // CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> |
| // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> |
| // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 where both operands carry one batch dimension of size 3. |
| // The output verified below splits each operand along axis 0, reshapes every |
| // slice to 2-D, multiplies the matching LHS/RHS slices with tf.MatMul and |
| // returns the tf.Pack of the three products directly. |
| func @batchMatMulV3OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3OneDim |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 where both operands have a single batch dimension of |
| // size 1. The output verified below reshapes each function argument straight |
| // to 2-D, emits one tf.MatMul and wraps it in a one-element tf.Pack to restore |
| // the leading batch dimension. |
| func @batchMatMulV3SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32> |
| return %0 : tensor<1x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3SingleBatch |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 with a 2-D (unbatched) LHS and an RHS with one batch |
| // dimension of size 3. The output verified below splits and reshapes only the |
| // RHS; %arg0 is fed unchanged as the LHS of each of the three tf.MatMul ops, |
| // whose results are packed into the 3x4x6 result. |
| func @batchMatMulV3UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3UnbatchedLeft |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} |
| // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} |
| |
| // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>) |
| // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 with an LHS that has one batch dimension of size 3 and a |
| // 2-D (unbatched) RHS. The output verified below splits and reshapes only the |
| // LHS; %arg1 is fed unchanged as the RHS of each of the three tf.MatMul ops, |
| // whose results are packed into the 3x4x6 result. |
| func @batchMatMulV3UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> |
| return %0 : tensor<3x4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3UnbatchedRight |
| // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> |
| // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64> |
| |
| // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>) |
| // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> |
| // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 on two plain 2-D matrices (no batch dimensions). The |
| // output verified below is a single tf.MatMul applied directly to the |
| // function arguments, whose result is returned as-is. |
| func @batchMatMulV3Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3Matrix |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 on two 2-D matrices with adj_x = adj_y = true. The output |
| // verified below applies an explicit tf.Transpose with permutation [1, 0] to |
| // each operand and then emits a tf.MatMul whose transpose_a / transpose_b |
| // attributes are both false. |
| func @batchMatMulV3MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> |
| return %0 : tensor<4x6xf32> |
| |
| // CHECK-LABEL: batchMatMulV3MatrixAdjXY |
| // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32> |
| |
| // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32> |
| // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32> |
| |
| // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> |
| |
| // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> |
| } |
| |
| // ----- |
| |
| // tf.BatchMatMulV3 on 2-D i8 operands producing an i32 result. The output |
| // verified below still contains the original tf.BatchMatMulV3 op, i.e. this |
| // mixed-type case is left as-is rather than rewritten to tf.MatMul. |
| // The result is matched through a capture instead of a literal SSA name so |
| // the test does not depend on how the printer numbers values. |
| func @batchMatMulV3MatrixInt8(%arg0: tensor<4x5xi8>, %arg1: tensor<5x6xi8>) -> tensor<4x6xi32> { |
| %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32> |
| return %0 : tensor<4x6xi32> |
| |
| // CHECK-LABEL: batchMatMulV3MatrixInt8 |
| // CHECK: %[[RESULT:.*]] = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32> |
| // CHECK: return %[[RESULT]] : tensor<4x6xi32> |
| } |