In InterpreterBuilder, set the number of threads after all subgraphs are added.

* Previously, in InterpreterBuilder::operator()(std::unique_ptr<Interpreter>* interpreter), (*interpreter)->SetNumThreads(num_threads_) was called before (*interpreter)->AddSubgraphs(...). As a result, the thread count was only set on the TfLiteContext of the primary subgraph. After this fix, the thread count is set on every subgraph's TfLiteContext.
* This CL also changes the TFLite benchmark tool to configure the thread count via InterpreterBuilder::SetNumThreads and the InterpreterBuilder::operator()(std::unique_ptr<Interpreter>*) API, instead of the deprecated overload that takes the thread count as a second argument.
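
A minimal sketch of the non-deprecated pattern (mirroring the benchmark tool change below; the model/resolver setup is abbreviated and the surrounding variable names are illustrative):

    // Assumes `model` and `resolver` were created as usual, e.g. via
    // FlatBufferModel::BuildFromFile and an op resolver.
    tflite::InterpreterBuilder builder(*model, *resolver);
    // Configure the thread count on the builder before building, so it is
    // propagated to every subgraph's TfLiteContext.
    if (builder.SetNumThreads(num_threads) != kTfLiteOk) {
      // handle error
    }
    std::unique_ptr<tflite::Interpreter> interpreter;
    builder(&interpreter);  // builds the interpreter with all subgraphs
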

PiperOrigin-RevId: 416414642
Change-Id: I96f1ec74684adc1711bafc8e268ec42b4abaebf6
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 97d9516..bf744b1 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -749,11 +749,13 @@
   }
 
   interpreter->reset(new Interpreter(error_reporter_));
-  (*interpreter)->SetNumThreads(num_threads_);
   if (subgraphs->size() > 1) {
     (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
   }
 
+  // Set num threads after all the subgraphs are added.
+  (*interpreter)->SetNumThreads(num_threads_);
+
   if (preserve_all_tensors_) {
     (*interpreter)->PreserveAllTensorsExperimental();
   }
@@ -776,7 +778,6 @@
     if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
       return cleanup_and_error();
     }
-    // Set num threads
     // Parse inputs/outputs
     modified_subgraph->SetInputs(
         FlatBufferIntArrayToVector(subgraph->inputs()));
diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc
index ab8fbb4..ecd9307 100644
--- a/tensorflow/lite/model_test.cc
+++ b/tensorflow/lite/model_test.cc
@@ -335,6 +335,26 @@
                       reporter.error_messages());
 }
 
+TEST(BasicFlatBufferModel, TestSetNumThreadsWithMultipleSubgraphs) {
+  TestErrorReporter reporter;
+  auto model = FlatBufferModel::BuildFromFile(
+      "tensorflow/lite/testdata/2_subgraphs.bin", &reporter);
+  ASSERT_TRUE(model);
+  std::unique_ptr<Interpreter> interpreter;
+  TrivialResolver resolver(&dummy_reg);
+  InterpreterBuilder builder(*model, resolver);
+
+  ASSERT_EQ(builder.SetNumThreads(4), kTfLiteOk);
+  interpreter.reset();
+  ASSERT_EQ(builder(&interpreter), kTfLiteOk);
+  ASSERT_NE(interpreter, nullptr);
+
+  // Check that each subgraph has the expected number of threads set.
+  for (int i = 0; i < interpreter->subgraphs_size(); ++i) {
+    EXPECT_EQ(interpreter->subgraph(i)->context()->recommended_num_threads, 4);
+  }
+}
+
 // Test that loading a model with TensorFlow ops fails when the flex delegate is
 // not linked into the target.
 TEST(FlexModel, FailureWithoutFlexDelegate) {
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index c4726b8..cc1729e 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -648,7 +648,14 @@
   auto resolver = GetOpResolver();
   const int32_t num_threads = params_.Get<int32_t>("num_threads");
   const bool use_caching = params_.Get<bool>("use_caching");
-  tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
+
+  tflite::InterpreterBuilder builder(*model_, *resolver);
+  if (builder.SetNumThreads(num_threads) != kTfLiteOk) {
+    TFLITE_LOG(ERROR) << "Failed to set thread number";
+    return kTfLiteError;
+  }
+
+  builder(&interpreter_);
   if (!interpreter_) {
     TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
     return kTfLiteError;