Add the kernel fusion pass to the pipeline and tests

The test is to make sure all the kernels, except softmax, are fused correctly.

PiperOrigin-RevId: 305521706
Change-Id: Id06148767779e943c04f48c5432a993e2a7517f6
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
index 2bc1568..74e81d7 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
@@ -77,6 +77,7 @@
         "quantize.h",
     ],
     deps = [
+        ":hlo_xla_quantization_passes",
         "//tensorflow/compiler/mlir/xla:hlo",
         "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
         "//tensorflow/compiler/tf2xla",
@@ -87,8 +88,10 @@
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/core/platform:status",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
index 6924001..b78338c 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
@@ -39,6 +39,11 @@
           (Fused2Ops<"generic.mul_add"> $mul, $add),
           [(NeedsToBeFused $add)]>;
 
+// add
+def : Pat<(HLO_AddOp:$add $_, $_, $_),
+          (Fused1Ops<"generic.add"> $add),
+          [(NeedsToBeFused $add)]>;
+
 // reduce_window: maxpool, avgpool
 def : Pat<(HLO_ReduceWindowOp:$reduce $_, $_, $_, $_, $_, $_, $_),
           (Fused1Ops<"generic.reduce_window"> $reduce),
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
index ba0c76a..857efdb 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
@@ -31,6 +31,9 @@
 // Rewrite the graph and quantize the constant.
 std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass();
 
+// Fuse HLO ops into quantized regions.
+std::unique_ptr<OperationPass<FuncOp>> CreateCpuKernelFusionPass();
+
 }  // namespace xla_hlo
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
index 9df41bb..223a55d 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
 
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
@@ -22,6 +23,7 @@
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.h"
@@ -34,6 +36,7 @@
   static bool init_once = []() {
     mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
     mlir::registerDialect<mlir::StandardOpsDialect>();
+    mlir::registerDialect<mlir::quant::QuantizationDialect>();
     return true;
   }();
   (void)init_once;
@@ -60,12 +63,13 @@
   pm.addPass(createInlinerPass());
   pm.addPass(createSymbolDCEPass());
   pm.addNestedPass<FuncOp>(createCSEPass());
+  pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
 
   mlir::StatusScopedDiagnosticHandler diag_handler(&context);
   LogicalResult result = pm.run(module.get());
   (void)result;
 
-  module->dump();
+  module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
 
   return tensorflow::Status::OK();
 }
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
index d3e8b48..6184c2f 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
@@ -1,15 +1,10 @@
 # RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant"  2>&1 | FileCheck %s -dump-input-on-failure
 
 # TODO(fengliuai): update this file with the progress of the implementation
-// CHECK: func @main
-// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
-// CHECK: %cst_1 = constant dense<8> : tensor<i32>
-// CHECK: %cst_2 = constant dense<false> : tensor<i1>
-// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
-// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
-// CHECK: return %4 : tuple<tensor<2x4xf32>>
-// CHECK: }
+
+// CHECK: "quant.region"
+// CHECK: ^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>):	// no predecessors
+// CHECK:   xla_hlo.add %arg0, %arg1
+// CHECK:   "quant.return"
+// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 0.49803921568627452:-128>, !quant.uniform<i8:f32, 0.49803921568627452:-128>],
+// CHECK-SAME: logical_kernel = "generic.add", output_specs = [!quant.uniform<i8:f32, 0.49803921568627452:-128>]}