blob: 683b69732ee2ca54482bdbc3788d1726eab89c09 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/shim/tflite_op_shim.h"
#include <cstdint>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
namespace tflite {
namespace shim {
namespace internal {
absl::StatusOr<AttrValue> GetAttr(const flexbuffers::Map* attr_map,
const std::string& attr_name) {
const auto value = (*attr_map)[attr_name.data()];
if (value.IsNull())
return absl::InvalidArgumentError(
absl::StrCat("Non-existent attribute: ", attr_name));
AttrValue ret;
switch (value.GetType()) {
case ::flexbuffers::FBT_BOOL: {
ret = value.AsBool();
break;
}
case ::flexbuffers::FBT_INT: {
ret = static_cast<int64_t>(value.AsInt64());
break;
}
case ::flexbuffers::FBT_FLOAT: {
ret = value.AsFloat();
break;
}
case ::flexbuffers::FBT_STRING: {
const auto str_val = value.AsString();
ret = absl::string_view(str_val.c_str(), str_val.length());
break;
}
default: {
return absl::FailedPreconditionError(
absl::StrCat("Unsupported type for attribute: ", attr_name,
" with value: ", value.ToString()));
}
}
return ret;
}
} // namespace internal
TfLiteInitContext::TfLiteInitContext(const TfLiteContext* context,
const flexbuffers::Map* attr_map)
: attr_map_(attr_map) {}
absl::StatusOr<AttrValue> TfLiteInitContext::GetAttr(
const std::string& attr_name) const {
return internal::GetAttr(attr_map_, attr_name);
}
TfLiteInvokeContext::TfLiteInvokeContext(TfLiteContext* context,
TfLiteNode* node)
: context_(context), node_(node) {}
ConstTensorViewOr TfLiteInvokeContext::GetInput(const int idx) const {
// Scope is used to ensure tensor_view string contents are flushed
const auto* tflite_tensor = ::tflite::GetInput(context_, node_, idx);
if (tflite_tensor == nullptr)
return absl::InternalError(
absl::StrCat("input tensor is null during invocation. idx: ", idx));
SH_ASSIGN_OR_RETURN(const TfLiteTensorView& tensor_view,
TensorView::New(tflite_tensor));
return absl::make_unique<const TfLiteTensorView>(tensor_view);
}
TensorViewOr TfLiteInvokeContext::GetOutput(const int idx,
const Shape& output_shape) const {
if (!output_shape.has_value()) {
return absl::InvalidArgumentError(
absl::StrCat("output_shape value should be populated. idx: ", idx));
}
auto* tflite_tensor = ::tflite::GetOutput(context_, node_, idx);
if (tflite_tensor == nullptr)
return absl::InternalError(
absl::StrCat("output tensor is null during invocation. idx: ", idx));
if (tflite_tensor->data.raw == nullptr) {
TfLiteIntArray* output_shape_array =
ShapeToTfLiteShape(output_shape.value());
context_->ResizeTensor(context_, tflite_tensor, output_shape_array);
} else {
DCHECK(TfLiteShapeToShape(tflite_tensor->dims) == output_shape);
}
SH_ASSIGN_OR_RETURN(TfLiteTensorView tensor_view,
TensorView::New(tflite_tensor));
return absl::make_unique<TfLiteTensorView>(std::move(tensor_view));
}
int TfLiteInvokeContext::NumInputs() const {
return ::tflite::NumInputs(node_);
}
int TfLiteInvokeContext::NumOutputs() const {
return ::tflite::NumOutputs(node_);
}
TfLiteShapeInferenceContext::TfLiteShapeInferenceContext(
TfLiteContext* context, TfLiteNode* node, const flexbuffers::Map* attr_map,
std::vector<Shape>* inferred_shapes)
: context_(context),
node_(node),
attr_map_(attr_map),
inferred_shapes_(inferred_shapes) {}
ShapeOr TfLiteShapeInferenceContext::GetInputShape(const int idx) const {
auto* tflite_tensor = ::tflite::GetInput(context_, node_, idx);
if (tflite_tensor == nullptr)
return absl::InternalError(absl::StrCat(
"input tensor is null during shape inference. idx: ", idx));
return TfLiteShapeToShape(tflite_tensor->dims);
}
// A function object to set output shape information from a Shape
// object
absl::Status TfLiteShapeInferenceContext::SetOutputShape(const int idx,
const Shape& shape) {
if (idx >= inferred_shapes_->size()) {
return absl::InternalError(absl::StrCat("output idx out of bounds: ", idx,
" >= ", inferred_shapes_->size()));
}
(*inferred_shapes_)[idx] = shape;
return absl::OkStatus();
}
// A function object to read input tensor during shape inference as a TensorView
ConstTensorViewOr TfLiteShapeInferenceContext::GetInputTensor(
const int idx) const {
const auto* tflite_tensor = ::tflite::GetInput(context_, node_, idx);
if (tflite_tensor == nullptr)
return absl::InternalError(absl::StrCat(
"input tensor is null during shape inference. idx: ", idx));
if (::tflite::IsConstantTensor(tflite_tensor)) {
SH_ASSIGN_OR_RETURN(const TfLiteTensorView& tensor_view,
TensorView::New(tflite_tensor));
return absl::make_unique<const TfLiteTensorView>(tensor_view);
} else {
return absl::FailedPreconditionError(absl::StrCat(
"input tensor is unavailable during shape inference. idx: ", idx));
}
}
absl::StatusOr<AttrValue> TfLiteShapeInferenceContext::GetAttr(
const std::string& attr_name) const {
return internal::GetAttr(attr_map_, attr_name);
}
int TfLiteShapeInferenceContext::NumInputs() const {
return ::tflite::NumInputs(node_);
}
int TfLiteShapeInferenceContext::NumOutputs() const {
return ::tflite::NumOutputs(node_);
}
TfLiteStatus StatusToTfLiteStatus(TfLiteContext* context,
const absl::Status& status) {
if (status.ok()) return kTfLiteOk;
const auto error_string = std::string(status.message());
TF_LITE_KERNEL_LOG(context, "error: %s", error_string.c_str());
return kTfLiteError;
}
TfLiteIntArray* ShapeToTfLiteShape(const std::vector<int>& shape) {
TfLiteIntArray* tflite_shape = TfLiteIntArrayCreate(shape.size());
// TfLiteIntArray has a data array inside which is not a pointer so there's a
// need for copy
std::memcpy(tflite_shape->data, shape.data(), sizeof(int) * shape.size());
return tflite_shape;
}
// Converts an int array representing shape in TFLite to Shape.
Shape TfLiteShapeToShape(const TfLiteIntArray* tflite_shape) {
Shape ret(tflite_shape->size);
std::memcpy(ret->data(), tflite_shape->data,
sizeof(int) * tflite_shape->size);
return ret;
}
} // namespace shim
} // namespace tflite