Utilize globally-registered delegate providers to initialize tflite delegates in the imagenet-based image classification accuracy evaluation tool.

PiperOrigin-RevId: 306357221
Change-Id: I34a9fbf22e3d928c1c2b376b93c6d792c74fa94c
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
index f350914..7b82988 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
@@ -58,5 +58,6 @@
     deps = [
         ":imagenet_accuracy_eval_lib",
         "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
     ],
 )
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
index af9bdf3..19ce803 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
@@ -15,19 +15,20 @@
 
 #include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 
 namespace {
-constexpr char kNumThreadsFlag[] = "num_threads";
+constexpr char kNumEvalThreadsFlag[] = "num_eval_threads";
 constexpr char kOutputFilePathFlag[] = "output_file_path";
 constexpr char kProtoOutputFilePathFlag[] = "proto_output_file_path";
 }  // namespace
 
 int main(int argc, char* argv[]) {
   std::string output_file_path, proto_output_file_path;
-  int num_threads = 4;
+  int num_eval_threads = 4;
   std::vector<tflite::Flag> flag_list = {
-      tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
-                               "Number of threads."),
+      tflite::Flag::CreateFlag(kNumEvalThreadsFlag, &num_eval_threads,
+                               "Number of threads used for evaluation."),
       tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
                                "Path to output file."),
       tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
@@ -36,14 +37,17 @@
   };
   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
 
-  if (num_threads <= 0) {
+  if (num_eval_threads <= 0) {
     LOG(ERROR) << "Invalid number of threads.";
     return EXIT_FAILURE;
   }
 
+  tflite::evaluation::DelegateProviders delegate_providers;
+  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
+
   std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator =
       tensorflow::metrics::CreateImagenetModelEvaluator(&argc, argv,
-                                                        num_threads);
+                                                        num_eval_threads);
 
   if (!evaluator) {
     LOG(ERROR) << "Fail to create the ImagenetModelEvaluator.";
@@ -59,8 +63,8 @@
   }
 
   evaluator->AddObserver(writer.get());
-  LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
-  if (evaluator->EvaluateModel() != kTfLiteOk) {
+  LOG(ERROR) << "Starting evaluation with: " << num_eval_threads << " threads.";
+  if (evaluator->EvaluateModel(&delegate_providers) != kTfLiteOk) {
     LOG(ERROR) << "Failed to evaluate the model!";
     return EXIT_FAILURE;
   }
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index 558ee8b..f318dc6 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -26,7 +26,6 @@
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@@ -155,12 +154,12 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
-                                   const std::vector<ImageLabel>& image_labels,
-                                   const std::vector<std::string>& model_labels,
-                                   const ImagenetModelEvaluator::Params& params,
-                                   ImagenetModelEvaluator::Observer* observer,
-                                   int num_ranks) {
+TfLiteStatus EvaluateModelForShard(
+    const uint64_t shard_id, const std::vector<ImageLabel>& image_labels,
+    const std::vector<std::string>& model_labels,
+    const ImagenetModelEvaluator::Params& params,
+    ImagenetModelEvaluator::Observer* observer, int num_ranks,
+    const tflite::evaluation::DelegateProviders* delegate_providers) {
   tflite::evaluation::EvaluationStageConfig eval_config;
   eval_config.set_name("image_classification");
   auto* classification_params = eval_config.mutable_specification()
@@ -174,7 +173,7 @@
 
   tflite::evaluation::ImageClassificationStage eval(eval_config);
   eval.SetAllLabels(model_labels);
-  TF_LITE_ENSURE_STATUS(eval.Init());
+  TF_LITE_ENSURE_STATUS(eval.Init(delegate_providers));
 
   for (const auto& image_label : image_labels) {
     eval.SetInputs(image_label.image, image_label.label);
@@ -191,7 +190,8 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
+TfLiteStatus ImagenetModelEvaluator::EvaluateModel(
+    const tflite::evaluation::DelegateProviders* delegate_providers) const {
   const std::string data_path = tflite::evaluation::StripTrailingSlashes(
                                     params_.ground_truth_images_path) +
                                 "/";
@@ -252,9 +252,10 @@
     const uint64_t shard_id = i + 1;
     shard_id_image_count_map[shard_id] = image_label.size();
     auto func = [shard_id, &image_label, &model_labels, this, &observer,
-                 &all_okay]() {
+                 &all_okay, delegate_providers]() {
       if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
-                                &observer, params_.num_ranks) != kTfLiteOk) {
+                                &observer, params_.num_ranks,
+                                delegate_providers) != kTfLiteOk) {
         all_okay = false;
       }
     };
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
index b10b0f8..65d4a2c 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 
 namespace tensorflow {
@@ -120,7 +121,8 @@
   const Params& params() const { return params_; }
 
   // Evaluates the provided model over the dataset.
-  TfLiteStatus EvaluateModel() const;
+  TfLiteStatus EvaluateModel(const tflite::evaluation::DelegateProviders*
+                                 delegate_providers = nullptr) const;
 
  private:
   const Params params_;