blob: 773ab76fe6ba0efaf3765e2aeb8d928e159bb9c2 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
#include <cmath>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
namespace tflite {
namespace task {
namespace vision {
namespace {
using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
// Threshold that ApplyScoreTransformation() hands to ClampedLog(). Inputs below
// this value take ClampedLog()'s mirrored branch instead of a plain log, which
// prevents log(<=0.0) from ever being evaluated.
constexpr float kLogScoreMinimum = 1e-16;
// Logarithm with a smooth continuation below `threshold`. Returns:
//   x >= threshold: log(x)
//   x <  threshold: 2 * log(threshold) - log(2 * threshold - x)
// The lower branch is the upper one mirrored through the point
// (threshold, log(threshold)), so the function is anti-symmetric about the
// threshold and has a continuous value and first derivative there. This avoids
// taking the log of values close to (or below) 0, which leads to floating point
// errors, and is better than simple clamping because it preserves the order of
// scores that fall below the threshold.
//
// `threshold` is expected to be strictly positive (log(threshold) is evaluated
// on the lower branch). Intermediate math is done in double precision and the
// result is narrowed to float on return.
float ClampedLog(float x, float threshold) {
  if (x < threshold) {
    // Here 2 * threshold - x > threshold, so the log argument stays positive
    // whenever threshold is.
    return 2.0 * std::log(static_cast<double>(threshold)) -
           std::log(2.0 * threshold - x);
  }
  return std::log(static_cast<double>(x));
}
// Applies the specified score transformation to the provided score.
// Currently supports the following,
//   IDENTITY         : f(x) = x
//   LOG              : f(x) = log(x)
//   INVERSE_LOGISTIC : f(x) = log(x) - log(1-x)
// Every log above is evaluated through ClampedLog() with kLogScoreMinimum as
// the threshold, so scores at or below 0 (and, for INVERSE_LOGISTIC, at or
// above 1) still produce a finite result.
//
// If `type` holds a value that matches none of the cases, the score is returned
// unchanged (same as IDENTITY).
float ApplyScoreTransformation(float score, const ScoreTransformation& type) {
  // No `default:` label on purpose: compilers can then still warn when a new
  // enumerator is added without a matching case.
  switch (type) {
    case ScoreTransformation::kIDENTITY:
      return score;
    case ScoreTransformation::kINVERSE_LOGISTIC:
      // 1.0 - score is computed in double and then narrowed to ClampedLog()'s
      // float parameter; the cast only makes that narrowing explicit.
      return (ClampedLog(score, kLogScoreMinimum) -
              ClampedLog(static_cast<float>(1.0 - score), kLogScoreMinimum));
    case ScoreTransformation::kLOG:
      return ClampedLog(score, kLogScoreMinimum);
  }
  // Reached only for an out-of-range `type`. Returning here keeps control from
  // flowing off the end of a non-void function, which is undefined behavior.
  return score;
}
// Builds a single Sigmoid from the label name and associated CSV file line.
//
// `line` must hold 3 or 4 comma-separated floats, read in this order:
//   scale,slope,offset[,min_uncalibrated_score]
// When the fourth value is absent, `min_uncalibrated_score` is left as
// default-constructed by Sigmoid.
//
// Returns an kInvalidArgument status carrying
// kMetadataMalformedScoreCalibrationError if the number of fields is wrong or
// if any field cannot be parsed as a float.
StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label,
                                          absl::string_view line) {
  std::vector<absl::string_view> str_params = absl::StrSplit(line, ',');
  if (str_params.size() != 3 && str_params.size() != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected 3 or 4 parameters per line in score "
                        "calibration file, got %d.",
                        str_params.size()),
        TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
  }
  // Always sized for the 4-field form; the size check above guarantees the
  // loop below never writes past the end.
  std::vector<float> float_params(4);
  // Unsigned index to match size() and avoid a signed/unsigned comparison.
  for (std::size_t i = 0; i < str_params.size(); ++i) {
    if (!absl::SimpleAtof(str_params[i], &float_params[i])) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Could not parse score calibration parameter as float: %s.",
              str_params[i]),
          TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
    }
  }
  Sigmoid sigmoid;
  sigmoid.label = std::string(label);
  sigmoid.scale = float_params[0];
  sigmoid.slope = float_params[1];
  sigmoid.offset = float_params[2];
  if (str_params.size() == 4) {
    sigmoid.min_uncalibrated_score = float_params[3];
  }
  return sigmoid;
}
// Converts a tflite::ScoreTransformationType to its
// tflite::task::vision::ScoreTransformation equivalent.
//
// A `type` that matches none of the known enumerators maps to
// ScoreTransformation::kIDENTITY.
// NOTE(review): `type` presumably originates from deserialized model metadata,
// in which case unexpected values are possible - confirm against callers.
ScoreTransformation ConvertScoreTransformationType(
    tflite::ScoreTransformationType type) {
  // No `default:` label on purpose: compilers can then still warn when a new
  // enumerator is added without a matching case.
  switch (type) {
    case tflite::ScoreTransformationType_IDENTITY:
      return ScoreTransformation::kIDENTITY;
    case tflite::ScoreTransformationType_LOG:
      return ScoreTransformation::kLOG;
    case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
      return ScoreTransformation::kINVERSE_LOGISTIC;
  }
  // Reached only for an out-of-range `type`. Returning here keeps control from
  // flowing off the end of a non-void function, which is undefined behavior.
  return ScoreTransformation::kIDENTITY;
}
} // namespace
// Writes `s` to `os` as comma-separated fields in the order
// label,slope,offset,scale, followed by ,min_uncalibrated_score only when that
// optional field is set. Returns `os` to allow chaining.
std::ostream& operator<<(std::ostream& os, const Sigmoid& s) {
  os << s.label;
  os << "," << s.slope;
  os << "," << s.offset;
  os << "," << s.scale;
  if (!s.min_uncalibrated_score.has_value()) {
    return os;
  }
  os << "," << s.min_uncalibrated_score.value();
  return os;
}
// Construction and destruction need no custom logic; both are defined here,
// out of line, with compiler-generated bodies.
ScoreCalibration::ScoreCalibration() = default;
ScoreCalibration::~ScoreCalibration() = default;
// Stores a copy of `params` and rebuilds the label -> sigmoid lookup map from
// it, discarding any previously stored state. If several sigmoids share the
// same label, the last one in `params.sigmoid` wins. Always returns OK.
absl::Status ScoreCalibration::InitializeFromParameters(
    const SigmoidCalibrationParameters& params) {
  // `params` is a const reference, so this has always been a copy; wrapping it
  // in std::move() would not move anything and would only obscure that.
  sigmoid_parameters_ = params;
  // Fill in the map from label -> sigmoid.
  sigmoid_parameters_map_.clear();
  for (const auto& sigmoid : sigmoid_parameters_.sigmoid) {
    sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid);
  }
  return absl::OkStatus();
}
float ScoreCalibration::ComputeCalibratedScore(const std::string& label,
float uncalibrated_score) const {
absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label);
if (!sigmoid.has_value() ||
(sigmoid.value().min_uncalibrated_score.has_value() &&
uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) {
return sigmoid_parameters_.default_score;
}
float transformed_score = ApplyScoreTransformation(
uncalibrated_score, sigmoid_parameters_.score_transformation);
float scale_shifted_score =
transformed_score * sigmoid.value().slope + sigmoid.value().offset;
// For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
// and exp(x) / (1+exp(x)) when scale_shifted_score < 0.
if (scale_shifted_score >= 0.0) {
return sigmoid.value().scale /
(1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
} else {
float score_exp = std::exp(static_cast<double>(scale_shifted_score));
return sigmoid.value().scale * score_exp / (1.0 + score_exp);
}
}
// Looks up the sigmoid to use for `label`. Preference order:
//   1. the sigmoid registered for `label` in the label map,
//   2. the default sigmoid, if one is configured,
//   3. absl::nullopt.
absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters(
    const std::string& label) const {
  const auto entry = sigmoid_parameters_map_.find(label);
  if (entry != sigmoid_parameters_map_.end()) {
    return entry->second;
  }
  if (sigmoid_parameters_.default_sigmoid.has_value()) {
    return sigmoid_parameters_.default_sigmoid.value();
  }
  return absl::nullopt;
}
// Builds SigmoidCalibrationParameters from the calibration options, the
// contents of the score calibration file and the label map.
//
// `score_calibration_file` is split on '\n'; line i supplies the sigmoid for
// `label_map_items[i]`. An empty line means "no sigmoid for this label" and is
// skipped. The number of lines must equal the number of labels. Note that a
// file ending in '\n' therefore counts one extra, empty, line.
//
// Returns kInvalidArgument if the file is empty, if the line and label counts
// differ (payload kMetadataNumLabelsMismatchError), or if any non-empty line
// is rejected by SigmoidFromLabelAndLine().
StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams(
    const tflite::ScoreCalibrationOptions& score_calibration_options,
    absl::string_view score_calibration_file,
    const std::vector<LabelMapItem>& label_map_items) {
  // Split file lines and perform sanity checks.
  if (score_calibration_file.empty()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Expected non-empty score calibration file.");
  }
  std::vector<absl::string_view> lines =
      absl::StrSplit(score_calibration_file, '\n');
  if (label_map_items.size() != lines.size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of labels (%d) and score "
                        "calibration parameters (%d).",
                        label_map_items.size(), lines.size()),
        TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
  }
  // Initialize SigmoidCalibrationParameters with its class-agnostic parameters.
  SigmoidCalibrationParameters sigmoid_params = {};
  sigmoid_params.score_transformation = ConvertScoreTransformationType(
      score_calibration_options.score_transformation());
  sigmoid_params.default_score = score_calibration_options.default_score();
  // Fill sigmoids for each class with parameters in the file.
  std::vector<Sigmoid> sigmoid_vector;
  // At most one sigmoid per label, so one allocation is enough.
  sigmoid_vector.reserve(label_map_items.size());
  // Unsigned index to match size() and avoid a signed/unsigned comparison.
  // `lines` has the same length as `label_map_items` (checked above).
  for (std::size_t i = 0; i < label_map_items.size(); ++i) {
    if (lines[i].empty()) {
      continue;
    }
    ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine(
                                          label_map_items[i].name, lines[i]));
    sigmoid_vector.emplace_back(std::move(sigmoid));
  }
  sigmoid_params.sigmoid = std::move(sigmoid_vector);
  return sigmoid_params;
}
} // namespace vision
} // namespace task
} // namespace tflite