Handle LSTM with cifg in MLIR quantizer.

* Added CIFG constraints to MLIR op definition.
* modified ops.mlir.test to check cifg constraint.

Currently, this logic is not merged to operator_property.cc to keep it simple, and also calibration step relied on size of `OperatorProperty::intermediates` to allocate necessary intermediate tensors.

PiperOrigin-RevId: 350266407
Change-Id: I1f1ea9318743054bf2db9687915cc7a73c62cdb1
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 9fdff92..4bb5e05 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -543,6 +543,7 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/optimize:operator_property",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index f283f97..e9b4500 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3691,8 +3691,12 @@
 
 def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
   "the optional peephole weights should all be specified or none",
-  And<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
-       TFL_TCopVTEtAreSameAt<9, 11, 16>]>>;
+  // Ignore input 9 (cell_to_input_weights) for LSTM with CIFG.
+  And<[
+    TFL_TCopVTEtAreSameAt<10, 11, 16>,
+    Or<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
+        And<[TypeIsPred<"input_to_input_weights", NoneType>,
+             TypeIsPred<"cell_to_input_weights", NoneType>]>]>]>>;
 
 def LstmProjectionWeightBiasConstraint : PredOpTrait<
   "either projection weight must be specified or both projection weight and "
@@ -3702,13 +3706,25 @@
            TypeIsPred<"projection_bias", NoneType>]>,
       Neg<TypeIsPred<"projection_weights", NoneType>>]>>;
 
-// TODO(b/137798843): Need to add two additional constraints for both LSTM and
+def LstmCifgInputConstraint : PredOpTrait<
+  "the cifg inputs should all be specified or none",
+   // If LSTM has combined input/forget gate, input 1, 5, 9, 12, 20 are all none
+   // or 1, 5, 12 should not be none. Inputs 9 and 20 depend on LSTM's variants.
+   Or<[
+     And<[TypeIsPred<"input_to_input_weights", NoneType>,
+          TypeIsPred<"recurrent_to_input_weights", NoneType>,
+          TypeIsPred<"cell_to_input_weights", NoneType>,
+          TypeIsPred<"input_gate_bias", NoneType>,
+          TypeIsPred<"input_layer_norm_coefficients", NoneType>]>,
+     Neg<Or<[
+       TypeIsPred<"input_to_input_weights", NoneType>,
+       TypeIsPred<"recurrent_to_input_weights", NoneType>,
+       TypeIsPred<"input_gate_bias", NoneType>]>>]>>;
+
+
+// TODO(b/137798843): Need to add an additional constraint for both LSTM and
 // UnidirectionalSequenceLstm
-// For coupling of input and forget gates (cifg): if cifg is true,
-// tensor {1, 5, 9, 12, 20} are null; if cifg is
-// false, tensors {1, 5, 12} are not null; tensor {9} is not null if
-// additionally peephole = true; tensor {20} is not null if additionally layer
-// norm = true. For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
+// For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
 // are null; if layer norm is true, tensors {21, 22, 23} are not null; tensor
 // {20} is not null if additionally cifg = false.
 
@@ -3759,6 +3775,7 @@
           [LstmMandatoryInputsConstraint,
            LstmOptionalPeepholeWeightConstraint,
            LstmProjectionWeightBiasConstraint,
+           LstmCifgInputConstraint,
            LstmResultConstraint,
            TFL_OperandHasRank<2, 2>,           // input_to_forget_weights
            TFL_OperandHasRank<3, 2>,           // input_to_cell_weights
@@ -3883,6 +3900,7 @@
           [LstmMandatoryInputsConstraint,
            LstmOptionalPeepholeWeightConstraint,
            LstmProjectionWeightBiasConstraint,
+           LstmCifgInputConstraint,
            LstmResultConstraint,
            TFL_OperandHasRankAtLeast<0, 2>,    // input
            TFL_OperandIsNoneOrHasRank<1, 2>,   // input_to_input_weights
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index b01edef..782d4ad 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -686,16 +686,37 @@
 // -----
 
 // CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
-func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>,
+  %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>,
+  %arg5: none, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>,
+  %arg9: none, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>,
+  %arg12: none, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>,
+  %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>,
+  %arg18: tensor<? x f32>, %arg19: tensor<? x f32>,
+  %arg20: none, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0,
+  // CHECK-SAME: %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+  // CHECK-SAME: %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} :
+  // CHECK-SAME: (tensor<?x?xf32>,
+  // CHECK-SAME: none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+  // CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
+  // CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
+    %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
+    %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+    %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>,
+      none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+      none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
+      tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
-func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -709,9 +730,8 @@
 }
 
 // -----
+
 // CHECK-LABEL: testLstmIntermediates
-
-
 func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
     %cst = constant unit
     %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
@@ -744,7 +764,6 @@
   // CHECK: return %[[RES1]]
 }
 
-
 // -----
 
 // CHECK-LABEL: testLstm
@@ -774,10 +793,27 @@
 // -----
 
 // CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
-func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
+func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>,
+  %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>,
+  %arg5: none, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>,
+  %arg9: none, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>,
+  %arg12: none, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>,
+  %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>,
+  %arg18: tensor<? x f32>, %arg19: tensor<? x f32>,
+  %arg20: none, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
   // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
-  // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} :
+  // CHECK-SAME: (tensor<?xf32>,
+  // CHECK-SAME: none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+  // CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
+  // CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.lstm"(%arg0,
+    %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
+    %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+    %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>,
+      none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+      none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
+      tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
index 6b5a599..aac58c9 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
@@ -115,6 +115,87 @@
 // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
 }
 
+// CHECK-LABEL: QuantizeLstmCifg
+func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %none = constant unit
+  %input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %23 = "tfl.lstm"(%input,
+    %none, %1, %2, %3,
+    %none, %5, %6, %7,
+    %none, %9, %10,
+    %none, %12, %13, %14,
+    %15, %16,
+    %recurrent_stats, %cell_stats,
+    %none, %20, %21, %22) ({}) {
+      cell_clip = 5.000000e+01 : f32,
+      effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
+      fused_activation_function = "TANH",
+      input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>,
+      input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>,
+      input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>,
+      proj_clip = 0.000000e+00 : f32,time_major = false} : (
+        tensor<1x5xf32>,
+        none, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
+        none, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
+        none, tensor<2xf32>, tensor<2xf32>,
+        none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<4x2xf32>, tensor<4xf32>,
+        tensor<1x4xf32>, tensor<1x2xf32>,
+        none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
+  %24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
+  return %24 : tensor<*xf32>
+
+// CHECK: %[[none:.*]] = constant unit
+// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform<i8:f32, 0.010588235481112611:-15>>) -> tensor<1x5xf32>
+// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2xf32>
+// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>) -> tensor<4xf32>
+// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i8:f32, 0.015686274509803921:-1>>) -> tensor<1x4xf32>
+// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x2xf32>
+// CHECK-DAG: %[[input_21:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_22:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_23:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>) -> tensor<2xf32>
+
+// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[none]], %[[input_2]], %[[input_3]], %[[input_4]], %[[none]], %[[input_6]], %[[input_7]], %[[input_8]],
+// CHECK-SAME: %[[none]], %[[input_10]], %[[input_11]], %[[none]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
+// CHECK-SAME: %[[none]], %[[input_21]], %[[input_22]], %[[input_23]])
+// CHECK-NEXT: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
+// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
+// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
+// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
+
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
+}
 
 // CHECK-LABEL: QuantizeUnidirectionalLstmFull
 func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
index ed92a3a..74c3e23 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
@@ -23,6 +23,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
@@ -104,6 +105,30 @@
       !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
 
   *op_property = operator_property::GetOperatorProperty(*lstm_variant);
+
+  // TODO(b/176258587) move this to operator_property.cc if this is needed in
+  // other components, too.
+  bool use_cifg =
+      op.input_to_input_weights().getType().template isa<NoneType>();
+  if (use_cifg) {
+    const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
+    const int cifg_non_intermediate = 0;
+    op_property->inputs.erase(
+        std::remove_if(
+            op_property->inputs.begin(), op_property->inputs.end(),
+            [&](std::pair<int, operator_property::TensorProperty> input) {
+              return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
+            }),
+        op_property->inputs.end());
+    op_property->intermediates.erase(
+        std::remove_if(op_property->intermediates.begin(),
+                       op_property->intermediates.end(),
+                       [&](std::pair<int, operator_property::TensorProperty>
+                               intermediate) {
+                         return intermediate.first == cifg_non_intermediate;
+                       }),
+        op_property->intermediates.end());
+  }
   return success();
 }