blob: 536eed4d7f6efc662ec62e79c8a17ea1b3540abe [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
#include <array>
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace task {
namespace vision {
// Parameters used for input image normalization when the input tensor has
// kTfLiteFloat32 type.
//
// Exactly 1 or 3 values are expected for `mean_values` and `std_values`. If
// only 1 value is specified, it is used for all channels. E.g. for an RGB
// image, the normalization is done as follows:
//
// (R - mean_values[0]) / std_values[0]
// (G - mean_values[1]) / std_values[1]
// (B - mean_values[2]) / std_values[2]
//
// `num_values` keeps track of how many values have been provided, which should
// be 1 or 3 (see above). In particular, single-channel grayscale images expect
// only 1 value; in that case only index 0 of each array is used.
//
// All members are value-initialized (zeros) so that a default-constructed
// instance never holds indeterminate values. The zero defaults are NOT a valid
// normalization (`num_values` must be 1 or 3 and `std_values` are used as
// divisors), so callers are expected to fill every member in. The struct is
// still an aggregate, so brace initialization such as
// `NormalizationOptions{{m, m, m}, {s, s, s}, 3}` keeps working.
struct NormalizationOptions {
  std::array<float, 3> mean_values{};
  std::array<float, 3> std_values{};
  int num_values = 0;
};
// Parameters related to the expected tensor specifications when the tensor
// represents an image.
//
// E.g. input tensor specifications expected by the model at Invoke() time. In
// such a case, and before running inference with the TF Lite interpreter, the
// caller must use these values and perform image preprocessing and/or
// normalization so as to fill the actual input tensor appropriately.
//
// Every member has a default initializer so that a default-constructed
// instance never holds indeterminate values. The struct remains an aggregate,
// so brace initialization by callers keeps working.
struct ImageTensorSpecs {
  // Expected image dimensions, e.g. image_width=224, image_height=224.
  int image_width = 0;
  int image_height = 0;
  // Expected color space, e.g. color_space=RGB. Value-initialized (the
  // enumerator with value 0) until set.
  // NOTE(review): presumably ColorSpaceType_UNKNOWN in the generated metadata
  // schema; confirm.
  tflite::ColorSpaceType color_space{};
  // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller
  // should usually perform some normalization to convert the uint8 pixels into
  // floats (see NormalizationOptions in TF Lite Metadata for more details).
  // Value-initialized (value 0) until set.
  // NOTE(review): presumably kTfLiteNoType; confirm against common.h.
  TfLiteType tensor_type{};
  // Optional normalization parameters read from TF Lite Metadata. These are
  // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input
  // image data into the expected range of floating point values; an error is
  // returned otherwise (see the sanity checks performed by
  // BuildInputImageTensorSpecs() below). They should be ignored for other
  // tensor input types, e.g. kTfLiteUInt8.
  absl::optional<NormalizationOptions> normalization_options;
};
// Performs sanity checks on the expected input tensor including consistency
// checks against model metadata, if any. For now, a single RGB input with BHWD
// layout, where B = 1 and D = 3, is expected. Returns the corresponding input
// specifications if they pass, or an error otherwise (too many input tensors,
// etc).
// Note: both interpreter and metadata extractor *must* be successfully
// initialized before calling this function by means of (respectively):
// - `tflite::InterpreterBuilder`,
// - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
const tflite::metadata::ModelMetadataExtractor& metadata_extractor);
} // namespace vision
} // namespace task
} // namespace tflite
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_