Add missing test cases in unroll-batch-matmul.mlir:
- adj_x/adj_y set to true (semantics sketched below)
- adj_x/adj_y set to true for unbatched matrices
- unbatched LHS and batched RHS
- broadcasting (supported by V2/V3)
- matching coverage for the BatchMatMulV2 and BatchMatMulV3 variants
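
For the adj_x/adj_y cases above, a minimal illustrative sketch (shapes are hypothetical, not copied from the test file): each adjoint flag transposes the innermost two dimensions of its operand before the product, which the unroll pass materializes as a tf.Transpose ahead of the per-slice tf.MatMul ops.

  // Both operands adjointed: a 5x4 LHS and a 6x5 RHS produce a 4x6 result.
  %0 = "tf.BatchMatMul"(%lhs, %rhs) {adj_x = true, adj_y = true}
         : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
  // The pass is expected to rewrite this roughly as:
  //   %lhs_t = "tf.Transpose"(%lhs, <[1, 0]>) : ... -> tensor<4x5xf32>
  //   %rhs_t = "tf.Transpose"(%rhs, <[1, 0]>) : ... -> tensor<5x6xf32>
  //   %0     = "tf.MatMul"(%lhs_t, %rhs_t)    : ... -> tensor<4x6xf32>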

PiperOrigin-RevId: 401346926
Change-Id: I491d3b27837d7828e7adb5bbe65d809d6fdc2201
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
index 97ca4bc..1f4ca70 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
@@ -46,6 +46,54 @@
 
 // -----
 
+func @batchMatMulTwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulTwoDimAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
+  // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
 func @batchMatMulOneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
   return %0 : tensor<3x4x6xf32>
@@ -95,6 +143,29 @@
 
 // -----
 
+func @batchMatMulUnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulUnbatchedLeft
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+
+  // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
 func @batchMatMulUnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
   return %0 : tensor<3x4x6xf32>
@@ -128,6 +199,23 @@
 }
 
 // -----
+
+func @batchMatMulMatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  return %0 : tensor<4x6xf32>
+
+  // CHECK-LABEL: batchMatMulMatrixAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+
+  // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
+  // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
+}
+
+// -----
 // ==== V2 tests ====
 
 func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
@@ -174,6 +262,91 @@
 
 // -----
 
+func @batchMatMulV2TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2TwoDimAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
+  // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV2Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2Broadcast
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
 func @batchMatMulV2OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
   return %0 : tensor<3x4x6xf32>
@@ -203,6 +376,72 @@
 
 // -----
 
+func @batchMatMulV2SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32>
+  return %0 : tensor<1x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2SingleBatch
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
+
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV2UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2UnbatchedLeft
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+
+  // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV2UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2UnbatchedRight
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+
+  // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
 func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   return %0 : tensor<4x6xf32>
@@ -213,8 +452,249 @@
 }
 
 // -----
+
+func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  return %0 : tensor<4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV2MatrixAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+
+  // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
+  // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
+}
+
+// -----
 // ==== V3 tests ====
 
+func @batchMatMulV3TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3TwoDim
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<2x3x5x4xf32>, tensor<2x3x6x5xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3TwoDimAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<2x3x5x4xf32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
+  // CHECK: %[[RHS_TRANSPOSED:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<2x3x6x5xf32>, tensor<4xi32>) -> tensor<2x3x5x6xf32>
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%[[LHS_TRANSPOSED]], %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_4:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#3, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_5:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#4, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_6:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#5, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%[[RHS_TRANSPOSED]], %[[RHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_4:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#3, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<2x1x4x5xf32>, tensor<1x3x5x6xf32>) -> tensor<2x3x4x6xf32>
+  return %0 : tensor<2x3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3Broadcast
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+
+  // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
+  // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]] = "tf.Reshape"(%arg1, %[[RHS_RESHAPED_SHAPE]]) : (tensor<1x3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
+  // CHECK: %[[RHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %[[RHS_RESHAPED]]) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3OneDim
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+
+  // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6xf32>) -> tensor<1x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<1x4x5xf32>, tensor<1x5x6xf32>) -> tensor<1x4x6xf32>
+  return %0 : tensor<1x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3SingleBatch
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
+
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3UnbatchedLeft
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+
+  // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
+  // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+  // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
+func @batchMatMulV3UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
+  return %0 : tensor<3x4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3UnbatchedRight
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+
+  // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
+  // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+  // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
+}
+
+// -----
+
 func @batchMatMulV3Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
   %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   return %0 : tensor<4x6xf32>
@@ -226,6 +706,23 @@
 
 // -----
 
+func @batchMatMulV3MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf32>) -> tensor<4x6xf32> {
+  %0 = "tf.BatchMatMulV3"(%arg0, %arg1) {adj_x = true, adj_y = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  return %0 : tensor<4x6xf32>
+
+  // CHECK-LABEL: batchMatMulV3MatrixAdjXY
+  // CHECK-DAG: %[[PERMUTATION:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+
+  // CHECK: %[[LHS_1:.*]] = "tf.Transpose"(%arg0, %[[PERMUTATION]]) : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<4x5xf32>
+  // CHECK: %[[RHS_1:.*]] = "tf.Transpose"(%arg1, %[[PERMUTATION]]) : (tensor<6x5xf32>, tensor<2xi32>) -> tensor<5x6xf32>
+
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+  // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
+}
+
+// -----
+
 func @batchMatMulV3MatrixInt8(%arg0: tensor<4x5xi8>, %arg1: tensor<5x6xi8>) -> tensor<4x6xi32> {
   %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32>
   return %0 : tensor<4x6xi32>