| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow_lite_support/cc/task/vision/image_classifier.h" |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/strings/str_format.h" |
| #include "absl/strings/string_view.h" |
| #include "flatbuffers/flatbuffers.h" // from @flatbuffers |
| #include "tensorflow/lite/interpreter.h" |
| #include "tensorflow_lite_support/cc/common.h" |
| #include "tensorflow_lite_support/cc/port/integral_types.h" |
| #include "tensorflow_lite_support/cc/port/status_macros.h" |
| #include "tensorflow_lite_support/cc/task/core/task_api_factory.h" |
| #include "tensorflow_lite_support/cc/task/core/task_utils.h" |
| #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" |
| #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" |
| #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" |
| #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" |
| #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" |
| #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" |
| |
| namespace tflite { |
| namespace task { |
| namespace vision { |
| |
| namespace { |
| |
| using ::absl::StatusCode; |
| using ::tflite::metadata::ModelMetadataExtractor; |
| using ::tflite::support::CreateStatusWithPayload; |
| using ::tflite::support::StatusOr; |
| using ::tflite::support::TfLiteSupportStatus; |
| using ::tflite::task::core::AssertAndReturnTypedTensor; |
| using ::tflite::task::core::TaskAPIFactory; |
| using ::tflite::task::core::TfLiteEngine; |
| |
| // Default score value used as a fallback for classes that (1) have no score |
| // calibration data or (2) have a very low confident uncalibrated score, i.e. |
| // lower than the `min_uncalibrated_score` threshold. |
| // |
| // (1) This happens when the ScoreCalibration does not cover all the classes |
| // listed in the label map. This can be used to enforce the blacklisting of |
| // given classes so that they are never returned. |
| // |
| // (2) This is an optional threshold provided part of the calibration data. It |
| // is used to mitigate false alarms on some classes. |
| // |
| // In both cases, a class that gets assigned a score of -1 is never returned as |
| // it gets discarded by the `score_threshold` check (see post-processing logic). |
| constexpr float kDefaultCalibratedScore = -1.0f; |
| |
| // Calibrated scores should be in the [0, 1] range, otherwise an error is |
| // returned at post-processing time. |
| constexpr float kMinCalibratedScore = 0.0f; |
| constexpr float kMaxCalibratedScore = 1.0f; |
| |
| } // namespace |
| |
| /* static */ |
| StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions( |
| const ImageClassifierOptions& options, |
| std::unique_ptr<tflite::OpResolver> resolver) { |
| RETURN_IF_ERROR(SanityCheckOptions(options)); |
| |
| // Copy options to ensure the ExternalFile outlives the constructed object. |
| auto options_copy = absl::make_unique<ImageClassifierOptions>(options); |
| |
| ASSIGN_OR_RETURN(auto image_classifier, |
| TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>( |
| &options_copy->model_file_with_metadata(), |
| std::move(resolver), options_copy->num_threads())); |
| |
| RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy))); |
| |
| return image_classifier; |
| } |
| |
| /* static */ |
| absl::Status ImageClassifier::SanityCheckOptions( |
| const ImageClassifierOptions& options) { |
| if (!options.has_model_file_with_metadata()) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| "Missing mandatory `model_file_with_metadata` field", |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| if (options.max_results() == 0) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| "Invalid `max_results` option: value must be != 0", |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| if (options.score_threshold() < 0 || options.score_threshold() >= 1) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat( |
| "`score_threshold` out of range: %f. Valid range is [0,1[.", |
| options.score_threshold()), |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| if (options.class_name_whitelist_size() > 0 && |
| options.class_name_blacklist_size() > 0) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| "`class_name_whitelist` and `class_name_blacklist` are mutually " |
| "exclusive options.", |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| if (options.num_threads() == 0 || options.num_threads() < -1) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| "`num_threads` must be greater than 0 or equal to -1.", |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ImageClassifier::Init( |
| std::unique_ptr<ImageClassifierOptions> options) { |
| // Set options. |
| options_ = std::move(options); |
| |
| // Perform pre-initialization actions (by default, sets the process engine for |
| // image pre-processing to kLibyuv as a sane default). |
| RETURN_IF_ERROR(PreInit()); |
| |
| // Sanity check and set inputs and outputs. |
| RETURN_IF_ERROR(CheckAndSetInputs()); |
| RETURN_IF_ERROR(CheckAndSetOutputs()); |
| |
| // Initialize class whitelisting/blacklisting, if any. |
| RETURN_IF_ERROR(CheckAndSetClassNameSet()); |
| |
| // Perform final initialization (by default, initialize score calibration |
| // parameters, if any). |
| RETURN_IF_ERROR(PostInit()); |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ImageClassifier::PreInit() { |
| SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); } |
| |
| absl::Status ImageClassifier::CheckAndSetOutputs() { |
| num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter()); |
| |
| // Perform sanity checks and extract metadata. |
| const ModelMetadataExtractor* metadata_extractor = |
| engine_->metadata_extractor(); |
| |
| const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* |
| output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata(); |
| |
| // Loop over output tensors metadata, if any. |
| // Note: models with no output tensor metadata at all are supported. |
| if (output_tensor_metadata != nullptr) { |
| int num_output_tensors = output_tensor_metadata->size(); |
| |
| if (num_outputs_ != num_output_tensors) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Mismatch between number of output tensors (%d) and " |
| "output tensors " |
| "metadata (%d).", |
| num_outputs_, num_output_tensors), |
| TfLiteSupportStatus::kMetadataInconsistencyError); |
| } |
| |
| for (int i = 0; i < num_output_tensors; ++i) { |
| const tflite::TensorMetadata* output_tensor = |
| output_tensor_metadata->Get(i); |
| |
| ASSIGN_OR_RETURN( |
| ClassificationHead head, |
| BuildClassificationHead(*metadata_extractor, *output_tensor, |
| options_->display_names_locale())); |
| |
| classification_heads_.emplace_back(std::move(head)); |
| } |
| } |
| |
| // If classifier heads are not set, build default ones based on model |
| // introspection. This happens if a model with partial or no metadata was |
| // provided through the `model_file_with_metadata` options field. |
| if (classification_heads_.empty()) { |
| classification_heads_.reserve(num_outputs_); |
| for (int output_index = 0; output_index < num_outputs_; ++output_index) { |
| classification_heads_.emplace_back(ClassificationHead{}); |
| } |
| } |
| |
| if (num_outputs_ != classification_heads_.size()) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Got %d classifier head(s), expected %d according to " |
| "the label map.", |
| num_outputs_, classification_heads_.size()), |
| TfLiteSupportStatus::kMetadataInconsistencyError); |
| } |
| |
| int num_quantized_outputs = 0; |
| for (int i = 0; i < num_outputs_; ++i) { |
| const TfLiteTensor* output_tensor = |
| TfLiteEngine::GetOutput(engine_->interpreter(), i); |
| const int num_dimensions = output_tensor->dims->size; |
| if (num_dimensions == 4) { |
| if (output_tensor->dims->data[1] != 1 || |
| output_tensor->dims->data[2] != 1) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Unexpected WxH sizes for output index %d: got " |
| "%dx%d, expected 1x1.", |
| i, output_tensor->dims->data[2], |
| output_tensor->dims->data[1]), |
| TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); |
| } |
| } else if (num_dimensions != 2) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat( |
| "Unexpected number of dimensions for output index %d: got %dD, " |
| "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, " |
| "H=1).", |
| i, num_dimensions), |
| TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); |
| } |
| if (output_tensor->dims->data[0] != 1) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("The output array is expected to have a batch size " |
| "of 1. Got %d for output index %d.", |
| output_tensor->dims->data[0], i), |
| TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); |
| } |
| int num_classes = output_tensor->dims->data[num_dimensions - 1]; |
| // If label map is not set, build a default one based on model |
| // introspection. This happens if a model with partial or no metadata was |
| // provided through the `model_file_with_metadata` options field. |
| if (classification_heads_[i].label_map_items.empty()) { |
| classification_heads_[i].label_map_items.reserve(num_classes); |
| for (int class_index = 0; class_index < num_classes; ++class_index) { |
| classification_heads_[i].label_map_items.emplace_back(LabelMapItem{}); |
| } |
| } |
| int num_label_map_items = classification_heads_[i].label_map_items.size(); |
| if (num_classes != num_label_map_items) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Got %d class(es) for output index %d, expected %d " |
| "according to the label map.", |
| output_tensor->dims->data[num_dimensions - 1], i, |
| num_label_map_items), |
| TfLiteSupportStatus::kMetadataInconsistencyError); |
| } |
| if (output_tensor->type == kTfLiteUInt8) { |
| num_quantized_outputs++; |
| } else if (output_tensor->type != kTfLiteFloat32) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Type mismatch for output tensor %s. Requested one " |
| "of these types: " |
| "kTfLiteUint8/kTfLiteFloat32, got %s.", |
| output_tensor->name, |
| TfLiteTypeGetName(output_tensor->type)), |
| TfLiteSupportStatus::kInvalidOutputTensorTypeError); |
| } |
| } |
| |
| if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all " |
| "provided outputs must be quantized).", |
| num_quantized_outputs, num_outputs_), |
| TfLiteSupportStatus::kInvalidOutputTensorTypeError); |
| } |
| has_uint8_outputs_ = (num_quantized_outputs > 0); |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ImageClassifier::CheckAndSetClassNameSet() { |
| // Exit early if no blacklist/whitelist. |
| if (options_->class_name_blacklist_size() == 0 && |
| options_->class_name_whitelist_size() == 0) { |
| return absl::OkStatus(); |
| } |
| |
| // Before processing class names whitelist or blacklist from the input options |
| // create a set with _all_ known class names from the label map(s). |
| absl::flat_hash_set<std::string> all_class_names; |
| int head_index = 0; |
| for (const auto& head : classification_heads_) { |
| absl::flat_hash_set<std::string> head_class_names; |
| for (const auto& item : head.label_map_items) { |
| if (!item.name.empty()) { |
| head_class_names.insert(item.name); |
| } |
| } |
| if (head_class_names.empty()) { |
| std::string name = head.name; |
| if (name.empty()) { |
| name = absl::StrFormat("#%d", head_index); |
| } |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat( |
| "Using `class_name_whitelist` or `class_name_blacklist` " |
| "requires labels to be present but none was found for " |
| "classification head: %s", |
| name), |
| TfLiteSupportStatus::kMetadataMissingLabelsError); |
| } |
| all_class_names.insert(head_class_names.begin(), head_class_names.end()); |
| head_index++; |
| } |
| |
| class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0; |
| const auto& class_names = class_name_set_.is_whitelist |
| ? options_->class_name_whitelist() |
| : options_->class_name_blacklist(); |
| |
| // Note: duplicate or unknown classes are just ignored. |
| class_name_set_.values.clear(); |
| for (const auto& class_name : class_names) { |
| if (!all_class_names.contains(class_name)) { |
| continue; |
| } |
| class_name_set_.values.insert(class_name); |
| } |
| |
| if (class_name_set_.values.empty()) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat( |
| "Invalid class names specified via `class_name_%s`: none match " |
| "with model labels.", |
| class_name_set_.is_whitelist ? "whitelist" : "blacklist"), |
| TfLiteSupportStatus::kInvalidArgumentError); |
| } |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ImageClassifier::InitScoreCalibrations() { |
| score_calibrations_.clear(); |
| score_calibrations_.resize(classification_heads_.size()); |
| |
| for (int i = 0; i < classification_heads_.size(); ++i) { |
| if (!classification_heads_[i].calibration_params.has_value()) { |
| continue; |
| } |
| |
| // Use a specific default score instead of the one specified by default in |
| // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore` |
| // documentation for more details. |
| classification_heads_[i].calibration_params->default_score = |
| kDefaultCalibratedScore; |
| |
| score_calibrations_[i] = absl::make_unique<ScoreCalibration>(); |
| if (score_calibrations_[i] == nullptr) { |
| return CreateStatusWithPayload( |
| StatusCode::kInternal, "Could not create score calibration object."); |
| } |
| |
| RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters( |
| classification_heads_[i].calibration_params.value())); |
| } |
| |
| return absl::OkStatus(); |
| } |
| |
| StatusOr<ClassificationResult> ImageClassifier::Classify( |
| const FrameBuffer& frame_buffer) { |
| BoundingBox roi; |
| roi.set_width(frame_buffer.dimension().width); |
| roi.set_height(frame_buffer.dimension().height); |
| return Classify(frame_buffer, roi); |
| } |
| |
| StatusOr<ClassificationResult> ImageClassifier::Classify( |
| const FrameBuffer& frame_buffer, const BoundingBox& roi) { |
| return InferWithFallback(frame_buffer, roi); |
| } |
| |
| StatusOr<ClassificationResult> ImageClassifier::Postprocess( |
| const std::vector<const TfLiteTensor*>& output_tensors, |
| const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) { |
| if (output_tensors.size() != num_outputs_) { |
| return CreateStatusWithPayload( |
| StatusCode::kInternal, |
| absl::StrFormat("Expected %d output tensors, found %d", num_outputs_, |
| output_tensors.size())); |
| } |
| |
| ClassificationResult result; |
| std::vector<std::pair<int, float>> score_pairs; |
| |
| for (int i = 0; i < num_outputs_; ++i) { |
| auto* classifications = result.add_classifications(); |
| classifications->set_head_index(i); |
| |
| const auto& head = classification_heads_[i]; |
| score_pairs.clear(); |
| score_pairs.reserve(head.label_map_items.size()); |
| |
| const TfLiteTensor* output_tensor = output_tensors[i]; |
| if (has_uint8_outputs_) { |
| const uint8* output_data = |
| AssertAndReturnTypedTensor<uint8>(output_tensor); |
| for (int j = 0; j < head.label_map_items.size(); ++j) { |
| score_pairs.emplace_back(j, output_tensor->params.scale * |
| (static_cast<int>(output_data[j]) - |
| output_tensor->params.zero_point)); |
| } |
| } else { |
| const float* output_data = |
| AssertAndReturnTypedTensor<float>(output_tensor); |
| for (int j = 0; j < head.label_map_items.size(); ++j) { |
| score_pairs.emplace_back(j, output_data[j]); |
| } |
| } |
| |
| // Optional score calibration. |
| if (score_calibrations_[i] != nullptr) { |
| for (auto& score_pair : score_pairs) { |
| const std::string& class_name = |
| head.label_map_items[score_pair.first].name; |
| score_pair.second = score_calibrations_[i]->ComputeCalibratedScore( |
| class_name, score_pair.second); |
| if (score_pair.second > kMaxCalibratedScore) { |
| return CreateStatusWithPayload( |
| StatusCode::kInternal, |
| absl::StrFormat("calibrated score is too high: got %f, expected " |
| "%f as maximum.", |
| score_pair.second, kMaxCalibratedScore)); |
| } |
| if (score_pair.second != kDefaultCalibratedScore && |
| score_pair.second < kMinCalibratedScore) { |
| return CreateStatusWithPayload( |
| StatusCode::kInternal, |
| absl::StrFormat("calibrated score is too low: got %f, expected " |
| "%f as minimum.", |
| score_pair.second, kMinCalibratedScore)); |
| } |
| } |
| } |
| |
| int num_results = |
| options_->max_results() >= 0 |
| ? std::min(static_cast<int>(head.label_map_items.size()), |
| options_->max_results()) |
| : head.label_map_items.size(); |
| float score_threshold = options_->has_score_threshold() |
| ? options_->score_threshold() |
| : head.score_threshold; |
| |
| if (class_name_set_.values.empty()) { |
| // Partially sort in descending order (higher score is better). |
| absl::c_partial_sort( |
| score_pairs, score_pairs.begin() + num_results, |
| [](const std::pair<int, float>& a, const std::pair<int, float>& b) { |
| return a.second > b.second; |
| }); |
| |
| for (int j = 0; j < num_results; ++j) { |
| float score = score_pairs[j].second; |
| if (score < score_threshold) { |
| break; |
| } |
| auto* cl = classifications->add_classes(); |
| cl->set_index(score_pairs[j].first); |
| cl->set_score(score); |
| } |
| } else { |
| // Sort in descending order (higher score is better). |
| absl::c_sort(score_pairs, [](const std::pair<int, float>& a, |
| const std::pair<int, float>& b) { |
| return a.second > b.second; |
| }); |
| |
| for (int j = 0; j < head.label_map_items.size(); ++j) { |
| float score = score_pairs[j].second; |
| if (score < score_threshold || |
| classifications->classes_size() >= num_results) { |
| break; |
| } |
| |
| const int class_index = score_pairs[j].first; |
| const std::string& class_name = head.label_map_items[class_index].name; |
| |
| bool class_name_found = class_name_set_.values.contains(class_name); |
| |
| if ((!class_name_found && class_name_set_.is_whitelist) || |
| (class_name_found && !class_name_set_.is_whitelist)) { |
| continue; |
| } |
| |
| auto* cl = classifications->add_classes(); |
| cl->set_index(class_index); |
| cl->set_score(score); |
| } |
| } |
| } |
| |
| RETURN_IF_ERROR(FillResultsFromLabelMaps(&result)); |
| |
| return result; |
| } |
| |
| absl::Status ImageClassifier::FillResultsFromLabelMaps( |
| ClassificationResult* result) { |
| for (int i = 0; i < result->classifications_size(); ++i) { |
| Classifications* classifications = result->mutable_classifications(i); |
| int head_index = classifications->head_index(); |
| if (head_index < 0 || head_index >= classification_heads_.size()) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Invalid head index (%d) with respect to total " |
| "number of classification heads (%d).", |
| head_index, classification_heads_.size()), |
| TfLiteSupportStatus::kMetadataInconsistencyError); |
| } |
| const std::vector<LabelMapItem>& label_map_items = |
| classification_heads_[head_index].label_map_items; |
| for (int j = 0; j < classifications->classes_size(); ++j) { |
| Class* current_class = classifications->mutable_classes(j); |
| int current_class_index = current_class->index(); |
| if (current_class_index < 0 || |
| current_class_index >= label_map_items.size()) { |
| return CreateStatusWithPayload( |
| StatusCode::kInvalidArgument, |
| absl::StrFormat("Invalid class index (%d) with respect to label " |
| "map size (%d) for head #%d.", |
| current_class_index, label_map_items.size(), |
| head_index), |
| TfLiteSupportStatus::kMetadataInconsistencyError); |
| } |
| const std::string& name = label_map_items[current_class_index].name; |
| if (!name.empty()) { |
| current_class->set_class_name(name); |
| } |
| const std::string& display_name = |
| label_map_items[current_class_index].display_name; |
| if (!display_name.empty()) { |
| current_class->set_display_name(display_name); |
| } |
| } |
| } |
| return absl::OkStatus(); |
| } |
| |
| } // namespace vision |
| } // namespace task |
| } // namespace tflite |