Attach internal states to fused ops

This patch adds a pass that attaches internal states to fused ops as regions.
These internal states are wired together by existing TFLite ops (except for the
terminator), and the quantization recipe is specified by the quant.any type.

The spec of these internal states is designed carefully so that

- every logging point in post-training quantization can find its corresponding internal tensor;
- every intermediate tensor used during kernel execution can find its corresponding internal tensor;
- the logging kernel matches this spec and, ideally, is generated from it.

Currently, we keep this spec and the logging kernel in sync manually.
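
As a rough sketch (abridged from the new load-quantization-recipe.mlir test;
SSA value names are illustrative, not verbatim pass output), the region
attached to a tfl.lstm op after -tfl-load-recipe looks like:

  "tfl.lstm"(%input, %in_w, ..., %proj_w, %proj_bias, ...) ({
    %cst = constant unit
    // input gate (the forget, cell and output gates follow the same pattern)
    %in1 = "tfl.fully_connected"(%input, %in_w, %cst) : ... -> tensor<?x!quant.any<i16:f32>>
    ...
    %act = "tfl.fully_connected"(%hidden, %proj_w, %proj_bias)
        : (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
    "tf_quant.pseudo_return"(%act) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
  }) {fused_activation_function = "NONE", kernel_type = "FULL"} : (...) -> tensor<?xf32>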

PiperOrigin-RevId: 268253490
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 356f56b..c4a3275 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -286,6 +286,7 @@
     name = "tensorflow_lite_quantize",
     srcs = [
         "transforms/generated_quantize.inc",
+        "transforms/load_quantization_recipe.cc",
         "transforms/post_quantize.cc",
         "transforms/prepare_quantize.cc",
         "transforms/quantize.cc",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index cc2238f..36ff87e 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -355,7 +355,7 @@
 
 // TODO(haoliang): Implement legalization pass after pattern rewrite generator
 // supports variadic inputs.
-def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
+def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "add_n operator";
 
   let description = [{
@@ -363,11 +363,11 @@
   }];
 
   let arguments = (ins
-    Variadic<TensorOf<[F32, I32]>>:$inputs
+    Variadic<TensorOf<[F32, I32, QI16, QUI16]>>:$inputs
   );
 
   let results = (outs
-    TensorOf<[F32, I32]>:$sum
+    TensorOf<[F32, I32, QI16, QUI16]>:$sum
   );
 }
 
@@ -1138,11 +1138,11 @@
   }];
 
   let arguments = (ins
-    TensorOf<[F32, QUI8, QI8, I8]>:$input,
+    TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
     TFL_AFAttr:$fused_activation_function
   );
 
-  let results = (outs TensorOf<[F32, QUI8, QI8, I8]>:$output);
+  let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);
 
   let hasOptions = 1;
 
@@ -1251,9 +1251,9 @@
     Computes element-wise Sigmoid of input
   }];
 
-  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x);
+  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);
 
-  let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y);
+  let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
 }
 
 def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
@@ -2092,9 +2092,9 @@
     Computes element-wise Hyperbolic tangent of input
   }];
 
-  let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$x);
+  let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
 
-  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$y);
+  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
 }
 
 def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@@ -2780,6 +2780,10 @@
 
   let results = (outs AnyTensor:$output);
 
+  // TODO(fengliuai): customize printer and parser to not display
+  // empty region.
+  let regions = (region AnyRegion:$internal);
+
   let hasOptions = 1;
 
   let verifier = [{ return Verify(*this); }];
diff --git a/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir b/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir
new file mode 100644
index 0000000..5c53d5e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir
@@ -0,0 +1,107 @@
+// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: testLstm
+func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
+  %0 = "tfl.lstm"(%arg0, // input
+    %arg1, %arg2, %arg3, %arg4, // weights
+    %arg5, %arg6, %arg7, %arg8, // recurrent weights
+    %arg9, %arg10, %arg11, // cell weights
+    %arg12, %arg13, %arg14, %arg15, // bias
+    %arg16, %arg17, // projection weight and bias
+    %arg18, %arg19, // stateful
+    %arg20, %arg21, %arg22, %arg23 // layer norm coefficients
+    ) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<? xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+
+// CHECK-NEXT:  "tfl.lstm"
+// CHECK-NEXT:  %[[cst:.*]] = constant unit
+
+// input gate
+// CHECK-NEXT:  %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in3:.*]] = "tfl.mul"(%arg19, %arg9)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in6:.*]] = tfl.add %[[in4]], %[[in5]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[in8:.*]] = "tfl.logistic"(%[[in7]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+
+// forget gate
+// CHECK-NEXT:  %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]])
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[fo8:.*]] = "tfl.logistic"(%[[fo7]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+
+// cell gate
+// CHECK-NEXT:  %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ce7:.*]] = "tfl.tanh"(%[[ce6]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+
+// CHECK-NEXT:  %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+
+// output gate
+// CHECK-NEXT:  %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15)
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ou8:.*]] = "tfl.logistic"(%[[ou7]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+
+// output activation
+// CHECK-NEXT:  %[[ac4:.*]] = "tfl.tanh"(%[[ac3]])
+// CHECK-SAME:    -> tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]]
+// CHECK-SAME:    tensor<?x!quant.any<i16:f32>>
+// CHECK-NEXT:  %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17)
+// CHECK-SAME:    (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
+// CHECK-NEXT:  %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
+// CHECK-NEXT:  })
+// CHECK-NEXT:  return
+
+  return %0 : tensor<?xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir
index ddb122f..23976db 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir
@@ -278,6 +278,6 @@
   %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
   %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
   %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
-  %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %24 : tensor<4xf32>
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 9906dc7..16054a5 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -103,7 +103,7 @@
 // test invalid AddN
 func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
 ^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
-  // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
+  // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
   %0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
   return %0 : tensor<? x f16>
 }
@@ -537,7 +537,7 @@
 // test invalid Logistic input
 func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
 ^bb0(%arg0: tensor<?xi32>):
-  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
+  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
   %0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
   return %0#0 : tensor<?xi32>
 }
@@ -591,8 +591,9 @@
 
 // CHECK-LABEL: testLstm
 func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
+  // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -600,8 +601,9 @@
 
 // CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
 func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
+  // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -610,7 +612,7 @@
 // test invalid none type applied to a tensor type arg
 func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
   // expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -619,7 +621,7 @@
 // test violation of projection weight and projection bias pred op trait
 func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
   // expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -628,7 +630,7 @@
 // test invalid kernel type
 func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
   // expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
new file mode 100644
index 0000000..01e54da
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
@@ -0,0 +1,232 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This transformation pass prepares the TFLite fused ops for quantization.
+
+#include "absl/memory/memory.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+
+//===----------------------------------------------------------------------===//
+// The LoadQuantizationRecipe Pass.
+//
+namespace mlir {
+namespace TFL {
+
+namespace {
+
+// This pass loads the quantization recipe for the TFLite ops to be quantized.
+// Specifically, it extends the fused ops with their internal implementation as
+// op regions. Each op in the region produces results with element type
+// AnyQuantizedType, so bit width, narrow_range, etc. are included. Each op
+// also defines its quantization traits, which are used to propagate the
+// quantization parameters through the following passes.
+struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
+  void runOnFunction() override;
+
+ private:
+  void Initialize(LSTMOp lstm, OpBuilder* builder);
+
+  // Create LSTM gates with different weights for input, recurrent and
+  // cell state, and also the layer normalization parameters.
+  Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
+                        Value* rec_w,
+                        llvm::Optional<std::pair<Value*, Value*>> cell,
+                        Value* ln_w, Value* ln_bias, OpBuilder* builder);
+
+  Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
+                             Value* ln_bias, OpBuilder* builder);
+
+  // Add the internal implementation of the LSTM to its regions.
+  void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
+
+  StringAttr none_af;
+  StringAttr fc_format;
+  BoolAttr keep_dims;
+  Type int8;
+  Type int16;
+  ConstantOp none_cst;
+};
+
+void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
+  Type expressed_type =
+      lstm.input()->getType().cast<ShapedType>().getElementType();
+  Type int8_storage_type = builder->getIntegerType(8);
+  Type int16_storage_type = builder->getIntegerType(16);
+  auto flag = quant::QuantizationFlags::FlagValue::Signed;
+  int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger(
+      flag, /*integralWidth=*/8);
+  int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger(
+      flag, /*integralWidth=*/8);
+  int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger(
+      flag, /*integralWidth=*/16);
+  int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger(
+      flag, /*integralWidth=*/16);
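+  // quant.any pins only the sign, storage type, and storage range; scale and
+  // zero point are intentionally left unspecified for later passes to infer.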
+  auto any_int8 = quant::AnyQuantizedType::get(
+      flag, int8_storage_type, expressed_type, int8_min, int8_max);
+  auto any_int16 = quant::AnyQuantizedType::get(
+      flag, int16_storage_type, expressed_type, int16_min, int16_max);
+
+  int8 = any_int8.castFromExpressedType(lstm.input()->getType());
+  int16 = any_int16.castFromExpressedType(lstm.input()->getType());
+}
+
+Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
+                                                   Value* ln_w, Value* ln_bias,
+                                                   OpBuilder* builder) {
+  // Note that the l2_normalization and add ops here are not the execution
+  // kernel implementation of layer_normalization; we just use them to
+  // model the quantization requirement.
+  auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
+  auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
+  return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
+                                           none_af, fc_format, keep_dims);
+}
+
+Operation* LoadQuantizationRecipe::CreateGate(
+    Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
+    llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
+    OpBuilder* builder) {
+  auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
+                                              none_af, fc_format, keep_dims);
+  auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
+                                              none_af, fc_format, keep_dims);
+
+  AddNOp s4;
+  if (cell.hasValue()) {
+    auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
+                                     cell.getValue().second, none_af);
+    s4 = builder->create<AddNOp>(
+        loc, int16,
+        llvm::ArrayRef<Value*>(
+            {*s1.output().begin(), *s2.output().begin(), s3.output()}));
+
+  } else {
+    s4 = builder->create<AddNOp>(
+        loc, int16,
+        llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
+  }
+
+  auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
+
+  if (cell.hasValue()) {
+    return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
+  } else {
+    return builder->create<TanhOp>(loc, int16, s5->getResult(0));
+  }
+}
+
+void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
+  Initialize(lstm, builder);
+
+  Region region;
+  region.push_back(new Block);
+  builder->setInsertionPointToEnd(&region.front());
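+  // The recipe ops below are built into this detached region; it is moved
+  // into the lstm op's `internal` region by takeBody() at the end.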
+  Location loc = lstm.getLoc();
+  Type int32_type = builder->getIntegerType(32);
+  Type int32_tensor = builder->getTensorType(int32_type);
+  none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
+                                         builder->getUnitAttr());
+
+  auto input_gate = CreateGate(
+      loc, lstm.input(), lstm.input_to_input_weights(),
+      lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
+      llvm::Optional<std::pair<Value*, Value*>>(
+          {lstm.input_cell_state(), lstm.cell_to_input_weights()}),
+      lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
+
+  auto forget_gate = CreateGate(
+      loc, lstm.input(), lstm.input_to_forget_weights(),
+      lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
+      llvm::Optional<std::pair<Value*, Value*>>(
+          {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
+      lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
+
+  auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
+                              lstm.input_activation_state(),
+                              lstm.recurrent_to_cell_weights(), llvm::None,
+                              lstm.cell_layer_norm_coefficients(),
+                              lstm.cell_bias(), builder);
+
+  auto forget_cell_state = builder->create<MulOp>(
+      loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
+  auto input_cell_state = builder->create<MulOp>(
+      loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
+  auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
+                                         input_cell_state.output(), none_af);
+
+  auto output_gate = CreateGate(
+      loc, lstm.input(), lstm.input_to_output_weights(),
+      lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
+      llvm::Optional<std::pair<Value*, Value*>>(
+          {new_cell, lstm.cell_to_output_weights()}),
+      lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
+
+  auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
+  auto hidden_state = builder->create<MulOp>(
+      loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
+  auto act = builder->create<FullyConnectedOp>(
+      loc, int8, hidden_state.output(), lstm.projection_weights(),
+      lstm.projection_bias(), none_af, fc_format, keep_dims);
+
+  // TODO(fengliuai): define and register the op in the QuantOps Dialect.
+  OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
+                              {int8}, {});
+  builder->createOperation(return_state);
+
+  lstm.internal().takeBody(region);
+}
+
+void LoadQuantizationRecipe::runOnFunction() {
+  FuncOp func = getFunction();
+  OpBuilder builder(func);
+  none_af = builder.getStringAttr("NONE");
+  fc_format = builder.getStringAttr("DEFAULT");
+  keep_dims = builder.getBoolAttr(false);
+
+  func.walk([&](Operation* op) {
+    if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
+      LoadForLSTMOp(lstm, &builder);
+    }
+    // Handles other ops.
+  });
+}
+
+}  // namespace
+
+// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
+// pass.
+std::unique_ptr<FunctionPassBase> CreateLoadQuantizationRecipePass() {
+  return absl::make_unique<LoadQuantizationRecipe>();
+}
+
+static PassRegistration<LoadQuantizationRecipe> pass(
+    "tfl-load-recipe", "Load TFL op quantization recipe");
+
+}  // namespace TFL
+}  // namespace mlir