Use globally-registered delegate providers to initialize TFLite delegates in the ImageNet-based image classification accuracy evaluation tool.
PiperOrigin-RevId: 306357221
Change-Id: I34a9fbf22e3d928c1c2b376b93c6d792c74fa94c
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
index f350914..7b82988 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
@@ -58,5 +58,6 @@
deps = [
":imagenet_accuracy_eval_lib",
"//tensorflow/lite/tools:command_line_flags",
+ "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
],
)
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
index af9bdf3..19ce803 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
@@ -15,19 +15,20 @@
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
#include "tensorflow/lite/tools/command_line_flags.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
namespace {
-constexpr char kNumThreadsFlag[] = "num_threads";
+constexpr char kNumEvalThreadsFlag[] = "num_eval_threads";
constexpr char kOutputFilePathFlag[] = "output_file_path";
constexpr char kProtoOutputFilePathFlag[] = "proto_output_file_path";
} // namespace
int main(int argc, char* argv[]) {
std::string output_file_path, proto_output_file_path;
- int num_threads = 4;
+ int num_eval_threads = 4;
std::vector<tflite::Flag> flag_list = {
- tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
- "Number of threads."),
+ tflite::Flag::CreateFlag(kNumEvalThreadsFlag, &num_eval_threads,
+ "Number of threads used for evaluation."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
"Path to output file."),
tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
@@ -36,14 +37,17 @@
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
- if (num_threads <= 0) {
+ if (num_eval_threads <= 0) {
LOG(ERROR) << "Invalid number of threads.";
return EXIT_FAILURE;
}
+ tflite::evaluation::DelegateProviders delegate_providers;
+ delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
+
std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator =
tensorflow::metrics::CreateImagenetModelEvaluator(&argc, argv,
- num_threads);
+ num_eval_threads);
if (!evaluator) {
LOG(ERROR) << "Fail to create the ImagenetModelEvaluator.";
@@ -59,8 +63,8 @@
}
evaluator->AddObserver(writer.get());
- LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
- if (evaluator->EvaluateModel() != kTfLiteOk) {
+ LOG(ERROR) << "Starting evaluation with: " << num_eval_threads << " threads.";
+ if (evaluator->EvaluateModel(&delegate_providers) != kTfLiteOk) {
LOG(ERROR) << "Failed to evaluate the model!";
return EXIT_FAILURE;
}
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index 558ee8b..f318dc6 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -26,7 +26,6 @@
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@@ -155,12 +154,12 @@
return kTfLiteOk;
}
-TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
- const std::vector<ImageLabel>& image_labels,
- const std::vector<std::string>& model_labels,
- const ImagenetModelEvaluator::Params& params,
- ImagenetModelEvaluator::Observer* observer,
- int num_ranks) {
+TfLiteStatus EvaluateModelForShard(
+ const uint64_t shard_id, const std::vector<ImageLabel>& image_labels,
+ const std::vector<std::string>& model_labels,
+ const ImagenetModelEvaluator::Params& params,
+ ImagenetModelEvaluator::Observer* observer, int num_ranks,
+ const tflite::evaluation::DelegateProviders* delegate_providers) {
tflite::evaluation::EvaluationStageConfig eval_config;
eval_config.set_name("image_classification");
auto* classification_params = eval_config.mutable_specification()
@@ -174,7 +173,7 @@
tflite::evaluation::ImageClassificationStage eval(eval_config);
eval.SetAllLabels(model_labels);
- TF_LITE_ENSURE_STATUS(eval.Init());
+ TF_LITE_ENSURE_STATUS(eval.Init(delegate_providers));
for (const auto& image_label : image_labels) {
eval.SetInputs(image_label.image, image_label.label);
@@ -191,7 +190,8 @@
return kTfLiteOk;
}
-TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
+TfLiteStatus ImagenetModelEvaluator::EvaluateModel(
+ const tflite::evaluation::DelegateProviders* delegate_providers) const {
const std::string data_path = tflite::evaluation::StripTrailingSlashes(
params_.ground_truth_images_path) +
"/";
@@ -252,9 +252,10 @@
const uint64_t shard_id = i + 1;
shard_id_image_count_map[shard_id] = image_label.size();
auto func = [shard_id, &image_label, &model_labels, this, &observer,
- &all_okay]() {
+ &all_okay, delegate_providers]() {
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
- &observer, params_.num_ranks) != kTfLiteOk) {
+ &observer, params_.num_ranks,
+ delegate_providers) != kTfLiteOk) {
all_okay = false;
}
};
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
index b10b0f8..65d4a2c 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -21,6 +21,7 @@
#include <vector>
#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tensorflow {
@@ -120,7 +121,8 @@
const Params& params() const { return params_; }
// Evaluates the provided model over the dataset.
- TfLiteStatus EvaluateModel() const;
+ TfLiteStatus EvaluateModel(const tflite::evaluation::DelegateProviders*
+ delegate_providers = nullptr) const;
private:
const Params params_;