[TFLite/MLIR] Avoids crash in ReduceWhileOperands.
PiperOrigin-RevId: 400071680
Change-Id: Ic8d50ed93f9b820dc63c9a40e5bb1e4906ab1fe4
diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
index 29038d5..e943885 100644
--- a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
@@ -67,7 +67,7 @@
void runOnFunction() override;
};
-void FindImplicityProducers(
+LogicalResult FindImplicityProducers(
const std::vector<uint64_t> &explicitly_consumed_ids,
std::vector<bool> &is_consumed_id,
const std::vector<std::vector<uint64_t>> &dependency_graph) {
@@ -80,12 +80,21 @@
while (!queue.empty()) {
auto i = queue.back();
queue.pop_back();
- for (auto j : dependency_graph[i]) {
+
+ // If there is a consumer which cannot be found in dependency graph, return
+ // false.
+ if (i >= dependency_graph.size()) {
+ return failure();
+ }
+
+ for (auto j : dependency_graph.at(i)) {
if (is_consumed_id[j]) continue;
queue.push_back(j);
is_consumed_id[j] = true;
}
}
+
+ return success();
}
void FindProducers(Value start_node, std::vector<uint64_t> &neighbors) {
@@ -218,8 +227,10 @@
}
std::vector<bool> is_consumed_id(n, false);
- FindImplicityProducers(explicitly_consumed_ids, is_consumed_id,
- dependency_graph);
+ if (failed(FindImplicityProducers(explicitly_consumed_ids, is_consumed_id,
+ dependency_graph))) {
+ return false;
+ }
// Find all consumed operations in while body.
llvm::DenseSet<Operation *> consumed_ops;