/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_

#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/time/clock.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace task {
namespace vision {
// Base class providing common logic for vision models.
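//
// Subclasses bind `OutputType` and implement the inherited `Postprocess`
// method to turn raw output tensors into results. A minimal sketch (the
// `MyClassifier` and `MyResult` names are hypothetical, not part of this
// library):
//
//   class MyClassifier : public BaseVisionTaskApi<MyResult> {
//    public:
//     using BaseVisionTaskApi<MyResult>::BaseVisionTaskApi;
//     // Implement Postprocess(), then expose e.g. a Classify(const
//     // FrameBuffer&) method that delegates to the inherited
//     // BaseTaskApi::Infer().
//   };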
template <class OutputType>
class BaseVisionTaskApi
: public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
const BoundingBox&> {
public:
explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
: tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
const BoundingBox&>(std::move(engine)) {}

// BaseVisionTaskApi is neither copyable nor movable.
BaseVisionTaskApi(const BaseVisionTaskApi&) = delete;
BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete;

// Number of bytes per pixel for RGB data with 8 bits per channel.
static constexpr int kRgbPixelBytes = 3;

// Sets the ProcessEngine used for image preprocessing. Must be called before
// any inference is performed, and can be called between inferences to
// override the current process engine.
void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) {
frame_buffer_utils_ = FrameBufferUtils::Create(process_engine);
}
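
// For example, to select the libyuv-backed engine (a usage sketch; `api`
// denotes an instance of a concrete subclass):
//
//   api->SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
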
protected:
using tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
const BoundingBox&>::engine_;

// Checks that the input tensor and associated metadata (if any) are valid, and
// returns an error otherwise. This must be called once at initialization time,
// before running inference, as it is a prerequisite for `Preprocess`.
// Note: the underlying interpreter and metadata extractor are assumed to be
// already successfully initialized before calling this method.
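//
// As reflected by the fields used in `Preprocess` below, the resulting
// `ImageTensorSpecs` notably conveys the expected input dimensions, color
// space, tensor type and optional normalization parameters.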
virtual absl::Status CheckAndSetInputs() {
ASSIGN_OR_RETURN(
ImageTensorSpecs input_specs,
BuildInputImageTensorSpecs(*engine_->interpreter(),
*engine_->metadata_extractor()));
if (input_specs.color_space != tflite::ColorSpaceType_RGB) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kUnimplemented,
"BaseVisionTaskApi only supports RGB color space for now.");
}
input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs);
return absl::OkStatus();
}
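
// In a typical subclass factory method, initialization thus happens in this
// order (a sketch only; `api` stands for a freshly created subclass
// instance):
//
//   api->SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
//   RETURN_IF_ERROR(api->CheckAndSetInputs());
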
// Performs image preprocessing on the input frame buffer over the region of
// interest so that it fits model requirements (e.g. upright 224x224 RGB) and
// populates the corresponding input tensor. This is performed by (in this
// order):
// - cropping the frame buffer to the region of interest (which, in most
//   cases, just covers the entire input image),
// - resizing it (with bilinear interpolation, aspect-ratio *not* preserved)
//   to the dimensions of the model input tensor,
// - converting it to the colorspace of the input tensor (i.e. RGB, which is
//   the only supported colorspace for now),
// - rotating it according to its `Orientation` so that inference is performed
//   on an "upright" image.
//
// IMPORTANT: as a consequence of cropping occurring first, the provided
// region of interest is expressed in the unrotated frame of reference
// coordinate system, i.e. in `[0, frame_buffer.width) x [0,
// frame_buffer.height)`, which are the dimensions of the underlying
// `frame_buffer` data before any `Orientation` flag gets applied. Also, the
// region of interest is not clamped, so this method will return a non-ok
// status if the region is out of these bounds.
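//
// For example, for a 640x480 `frame_buffer` carrying a kRightTop orientation
// (i.e. displayed as 480x640 once rotated upright), a full-image region of
// interest is `origin_x=0, origin_y=0, width=640, height=480`, expressed in
// the unrotated 640x480 coordinate system. (Numbers are purely illustrative.)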
absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
const FrameBuffer& frame_buffer,
const BoundingBox& roi) override {
if (input_specs_ == nullptr) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Uninitialized input tensor specs: CheckAndSetInputs must be called "
"at initialization time.");
}
if (frame_buffer_utils_ == nullptr) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Uninitialized frame buffer utils: SetProcessEngine must be called "
"at initialization time.");
}
if (input_tensors.size() != 1) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal, "A single input tensor is expected.");
}
// Input data to be normalized (if needed) and used for inference. In most
// cases, this is the result of image preprocessing. In case no image
// preprocessing is needed (see below), this points to the input frame
// buffer's raw data.
const uint8* input_data;
size_t input_data_byte_size;
// Optional buffers in case image preprocessing is needed.
std::unique_ptr<FrameBuffer> preprocessed_frame_buffer;
std::vector<uint8> preprocessed_data;
if (IsImagePreprocessingNeeded(frame_buffer, roi)) {
// Preprocess input image to fit model requirements.
// For now RGB is the only color space supported, which is ensured by
// `CheckAndSetInputs`.
FrameBuffer::Dimension to_buffer_dimension = {input_specs_->image_width,
input_specs_->image_height};
input_data_byte_size =
GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB);
preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0);
input_data = preprocessed_data.data();
FrameBuffer::Plane preprocessed_plane = {
/*buffer=*/preprocessed_data.data(),
/*stride=*/{input_specs_->image_width * kRgbPixelBytes,
kRgbPixelBytes}};
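// The stride initializer above is {row_stride_bytes, pixel_stride_bytes}:
// rows of the RGB buffer are tightly packed, with no padding bytes.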
preprocessed_frame_buffer = FrameBuffer::Create(
{preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB,
FrameBuffer::Orientation::kTopLeft);
RETURN_IF_ERROR(frame_buffer_utils_->Preprocess(
frame_buffer, roi, preprocessed_frame_buffer.get()));
} else {
// The input frame buffer already meets the model requirements: skip image
// preprocessing. For RGB, the data is always stored in a single plane.
input_data = frame_buffer.plane(0).buffer;
input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes *
frame_buffer.dimension().height;
}
// Then normalize pixel data (if needed) and populate the input tensor.
switch (input_specs_->tensor_type) {
case kTfLiteUInt8:
if (input_tensors[0]->bytes != input_data_byte_size) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Size mismatch or unsupported padding bytes between pixel data "
"and input tensor.");
}
// No normalization required: directly populate data.
tflite::task::core::PopulateTensor(
input_data, input_data_byte_size / sizeof(uint8), input_tensors[0]);
break;
case kTfLiteFloat32: {
if (input_tensors[0]->bytes / sizeof(float) !=
input_data_byte_size / sizeof(uint8)) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Size mismatch or unsupported padding bytes between pixel data "
"and input tensor.");
}
// Normalize and populate.
float* normalized_input_data =
tflite::task::core::AssertAndReturnTypedTensor<float>(
input_tensors[0]);
const tflite::task::vision::NormalizationOptions&
normalization_options = input_specs_->normalization_options.value();
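// Each pixel value v is mapped to (v - mean) / std, using the precomputed
// reciprocal of std. A single (mean, std) pair may be broadcast across all
// channels, or one pair may be given per (R, G, B) channel, as below.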
if (normalization_options.num_values == 1) {
float mean_value = normalization_options.mean_values[0];
float inv_std_value = (1.0f / normalization_options.std_values[0]);
for (int i = 0; i < input_data_byte_size / sizeof(uint8);
i++, input_data++, normalized_input_data++) {
*normalized_input_data =
inv_std_value * (static_cast<float>(*input_data) - mean_value);
}
} else {
std::array<float, 3> inv_std_values = {
1.0f / normalization_options.std_values[0],
1.0f / normalization_options.std_values[1],
1.0f / normalization_options.std_values[2]};
for (int i = 0; i < input_data_byte_size / sizeof(uint8);
i++, input_data++, normalized_input_data++) {
*normalized_input_data = inv_std_values[i % 3] *
(static_cast<float>(*input_data) -
normalization_options.mean_values[i % 3]);
}
}
break;
}
case kTfLiteInt8:
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kUnimplemented,
"kTfLiteInt8 input type is not implemented yet.");
default:
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal, "Unexpected input tensor type.");
}
return absl::OkStatus();
}

// Utils for input image preprocessing (resizing, colorspace conversion, etc.).
std::unique_ptr<FrameBufferUtils> frame_buffer_utils_;

// Parameters related to the input tensor, which represents an image.
std::unique_ptr<ImageTensorSpecs> input_specs_;

private:
// Returns true if image preprocessing is needed, i.e. if the frame buffer
// must be cropped, resized, rotated or converted to match the model input
// specifications, and false if it can be mapped to the input tensor as-is
// (e.g. an upright 224x224 RGB buffer with a full-image region of interest,
// for a 224x224 RGB input tensor).
bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer,
const BoundingBox& roi) {
// Is crop required?
if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
roi.width() != frame_buffer.dimension().width ||
roi.height() != frame_buffer.dimension().height) {
return true;
}
// Are image transformations required?
if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft ||
frame_buffer.format() != FrameBuffer::Format::kRGB ||
frame_buffer.dimension().width != input_specs_->image_width ||
frame_buffer.dimension().height != input_specs_->image_height) {
return true;
}
return false;
}
};

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_