Add the kernel fusion pass to the pipeline and tests
The test is to make sure all the kernels, except softmax, are fused correctly.
PiperOrigin-RevId: 305521706
Change-Id: Id06148767779e943c04f48c5432a993e2a7517f6
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
index 2bc1568..74e81d7 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
@@ -77,6 +77,7 @@
"quantize.h",
],
deps = [
+ ":hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
@@ -87,8 +88,10 @@
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/platform:status",
+ "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
index 6924001..b78338c 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td
@@ -39,6 +39,11 @@
(Fused2Ops<"generic.mul_add"> $mul, $add),
[(NeedsToBeFused $add)]>;
+// add
+def : Pat<(HLO_AddOp:$add $_, $_, $_),
+ (Fused1Ops<"generic.add"> $add),
+ [(NeedsToBeFused $add)]>;
+
// reduce_window: maxpool, avgpool
def : Pat<(HLO_ReduceWindowOp:$reduce $_, $_, $_, $_, $_, $_, $_),
(Fused1Ops<"generic.reduce_window"> $reduce),
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
index ba0c76a..857efdb 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h
@@ -31,6 +31,9 @@
// Rewrite the graph and quantize the constant.
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass();
+// Fuse HLO ops into quantized regions.
+std::unique_ptr<OperationPass<FuncOp>> CreateCpuKernelFusionPass();
+
} // namespace xla_hlo
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
index 9df41bb..223a55d 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
+#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
@@ -22,6 +23,7 @@
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
@@ -34,6 +36,7 @@
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
+ mlir::registerDialect<mlir::quant::QuantizationDialect>();
return true;
}();
(void)init_once;
@@ -60,12 +63,13 @@
pm.addPass(createInlinerPass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(createCSEPass());
+ pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
LogicalResult result = pm.run(module.get());
(void)result;
- module->dump();
+ module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
index d3e8b48..6184c2f 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir
@@ -1,15 +1,10 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
# TODO(fengliuai): update this file with the progress of the implementation
-// CHECK: func @main
-// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
-// CHECK: %cst_1 = constant dense<8> : tensor<i32>
-// CHECK: %cst_2 = constant dense<false> : tensor<i1>
-// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
-// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
-// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
-// CHECK: return %4 : tuple<tensor<2x4xf32>>
-// CHECK: }
+
+// CHECK: "quant.region"
+// CHECK: ^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>): // no predecessors
+// CHECK: xla_hlo.add %arg0, %arg1
+// CHECK: "quant.return"
+// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 0.49803921568627452:-128>, !quant.uniform<i8:f32, 0.49803921568627452:-128>],
+// CHECK-SAME: logical_kernel = "generic.add", output_specs = [!quant.uniform<i8:f32, 0.49803921568627452:-128>]}