lite: Update Subgraph::RemoveUnusedInputs()
The method has a bug which doesn't count graph output tensors.
This change fixes it.
PiperOrigin-RevId: 422469893
Change-Id: Ie0b1a4b56b0e29fe4c5d14f0713d1eeea72f3d55
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 1eda305..1cbfe2c 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1120,6 +1120,11 @@
}
}
}
+ // Count references to SubGraph output tensors.
+ for (auto iter = outputs_.begin(); iter != outputs_.end(); iter++) {
+ if (*iter == kTfLiteOptionalTensor) continue;
+ refcounts[*iter]++;
+ }
// Mark unused inputs as kTfLiteOptionalTensor.
for (auto iter = inputs_.begin(); iter != inputs_.end(); iter++) {
diff --git a/tensorflow/lite/core/subgraph_test.cc b/tensorflow/lite/core/subgraph_test.cc
index 4ee5105..bbf9b27 100644
--- a/tensorflow/lite/core/subgraph_test.cc
+++ b/tensorflow/lite/core/subgraph_test.cc
@@ -60,5 +60,16 @@
ASSERT_EQ(subgraph.inputs(), std::vector<int>({-1, -1, 2}));
}
+TEST(RemoveUnusedInputs, BypassInputsWithoutOp) {
+ Interpreter interpreter;
+ auto& subgraph = interpreter.primary_subgraph();
+ subgraph.AddTensors(3);
+ subgraph.SetInputs({0, 1, 2});
+ subgraph.SetOutputs({0, 2});
+
+ ASSERT_EQ(subgraph.RemoveUnusedInputs(), kTfLiteOk);
+ ASSERT_EQ(subgraph.inputs(), std::vector<int>({0, -1, 2}));
+}
+
} // namespace
} // namespace tflite