MLIR Lite resolutions
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 2e69a17..89fae87 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -1418,7 +1418,7 @@
} else {
auto segments = dim_metadata.segments();
std::vector<int> vector_segments(segments.size(), 0);
- for (int j = 0; j < segments.size(); j++) {
+ for (int j = 0, end = segments.size(); j < end; j++) {
vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
}
tflite::SparseIndexVector segments_type;
@@ -1450,7 +1450,7 @@
auto indices = dim_metadata.indices();
std::vector<int> vector_indices(indices.size(), 0);
int max_of_indices = 0;
- for (int j = 0; j < indices.size(); j++) {
+ for (int j = 0, end = indices.size(); j < end; j++) {
vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
if (vector_indices[j] > max_of_indices) {
max_of_indices = vector_indices[j];
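
Note: nearly every hunk in this patch applies the same idiom to silence -Wsign-compare: hoist the unsigned container size into a signed local once, so the loop condition compares signed against signed. A minimal standalone sketch of the idiom (all names hypothetical):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> segments = {3, 1, 4, 1, 5};
      // Before: `j < segments.size()` compares a signed int with size_t, which
      // -Wsign-compare flags. Hoisting size() into a signed `end` makes the
      // comparison int-vs-int and also avoids re-evaluating size() per step.
      for (int j = 0, end = segments.size(); j < end; j++) {
        std::printf("%d\n", segments[j]);
      }
      return 0;
    }
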
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index fa85b4e..29484fa 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -229,7 +229,7 @@
llvm::SmallVector<llvm::APFloat, 4> min_maxs;
min_maxs.reserve(mins.size() * 2);
- for (int i = 0; i < mins.size(); ++i) {
+ for (int i = 0, end = mins.size(); i < end; ++i) {
llvm::APFloat min(mins[i]);
llvm::APFloat max(maxs[i]);
min_maxs.push_back(min);
@@ -281,7 +281,7 @@
int bytes_len = bytes.size();
assert(bytes_len % read_size == 0);
- size_t elem_count = bytes_len / read_size;
+ int elem_count = bytes_len / read_size;
ret.reserve(elem_count);
const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
@@ -318,7 +318,7 @@
switch (elem_type.getWidth()) {
case 16: {
assert(bytes_len % 2 == 0);
- size_t elem_count = bytes_len / 2;
+ int elem_count = bytes_len / 2;
std::vector<llvm::APFloat> values;
values.reserve(elem_count);
@@ -337,12 +337,11 @@
}
case 32: {
assert(bytes_len % 4 == 0);
- size_t elem_count = bytes_len / 4;
+ int elem_count = bytes_len / 4;
std::vector<float> values;
values.reserve(elem_count);
const char* data = reinterpret_cast<const char*>(buffer.data());
-
for (int i = 0; i < elem_count; i++) {
uint32_t bit_repr =
llvm::support::endian::readNext<uint32_t, llvm::support::little,
@@ -353,7 +352,7 @@
}
case 64: {
assert(bytes_len % 8 == 0);
- size_t elem_count = bytes_len / 8;
+ int elem_count = bytes_len / 8;
std::vector<double> values;
values.reserve(elem_count);
@@ -829,7 +828,7 @@
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
- for (int i = 0; i < subgraph.tensors.size(); i++) {
+ for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
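
The size_t-to-int elem_count changes above follow the same cleanup: the loop indices are plain int, so the element count is declared int to match. A minimal sketch of the 32-bit read loop, assuming buffer lengths always fit in int (a hypothetical reduction of the importer code, with std::memcpy standing in for llvm::support::endian::readNext):

    #include <cassert>
    #include <cstring>
    #include <vector>

    std::vector<float> ReadFloat32Buffer(const char* data, int bytes_len) {
      assert(bytes_len % 4 == 0);
      int elem_count = bytes_len / 4;  // int, matching the `int i` loop index
      std::vector<float> values;
      values.reserve(elem_count);
      for (int i = 0; i < elem_count; i++) {
        float v;
        std::memcpy(&v, data + 4 * i, sizeof v);
        values.push_back(v);
      }
      return values;
    }
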
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 427b9c6..c7c3f57 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -147,18 +147,10 @@
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
- // Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
- // which are broadcastable shapes up to five dimension or have same shapes.
+ // Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
+ // which are broadcastable shapes up to four dimensions or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
- IsQUI8Type(element_type)) {
- return VerifyOperandsHaveSameShapesOrBroadcastableShape(
- /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
- /*max_bcast_rank=*/5);
- }
-
- // Allows I32 output when the operands have valid shapes, which are
- // broadcastable shapes up to four dimension or have same shapes.
- if (IsI32Type(element_type)) {
+ IsQUI8Type(element_type) || IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
@@ -210,20 +202,13 @@
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
- /*max_bcast_rank=*/5);
+ /*max_bcast_rank=*/4);
}
- // Allows F32 output when the operands have valid shapes, which are
- // broadcastable shapes up to five dimension or have same shapes.
- if (element_type.isF32()) {
- return VerifyOperandsHaveSameShapesOrBroadcastableShape(
- /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
- /*max_bcast_rank=*/5);
- }
-
- // Allows I32 and QI16 outputs when the operands have valid shapes, which are
- // broadcastable shapes up to four dimension or have same shapes.
- if (IsI32Type(element_type) || IsQI16Type(element_type)) {
+ // Allows I32, QI16 and F32 outputs when the operands have valid shapes,
+ // which are broadcastable shapes up to four dimensions or have same shapes.
+ if (IsI32Type(element_type) || IsQI16Type(element_type) ||
+ element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
@@ -773,7 +758,8 @@
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
- if (opaque_attr.getValue().size() !=
+ const int opaque_attr_getValue_size = opaque_attr.getValue().size();
+ if (opaque_attr_getValue_size !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
@@ -955,7 +941,7 @@
// Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
// dimensions of `updates` and `shape` are equal.
for (auto shape_it : llvm::enumerate(shape_value)) {
- auto i = shape_it.index();
+ int64_t i = shape_it.index();
auto value = shape_it.value().getSExtValue();
if (i >= outermost_dim) {
auto corresponding_dim = i - outermost_dim + outer_dims;
@@ -1192,7 +1178,8 @@
return failure();
const int total_pack_inputs = pack_op.getNumOperands();
- if (total_pack_inputs != input_unpack_op.getNumResults()) return failure();
+ const int input_unpack_op_getNumResults = input_unpack_op.getNumResults();
+ if (total_pack_inputs != input_unpack_op_getNumResults) return failure();
for (auto input_output :
llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
Value pack_input = std::get<0>(input_output);
@@ -1261,7 +1248,7 @@
}
if (begin && size && input_type.hasStaticShape()) {
- const int input_rank = begin.getNumElements();
+ const uint64_t input_rank = begin.getNumElements();
for (uint64_t i = 0; i < input_rank; i++) {
int begin_i =
begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index dafcfd1..529c9ee 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -75,7 +75,8 @@
}
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
input_names.split(function_input_names, ",");
- if (function_input_names.size() != model_flags.input_arrays().size()) {
+ const int function_input_names_size = function_input_names.size();
+ if (function_input_names_size != model_flags.input_arrays().size()) {
return errors::InvalidArgument(
"input array size mismatch: got ", function_input_names.size(),
", expected: ", model_flags.input_arrays().size());
@@ -99,7 +100,8 @@
}
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
output_names.split(function_output_names, ",");
- if (function_output_names.size() != model_flags.output_arrays().size()) {
+ const int function_output_names_size = function_output_names.size();
+ if (function_output_names_size != model_flags.output_arrays().size()) {
return errors::InvalidArgument(
"output array size mismatch: got ", function_output_names.size(),
", expected: ", model_flags.output_arrays().size());
diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
index b745be7..2054bab 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
+++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
@@ -276,7 +276,7 @@
}
// Check that the block_shape of `stb_op` and `bts_op` are equal.
if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
- for (uint64_t i = 0; i < stb_bs_attr.getNumElements(); ++i) {
+ for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 7d6866d..c85e825 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -170,7 +170,7 @@
size_t num_samples = Distribution::kResultElementCount;
llvm::SmallVector<float, 32> data;
data.resize(num_elements);
- while (offset < num_elements) {
+ while (static_cast<int>(offset) < num_elements) {
const typename Distribution::ResultType samples = dist(&generator);
std::copy(&samples[0],
&samples[0] + std::min(num_samples, data.size() - offset),
@@ -631,6 +631,156 @@
}
};
+// Put two BroadcastTo ops in front of the given TF binary broadcast op so
+// that the binary broadcastable-op conversion always succeeds and does not
+// require the flex delegate.
+template <typename SourceOp>
+class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
+ public:
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SourceOp src_op,
+ PatternRewriter& rewriter) const override {
+ Operation* op = static_cast<Operation*>(src_op);
+ auto lhs = op->getOperand(0);
+ auto rhs = op->getOperand(1);
+
+ // Should have static shapes to calculate the broadcasted shape.
+ if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+ !rhs.getType().cast<ShapedType>().hasStaticShape()) {
+ return failure();
+ }
+
+ // Calculate the broadcasted shape.
+ SmallVector<int64_t, 4> result_shape;
+ if (!OpTrait::util::getBroadcastedShape(
+ lhs.getType().cast<ShapedType>().getShape(),
+ rhs.getType().cast<ShapedType>().getShape(), result_shape)) {
+ return failure();
+ }
+
+ RankedTensorType result_type = RankedTensorType::get(
+ result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
+
+ // Create a const op that stores the above broadcasted shape.
+ auto new_shape_attr = mlir::DenseIntElementsAttr::get(
+ RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
+ result_shape);
+ auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+ // Apply BroadcastTo ops to each input.
+ auto broadcast_type = RankedTensorType::get(
+ result_shape, getElementTypeOrSelf(lhs.getType()));
+
+ if (result_type.getShape() != lhs.getType().cast<ShapedType>().getShape()) {
+ lhs = rewriter
+ .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
+ new_shape)
+ .output();
+ }
+ if (result_type.getShape() != rhs.getType().cast<ShapedType>().getShape()) {
+ rhs = rewriter
+ .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
+ new_shape)
+ .output();
+ }
+
+ // Recreate an op with the above Broadcast op results.
+ rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
+ return success();
+ }
+};
+
+// This specialization is for the TF SelectV2 op. SelectV2 has three inputs,
+// all of which should have broadcastable shapes.
+template <>
+class ApplyExplicitBroadcasting<TF::SelectV2Op>
+ : public OpRewritePattern<TF::SelectV2Op> {
+ public:
+ using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
+ PatternRewriter& rewriter) const override {
+ Operation* op = static_cast<Operation*>(src_op);
+ auto cond = op->getOperand(0);
+ auto lhs = op->getOperand(1);
+ auto rhs = op->getOperand(2);
+
+ // Should have static shapes to calculate the broadcasted shape.
+ if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+ !rhs.getType().cast<ShapedType>().hasStaticShape() ||
+ !cond.getType().cast<ShapedType>().hasStaticShape()) {
+ return failure();
+ }
+
+ // Calculate the broadcasted shape.
+ SmallVector<int64_t, 4> broadcasted_shape;
+ if (!OpTrait::util::getBroadcastedShape(
+ lhs.getType().cast<ShapedType>().getShape(),
+ rhs.getType().cast<ShapedType>().getShape(), broadcasted_shape)) {
+ return failure();
+ }
+
+ SmallVector<int64_t, 4> result_shape;
+ if (!OpTrait::util::getBroadcastedShape(
+ broadcasted_shape, cond.getType().cast<ShapedType>().getShape(),
+ result_shape)) {
+ return failure();
+ }
+
+ // Create a const op that stores the above broadcasted shape.
+ auto shape_type =
+ RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
+ auto new_shape_attr =
+ mlir::DenseIntElementsAttr::get(shape_type, result_shape);
+ auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+ // Apply BroadcastTo ops to each input.
+ auto cond_result_type =
+ RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
+ auto result_type = RankedTensorType::get(
+ result_shape, getElementTypeOrSelf(lhs.getType()));
+
+ if (result_shape != cond.getType().cast<ShapedType>().getShape()) {
+ cond = rewriter
+ .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
+ cond, new_shape)
+ .output();
+ }
+ if (result_shape != lhs.getType().cast<ShapedType>().getShape()) {
+ lhs = rewriter
+ .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
+ new_shape)
+ .output();
+ }
+ if (result_shape != rhs.getType().cast<ShapedType>().getShape()) {
+ rhs = rewriter
+ .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
+ new_shape)
+ .output();
+ }
+
+ // Recreate an op with the above Broadcast op results.
+ rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
+ rhs);
+ return success();
+ }
+};
+
+void applyPatterns(FuncOp func, ConversionTarget& target,
+ const OwningRewritePatternList& patterns) {
+ // Keep trying to convert.
+ // TODO(karimnosseir): This is similar to what applying greedy patterns does.
+ // Check whether there is a function that retries until it converges.
+ // Currently the unit tests don't do multiple tries, so we need this.
+ const int max_iterations = 15;
+ for (int i = 0; i < max_iterations; ++i) {
+ if (failed(applyPartialConversion(func, target, patterns))) {
+ return;
+ }
+ }
+}
+
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* context = &getContext();
@@ -681,16 +831,32 @@
return success(current_thread_id == llvm::get_threadid());
});
- // Keep trying to convert.
- // TODO(karimnosseir): This is similar to what apply greedy patterns does.
- // Look if there is a function that tries until it converge.
- // Currently unit-test doesn't do multiple tries, so we need this.
- const int max_iterations = 15;
- for (int i = 0; i < max_iterations; ++i) {
- if (failed(applyPartialConversion(func, target, patterns))) {
- return;
- }
- }
+ applyPatterns(func, target, patterns);
+
+ // Explicit BroadcastTo addition for left-over broadcastable ops.
+ // The following pattern matching should be done after the other legalization
+ // rules in order not to add unnecessary BroadcastTo ops.
+ patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
+ ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
+ ApplyExplicitBroadcasting<TF::NotEqualOp>,
+ ApplyExplicitBroadcasting<TF::GreaterOp>,
+ ApplyExplicitBroadcasting<TF::LessOp>,
+ ApplyExplicitBroadcasting<TF::EqualOp>,
+ ApplyExplicitBroadcasting<TF::AddOp>,
+ ApplyExplicitBroadcasting<TF::AddV2Op>,
+ ApplyExplicitBroadcasting<TF::MulOp>,
+ ApplyExplicitBroadcasting<TF::DivOp>,
+ ApplyExplicitBroadcasting<TF::RealDivOp>,
+ ApplyExplicitBroadcasting<TF::SubOp>,
+ ApplyExplicitBroadcasting<TF::FloorDivOp>,
+ ApplyExplicitBroadcasting<TF::FloorModOp>,
+ ApplyExplicitBroadcasting<TF::PowOp>,
+ ApplyExplicitBroadcasting<TF::MaximumOp>,
+ ApplyExplicitBroadcasting<TF::MinimumOp>,
+ ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
+ ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
+
+ applyPatterns(func, target, patterns);
}
} // namespace
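
Both ApplyExplicitBroadcasting variants above lean on OpTrait::util::getBroadcastedShape to compute the common result shape before inserting BroadcastTo ops. For reference, a standalone sketch of that numpy-style shape broadcasting (a hypothetical simplification; the real helper also deals with dynamic dimensions):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Right-aligns the two shapes; each dimension pair must match or contain
    // a 1. Returns false when the shapes cannot be broadcast together.
    bool GetBroadcastedShape(const std::vector<int64_t>& a,
                             const std::vector<int64_t>& b,
                             std::vector<int64_t>& out) {
      const size_t rank = std::max(a.size(), b.size());
      out.assign(rank, 1);
      for (size_t i = 0; i < rank; ++i) {
        const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) return false;
        out[rank - 1 - i] = (da == 1) ? db : da;
      }
      return true;
    }
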
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index d26a490..751c526 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -198,7 +198,7 @@
auto output_type = output_val.getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();
std::vector<int32_t> shape(shape_vector.size());
- for (int i = 0; i < shape_vector.size(); ++i) {
+ for (int i = 0, end = shape_vector.size(); i < end; ++i) {
shape[i] = shape_vector[i];
}
return mlir::DenseElementsAttr::get(
@@ -684,7 +684,7 @@
SmallVector<int, 8> old_major_index_ordering;
SmallVector<int, 8> new_major_index_ordering;
- for (int i = 0; i < input_shape.size(); i++) {
+ for (int i = 0, end = input_shape.size(); i < end; i++) {
if (input_shape[i] != 1) {
old_major_index_ordering.push_back(i);
}
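
For context, this hunk sits in a fold that treats a transpose as a plain reshape when the permutation leaves the relative order of non-unit dimensions unchanged. A standalone sketch of that check (a hypothetical simplification of the surrounding pattern):

    #include <cstdint>
    #include <vector>

    // A permutation acts as a reshape iff the non-unit dimensions keep their
    // relative order; size-1 dims can move freely without reordering data.
    bool PermutationIsReshape(const std::vector<int64_t>& input_shape,
                              const std::vector<int>& perm) {
      std::vector<int> old_order, new_order;
      for (int i = 0, end = input_shape.size(); i < end; i++) {
        if (input_shape[i] != 1) old_order.push_back(i);
      }
      for (int axis : perm) {
        if (input_shape[axis] != 1) new_order.push_back(axis);
      }
      return old_order == new_order;
    }
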
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index f792384..9261dea 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -225,7 +225,8 @@
LogicalResult CheckOutputConsumer(
Operation* call_op, int expected_num_outputs,
llvm::DenseSet<int> expected_consumer_indices) {
- if (call_op->getNumResults() != expected_num_outputs) return failure();
+ const int call_op_getNumResults = call_op->getNumResults();
+ if (call_op_getNumResults != expected_num_outputs) return failure();
for (int i = 0; i < expected_num_outputs; ++i) {
auto it = expected_consumer_indices.find(i);
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 9a883a3..0a7802c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -41,7 +41,9 @@
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@@ -49,6 +51,7 @@
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
@@ -58,7 +61,9 @@
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@@ -495,7 +500,8 @@
original_input_type.getShape();
SmallVector<int64_t, 4> new_shape;
int index = 0;
- while (index < original_input_shape.size() || new_axis_mask) {
+ const int original_input_shape_size = original_input_shape.size();
+ while (index < original_input_shape_size || new_axis_mask) {
if (new_axis_mask & 1) {
new_shape.emplace_back(1);
} else {
@@ -737,6 +743,23 @@
return failure(has_illegal_ops);
}
+// Converts a set of TF2XLA ops into pure TF ops for later legalization, as
+// TF2XLA ops aren't supported by downstream stages.
+LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
+ ConversionTarget target(*context);
+ target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<TF::TensorFlowDialect>();
+ target.addLegalOp<ModuleOp>();
+ target.addLegalOp<FuncOp>();
+ target.addIllegalOp<TF::XlaConvOp>();
+
+ OwningRewritePatternList patterns;
+ mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
+ TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
+
+ return applyPartialConversion(func, target, patterns);
+}
+
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
@@ -752,6 +775,11 @@
return;
}
+ if (failed(ConvertTf2XlaOps(func, ctx))) {
+ signalPassFailure();
+ return;
+ }
+
- // This pattern was intented to uses TFL QDQs to preserve the quantization
- // parameters from the TF Quant ops, thus this pattern should run with the
- // first `applyPatternsGreedily` method, which would otherwise removes the
+ // This pattern is intended to use TFL QDQs to preserve the quantization
+ // parameters from the TF Quant ops; thus this pattern should run with the
+ // first `applyPatternsGreedily` method, which would otherwise remove the
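
With ConvertTf2XlaOps wired into PrepareTFPass, a module containing tf.XlaConv is first lowered to mhlo through the TF2XLA fallback patterns and then raised back to plain TF ops, keeping it off the flex-delegate path. A minimal driver sketch; the pass-factory name and signature are assumptions based on the TFL passes header of this era:

    #include "mlir/IR/MLIRContext.h"         // from @llvm-project
    #include "mlir/IR/Module.h"              // from @llvm-project
    #include "mlir/Pass/PassManager.h"       // from @llvm-project
    #include "mlir/Support/LogicalResult.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/lite/transforms/passes.h"

    // Hypothetical driver: run prepare-tf over a parsed module so tf.XlaConv
    // is lowered through mhlo and raised back to pure TF ops.
    mlir::LogicalResult RunPrepareTF(mlir::ModuleOp module,
                                     mlir::MLIRContext* context) {
      mlir::PassManager pm(context);
      pm.addPass(mlir::TFL::CreatePrepareTFPass(/*unfold_batch_matmul=*/true));
      return pm.run(module);
    }
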
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 2f876c6..3a469dd 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -134,7 +134,7 @@
// the input tensor's dimensions, return 0-valued tensor of the requested
// shape.
ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
- for (int i = 0; i < input_shape.size(); i++) {
+ for (int i = 0, end = input_shape.size(); i < end; i++) {
if (begin_values[i] < 0 ||
(begin_values[i] + size_values[i] > input_shape[i])) {
return CreateF32SplatConst(builder, size_shape, 0, location);
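
The lstm_utils.cc change is the same loop-bound idiom; the enclosing helper returns a zero-filled splat when the requested slice exceeds the input bounds. A standalone sketch of that bounds check (a hypothetical reduction):

    #include <cstdint>
    #include <vector>

    // True when [begin, begin + size) fits inside the input along every axis.
    bool SliceInBounds(const std::vector<int64_t>& input_shape,
                       const std::vector<int64_t>& begin,
                       const std::vector<int64_t>& size) {
      for (int i = 0, end = input_shape.size(); i < end; i++) {
        if (begin[i] < 0 || begin[i] + size[i] > input_shape[i]) return false;
      }
      return true;
    }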