blob: 52324317fb16da3bd16467de1a3ec1ec33f89581 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/shim/tf_op_shim.h"
#include <cstdint>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
namespace tflite {
namespace shim {
namespace {
// Translates a TensorFlow `AttrValue` proto into the shim-level `AttrValue`.
//
// Only the scalar `b` (bool), `i` (integer), `f` (float) and `s` (string)
// cases of `value_case()` are understood. Every other value case yields a
// FailedPreconditionError whose message embeds the proto's debug string.
//
// NOTE(review): the string case is built from `attr_value.s()`. If the shim
// `AttrValue` keeps a non-owning view of that string, `attr_value` must
// outlive the returned value. Confirm against the shim AttrValue definition.
absl::StatusOr<AttrValue> TfAttrValueToShimAttrValue(
    const ::tensorflow::AttrValue& attr_value) {
  const auto which = attr_value.value_case();
  AttrValue converted;
  if (which == ::tensorflow::AttrValue::kB) {
    converted = attr_value.b();
  } else if (which == ::tensorflow::AttrValue::kI) {
    converted = attr_value.i();
  } else if (which == ::tensorflow::AttrValue::kF) {
    converted = attr_value.f();
  } else if (which == ::tensorflow::AttrValue::kS) {
    converted = attr_value.s();
  } else {
    return absl::FailedPreconditionError(absl::StrCat(
        "Unsupported attribute type: ", attr_value.DebugString()));
  }
  return converted;
}
}  // namespace
TfInitContext::TfInitContext(const ::tensorflow::OpKernelConstruction* context)
: context_(context) {}
// Looks up `attr_name` in the node definition returned by `context_->def()`
// and converts it to a shim `AttrValue`.
//
// Returns InvalidArgumentError when the attribute is absent; the message
// includes the full op def to aid debugging. Otherwise returns whatever
// TfAttrValueToShimAttrValue produces, including its error for unsupported
// attribute kinds.
absl::StatusOr<AttrValue> TfInitContext::GetAttr(
    const std::string& attr_name) const {
  if (!context_->HasAttr(attr_name))
    return absl::InvalidArgumentError(
        absl::StrCat("Non-existent attribute: ", attr_name, "\nop def:\n",
                     context_->def().DebugString()));
  // Bind by reference rather than copying the proto. A copy is wasted work.
  // It is also unsafe if the shim AttrValue stores a non-owning view of the
  // string built from `s()`: that view would dangle once a local copy is
  // destroyed on return. The node def outlives this call.
  // NOTE(review): confirm AttrValue's string alternative type.
  const auto& attr_value = context_->def().attr().at(attr_name);
  return TfAttrValueToShimAttrValue(attr_value);
}
TfInvokeContext::TfInvokeContext(::tensorflow::OpKernelContext* context)
: context_(context) {}
// Returns a read-only shim view over kernel input `idx`.
//
// Returns InvalidArgumentError when `idx` is negative or not less than
// `num_inputs()`. Forwards any error from TensorView::New.
ConstTensorViewOr TfInvokeContext::GetInput(const int idx) const {
  // Negative indices previously slipped past the upper-bound check below and
  // reached `context_->input(idx)` unchecked.
  if (idx < 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Expected idx >= 0. idx: ", idx));
  }
  if (idx >= context_->num_inputs()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Expected idx < num_inputs. idx: ", idx,
                     " num_inputs: ", context_->num_inputs()));
  }
  // Bind to the context-owned tensor instead of copying the Tensor object.
  const ::tensorflow::Tensor& tf_tensor = context_->input(idx);
  SH_ASSIGN_OR_RETURN(const TfTensorView& tensor_view,
                      TensorView::New(&tf_tensor));
  return absl::make_unique<const TfTensorView>(tensor_view);
}
// Allocates kernel output `idx` with the given `shape` and returns a mutable
// shim view over it.
//
// Returns InvalidArgumentError when `shape` has no value (unknown shape).
// Errors from `allocate_output` are converted with ToAbslStatus and returned.
// Errors from TensorView::New are forwarded.
TensorViewOr TfInvokeContext::GetOutput(const int idx,
                                        const Shape& shape) const {
  if (!shape.has_value())
    return absl::InvalidArgumentError("Output shape needs to be specified.");
  // TensorShape is built from 64-bit dims; widen the shim's dims element-wise.
  const std::vector<int64_t> shape_64(shape->begin(), shape->end());
  tensorflow::Tensor* output_t = nullptr;
  const auto status = context_->allocate_output(
      idx, ::tensorflow::TensorShape(shape_64), &output_t);
  if (!status.ok()) return ToAbslStatus(status);
  SH_ASSIGN_OR_RETURN(const TfTensorView& tensor_view,
                      TensorView::New(output_t));
  // `tensor_view` is a const reference, so std::move would still copy; copy
  // explicitly instead of implying a move.
  return absl::make_unique<TfTensorView>(tensor_view);
}
// Number of inputs, as reported by the wrapped OpKernelContext.
int TfInvokeContext::NumInputs() const {
  const int input_count = context_->num_inputs();
  return input_count;
}
// Number of outputs, as reported by the wrapped OpKernelContext.
int TfInvokeContext::NumOutputs() const {
  const int output_count = context_->num_outputs();
  return output_count;
}
TfShapeInferenceContext::TfShapeInferenceContext(
::tensorflow::shape_inference::InferenceContext* context)
: context_(context) {}
// Returns the shape of input `idx` as a shim Shape.
//
// - Returns an empty Shape() (no value) when the input's rank is unknown.
// - Otherwise returns one entry per dimension, taken from
//   `context_->Value(...)`. NOTE(review): for unknown dims this is presumably
//   TF's unknown-dim sentinel (-1) — confirm.
// - Returns InvalidArgumentError when `idx` is outside [0, num_inputs()).
ShapeOr TfShapeInferenceContext::GetInputShape(const int idx) const {
  // `context_->input(idx)` does not validate `idx` here, so reject
  // out-of-range indices up front (mirrors TfInvokeContext::GetInput).
  if (idx < 0 || idx >= context_->num_inputs()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Expected 0 <= idx < num_inputs. idx: ", idx,
                     " num_inputs: ", context_->num_inputs()));
  }
  const auto& shape = context_->input(idx);
  if (!context_->RankKnown(shape)) return Shape();
  const int rank = context_->Rank(shape);
  std::vector<int> dims(rank);
  // Loop over the int rank, avoiding a signed/unsigned comparison with size().
  for (int i = 0; i < rank; ++i)
    dims[i] = context_->Value(context_->Dim(shape, i));
  return Shape(dims);
}
// Publishes `shape` as the inferred shape of output `idx`.
//
// A Shape without a value is reported as a fully unknown shape. Otherwise
// each entry becomes a dimension built with `MakeDim`. Always returns OK.
absl::Status TfShapeInferenceContext::SetOutputShape(const int idx,
                                                     const Shape& shape) {
  if (!shape.has_value()) {
    context_->set_output(idx, context_->UnknownShape());
    return absl::OkStatus();
  }
  std::vector<::tensorflow::shape_inference::DimensionHandle> dims;
  dims.reserve(shape->size());
  for (const auto dim_size : *shape) {
    dims.push_back(context_->MakeDim(dim_size));
  }
  context_->set_output(idx, context_->MakeShape(dims));
  return absl::OkStatus();
}
// Returns a read-only shim view over the value of input `idx`, when the
// inference context has that value.
//
// - Returns InvalidArgumentError when `idx` is outside [0, num_inputs()).
// - Returns UnavailableError when `context_->input_tensor(idx)` yields
//   nullptr.
// - Forwards any error from TensorView::New.
ConstTensorViewOr TfShapeInferenceContext::GetInputTensor(const int idx) const {
  // `input_tensor(idx)` does not validate `idx` here, so reject out-of-range
  // indices before indexing (mirrors TfInvokeContext::GetInput).
  if (idx < 0 || idx >= context_->num_inputs()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Expected 0 <= idx < num_inputs. idx: ", idx,
                     " num_inputs: ", context_->num_inputs()));
  }
  const auto* tf_tensor = context_->input_tensor(idx);
  if (tf_tensor == nullptr) {
    return absl::UnavailableError(
        absl::StrCat("Tensor is not available. idx: ", idx));
  }
  SH_ASSIGN_OR_RETURN(const TfTensorView& tensor_view,
                      TensorView::New(tf_tensor));
  return absl::make_unique<const TfTensorView>(tensor_view);
}
// Looks up `attr_name` through the shape-inference context and converts it to
// a shim `AttrValue`.
//
// Returns InvalidArgumentError when the context returns no attribute.
// Otherwise returns whatever TfAttrValueToShimAttrValue produces.
absl::StatusOr<AttrValue> TfShapeInferenceContext::GetAttr(
    const std::string& attr_name) const {
  const auto* const found = context_->GetAttr(attr_name);
  if (found != nullptr) return TfAttrValueToShimAttrValue(*found);
  return absl::InvalidArgumentError(
      absl::StrCat("Non-existent attribute: ", attr_name));
}
// Number of inputs, as reported by the wrapped InferenceContext.
int TfShapeInferenceContext::NumInputs() const {
  const int input_count = context_->num_inputs();
  return input_count;
}
// Number of outputs, as reported by the wrapped InferenceContext.
int TfShapeInferenceContext::NumOutputs() const {
  const int output_count = context_->num_outputs();
  return output_count;
}
// Converts an absl::Status into a tensorflow::Status.
//
// OK maps to a default-constructed (OK) tensorflow::Status. Otherwise only the
// code and message are carried over, and the code is converted by numeric
// cast. NOTE(review): this assumes the two enums share numeric values.
::tensorflow::Status FromAbslStatus(const absl::Status& s) {
  if (!s.ok()) {
    const auto tf_code = static_cast<::tensorflow::error::Code>(s.code());
    return ::tensorflow::Status(tf_code, s.message());
  }
  return ::tensorflow::Status();
}
// Converts a tensorflow::Status into an absl::Status.
//
// OK maps to absl::OkStatus(). Otherwise only the code and error message are
// carried over, and the code is converted by numeric cast. NOTE(review): this
// assumes the two enums share numeric values.
absl::Status ToAbslStatus(const ::tensorflow::Status& s) {
  if (s.ok()) return absl::OkStatus();
  const auto absl_code = static_cast<absl::StatusCode>(s.code());
  return absl::Status(absl_code, s.error_message());
}
} // namespace shim
} // namespace tflite