MLIR Lite resolutions: sign-compare loop fixes, explicit broadcasting for binary ops, and TF2XLA op conversion
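
Most of the one-line loop changes below share a single shape: a container's
size() returns an unsigned type, so comparing it against a signed int index
trips -Wsign-compare. Hoisting the size into a signed `end` bound makes the
comparison int-vs-int and also stops re-evaluating size() on every iteration.
A minimal sketch of the before/after, on an illustrative container:

    #include <vector>

    void Example(const std::vector<int>& items) {
      // Before: `i < items.size()` compared a signed int with an unsigned
      // size_t. The hoisted bound narrows size_t to int once, which these
      // files accept because the sizes involved are known to fit in int.
      for (int i = 0, end = items.size(); i < end; i++) {
        (void)items[i];  // ... use items[i] ...
      }
    }
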
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 2e69a17..89fae87 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -1418,7 +1418,7 @@
     } else {
       auto segments = dim_metadata.segments();
       std::vector<int> vector_segments(segments.size(), 0);
-      for (int j = 0; j < segments.size(); j++) {
+      for (int j = 0, end = segments.size(); j < end; j++) {
         vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
       }
       tflite::SparseIndexVector segments_type;
@@ -1450,7 +1450,7 @@
       auto indices = dim_metadata.indices();
       std::vector<int> vector_indices(indices.size(), 0);
       int max_of_indices = 0;
-      for (int j = 0; j < indices.size(); j++) {
+      for (int j = 0, end = indices.size(); j < end; j++) {
         vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
         if (vector_indices[j] > max_of_indices) {
           max_of_indices = vector_indices[j];
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index fa85b4e..29484fa 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -229,7 +229,7 @@
 
   llvm::SmallVector<llvm::APFloat, 4> min_maxs;
   min_maxs.reserve(mins.size() * 2);
-  for (int i = 0; i < mins.size(); ++i) {
+  for (int i = 0, end = mins.size(); i < end; ++i) {
     llvm::APFloat min(mins[i]);
     llvm::APFloat max(maxs[i]);
     min_maxs.push_back(min);
@@ -281,7 +281,7 @@
   int bytes_len = bytes.size();
   assert(bytes_len % read_size == 0);
 
-  size_t elem_count = bytes_len / read_size;
+  int elem_count = bytes_len / read_size;
   ret.reserve(elem_count);
 
   const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
@@ -318,7 +318,7 @@
   switch (elem_type.getWidth()) {
     case 16: {
       assert(bytes_len % 2 == 0);
-      size_t elem_count = bytes_len / 2;
+      int elem_count = bytes_len / 2;
       std::vector<llvm::APFloat> values;
       values.reserve(elem_count);
 
@@ -337,12 +337,11 @@
     }
     case 32: {
       assert(bytes_len % 4 == 0);
-      size_t elem_count = bytes_len / 4;
+      int elem_count = bytes_len / 4;
       std::vector<float> values;
       values.reserve(elem_count);
 
       const char* data = reinterpret_cast<const char*>(buffer.data());
-
       for (int i = 0; i < elem_count; i++) {
         uint32_t bit_repr =
             llvm::support::endian::readNext<uint32_t, llvm::support::little,
@@ -353,7 +352,7 @@
     }
     case 64: {
       assert(bytes_len % 8 == 0);
-      size_t elem_count = bytes_len / 8;
+      int elem_count = bytes_len / 8;
       std::vector<double> values;
       values.reserve(elem_count);
 
@@ -829,7 +828,7 @@
   // Add state variables to inputs.
   absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
                                                func_inputs.end());
-  for (int i = 0; i < subgraph.tensors.size(); i++) {
+  for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
     auto& tensor = *subgraph.tensors.at(i);
     if (tensor.is_variable && !input_index_set.contains(i)) {
       func_inputs.emplace_back(i);
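
For context on the elem_count hunks above: bytes_len is already an int, so
typing the derived elem_count as int keeps the `i < elem_count` comparisons in
the decode loops sign-consistent. A self-contained sketch of the 32-bit float
case, assuming a little-endian buffer (the wrapper function is illustrative):

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <vector>
    #include "llvm/Support/Endian.h"

    std::vector<float> DecodeF32(const char* data_ptr, int bytes_len) {
      assert(bytes_len % 4 == 0);
      int elem_count = bytes_len / 4;  // int-typed, matching bytes_len
      std::vector<float> values;
      values.reserve(elem_count);
      for (int i = 0; i < elem_count; i++) {
        // readNext advances data_ptr past the 4 bytes it consumes.
        uint32_t bit_repr =
            llvm::support::endian::readNext<uint32_t, llvm::support::little,
                                            llvm::support::unaligned>(data_ptr);
        float value;
        std::memcpy(&value, &bit_repr, sizeof(value));  // bit-cast, no cast
        values.push_back(value);
      }
      return values;
    }
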
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 427b9c6..c7c3f57 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -147,18 +147,10 @@
 bool VerifyAddOpShapeConstraints(AddOp op) {
   auto element_type = getElementTypeOrSelf(op.output().getType());
 
-  // Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
-  // which are broadcastable shapes up to five dimension or have same shapes.
+  // Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimensions or have same shapes.
   if (element_type.isF32() || IsQI8Type(element_type) ||
-      IsQUI8Type(element_type)) {
-    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
-        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
-  }
-
-  // Allows I32 output when the operands have valid shapes, which are
-  // broadcastable shapes up to four dimension or have same shapes.
-  if (IsI32Type(element_type)) {
+      IsQUI8Type(element_type) || IsI32Type(element_type)) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);
@@ -210,20 +202,13 @@
     }
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
+        /*max_bcast_rank=*/4);
   }
 
-  // Allows F32 output when the operands have valid shapes, which are
-  // broadcastable shapes up to five dimension or have same shapes.
-  if (element_type.isF32()) {
-    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
-        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
-  }
-
-  // Allows I32 and QI16 outputs when the operands have valid shapes, which are
-  // broadcastable shapes up to four dimension or have same shapes.
-  if (IsI32Type(element_type) || IsQI16Type(element_type)) {
+  // Allows I32, QI16 and F32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimensions or have same shapes.
+  if (IsI32Type(element_type) || IsQI16Type(element_type) ||
+      element_type.isF32()) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);
@@ -773,7 +758,8 @@
       op.custom_option().cast<OpaqueElementsAttr>();
   if (!opaque_attr.getType().hasStaticShape())
     return op.emitOpError("custom_option should have a static shape.");
-  if (opaque_attr.getValue().size() !=
+  const int opaque_attr_getValue_size = opaque_attr.getValue().size();
+  if (opaque_attr_getValue_size !=
       opaque_attr.getType().cast<ShapedType>().getDimSize(0))
     return op.emitOpError(
         "custom_option should have the same length of content with shape.");
@@ -955,7 +941,7 @@
     // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
     // dimensions of `updates` and `shape` are equal.
     for (auto shape_it : llvm::enumerate(shape_value)) {
-      auto i = shape_it.index();
+      int64_t i = shape_it.index();
       auto value = shape_it.value().getSExtValue();
       if (i >= outermost_dim) {
         auto corresponding_dim = i - outermost_dim + outer_dims;
@@ -1192,7 +1178,8 @@
       return failure();
 
     const int total_pack_inputs = pack_op.getNumOperands();
-    if (total_pack_inputs != input_unpack_op.getNumResults()) return failure();
+    const int input_unpack_op_getNumResults = input_unpack_op.getNumResults();
+    if (total_pack_inputs != input_unpack_op_getNumResults) return failure();
     for (auto input_output :
          llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
       Value pack_input = std::get<0>(input_output);
@@ -1261,7 +1248,7 @@
   }
 
   if (begin && size && input_type.hasStaticShape()) {
-    const int input_rank = begin.getNumElements();
+    const uint64_t input_rank = begin.getNumElements();
     for (uint64_t i = 0; i < input_rank; i++) {
       int begin_i =
           begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
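
The AddOp/SubOp verifier consolidation above is not purely cosmetic: the
F32/QI8/QUI8 branches previously allowed broadcasts up to rank 5 and now share
the rank-4 limit with I32/QI16. A hedged sketch of the check those branches
delegate to (the helper name is illustrative; getBroadcastedShape is the real
MLIR utility):

    #include "llvm/ADT/SmallVector.h"
    #include "mlir/Dialect/Traits.h"

    // Returns true when two static operand shapes are identical or
    // broadcastable within the rank budget, mirroring what the consolidated
    // branches enforce.
    bool ShapesOkForElementwiseOp(llvm::ArrayRef<int64_t> lhs,
                                  llvm::ArrayRef<int64_t> rhs,
                                  int max_bcast_rank) {
      if (lhs == rhs) return true;  // same shapes are always fine
      llvm::SmallVector<int64_t, 4> bcast;
      if (!mlir::OpTrait::util::getBroadcastedShape(lhs, rhs, bcast))
        return false;
      // e.g. with max_bcast_rank = 4, [1, 2, 3, 4, 5] vs [5] is now rejected.
      return static_cast<int>(bcast.size()) <= max_bcast_rank;
    }
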
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index dafcfd1..529c9ee 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -75,7 +75,8 @@
   }
   auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
   input_names.split(function_input_names, ",");
-  if (function_input_names.size() != model_flags.input_arrays().size()) {
+  const int function_input_names_size = function_input_names.size();
+  if (function_input_names_size != model_flags.input_arrays().size()) {
     return errors::InvalidArgument(
         "input array size mismatch: got ", function_input_names.size(),
         ", expected: ", model_flags.input_arrays().size());
@@ -99,7 +100,8 @@
   }
   auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
   output_names.split(function_output_names, ",");
-  if (function_output_names.size() != model_flags.output_arrays().size()) {
+  const int function_output_names_size = function_output_names.size();
+  if (function_output_names_size != model_flags.output_arrays().size()) {
     return errors::InvalidArgument(
         "output array size mismatch: got ", function_output_names.size(),
         ", expected: ", model_flags.output_arrays().size());
diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
index b745be7..2054bab 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
+++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
@@ -276,7 +276,7 @@
   }
   // Check that the block_shape of `stb_op` and `bts_op` are equal.
   if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
-  for (uint64_t i = 0; i < stb_bs_attr.getNumElements(); ++i) {
+  for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
     if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
   }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 7d6866d..c85e825 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -170,7 +170,7 @@
       size_t num_samples = Distribution::kResultElementCount;
       llvm::SmallVector<float, 32> data;
       data.resize(num_elements);
-      while (offset < num_elements) {
+      while (static_cast<int>(offset) < num_elements) {
         const typename Distribution::ResultType samples = dist(&generator);
         std::copy(&samples[0],
                   &samples[0] + std::min(num_samples, data.size() - offset),
@@ -631,6 +631,156 @@
   }
 };
 
+// Put two TF BroadcastTo ops in front of the given TF binary broadcast op
+// so that the binary broadcast-able op conversion always succeeds and does
+// not require the flex delegate.
+template <typename SourceOp>
+class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
+ public:
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SourceOp src_op,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = static_cast<Operation*>(src_op);
+    auto lhs = op->getOperand(0);
+    auto rhs = op->getOperand(1);
+
+    // Should have static shapes to calculate the broadcasted shape.
+    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !rhs.getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    // Calculate the broadcasted shape.
+    SmallVector<int64_t, 4> result_shape;
+    if (!OpTrait::util::getBroadcastedShape(
+            lhs.getType().cast<ShapedType>().getShape(),
+            rhs.getType().cast<ShapedType>().getShape(), result_shape)) {
+      return failure();
+    }
+
+    RankedTensorType result_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
+
+    // Create a const op that stores the broadcasted shape computed above.
+    auto new_shape_attr = mlir::DenseIntElementsAttr::get(
+        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
+        result_shape);
+    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+    // Apply BroadcastTo ops to each input.
+    auto broadcast_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(lhs.getType()));
+
+    if (result_type.getShape() != lhs.getType().cast<ShapedType>().getShape()) {
+      lhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
+                                           new_shape)
+                .output();
+    }
+    if (result_type.getShape() != rhs.getType().cast<ShapedType>().getShape()) {
+      rhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
+                                           new_shape)
+                .output();
+    }
+
+    // Recreate an op with the above Broadcast op results.
+    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
+    return success();
+  }
+};
+
+// This specialization is for the TF SelectV2 op. The SelectV2 op has three
+// inputs, and they should have broadcastable shapes.
+template <>
+class ApplyExplicitBroadcasting<TF::SelectV2Op>
+    : public OpRewritePattern<TF::SelectV2Op> {
+ public:
+  using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = static_cast<Operation*>(src_op);
+    auto cond = op->getOperand(0);
+    auto lhs = op->getOperand(1);
+    auto rhs = op->getOperand(2);
+
+    // Should have static shapes to calculate the broadcasted shape.
+    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !rhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !cond.getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    // Calculate the broadcasted shape.
+    SmallVector<int64_t, 4> broadcasted_shape;
+    if (!OpTrait::util::getBroadcastedShape(
+            lhs.getType().cast<ShapedType>().getShape(),
+            rhs.getType().cast<ShapedType>().getShape(), broadcasted_shape)) {
+      return failure();
+    }
+
+    SmallVector<int64_t, 4> result_shape;
+    if (!OpTrait::util::getBroadcastedShape(
+            broadcasted_shape, cond.getType().cast<ShapedType>().getShape(),
+            result_shape)) {
+      return failure();
+    }
+
+    // Create a const op that stores the broadcasted shape computed above.
+    auto shape_type =
+        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
+    auto new_shape_attr =
+        mlir::DenseIntElementsAttr::get(shape_type, result_shape);
+    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+    // Apply BroadcastTo ops to each input.
+    auto cond_result_type =
+        RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
+    auto result_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(lhs.getType()));
+
+    if (result_shape != cond.getType().cast<ShapedType>().getShape()) {
+      cond = rewriter
+                 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
+                                            cond, new_shape)
+                 .output();
+    }
+    if (result_shape != lhs.getType().cast<ShapedType>().getShape()) {
+      lhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
+                                           new_shape)
+                .output();
+    }
+    if (result_shape != rhs.getType().cast<ShapedType>().getShape()) {
+      rhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
+                                           new_shape)
+                .output();
+    }
+
+    // Recreate an op with the above Broadcast op results.
+    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
+                                                rhs);
+    return success();
+  }
+};
+
+void applyPatterns(FuncOp func, ConversionTarget& target,
+                   const OwningRewritePatternList& patterns) {
+  // Keep trying to convert.
+  // TODO(karimnosseir): This is similar to what applyPatternsGreedily does.
+  // Check whether there is a function that retries until it converges.
+  // Currently the unit tests don't do multiple tries, so we need this.
+  const int max_iterations = 15;
+  for (int i = 0; i < max_iterations; ++i) {
+    if (failed(applyPartialConversion(func, target, patterns))) {
+      return;
+    }
+  }
+}
+
 void LegalizeTF::runOnFunction() {
   OwningRewritePatternList patterns;
   auto* context = &getContext();
@@ -681,16 +831,32 @@
         return success(current_thread_id == llvm::get_threadid());
       });
 
-  // Keep trying to convert.
-  // TODO(karimnosseir): This is similar to what apply greedy patterns does.
-  // Look if there is a function that tries until it converge.
-  // Currently unit-test doesn't do multiple tries, so we need this.
-  const int max_iterations = 15;
-  for (int i = 0; i < max_iterations; ++i) {
-    if (failed(applyPartialConversion(func, target, patterns))) {
-      return;
-    }
-  }
+  applyPatterns(func, target, patterns);
+
+  // Explicit BroadcastTo addition for left-over broadcast-able ops.
+  // The following pattern matching should be done after the other legalization
+  // rules in order not to add unnecessary BroadcastTo ops.
+  patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
+                  ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
+                  ApplyExplicitBroadcasting<TF::NotEqualOp>,
+                  ApplyExplicitBroadcasting<TF::GreaterOp>,
+                  ApplyExplicitBroadcasting<TF::LessOp>,
+                  ApplyExplicitBroadcasting<TF::EqualOp>,
+                  ApplyExplicitBroadcasting<TF::AddOp>,
+                  ApplyExplicitBroadcasting<TF::AddV2Op>,
+                  ApplyExplicitBroadcasting<TF::MulOp>,
+                  ApplyExplicitBroadcasting<TF::DivOp>,
+                  ApplyExplicitBroadcasting<TF::RealDivOp>,
+                  ApplyExplicitBroadcasting<TF::SubOp>,
+                  ApplyExplicitBroadcasting<TF::FloorDivOp>,
+                  ApplyExplicitBroadcasting<TF::FloorModOp>,
+                  ApplyExplicitBroadcasting<TF::PowOp>,
+                  ApplyExplicitBroadcasting<TF::MaximumOp>,
+                  ApplyExplicitBroadcasting<TF::MinimumOp>,
+                  ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
+                  ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
+
+  applyPatterns(func, target, patterns);
 }
 
 }  // namespace
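
Each of the three shape checks in the SelectV2 specialization (and the two in
the generic pattern) follows the same per-operand rule, which condenses to the
sketch below (the helper name is illustrative; the ops and types are the ones
used above):

    #include "mlir/IR/PatternMatch.h"
    #include "mlir/IR/StandardTypes.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    // Inserts tf.BroadcastTo in front of an operand only when its static
    // shape differs from the broadcasted result shape; otherwise the operand
    // is reused unchanged.
    mlir::Value MaybeBroadcast(mlir::PatternRewriter& rewriter,
                               mlir::Location loc, mlir::Value operand,
                               mlir::Value new_shape,
                               mlir::RankedTensorType broadcast_type) {
      auto operand_shape =
          operand.getType().cast<mlir::ShapedType>().getShape();
      if (broadcast_type.getShape() == operand_shape) return operand;
      return rewriter
          .create<mlir::TF::BroadcastToOp>(loc, broadcast_type, operand,
                                           new_shape)
          .output();
    }
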
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index d26a490..751c526 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -198,7 +198,7 @@
   auto output_type = output_val.getType().cast<RankedTensorType>();
   auto shape_vector = output_type.getShape();
   std::vector<int32_t> shape(shape_vector.size());
-  for (int i = 0; i < shape_vector.size(); ++i) {
+  for (int i = 0, end = shape_vector.size(); i < end; ++i) {
     shape[i] = shape_vector[i];
   }
   return mlir::DenseElementsAttr::get(
@@ -684,7 +684,7 @@
 
     SmallVector<int, 8> old_major_index_ordering;
     SmallVector<int, 8> new_major_index_ordering;
-    for (int i = 0; i < input_shape.size(); i++) {
+    for (int i = 0, end = input_shape.size(); i < end; i++) {
       if (input_shape[i] != 1) {
         old_major_index_ordering.push_back(i);
       }
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index f792384..9261dea 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -225,7 +225,8 @@
 LogicalResult CheckOutputConsumer(
     Operation* call_op, int expected_num_outputs,
     llvm::DenseSet<int> expected_consumer_indices) {
-  if (call_op->getNumResults() != expected_num_outputs) return failure();
+  const int call_op_getNumResults = call_op->getNumResults();
+  if (call_op_getNumResults != expected_num_outputs) return failure();
 
   for (int i = 0; i < expected_num_outputs; ++i) {
     auto it = expected_consumer_indices.find(i);
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 9a883a3..0a7802c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -41,7 +41,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
@@ -49,6 +51,7 @@
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
@@ -58,7 +61,9 @@
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 
 #define DEBUG_TYPE "tf-tfl-legalization"
 
@@ -495,7 +500,8 @@
         original_input_type.getShape();
     SmallVector<int64_t, 4> new_shape;
     int index = 0;
-    while (index < original_input_shape.size() || new_axis_mask) {
+    const int original_input_shape_size = original_input_shape.size();
+    while (index < original_input_shape_size || new_axis_mask) {
       if (new_axis_mask & 1) {
         new_shape.emplace_back(1);
       } else {
@@ -737,6 +743,23 @@
   return failure(has_illegal_ops);
 }
 
+// Converts a set of TF2XLA ops into pure TF ops for future legalizations, as
+// TF2XLA ops aren't supported by later stages.
+LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
+  ConversionTarget target(*context);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalDialect<TF::TensorFlowDialect>();
+  target.addLegalOp<ModuleOp>();
+  target.addLegalOp<FuncOp>();
+  target.addIllegalOp<TF::XlaConvOp>();
+
+  OwningRewritePatternList patterns;
+  mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
+  TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
+
+  return applyPartialConversion(func, target, patterns);
+}
+
 void PrepareTFPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
@@ -752,6 +775,11 @@
     return;
   }
 
+  if (failed(ConvertTf2XlaOps(func, ctx))) {
+    signalPassFailure();
+    return;
+  }
+
-    // This pattern was intented to uses TFL QDQs to preserve the quantization
-    // parameters from the TF Quant ops, thus this pattern should run with the
-    // first `applyPatternsGreedily` method, which would otherwise removes the
+    // This pattern was intended to use TFL QDQs to preserve the quantization
+    // parameters from the TF Quant ops, thus this pattern should run with the
+    // first `applyPatternsGreedily` method, which would otherwise remove the
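
The new ConvertTf2XlaOps step is a round trip: PopulateLegalizeTfWithTf2XlaPatterns
lowers tf.XlaConv through the TF2XLA kernel fallback (targeting the
"XLA_CPU_JIT" device) into MHLO, and PopulateLegalizeHloToTfPatterns raises the
result back into plain TF ops that the rest of this pass can handle. A
hypothetical driver sketch, assuming the pass factory declared in this
directory's passes.h:

    #include "mlir/IR/MLIRContext.h"
    #include "mlir/IR/Module.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/lite/transforms/passes.h"

    // Schedules PrepareTFPass on every function; tf.XlaConv is converted as
    // part of the pass, so callers need no extra step. (Factory signature is
    // an assumption based on this file's passes.h.)
    mlir::LogicalResult RunPrepareTF(mlir::ModuleOp module,
                                     mlir::MLIRContext* context) {
      mlir::PassManager pm(context);
      pm.addNestedPass<mlir::FuncOp>(
          mlir::TFL::CreatePrepareTFPass(/*unfold_batch_matmul=*/true));
      return pm.run(module);
    }
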
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 2f876c6..3a469dd 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -134,7 +134,7 @@
   // the input tensor's dimensions, return 0-valued tensor of the requested
   // shape.
   ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
-  for (int i = 0; i < input_shape.size(); i++) {
+  for (int i = 0, end = input_shape.size(); i < end; i++) {
     if (begin_values[i] < 0 ||
         (begin_values[i] + size_values[i] > input_shape[i])) {
       return CreateF32SplatConst(builder, size_shape, 0, location);