In InterpreterBuilder, set number of threads after all subgraphs are added.
*Previously, in the InterpreterBuilder::operator()(std::unique_ptr<Interpreter>* interpreter) method, we called (*interpreter)->SetNumThreads(num_threads_) before calling (*interpreter)->AddSubgraphs. As a result, the thread count was only applied to the TfLiteContext of the primary subgraph. After this fix, the thread count is applied to each subgraph's TfLiteContext.
*This CL also changes the TFLite benchmark tool to set the thread count via InterpreterBuilder::SetNumThreads() and the single-argument operator()(std::unique_ptr<Interpreter>*), instead of the deprecated operator() overload that takes num_threads.
PiperOrigin-RevId: 416414642
Change-Id: I96f1ec74684adc1711bafc8e268ec42b4abaebf6
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 97d9516..bf744b1 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -749,11 +749,13 @@
}
interpreter->reset(new Interpreter(error_reporter_));
- (*interpreter)->SetNumThreads(num_threads_);
if (subgraphs->size() > 1) {
(*interpreter)->AddSubgraphs(subgraphs->size() - 1);
}
+ // Set num threads after all the subgraphs are added.
+ (*interpreter)->SetNumThreads(num_threads_);
+
if (preserve_all_tensors_) {
(*interpreter)->PreserveAllTensorsExperimental();
}
@@ -776,7 +778,6 @@
if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
return cleanup_and_error();
}
- // Set num threads
// Parse inputs/outputs
modified_subgraph->SetInputs(
FlatBufferIntArrayToVector(subgraph->inputs()));
diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc
index ab8fbb4..ecd9307 100644
--- a/tensorflow/lite/model_test.cc
+++ b/tensorflow/lite/model_test.cc
@@ -335,6 +335,26 @@
reporter.error_messages());
}
+TEST(BasicFlatBufferModel, TestSetNumThreadsWithMultipleSubgraphs) {
+ TestErrorReporter reporter;
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/lite/testdata/2_subgraphs.bin", &reporter);
+ ASSERT_TRUE(model);
+ std::unique_ptr<Interpreter> interpreter;
+ TrivialResolver resolver(&dummy_reg);
+ InterpreterBuilder builder(*model, resolver);
+
+ ASSERT_EQ(builder.SetNumThreads(4), kTfLiteOk);
+ interpreter.reset();
+ ASSERT_EQ(builder(&interpreter), kTfLiteOk);
+ ASSERT_NE(interpreter, nullptr);
+
+ // Check that each subgraph has the expected number of threads set.
+ for (int i = 0; i < interpreter->subgraphs_size(); ++i) {
+ EXPECT_EQ(interpreter->subgraph(i)->context()->recommended_num_threads, 4);
+ }
+}
+
// Test that loading a model with TensorFlow ops fails when the flex delegate is
// not linked into the target.
TEST(FlexModel, FailureWithoutFlexDelegate) {
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index c4726b8..cc1729e 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -648,7 +648,14 @@
auto resolver = GetOpResolver();
const int32_t num_threads = params_.Get<int32_t>("num_threads");
const bool use_caching = params_.Get<bool>("use_caching");
- tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
+
+ tflite::InterpreterBuilder builder(*model_, *resolver);
+ if (builder.SetNumThreads(num_threads) != kTfLiteOk) {
+ TFLITE_LOG(ERROR) << "Failed to set thread number";
+ return kTfLiteError;
+ }
+
+ builder(&interpreter_);
if (!interpreter_) {
TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
return kTfLiteError;