BatchMatMul conversion implemented
PiperOrigin-RevId: 267505822
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 9974051..225fd39 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -198,9 +198,11 @@
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/trim_functions_tf.cc",
+ "transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/passes.h",
+ "transforms/unroll_batch_matmul.h",
],
deps = [
":common",
diff --git a/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
new file mode 100644
index 0000000..09f1dfc
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
@@ -0,0 +1,223 @@
+// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s
+
+func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
+ %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
+ return %0 : tensor<2x3x4x6xf32>
+
+ // CHECK-LABEL: batchMatMulV2TwoDim
+ // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+ // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+ // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+ // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
+
+ // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+ // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+ // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+ // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+ // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+ // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+ // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+
+ // CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
+}
+
+func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+ %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+ return %0 : tensor<3x4x6xf32>
+
+ // CHECK-LABEL: batchMatMulV2FlatInput
+ // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+ // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+ // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+ // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
+
+ // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
+ // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+ // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
+ // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+ // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+ // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+ // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
+
+ // CHECK: return %[[v18]] : tensor<3x4x6xf32>
+}
+
+func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
+ %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ return %0 : tensor<4x6xf32>
+
+ // CHECK-LABEL: batchMatMulV2Matrix
+ // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: return %[[v0]] : tensor<4x6xf32>
+}
+
+func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
+ %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
+ return %0 : tensor<2x3x4x6xf32>
+
+ // CHECK-LABEL: batchMatMulTwoDim
+ // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+ // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+ // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+ // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
+
+ // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
+ // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+ // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
+ // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+ // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+ // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+ // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
+
+ // CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
+}
+
+func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
+ %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+ return %0 : tensor<3x4x6xf32>
+
+ // CHECK-LABEL: batchMatMulFlatInput
+ // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+ // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+ // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+ // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+ // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+ // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+ // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
+
+ // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
+ // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+ // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
+ // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
+
+ // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
+ // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+ // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
+ // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
+
+ // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+
+ // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+ // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
+
+ // CHECK: return %[[v18]] : tensor<3x4x6xf32>
+}
+
+func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
+ %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ return %0 : tensor<4x6xf32>
+
+ // CHECK-LABEL: batchMatMulMatrix
+ // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+ // CHECK: return %[[v0]] : tensor<4x6xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 7c7983a..102887d 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -50,6 +50,7 @@
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -377,6 +378,11 @@
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
+
+ patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
+ ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
+ applyPatternsGreedily(func, patterns);
+
// This pattern was intented to uses TFL QDQs to preserve the quantization
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise removes the
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
new file mode 100644
index 0000000..1fde6ac
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
@@ -0,0 +1,328 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This transformation pass unrolls TF BatchMatMul and BatchMatMulV2 ops on
+// their batch dimensions. TFLite does not support a BatchMatMul operation, so
+// each such op is decomposed into ops with simple TFLite equivalents: the
+// inputs are reshaped to rank-3 tensors, each batch is sliced out and
+// squeezed to a matrix, the per-batch matrices are multiplied with tf.MatMul,
+// and the results are packed and reshaped back to the original output shape.
+// Inputs marked adjoint (adj_x/adj_y) are first transposed explicitly.
+//
+// The newly created operations stay in the TensorFlow dialect so that
+// constant folding opportunities can be exploited before legalization to the
+// TensorFlow Lite dialect. Broadcasting batch dimensions are handled via
+// tensorflow::MatMulBCast, mirroring the runtime broadcasting semantics of
+// BatchMatMulV2. Rank-2 inputs are rewritten directly to a single tf.MatMul.
+// The rewrite runs both as the standalone "tfl-unroll-batch-matmul" pass and
+// as part of the PrepareTF pass.
+
+#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
+
+#include <climits>
+#include <cstdint>
+
+#include "absl/memory/memory.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
+#include "mlir/IR/Attributes.h" // TF:local_config_mlir
+#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
+#include "mlir/Pass/Pass.h" // TF:local_config_mlir
+#include "mlir/Support/Functional.h" // TF:local_config_mlir
+#include "mlir/Support/LLVM.h" // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
+#include "tensorflow/compiler/mlir/lite/utils/validators.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+
+namespace mlir {
+namespace TFL {
+
+namespace {
+// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
+// of the inputs, matmul them individually, then stack them all back together at
+// the end.
+struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
+ void runOnFunction() override;
+};
+
+void UnrollBatchMatMulPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto func = getFunction();
+
+ patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
+ ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
+ applyPatternsGreedily(func, patterns);
+}
+
+} // namespace
+
+template <typename BatchMatMulOpType>
+TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
+ Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
+ PatternRewriter& rewriter) {
+ int64_t shape_rank = shape.size();
+ auto shapeSpecType =
+ rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
+ Type resultType = rewriter.getTensorType(shape, elementType);
+ auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
+ auto shapeTensor =
+ rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
+ return rewriter.create<TF::ReshapeOp>(loc, resultType, /* tensor = */ value,
+ /* shape = */ shapeTensor);
+}
+
+template <typename BatchMatMulOpType>
+std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
+ Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
+ RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
+ Type elementType = tensorType.getElementType();
+
+ int rank = tensorType.getShape().size();
+ int num_rows = tensorType.getShape()[rank - 2];
+ int num_cols = tensorType.getShape()[rank - 1];
+
+ // Reshape to rank-3 Tensor with first dimension as the batch size.
+ auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
+ elementType, loc, rewriter);
+
+ SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};
+
+ std::vector<Value*> sliced;
+ Type int64Type = rewriter.getIntegerType(64);
+ Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);
+
+ // Slice along each batch index and remember the slice output for future
+ // use.
+ for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+ auto vector3Type = rewriter.getTensorType({3}, int64Type);
+
+ auto begin_attr =
+ DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
+ auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
+ auto sliceOp = rewriter.create<TF::SliceOp>(
+ loc, sliceResultType,
+ /* input = */ reshapeOp.output(),
+ /* begin = */
+ rewriter.create<ConstantOp>(loc, vector3Type, begin_attr),
+ /* size = */
+ rewriter.create<ConstantOp>(loc, vector3Type, size_attr));
+
+ // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
+ // num_cols]
+ auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
+ elementType, loc, rewriter);
+
+ sliced.emplace_back(squeezeOp.output());
+ }
+ return sliced;
+}
+
+template <typename BatchMatMulOpType>
+TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
+ Value* value, Location loc, PatternRewriter& rewriter) {
+ auto valueType = value->getType().cast<RankedTensorType>();
+ auto shape = valueType.getShape();
+ int dims = shape.size();
+
+ std::vector<int32_t> perm(dims);
+ for (int i = 0; i < dims - 2; i++) {
+ perm[i] = i;
+ }
+ perm[dims - 2] = dims - 1;
+ perm[dims - 1] = dims - 2;
+
+ auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
+ rewriter.getIntegerType(32));
+
+ auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
+ auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
+
+ std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
+ int64_t r = transposed_shape[dims - 1];
+ int64_t c = transposed_shape[dims - 2];
+
+ transposed_shape[dims - 1] = c;
+ transposed_shape[dims - 2] = r;
+
+ auto transposed_type =
+ rewriter.getTensorType(transposed_shape, valueType.getElementType());
+ return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
+}
+
+template <typename BatchMatMulOpType>
+TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
+ const std::vector<Value*>& sliced_lhs,
+ const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
+ int rows, int cols, Type elementType, Location loc,
+ PatternRewriter& rewriter) {
+ auto matmulType = rewriter.getTensorType({rows, cols}, elementType);
+
+ std::vector<Value*> matmuls;
+ for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
+ int lhs_batch_idx, rhs_batch_idx;
+ if (bcast.IsBroadcastingRequired()) {
+ lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
+ rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
+ } else {
+ lhs_batch_idx = batch_idx;
+ rhs_batch_idx = batch_idx;
+ }
+ auto matmul = rewriter.create<TF::MatMulOp>(
+ loc, matmulType,
+ /* a = */ sliced_lhs[lhs_batch_idx],
+ /* b = */ sliced_rhs[rhs_batch_idx],
+ /* transpose_a = */ rewriter.getBoolAttr(false),
+ /* transpose_b = */ rewriter.getBoolAttr(false));
+ matmuls.emplace_back(matmul.product());
+ }
+
+ // Combine the result of each individual MatMul into a rank-3 Tensor.
+ Type packedType = rewriter.getTensorType(
+ {bcast.output_batch_size(), rows, cols}, elementType);
+
+ return rewriter.create<TF::PackOp>(
+ loc, packedType,
+ /* values = */ matmuls,
+ /* N = */ rewriter.getI64IntegerAttr(matmuls.size()),
+ /* axis = */ rewriter.getI64IntegerAttr(0));
+}
+
+template <typename BatchMatMulOpType>
+PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
+ BatchMatMulOpType op, PatternRewriter& rewriter) const {
+ Value* input_lhs = op.x();
+ Value* input_rhs = op.y();
+
+ if (!input_lhs->getType().isa<RankedTensorType>()) {
+ // LHS must be a ranked tensor type
+ return this->matchFailure();
+ }
+ if (!input_rhs->getType().isa<RankedTensorType>()) {
+ // RHS must be a ranked tensor type
+ return this->matchFailure();
+ }
+
+ auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
+ auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
+
+ auto elementType = lhs_type.getElementType();
+
+ if (elementType != rhs_type.getElementType()) {
+ // The element type of LHS must be the same with element type of RHS
+ return this->matchFailure();
+ }
+
+ auto lhs_shape = lhs_type.getShape();
+ auto rhs_shape = rhs_type.getShape();
+
+ Location loc = op.getLoc();
+
+ // Transpose LHS input if necessary.
+ if (op.adj_x()) {
+ input_lhs = createTransposeOp(input_lhs, loc, rewriter);
+
+ lhs_type = input_lhs->getType().cast<RankedTensorType>();
+ lhs_shape = lhs_type.getShape();
+ }
+
+ // Transpose RHS input if necessary.
+ if (op.adj_y()) {
+ input_rhs = createTransposeOp(input_rhs, loc, rewriter);
+
+ rhs_type = input_rhs->getType().cast<RankedTensorType>();
+ rhs_shape = rhs_type.getShape();
+ }
+
+ // Ensure that input ranks are at least 2 and batch shapes are
+ // broadcastable.
+ const int dims_a = lhs_shape.size();
+ const int dims_b = rhs_shape.size();
+ if (dims_a < 2 || dims_b < 2) {
+ // Both inputs must have rank >= 2
+ return this->matchFailure();
+ }
+
+ if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
+    // Input dimensions must be compatible for multiplication.
+ return this->matchFailure();
+ }
+
+ if (dims_a == 2 && dims_b == 2) {
+ // When both inputs are matrices, just replace the op to a matmul op.
+ Type resultType =
+ rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
+ rewriter.replaceOpWithNewOp<TF::MatMulOp>(
+ op, resultType,
+ /* a = */ input_lhs,
+ /* b = */ input_rhs,
+ /* transpose_a = */ rewriter.getBoolAttr(false),
+ /* transpose_b = */ rewriter.getBoolAttr(false));
+ return this->matchSuccess();
+ }
+
+ tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
+ lhs_shape.begin(), lhs_shape.end()),
+ absl::InlinedVector<tensorflow::int64, 4>(
+ rhs_shape.begin(), rhs_shape.end()));
+
+ if (!bcast.IsValid()) {
+ // Input batch dimensions must be broadcastable
+ return this->matchFailure();
+ }
+
+ // Compute slices for each batch in the LHS and RHS.
+ std::vector<Value*> sliced_lhs =
+ sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
+ std::vector<Value*> sliced_rhs =
+ sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);
+
+ // Compute (single batch) MatMul for each output batch. The MatMul outputs
+ // are then packed together into one output Tensor.
+ auto packOp =
+ createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
+ rhs_shape[dims_b - 1], elementType, loc, rewriter);
+
+ // Reshape the rank-3 Tensor into the correct output shape.
+ const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
+ std::vector<int64_t> resultShape(resultBatchShape.begin(),
+ resultBatchShape.end());
+ resultShape.push_back(lhs_shape[dims_a - 2]);
+ resultShape.push_back(rhs_shape[dims_b - 1]);
+
+ auto reshapeOp =
+ createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
+ rewriter.replaceOp(op, reshapeOp.output());
+ return this->matchSuccess();
+}
+
+static PassRegistration<UnrollBatchMatMulPass> pass(
+ "tfl-unroll-batch-matmul",
+ "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
+
+} // namespace TFL
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h
new file mode 100644
index 0000000..d4b46ea
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h
@@ -0,0 +1,60 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
+
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Location.h" // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
+#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+
+namespace mlir {
+namespace TFL {
+
+// Unrolls a tf.BatchMatMul or tf.BatchMatMulV2 op into a sequence of TF ops.
+// Since TFLite does not support the BatchMatMul operation, each batched matmul
+// is lowered to tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
+template <typename BatchMatMulOpType>
+class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
+ using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;
+
+ static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
+ Type elementType, Location loc,
+ PatternRewriter& rewriter);
+
+ static std::vector<Value*> sliceInput(Value* value, int batch_size,
+ Location loc,
+ PatternRewriter& rewriter);
+
+ static TF::TransposeOp createTransposeOp(Value* value, Location loc,
+ PatternRewriter& rewriter);
+
+ static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
+ const std::vector<Value*>& sliced_rhs,
+ const tensorflow::MatMulBCast& bcast,
+ int rows, int cols, Type elementType,
+ Location loc, PatternRewriter& rewriter);
+
+ PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
+ PatternRewriter& rewriter) const override;
+};
+
+} // namespace TFL
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 153ac53..8facd95 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -261,6 +261,88 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
+ let summary = "Multiplies slices of two tensors in batches.";
+
+ let description = [{
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
+
+The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
+
+ r_o = c_x if adj_x else r_x
+ c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+ output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+ }];
+
+ let arguments = (ins
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
+
+ DefaultValuedAttr<BoolAttr, "false">:$adj_x,
+ DefaultValuedAttr<BoolAttr, "false">:$adj_y
+ );
+
+ let results = (outs
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
+ let summary = "Multiplies slices of two tensors in batches.";
+
+ let description = [{
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
+
+The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
+
+ r_o = c_x if adj_x else r_x
+ c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+ output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+
+*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
+about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+ }];
+
+ let arguments = (ins
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
+
+ DefaultValuedAttr<BoolAttr, "false">:$adj_x,
+ DefaultValuedAttr<BoolAttr, "false">:$adj_y
+ );
+
+ let results = (outs
+ TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
let summary = "BatchToSpace for N-D tensors of type T.";