/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

namespace tflite {
namespace task {
namespace core {

class BaseUntypedTaskApi {
 public:
  explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)
      : engine_{std::move(engine)} {}

  virtual ~BaseUntypedTaskApi() = default;

  TfLiteEngine* GetTfLiteEngine() { return engine_.get(); }
  const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); }

  const metadata::ModelMetadataExtractor* GetMetadataExtractor() const {
    return engine_->metadata_extractor();
  }

 protected:
  std::unique_ptr<TfLiteEngine> engine_;
};
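
// A minimal usage sketch for BaseUntypedTaskApi's accessors (the `task`
// object and the null check are illustrative assumptions, not part of this
// header): the metadata extractor gives read access to the TFLite Model
// Metadata bundled with the model, if any.
//
//   const tflite::metadata::ModelMetadataExtractor* extractor =
//       task->GetMetadataExtractor();
//   if (extractor != nullptr) {
//     // Inspect model metadata, e.g. input / output tensor descriptions.
//   }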

template <class OutputType, class... InputTypes>
class BaseTaskApi : public BaseUntypedTaskApi {
 public:
  explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)
      : BaseUntypedTaskApi(std::move(engine)) {}

  // BaseTaskApi is neither copyable nor movable.
  BaseTaskApi(const BaseTaskApi&) = delete;
  BaseTaskApi& operator=(const BaseTaskApi&) = delete;

  // Cancels the currently running TFLite invocation on CPU.
  //
  // Usually called on a different thread than the one inference is running on.
  // Calling Cancel() while an invocation is running causes the underlying
  // TFLite interpreter to return an error, which turns into a `CANCELLED`
  // status and empty results. Calling Cancel() at any other time has no effect
  // on the current or subsequent invocations. It is perfectly fine to run
  // inference again on the same instance after a cancelled invocation. If the
  // TFLite inference is partially delegated on CPU, a warning message is
  // logged and only the part of the invocation running on CPU is cancelled;
  // other operations that depend on the output of the CPU computation will
  // not be executed.
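  //
  // A minimal calling sketch (the `MyStringTask` type and its `Run()` entry
  // point are illustrative assumptions; see the subclass sketch below):
  //
  //   tflite::support::StatusOr<std::string> result;
  //   std::thread worker([&] { result = task.Run("some input"); });
  //   task.Cancel();  // From the main thread while inference is running.
  //   worker.join();  // `result` now holds a CANCELLED status.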
  void Cancel() { engine_->Cancel(); }

 protected:
  // Subclasses need to populate input_tensors from api_inputs (see the
  // subclass sketch below).
  virtual absl::Status Preprocess(
      const std::vector<TfLiteTensor*>& input_tensors,
      InputTypes... api_inputs) = 0;

  // Subclasses need to construct the OutputType object from output_tensors.
  // The original inputs are also provided, as they may be needed.
  virtual tflite::support::StatusOr<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      InputTypes... api_inputs) = 0;
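
  // A minimal subclass sketch tying Preprocess(), Postprocess() and Infer()
  // together (the `MyStringTask` name, its `Run()` entry point and the tensor
  // handling are illustrative assumptions, not part of this header):
  //
  //   class MyStringTask
  //       : public BaseTaskApi<std::string, const std::string&> {
  //    public:
  //     using BaseTaskApi::BaseTaskApi;
  //
  //     // Infer() is protected; expose it through a public entry point.
  //     tflite::support::StatusOr<std::string> Run(const std::string& input) {
  //       return Infer(input);
  //     }
  //
  //    protected:
  //     absl::Status Preprocess(
  //         const std::vector<TfLiteTensor*>& input_tensors,
  //         const std::string& input) override {
  //       // Populate input_tensors[0] from `input`.
  //       return absl::OkStatus();
  //     }
  //
  //     tflite::support::StatusOr<std::string> Postprocess(
  //         const std::vector<const TfLiteTensor*>& output_tensors,
  //         const std::string& /*input*/) override {
  //       // Decode output_tensors[0] into the result.
  //       return std::string();
  //     }
  //   };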

  // Returns (pointers to) the model's input tensors.
  std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); }

  // Returns (pointers to) the model's output tensors.
  std::vector<const TfLiteTensor*> GetOutputTensors() {
    return engine_->GetOutputs();
  }

  // Performs inference using
  // tflite::support::TfLiteInterpreterWrapper::InvokeWithoutFallback().
  tflite::support::StatusOr<OutputType> Infer(InputTypes... args) {
    tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
        engine_->interpreter_wrapper();
    // Note: AllocateTensors() is already performed by the interpreter wrapper
    // at InitInterpreter time (see TfLiteEngine).
    RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
    absl::Status status = interpreter_wrapper->InvokeWithoutFallback();
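    // Wrap errors that don't already carry a TFLite Support payload (i.e.
    // errors reported by the interpreter rather than by task library code) so
    // that callers get a uniform error format.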
    if (!status.ok()) {
      return status.GetPayload(tflite::support::kTfLiteSupportPayload)
                 .has_value()
                 ? status
                 : tflite::support::CreateStatusWithPayload(status.code(),
                                                            status.message());
    }
    return Postprocess(GetOutputTensors(), args...);
  }

  // Performs inference using
  // tflite::support::TfLiteInterpreterWrapper::InvokeWithFallback() to benefit
  // from automatic fallback from delegate to CPU execution where applicable.
  tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) {
    tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
        engine_->interpreter_wrapper();
    // Note: AllocateTensors() is already performed by the interpreter wrapper
    // at InitInterpreter time (see TfLiteEngine).
    RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
    auto set_inputs_nop =
        [](tflite::task::core::TfLiteEngine::Interpreter* interpreter)
        -> absl::Status {
      // NOP since inputs are populated at Preprocess() time.
      return absl::OkStatus();
    };
    absl::Status status =
        interpreter_wrapper->InvokeWithFallback(set_inputs_nop);
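    // As in Infer() above, wrap errors that lack a TFLite Support payload so
    // that callers get a uniform error format.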
    if (!status.ok()) {
      return status.GetPayload(tflite::support::kTfLiteSupportPayload)
                 .has_value()
                 ? status
                 : tflite::support::CreateStatusWithPayload(status.code(),
                                                            status.message());
    }
    return Postprocess(GetOutputTensors(), args...);
  }
};
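
// A minimal end-to-end sketch (the `MyStringTask` type from the subclass
// sketch above and the engine setup steps are illustrative assumptions):
//
//   auto engine = std::make_unique<TfLiteEngine>();
//   // ... build the model and initialize the interpreter on `engine` ...
//   auto task = std::make_unique<MyStringTask>(std::move(engine));
//   tflite::support::StatusOr<std::string> result = task->Run("some input");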

}  // namespace core
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_