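Relax TFL_BatchMatMulOp input constraints for hybrid quantization.

Remove the "y and output must have same element type" trait from
TFL_BatchMatMulOp so that a float x and float output can be combined with a
quantized y operand (hybrid quantization). The "x and output must have same
element type" trait is kept. Add a test in ops.mlir covering an f32 x/output
with a quantized i8 y.

PiperOrigin-RevId: 360471859
Change-Id: I23c5b452869f772ad61a166253d1f6d4b4a508cc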
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 507752c..1fd70a2 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -1003,9 +1003,7 @@
TFL_OperandHasAtleastRank<0, 2>,
TFL_OperandHasAtleastRank<1, 2>,
PredOpTrait<"x and output must have same element type",
- TFL_TCresVTEtIsSameAsOp<0, 0>>,
- PredOpTrait<"y and output must have same element type",
- TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
+ TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Batch Matrix Multiply Operator";
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index fd33655..2b1c318 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -1387,6 +1387,15 @@
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
return %0 : tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
}
+
+// -----
+
+func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32> {
+ // CHECK: "tfl.batch_matmul"(%arg0, %arg1)
+ %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32>
+ return %0 : tensor<1x4x384x384xf32>
+}
+
// -----
func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> {