blob: 79ab5e92be3211fe0f6e9fd5c4c82720163bf614 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include <stddef.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <fp16.h>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace gpu {
namespace {
// Creates a node that consumes output from the given node. Because output need
// to stay the same, newly created node will inherit the output from the given
// node, which will in turn get newly created copy of output. This is necessary
// to preserve reference consistency if another node was pointing at that
// output:
// node(output)
// will turn into:
// node(copy(output)) <- passthrough_node(output)
Status NewPassthroughNode(GraphFloat32* graph, Node* node,
const Value<TensorRef<BHWC>>* output,
Node** passthru_node) {
*passthru_node = graph->NewNode();
// Make copies for every output in the original node.
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
copy_output->tensor = output->tensor;
copy_output->tensor.ref = -1;
return OkStatus();
}
template <typename T>
Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) {
if (tensor.bytes % sizeof(T) != 0) {
return InvalidArgumentError(
absl::StrCat("Input data size ", tensor.bytes,
" is not aligned to expected type: ", sizeof(T)));
}
std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes);
return OkStatus();
}
void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
float* dst) {
for (size_t i = 0; i < num_elements; i++) {
*dst++ = fp16_ieee_to_fp32_value(*src++);
}
}
template <>
Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
float* tensor_data) {
switch (tensor.type) {
case kTfLiteFloat32:
std::memcpy(tensor_data, tensor.data.f, tensor.bytes);
break;
case kTfLiteFloat16:
ConvertFloat16ToFloat32(
NumElements(&tensor),
reinterpret_cast<uint16_t const*>(tensor.data.f16), tensor_data);
break;
default:
return InvalidArgumentError("Unsupported data type for float32 tensor");
}
return OkStatus();
}
template <typename ShapeT>
Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape);
template <>
Status SetAllDimensions<Scalar>(const TfLiteIntArray* dimensions,
Scalar* shape) {
if (dimensions->size < 0) {
return InvalidArgumentError("Invalid Scalar dimensions");
}
for (int i = 0; i < dimensions->size; ++i) {
if (dimensions->data[i] != 1) {
return InvalidArgumentError("Dimension can not be reduced to scalar.");
}
}
shape->v = 1;
return OkStatus();
}
template <>
Status SetAllDimensions<Linear>(const TfLiteIntArray* dimensions,
Linear* shape) {
if (dimensions->size <= 0) {
return InvalidArgumentError("Dimension is empty.");
}
for (int i = 0; i < dimensions->size - 1; ++i) {
if (dimensions->data[i] != 1) {
return InvalidArgumentError("Dimension can not be reduced to linear.");
}
}
shape->v = dimensions->data[dimensions->size - 1];
return OkStatus();
}
template <>
Status SetAllDimensions<HWC>(const TfLiteIntArray* dimensions, HWC* shape) {
if (dimensions->size != 4) {
return InvalidArgumentError("Dimensions are not HWC");
}
if (dimensions->data[0] != 1) {
return UnimplementedError("Batch size is not equal to 1.");
}
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->c = dimensions->data[3];
return OkStatus();
}
template <>
Status SetAllDimensions<HW>(const TfLiteIntArray* dimensions, HW* shape) {
if (dimensions->size != 2) {
return InvalidArgumentError("Dimensions are not HW");
}
shape->h = dimensions->data[0];
shape->w = dimensions->data[1];
return OkStatus();
}
template <>
Status SetAllDimensions<OHWI>(const TfLiteIntArray* dimensions, OHWI* shape) {
if (dimensions->size != 4) {
return InvalidArgumentError(
absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
}
shape->o = dimensions->data[0];
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->i = dimensions->data[3];
return OkStatus();
}
template <>
Status SetAllDimensions<IHWO>(const TfLiteIntArray* dimensions, IHWO* shape) {
if (dimensions->size != 4) {
return InvalidArgumentError(
absl::StrCat("Dimensions are not IHWO: ", dimensions->size));
}
shape->i = dimensions->data[0];
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->o = dimensions->data[3];
return OkStatus();
}
template <>
Status SetAllDimensions<BHWC>(const TfLiteIntArray* dimensions, BHWC* shape) {
if (dimensions->size != 4) {
return InvalidArgumentError("Dimensions are not BHWC");
}
shape->b = dimensions->data[0];
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->c = dimensions->data[3];
return OkStatus();
}
DataType ToDataType(TfLiteType type) {
switch (type) {
case kTfLiteFloat32:
return DataType::FLOAT32;
case kTfLiteInt32:
return DataType::INT32;
case kTfLiteInt64:
return DataType::INT64;
case kTfLiteUInt8:
return DataType::UINT8;
default:
return DataType::UNKNOWN;
}
}
int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
int number_of_runtime_inputs = 0;
for (int i = 0; i < tflite_node->inputs->size; i++) {
if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
number_of_runtime_inputs++;
}
}
return number_of_runtime_inputs;
}
int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
int number_of_runtime_outputs = 0;
for (int i = 0; i < tflite_node->outputs->size; i++) {
if (!IsConstantTensor(&context->tensors[tflite_node->outputs->data[i]])) {
number_of_runtime_outputs++;
}
}
return number_of_runtime_outputs;
}
Status CheckTensorIsAvailable(const TfLiteContext* context,
const TfLiteNode* tflite_node, int idx) {
// If tensor id is in range, it's guaranteed that it'll be available.
if (idx >= tflite_node->inputs->size) {
return OutOfRangeError(
absl::StrFormat("Requested index goes beyond array size (%d vs %d).",
idx, tflite_node->inputs->data[idx]));
}
return OkStatus();
}
class ObjectReader {
public:
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* tflite_node,
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value)
: graph_(graph),
context_(context),
tflite_node_(tflite_node),
tensor_to_value_(tensor_to_value) {}
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) const {
if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(
absl::StrCat("ReadValue: input tensor index: ", idx));
}
return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
}
int GetNumberOfRuntimeInputs() const {
return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
}
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
}
const int tensor_idx = tflite_node_->inputs->data[idx];
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
return OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
}
const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
*dimensions = *tflite_tensor.dims;
return OkStatus();
}
template <typename TensorT>
Status ReadTensor(uint32_t idx, TensorT* t) const {
RETURN_IF_ERROR(CheckTensorIsAvailable(context_, tflite_node_, idx));
const int32_t tensor_idx = tflite_node_->inputs->data[idx];
const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx;
t->data.resize(NumElements(tflite_tensor));
RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &t->data[0]));
// Axis and data layout depend on operation this tensor is used in. So,
// postpone resolutions until operations are parsed.
t->id = tensor_idx;
return SetAllDimensions(tflite_tensor->dims, &t->shape);
}
Status AddOutput(const Node* node, int id) {
if (tflite_node_->outputs->size <= id) {
return InvalidArgumentError(absl::StrCat(
"Data id ", id, " must be less than tflite node outputs size ",
tflite_node_->outputs->size));
}
int output_tensor_idx = tflite_node_->outputs->data[id];
Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
return OkStatus();
}
Status AddOutputs(const Node* node) {
for (int i = 0; i < tflite_node_->outputs->size; ++i) {
RETURN_IF_ERROR(AddOutput(node, i));
}
return OkStatus();
}
Status AddInput(const Node* node, uint32_t idx) {
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(ReadValue(idx, &input));
return graph_->AddConsumer(node->id, input->id);
}
Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRef<BHWC>>** value) const {
if (tensor_idx >= tensor_to_value_->size()) {
return OutOfRangeError(
absl::StrCat("ReadValue: input tensor index: ", tensor_idx));
}
if ((*tensor_to_value_)[tensor_idx] == nullptr) {
const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
if (tflite::IsConstantTensor(&tflite_tensor)) {
return NotFoundError(absl::StrCat(
"ReadValue: value is a constant tensor: ", tensor_idx));
}
Value<TensorRef<BHWC>>* value = graph_->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx;
(*tensor_to_value_)[tensor_idx] = value;
}
*value = (*tensor_to_value_)[tensor_idx];
return OkStatus();
}
TfLiteTensor* GetInputTensor(int index) const {
return index >= 0 && index < tflite_node_->inputs->size
? context_->tensors + tflite_node_->inputs->data[index]
: nullptr;
}
TfLiteTensor* GetOutputTensor(int index) const {
return index >= 0 && index < tflite_node_->outputs->size
? context_->tensors + tflite_node_->outputs->data[index]
: nullptr;
}
private:
GraphFloat32* graph_ = nullptr;
const TfLiteContext* context_ = nullptr;
const TfLiteNode* tflite_node_ = nullptr;
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
};
Status CheckInputsOutputs(const TfLiteContext* context,
const TfLiteNode* tflite_node, int inputs,
int outputs) {
int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node);
if (runtime_inputs != inputs) {
return InternalError(
absl::StrFormat("Expected %d input tensor(s), but node has %d runtime "
"input(s).",
inputs, runtime_inputs));
}
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != outputs) {
return InternalError(
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
"output(s).",
outputs, runtime_outputs));
}
return OkStatus();
}
// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
public:
virtual ~TFLiteOperationParser() = default;
// Parses TFLite operation. This method allows expanding fused operations
// into more than one node.
virtual Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) = 0;
// Verifies whether passed tflite node may be built by GPU delegate or not.
virtual Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) = 0;
};
Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
switch (fused_activation) {
case kTfLiteActNone:
case kTfLiteActRelu:
case kTfLiteActRelu1:
case kTfLiteActRelu6:
case kTfLiteActTanh:
return OkStatus();
case kTfLiteActSignBit:
return UnimplementedError("TfLiteFusedActivation.kTfLiteActSignBit");
case kTfLiteActSigmoid:
return UnimplementedError("TfLiteFusedActivation.kTfLiteActSigmoid");
// Do not add default; we want compilation error rather than run-time
// error.
}
}
// If there is fused activation present, then there will be another node created
// that will have identical output as the given node. New operation node will
// depend on the given node output.
Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
const std::vector<uint32_t>& output_indices,
GraphFloat32* graph, Node* node) {
if (fused_activation == kTfLiteActNone) {
return OkStatus();
}
const auto& outputs = graph->FindOutputs(node->id);
if (outputs.empty()) {
return InternalError("Empty outputs in fused node");
}
switch (fused_activation) {
case kTfLiteActRelu:
case kTfLiteActRelu1:
case kTfLiteActRelu6: {
ReLUAttributes attr;
attr.clip = fused_activation == kTfLiteActRelu
? 0.0f
: (fused_activation == kTfLiteActRelu1 ? 1.0f : 6.0f);
for (auto index : output_indices) {
Node* activation_node;
RETURN_IF_ERROR(
NewPassthroughNode(graph, node, outputs[index], &activation_node));
activation_node->operation.type = ToString(OperationType::RELU);
activation_node->operation.attributes = attr;
}
break;
}
case kTfLiteActTanh:
for (auto index : output_indices) {
Node* activation_node;
RETURN_IF_ERROR(
NewPassthroughNode(graph, node, outputs[index], &activation_node));
activation_node->operation.type = ToString(OperationType::TANH);
}
break;
default:
return NotFoundError(
absl::StrCat("Unsupported fused activation: ", fused_activation));
}
return OkStatus();
}
Status MaybeFuseActivationToTheSingleOutput(
TfLiteFusedActivation fused_activation, GraphFloat32* graph, Node* node) {
if (graph->FindOutputs(node->id).size() != 1) {
return InternalError("Number of outputs exceeds 1");
}
return MaybeFuseActivation(fused_activation, {0}, graph, node);
}
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
template <typename AttrT>
void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape,
AttrT* attr) {
if (padding == kTfLitePaddingSame) {
attr->padding = CalculateSamePadding(input_shape, *attr);
} else {
attr->padding.prepended = HW(0, 0);
attr->padding.appended = HW(0, 0);
}
}
Status GetFullyConnectedAttributes(int weights_tensor_id, int bias_tensor_id,
ObjectReader* reader,
FullyConnectedAttributes* attr) {
Tensor<HW, DataType::FLOAT32> weights;
RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
attr->weights.data = std::move(weights.data);
attr->weights.id = weights.id;
attr->weights.shape.h = 1;
attr->weights.shape.w = 1;
attr->weights.shape.o = weights.shape.h;
attr->weights.shape.i = weights.shape.w;
reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional
return OkStatus();
}
template <typename ParamsT>
Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
ParamsT** tf_options) {
const auto* params =
reinterpret_cast<const ParamsT*>(tflite_node->builtin_data);
if (!params) {
return InternalError("Unable to retrieve builtin_data.");
}
*tf_options = const_cast<ParamsT*>(params);
return OkStatus();
}
template <typename ParamsType>
Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
ParamsType** tf_options) {
const auto* params =
reinterpret_cast<const ParamsType*>(tflite_node->custom_initial_data);
if (!params) {
return InternalError("Unable to retrieve custom_initial_data.");
}
*tf_options = const_cast<ParamsType*>(params);
return OkStatus();
}
Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration,
int max_version) {
const int op_version = registration->version;
if (op_version > max_version) {
return UnimplementedError(
absl::StrFormat("Max version supported: %d. Requested version %d.",
max_version, op_version));
}
return OkStatus();
}
Status CheckExactSupportedOpVersion(const TfLiteRegistration* registration,
int expected_version) {
int op_version = registration->version;
if (op_version != expected_version) {
return UnimplementedError(
absl::StrFormat("Only version %d is supported. Requested version %d.",
expected_version, op_version));
}
return OkStatus();
}
Status CheckKernels(int kernel_h, int kernel_w) {
if (kernel_h <= 0 || kernel_w <= 0) {
return InvalidArgumentError(absl::StrFormat(
"Incorrect kernel values: kernel_height = %d, kernel_width = %d.",
kernel_h, kernel_w));
}
return OkStatus();
}
Status CheckStrides(int strides_h, int strides_w) {
if (strides_h <= 0 || strides_w <= 0) {
return InvalidArgumentError(absl::StrFormat(
"Incorrect stride values: stride_height = %d, stride_width = %d.",
strides_h, strides_w));
}
return OkStatus();
}
Status CheckDilation(int dilation_h, int dilation_w) {
if (dilation_h <= 0 || dilation_w <= 0) {
return InvalidArgumentError(
absl::StrFormat("Incorrect dilation values: dilation_factor = %d, "
"dilation_factor = %d.",
dilation_h, dilation_w));
}
return OkStatus();
}
Status CheckStridesAndDilation(int strides_h, int strides_w, int dilation_h,
int dilation_w) {
RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w));
return OkStatus();
}
Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
int strides_w) {
RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w));
RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
return OkStatus();
}
// Creates a simple node that holds tensor value.
Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
Value<TensorRef<BHWC>>** value) {
ConstTensorAttributes attr;
attr.tensor = std::move(t);
Node* node = graph->NewNode();
node->operation.attributes = attr;
node->operation.type = ToString(OperationType::CONST);
*value = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
// Keep data inside this tensor.
(*value)->tensor.ref = attr.tensor.id;
(*value)->tensor.type = attr.tensor.kType;
(*value)->tensor.shape = attr.tensor.shape;
return OkStatus();
}
Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
const BHWC& input_shape,
Pooling2DAttributes* attr) {
attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width);
UpdatePadding(tf_options->padding, input_shape, attr);
return OkStatus();
}
Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
const TfLiteIntArray* dims = tflite_tensor.dims;
switch (dims->size) {
case 1:
*bhwc = BHWC(dims->data[0], 1, 1, 1);
return OkStatus();
case 2:
*bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
return OkStatus();
case 3:
*bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
return OkStatus();
case 4:
*bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
return OkStatus();
default:
return InvalidArgumentError(absl::StrCat(
"Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
"\" has bad input dims size: ", dims->size, "."));
}
}
class AddOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
if (tflite_node->inputs->size != 2) {
return UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
TfLiteAddParams* tf_options = nullptr;
return RetrieveBuiltinData(tflite_node, &tf_options);
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
// TFLite currently only supports 2 input ADDs. Thus, the logic below only
// considers 2 input cases. The underlying GPU shader programs can accept
// more inputs, but the logic below would have to be expanded.
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return InvalidArgumentError("Couldn't get the 1st input tensor for ADD.");
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return InvalidArgumentError("Couldn't get the 2nd input tensor for ADD.");
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return InvalidArgumentError("No runtime input tensors for ADD.");
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::ADD);
RETURN_IF_ERROR(reader->AddOutputs(node));
AddAttributes attr;
if (runtime_tensor0 && runtime_tensor1) {
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else {
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
}
node->operation.attributes = std::move(attr);
const auto* tf_options =
reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
node);
}
};
class ConcatenationOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
// TODO(eignasheva): add proper tensor availability checking
// for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
// RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
// }
// TODO(eignasheva): add axis checking.
TfLiteConcatenationParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
ConcatAttributes attr;
// Read inputs first to make sure const node is added to a graph before
// concat node to ensure topological order.
std::vector<const Value<TensorRef<BHWC>>*> inputs;
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
Value<TensorRef<BHWC>>* value;
const auto status = reader->ReadValue(idx, &value);
if (status.ok()) {
inputs.push_back(value);
} else {
TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
inputs.push_back(value);
}
}
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONCAT);
RETURN_IF_ERROR(reader->AddOutputs(node));
for (const Value<TensorRef<BHWC>>* input : inputs) {
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
}
std::vector<BHWC> input_shapes;
for (auto input : graph->FindInputs(node->id)) {
input_shapes.push_back(input->tensor.shape);
}
RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));
// Guess axis.
BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
for (auto input : graph->FindInputs(node->id)) {
if (input->tensor.shape.h != output_shape.h) {
attr.axis = Axis::HEIGHT;
break;
}
if (input->tensor.shape.w != output_shape.w) {
attr.axis = Axis::WIDTH;
break;
}
if (input->tensor.shape.c != output_shape.c) {
attr.axis = Axis::CHANNELS;
break;
}
}
const auto* tf_options = reinterpret_cast<const TfLiteConcatenationParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
node->operation.attributes = attr;
return OkStatus();
}
private:
Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
*axis = Axis::BATCH;
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].h != input_shapes[i].h &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::HEIGHT;
break;
}
}
if (*axis == Axis::BATCH) return OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::WIDTH;
break;
}
}
if (*axis == Axis::HEIGHT) return OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].h != input_shapes[i].h &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::CHANNELS;
break;
}
}
if (*axis == Axis::WIDTH) return OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].h != input_shapes[i].h) {
return UnimplementedError(
"Can concatenate tensors only by batch, height, width, or "
"channels.");
}
}
return OkStatus();
}
};
class Conv2DOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckStridesAndDilation(
tf_options->stride_height, tf_options->stride_width,
tf_options->dilation_height_factor, tf_options->dilation_width_factor));
return IsActivationSupported(tf_options->activation);
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
Convolution2DAttributes attr;
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
const auto* tf_options =
reinterpret_cast<const TfLiteConvParams*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
attr.dilations = HW(tf_options->dilation_height_factor,
tf_options->dilation_width_factor);
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
node->operation.attributes = std::move(attr);
return OkStatus();
}
};
class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteTransposeConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* params = reinterpret_cast<const TfLiteTransposeConvParams*>(
tflite_node->custom_initial_data);
ConvolutionTransposedAttributes attr;
attr.stride =
params ? HW(params->stride_height, params->stride_width) : HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape,
&attr);
node->operation.attributes = std::move(attr);
return OkStatus();
}
};
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteDepthwiseConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckStridesAndDilation(
tf_options->stride_height, tf_options->stride_width,
tf_options->dilation_height_factor, tf_options->dilation_width_factor));
RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
const int depth_multiplier = tf_options->depth_multiplier;
const auto* input = context->tensors + tflite_node->inputs->data[0];
const auto* filter = context->tensors + tflite_node->inputs->data[1];
const auto* bias = tflite_node->inputs->size > 2
? context->tensors + tflite_node->inputs->data[2]
: nullptr;
const auto* output = context->tensors + tflite_node->outputs->data[0];
if (!input->dims || input->dims->size != 4) {
return InvalidArgumentError("input.dims.size != 4");
}
if (!filter->dims || filter->dims->size != 4) {
return InvalidArgumentError("filter.dims.size != 4");
}
if (!output->dims || output->dims->size != 4) {
return InvalidArgumentError("output.dims.size != 4");
}
if (input->dims->data[0] != output->dims->data[0]) {
return InvalidArgumentError("input.b != output.b");
}
const int input_depth = input->dims->data[3];
const int output_depth = output->dims->data[3];
if (filter->dims->data[3] != output_depth) {
return InvalidArgumentError("filter.i != output.c");
}
if (output_depth != input_depth * depth_multiplier) {
return InvalidArgumentError("output.c != input.c * depth_multiplier");
}
if (bias && NumElements(bias) != output_depth) {
return InvalidArgumentError("bias.size != output.c");
}
if (depth_multiplier != 1 && input_depth != 1) {
return UnimplementedError("depth_multiplier != 1 && input.c != 1");
}
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
DepthwiseConvolution2DAttributes attr;
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
TfLiteDepthwiseConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
std::max(1, tf_options->dilation_width_factor));
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
const int depth_multiplier = tf_options->depth_multiplier;
if (depth_multiplier != 1) {
const TfLiteTensor* input = reader->GetInputTensor(0);
const TfLiteTensor* filter = reader->GetInputTensor(1);
const TfLiteTensor* output = reader->GetOutputTensor(0);
TransposeWeights(input, filter, output, depth_multiplier, &attr);
}
node->operation.attributes = std::move(attr);
return OkStatus();
}
private:
// TFLite CPU stores weights as:
// [1, kernel_height, kernel_width, input_depth * depth_multiplier]
// TFLite GPU stores weights as:
// [depth_multiplier, kernel_height, kernel_width, input_depth]
static void TransposeWeights(const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* output, int depth_multiplier,
DepthwiseConvolution2DAttributes* attr) {
const int input_depth = input->dims->data[3];
const int filter_height = filter->dims->data[1];
const int filter_width = filter->dims->data[2];
const int output_depth = output->dims->data[3];
Tensor<OHWI, DataType::FLOAT32> weights;
weights.id = attr->weights.id;
weights.shape =
OHWI(output_depth, filter_height, filter_width, input_depth);
weights.data.resize(weights.shape.DimensionsProduct());
float* dst = &weights.data[0];
for (int j = 0; j < output_depth; ++j) {
const float* src = attr->weights.data.data() + j;
for (int i = 0; i < filter_height * filter_width; ++i) {
*dst = *src;
dst++;
src += output_depth;
}
}
attr->weights = std::move(weights);
}
};
class ElementwiseOperationParser : public TFLiteOperationParser {
public:
explicit ElementwiseOperationParser(OperationType operation_type)
: operation_type_(operation_type) {}
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
/*outputs=*/1));
} else if (IsTwoArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2,
/*outputs=*/1));
} else {
return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
}
TfLiteFusedActivation activation;
RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
return IsActivationSupported(activation);
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(operation_type_);
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(reader->AddInput(node, 0));
} else if (IsTwoArgumentOperation()) {
if (tflite_node->inputs->size != 2) {
return InvalidArgumentError("Applies only two input tensors");
}
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
TfLiteFusedActivation activation = kTfLiteActNone;
switch (operation_type_) {
case OperationType::SUB: {
const auto* tf_options = reinterpret_cast<const TfLiteSubParams*>(
tflite_node->builtin_data);
if (tf_options != nullptr) {
activation = tf_options->activation;
}
break;
}
case OperationType::DIV: {
const auto* tf_options = reinterpret_cast<const TfLiteDivParams*>(
tflite_node->builtin_data);
if (tf_options != nullptr) {
activation = tf_options->activation;
}
break;
}
default:
// No activation expected.
activation = kTfLiteActNone;
}
if (activation) {
RETURN_IF_ERROR(
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
}
} else {
return InvalidArgumentError("Incorrect operation type passed");
}
return reader->AddOutputs(node);
}
private:
Status GetActivation(const TfLiteNode* tflite_node,
TfLiteFusedActivation* activation) const {
if (operation_type_ == OperationType::DIV) {
TfLiteDivParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
*activation = tf_options ? tf_options->activation : kTfLiteActNone;
return OkStatus();
}
if (operation_type_ == OperationType::SUB) {
TfLiteSubParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
*activation = tf_options ? tf_options->activation : kTfLiteActNone;
return OkStatus();
}
// Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
// TfLiteXxxParams.activation.
*activation = kTfLiteActNone;
return OkStatus();
}
bool IsOneArgumentOperation() const {
switch (operation_type_) {
case OperationType::ABS:
case OperationType::COS:
case OperationType::LOG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN:
case OperationType::SQRT:
case OperationType::SQUARE:
case OperationType::TANH:
return true;
default:
return false;
}
}
bool IsTwoArgumentOperation() const {
switch (operation_type_) {
case OperationType::DIV:
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB:
return true;
default:
return false;
}
}
OperationType operation_type_;
};
class FullyConnectedOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
TfLiteFullyConnectedParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->weights_format !=
kTfLiteFullyConnectedWeightsFormatDefault) {
return UnimplementedError("Unsupported FullyConnected weights format.");
}
// TODO(eignasheva): check input shape
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0));
const auto* tf_options =
reinterpret_cast<const TfLiteFullyConnectedParams*>(
tflite_node->builtin_data);
if (tf_options->weights_format !=
kTfLiteFullyConnectedWeightsFormatDefault) {
return UnimplementedError("Unsupported FullyConnected weights format.");
}
FullyConnectedAttributes attr;
RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
Tensor<HW, DataType::FLOAT32> weights;
RETURN_IF_ERROR(reader->ReadTensor(1, &weights));
auto input = graph->FindInputs(node->id)[0];
int batch_size = input->tensor.shape.b;
if (input->tensor.shape.DimensionsProduct() / batch_size !=
weights.shape.w) {
return UnimplementedError(
"Amount of input data should match weights width");
}
Node* conv = node;
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
auto& reshape = node;
conv = graph->NewNode(); // reset conv pointer!
Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w);
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
reshape->operation.type = ToString(OperationType::RESHAPE);
ReshapeAttributes attr;
attr.new_shape = reshaped_value->tensor.shape;
reshape->operation.attributes = attr;
RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
}
conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
conv->operation.attributes = std::move(attr);
Status result = reader->AddOutputs(conv);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, conv));
return result;
}
};
class HardSwishOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration*) final {
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
/*outputs=*/1);
}
Status Parse(const TfLiteNode*, const TfLiteRegistration*,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::HARD_SWISH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
return reader->AddOutputs(node);
}
};
// Basic LSTM Cell:
//
// 1name = name is at input index 1
// name1 = name is at output index 1
//
// 0input 1prev_activ
// \ /
// [[concat]]
// \
// concat_temp2 2weights 3biases
// \ / /
// [[fully-connected]]
// \
// activ_temp3 4prev_state
// \ /
// [[LSTM]]
// / \
// new_state1 activation0
//
class LSTMOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
// TODO(eignasheva): Fix bad check.
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5,
// /*outputs=*/4));
TfLiteLSTMParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckParameters(tf_options));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
if (tflite_node->inputs->size != 5) {
return InvalidArgumentError("LSTM should have 5 input tensors");
}
if (tflite_node->outputs->size != 4) {
return InvalidArgumentError("LSTM should have 4 output tensors");
}
const auto* params =
reinterpret_cast<const TfLiteLSTMParams*>(tflite_node->builtin_data);
if (!params) {
return InternalError("Missing tflite params");
}
RETURN_IF_ERROR(CheckParameters(params));
Node* concat_node = graph->NewNode();
concat_node->operation.type = ToString(OperationType::CONCAT);
ConcatAttributes concat_attr;
concat_attr.axis = Axis::CHANNELS;
concat_node->operation.attributes = concat_attr;
Node* fc_node = graph->NewNode();
fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
FullyConnectedAttributes fc_attr;
RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
fc_node->operation.attributes = std::move(fc_attr);
Node* lstm_node = graph->NewNode();
lstm_node->operation.type = ToString(OperationType::LSTM);
LstmAttributes lstm_attr;
lstm_attr.kernel_type = LstmKernelType::BASIC;
lstm_node->operation.attributes = lstm_attr;
Value<TensorRef<BHWC>>* concat_temp;
int concat_tensor_idx = tflite_node->outputs->data[2];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
Value<TensorRef<BHWC>>* activ_temp;
int activ_tensor_idx = tflite_node->outputs->data[3];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input
RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ
RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state
RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state
RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation
return OkStatus();
}
private:
Status CheckParameters(const TfLiteLSTMParams* tf_options) {
if (tf_options->kernel_type !=
TfLiteLSTMKernelType::kTfLiteLSTMBasicKernel) {
return UnimplementedError("Only kTfLiteLSTMBasicKernel is supported.");
}
if (tf_options->activation != kTfLiteActTanh) {
return UnimplementedError("Only TANH activation is supported.");
}
if (tf_options->cell_clip != 0.0f) {
return UnimplementedError("cell_clip is not supported.");
}
if (tf_options->proj_clip != 0.0f) {
return UnimplementedError("proj_clip is not supported.");
}
return OkStatus();
}
};
class MulOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
if (tflite_node->inputs->size != 2) {
return UnimplementedError("MUL requires two input tensors.");
}
// TODO(eignasheva): Add params check.
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return InvalidArgumentError("Couldn't get the 1st input tensor for MUL.");
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL.");
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return InvalidArgumentError("No runtime input tensors for MUL.");
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
if (runtime_tensor0 && runtime_tensor1) {
BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
BHWC shape1;
RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
int input_tensor0 = 0;
int input_tensor1 = 1;
if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
shape0.c == shape1.c) {
input_tensor0 = 1;
input_tensor1 = 0;
}
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
}
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
// input and the constant input tensor must be bound to 2nd input.
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
graph, reader);
}
private:
Status ParseApplyMask(int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::APPLY_MASK);
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
MultiplyScalarAttributes attr;
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
};
class PReLUOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
// TODO(eignasheva): add params check
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::PRELU);
RETURN_IF_ERROR(reader->AddInput(node, 0));
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
PReLUAttributes attr;
Tensor<Linear, DataType::FLOAT32> linear_alpha;
Status status = reader->ReadTensor(1, &linear_alpha);
if (status.ok()) {
if (linear_alpha.shape.v != input_shape.c) {
return InvalidArgumentError(
"Linear alpha shape does not match the number of input channels.");
}
attr.alpha = std::move(linear_alpha);
} else {
Tensor<HWC, DataType::FLOAT32> hwc_alpha;
RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
if (hwc_alpha.shape.h != input_shape.h ||
hwc_alpha.shape.w != input_shape.w ||
hwc_alpha.shape.c != input_shape.c) {
return InvalidArgumentError("Alpha shape does not match input shape.");
}
attr.alpha = std::move(hwc_alpha);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
};
class PadOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::PAD);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
PadAttributes attr;
attr.type = PaddingContentType::ZEROS;
Tensor<HW, DataType::INT32> paddings;
RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
// 4x2 tensor with paddings.
if (paddings.shape.h != 4 || paddings.shape.w != 2) {
return InvalidArgumentError("Paddings tensor has unexpected shape.");
}
attr.prepended = BHWC(paddings.data[0], paddings.data[2], paddings.data[4],
paddings.data[6]);
attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
paddings.data[7]);
node->operation.attributes = attr;
return OkStatus();
}
};
class Pooling2DOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
TfLitePoolParams* tf_options = nullptr;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
if (status.ok()) { // custom case with indices as a second output
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
/*outputs=*/2));
} else { // common pooling with 1 output
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
/*outputs=*/1));
}
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
return IsActivationSupported(tf_options->activation);
}
public:
explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::POOLING_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutput(node, 0));
Pooling2DAttributes attr;
attr.type = type_;
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
// check whether there are custom options encoded. It happens if operation
// is MaxPoolingWithArgmax2D. There is no way to read
// tflite_node->builtin_code, so, simply check whether custom data is
// available.
auto* tf_options = reinterpret_cast<const TfLitePoolParams*>(
tflite_node->custom_initial_data);
if (!tf_options) {
tf_options =
reinterpret_cast<const TfLitePoolParams*>(tflite_node->builtin_data);
}
if (!tf_options) {
return InternalError("Missing tflite params");
}
std::vector<uint32_t> max_tensor_id{0};
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, max_tensor_id,
graph, node));
// Second output is optional. It is not required, it but must be added after
// MaybeAddFusedActivation function is called
reader->AddOutput(node, 1).IgnoreError();
// First output is the result of pooling operation, while second output is
// indices used for pooling.
auto outputs = graph->FindOutputs(node->id);
attr.output_indices = outputs.size() == 2;
if (attr.output_indices) {
// Fix data type for output indices. In the model it is set as float32.
outputs[1]->tensor.type = DataType::INT32;
}
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
node->operation.attributes = attr;
return OkStatus();
}
private:
const PoolingType type_;
};
class ReLUOperationParser : public TFLiteOperationParser {
public:
explicit ReLUOperationParser(int clip) : clip_(clip) {}
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::RELU);
RETURN_IF_ERROR(reader->AddInput(node, 0));
ReLUAttributes attr;
TfLiteLeakyReluParams* tf_options = nullptr;
RetrieveBuiltinData(tflite_node, &tf_options).IgnoreError();
attr.alpha = tf_options ? tf_options->alpha : 0;
attr.clip = clip_;
node->operation.attributes = attr;
return reader->AddOutputs(node);
}
private:
const int clip_;
};
class ReshapeOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
// TODO(eignasheva): add shape checking
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::RESHAPE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Here we may have extra inputs. Other tensors were supposed to
// define new shape, but in TFLite these are ignored.
// TODO(akulik): check that shapes match?
// New shape comes from output shape.
ReshapeAttributes attr;
attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
node->operation.attributes = attr;
return OkStatus();
}
};
class ResizeBilinearOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
// TODO(eignasheva): check shapes.
TfLiteResizeBilinearParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::UPSAMPLE_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Here we may have extra inputs. Other tensors were supposed to
// define new shape, but in TFLite these are ignored.
const auto* tf_options =
reinterpret_cast<const TfLiteResizeBilinearParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
Upsample2DAttributes attr;
attr.align_corners = tf_options->align_corners;
attr.type = UpsamplingType::BILINEAR;
attr.new_shape.CopyAllDefinedAxis(
graph->FindOutputs(node->id)[0]->tensor.shape);
node->operation.attributes = attr;
return OkStatus();
}
};
class SoftmaxOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
TfLiteSoftmaxParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->beta != 1) {
// TODO(eignasheva): figure out, what's wrong with softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
}
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SOFTMAX);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options =
reinterpret_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
if (tf_options->beta != 1) {
// there is multiply by scalar operation fused in softmax. Make a layer
// out of it before softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
// auto mul_node = reader->NewPassthroughNode(node);
// mul_node->operation.type = ToString(OperationType::MUL);
}
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS; // always by channels
node->operation.attributes = attr;
return OkStatus();
}
};
class SliceOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
SliceAttributes attr;
attr.strides = BHWC(1, 1, 1, 1);
Tensor<Linear, DataType::INT32> starts, sizes;
RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
if (starts.data.size() != sizes.data.size()) {
return InvalidArgumentError("Starts amount != sizes amount.");
}
if (starts.data.size() == 4) {
attr.starts =
BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]);
attr.ends =
BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1],
starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]);
} else if (starts.data.size() == 3) {
attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]);
attr.ends =
BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0],
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
} else {
return UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
RETURN_IF_ERROR(UpdateIfNegative(input->tensor.shape, &attr));
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
if ((attr.ends.b - attr.starts.b) != out_shape.b) {
return UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h) != out_shape.h) {
return UnimplementedError("Output height doesn't match");
}
if ((attr.ends.w - attr.starts.w) != out_shape.w) {
return UnimplementedError("Output width doesn't match");
}
if ((attr.ends.c - attr.starts.c) != out_shape.c) {
return UnimplementedError("Output channels don't match");
}
node->operation.attributes = attr;
return OkStatus();
}
private:
Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) {
if (attr->ends.h < 0) {
attr->ends.h = input_shape.h + attr->ends.h;
}
if (attr->ends.w < 0) {
attr->ends.w = input_shape.w + attr->ends.w;
}
if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c;
}
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return OkStatus();
}
};
class StridedSliceOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
TfLiteStridedSliceParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
Tensor<Linear, DataType::INT32> tmp;
RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
bool read_without_batch = tmp.data.size() == 3;
bool read_with_batch = tmp.data.size() == 4;
if (!read_without_batch && !read_with_batch) {
return UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
const auto* tf_options = reinterpret_cast<const TfLiteStridedSliceParams*>(
tflite_node->builtin_data);
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
if (!tf_options) {
return InternalError("Missing tflite params");
}
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
SliceAttributes attr;
if (read_without_batch) {
RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
input->tensor.shape, &attr));
}
if (read_with_batch) {
RETURN_IF_ERROR(
ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
}
if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
attr.strides.c == 0) {
return InvalidArgumentError("stride values must be non-zero");
}
if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
attr.strides.c < 0) {
return UnimplementedError("Reverse slices are not supported.");
}
if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
out_shape.b) {
return UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
out_shape.h) {
return UnimplementedError("Output height doesn't match");
}
if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
out_shape.w) {
return UnimplementedError("Output width doesn't match");
}
if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
out_shape.c) {
return UnimplementedError("Output channels don't match");
}
node->operation.attributes = attr;
return OkStatus();
}
private:
Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, int ignore_b, int ignore_h,
int ignore_w, int ignore_c, SliceAttributes* attr) {
if (tf_options->begin_mask & ignore_h) {
attr->starts.h = 0;
}
if (tf_options->begin_mask & ignore_w) {
attr->starts.w = 0;
}
if (tf_options->begin_mask & ignore_c) {
attr->starts.c = 0;
}
if (tf_options->begin_mask & ignore_b) {
attr->starts.b = 0;
}
if (tf_options->end_mask & ignore_h) {
attr->ends.h = input_shape.h;
}
if (tf_options->end_mask & ignore_w) {
attr->ends.w = input_shape.w;
}
if (tf_options->end_mask & ignore_c) {
attr->ends.c = input_shape.c;
}
if (tf_options->end_mask & ignore_b) {
attr->ends.b = input_shape.b;
}
return OkStatus();
}
Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) {
if (attr->ends.h < 0) {
attr->ends.h = input_shape.h + attr->ends.h;
}
if (attr->ends.w < 0) {
attr->ends.w = input_shape.w + attr->ends.w;
}
if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c;
}
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return OkStatus();
}
Status ReadAttribsWithBatch(const ObjectReader* reader,
const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, SliceAttributes* attr) {
auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
return OkStatus();
};
RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
return OkStatus();
}
Status ReadAttribsWithoutBatch(const ObjectReader* reader,
const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape,
SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
return OkStatus();
};
RETURN_IF_ERROR(read_hwc(1, &attr->starts));
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
attr->starts.b = 0;
attr->ends.b = input_shape.b;
attr->strides.b = 1;
return OkStatus();
}
Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
if (tf_options->ellipsis_mask) {
return UnimplementedError("Slice does not support ellipsis_mask.");
}
if (tf_options->new_axis_mask) {
return UnimplementedError("Slice does not support new_axis_mask.");
}
if (tf_options->shrink_axis_mask) {
return UnimplementedError(
"Slice does not support shrink_axis_mask parameter. ");
}
return OkStatus();
}
};
class TransposeConvOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteTransposeConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return OkStatus();
}
// TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input)
// and allows configurable padding & stride.
// TODO(impjdi): Translate output_shape to attr.adjacent.
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(2, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options = reinterpret_cast<const TfLiteTransposeConvParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite options.");
}
ConvolutionTransposedAttributes attr;
attr.stride = tf_options
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
// TFLite does not support bias.
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
node->operation.attributes = std::move(attr);
return OkStatus();
}
};
class TransposeOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::TRANSPOSE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
TransposeAttributes attr;
Tensor<Linear, DataType::INT32> perm;
RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
if (perm.data.size() == 4) {
attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[0], perm.data[3]);
} else if (perm.data.size() == 3) {
attr.perm = BHWC(0, perm.data[0] + 1, perm.data[1] + 1, perm.data[2] + 1);
} else if (perm.data.size() == 2) {
attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2);
} else {
return InvalidArgumentError("Permutation for transpose is invalid.");
}
node->operation.attributes = attr;
return OkStatus();
}
};
class Unpooling2DOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
TfLitePoolParams* tf_options = nullptr;
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
RETURN_IF_ERROR(reader->AddOutputs(node));
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
MaxUnpooling2DAttributes attr;
const auto* tf_options = reinterpret_cast<const TfLitePoolParams*>(
tflite_node->custom_initial_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
UpdatePadding(tf_options->padding, input_shape, &attr);
node->operation.attributes = attr;
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
return OkStatus();
}
};
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
class BatchToSpaceOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
BatchToSpaceAttributes bs_attr;
Tensor<Linear, DataType::INT32> block;
RETURN_IF_ERROR(reader->ReadTensor(1, &block));
if (block.shape.v != 2) {
return InternalError("Space has to be HxW.");
}
bs_attr.block.h = block.data[0];
bs_attr.block.w = block.data[1];
Tensor<HW, DataType::INT32> crop;
RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
auto crop_shape = crop.shape;
if (crop_shape.h != 2 && crop_shape.w != 2) {
return InternalError("Space has to be HxW.");
}
bs_attr.crop.prepended.h = crop.data[0];
bs_attr.crop.prepended.w = crop.data[2];
bs_attr.crop.appended.h = crop.data[1];
bs_attr.crop.appended.w = crop.data[3];
node->operation.attributes = std::move(bs_attr);
return OkStatus();
}
};
class SpaceToBatchOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
SpaceToBatchAttributes sb_attr;
Tensor<Linear, DataType::INT32> block;
RETURN_IF_ERROR(reader->ReadTensor(1, &block));
if (block.shape.v != 2) {
return InternalError("Space has to be HxW.");
}
sb_attr.block.h = block.data[0];
sb_attr.block.w = block.data[1];
Tensor<HW, DataType::INT32> padding;
RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
auto padding_shape = padding.shape;
if (padding_shape.h != 2 && padding_shape.w != 2) {
return InternalError("Space has to be HxW.");
}
sb_attr.padding.prepended.h = padding.data[0];
sb_attr.padding.prepended.w = padding.data[2];
sb_attr.padding.appended.h = padding.data[1];
sb_attr.padding.appended.w = padding.data[3];
node->operation.attributes = std::move(sb_attr);
return OkStatus();
}
};
class UnsupportedOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return UnimplementedError("Operation is not supported.");
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
return UnimplementedError("Operation is not supported.");
}
};
std::unique_ptr<TFLiteOperationParser> NewOperationParser(
const TfLiteRegistration* registration) {
const auto builtin_code = registration->builtin_code;
const absl::string_view custom_name = registration->custom_name;
switch (builtin_code) {
case kTfLiteBuiltinAbs:
return absl::make_unique<ElementwiseOperationParser>(OperationType::ABS);
case kTfLiteBuiltinAdd:
return absl::make_unique<AddOperationParser>();
case kTfLiteBuiltinAveragePool2d:
return absl::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
case kTfLiteBuiltinConcatenation:
return absl::make_unique<ConcatenationOperationParser>();
case kTfLiteBuiltinConv2d:
return absl::make_unique<Conv2DOperationParser>();
case kTfLiteBuiltinCos:
return absl::make_unique<ElementwiseOperationParser>(OperationType::COS);
case kTfLiteBuiltinDepthwiseConv2d:
return absl::make_unique<DepthwiseConvolutionOperationParser>();
case kTfLiteBuiltinDiv:
return absl::make_unique<ElementwiseOperationParser>(OperationType::DIV);
case kTfLiteBuiltinFullyConnected:
return absl::make_unique<FullyConnectedOperationParser>();
case kTfLiteBuiltinHardSwish:
return absl::make_unique<HardSwishOperationParser>();
case kTfLiteBuiltinLogistic:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::SIGMOID);
case kTfLiteBuiltinLog:
return absl::make_unique<ElementwiseOperationParser>(OperationType::LOG);
case kTfLiteBuiltinLstm:
return absl::make_unique<LSTMOperationParser>();
case kTfLiteBuiltinMaxPool2d:
return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
case kTfLiteBuiltinMul:
return absl::make_unique<MulOperationParser>();
case kTfLiteBuiltinPad:
return absl::make_unique<PadOperationParser>();
case kTfLiteBuiltinPow:
return absl::make_unique<ElementwiseOperationParser>(OperationType::POW);
case kTfLiteBuiltinRelu:
return absl::make_unique<ReLUOperationParser>(0);
case kTfLiteBuiltinRelu6:
return absl::make_unique<ReLUOperationParser>(6);
case kTfLiteBuiltinLeakyRelu:
return absl::make_unique<ReLUOperationParser>(0);
case kTfLiteBuiltinPrelu:
return absl::make_unique<PReLUOperationParser>();
case kTfLiteBuiltinReshape:
return absl::make_unique<ReshapeOperationParser>();
case kTfLiteBuiltinResizeBilinear:
return absl::make_unique<ResizeBilinearOperationParser>();
case kTfLiteBuiltinRsqrt:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::RSQRT);
case kTfLiteBuiltinSin:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SIN);
case kTfLiteBuiltinSoftmax:
return absl::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSlice:
return absl::make_unique<SliceOperationParser>();
case kTfLiteBuiltinStridedSlice:
return absl::make_unique<StridedSliceOperationParser>();
case kTfLiteBuiltinSqrt:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
case kTfLiteBuiltinSquare:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::SQUARE);
case kTfLiteBuiltinSquaredDifference:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::SQUARED_DIFF);
case kTfLiteBuiltinSub:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
case kTfLiteBuiltinTanh:
return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
case kTfLiteBuiltinTranspose:
return absl::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return absl::make_unique<TransposeConvOperationParser>();
case kTfLiteBuiltinCustom:
if (custom_name == "Convolution2DTransposeBias") {
return absl::make_unique<Convolution2DTransposeBiasParser>();
}
if (custom_name == "MaxPoolingWithArgmax2D") {
return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
}
if (custom_name == "MaxUnpooling2D") {
return absl::make_unique<Unpooling2DOperationParser>();
}
break;
}
return absl::make_unique<UnsupportedOperationParser>();
}
} // namespace
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRef<BHWC>* tensor_ref) {
tensor_ref->type = ToDataType(tflite_tensor.type);
return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
}
Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
const TfLiteRegistration* registration) {
return NewOperationParser(registration)
->IsSupported(context, node, registration);
}
bool IsAllFloatTensors(const TfLiteContext* context,
const TfLiteIntArray* array) {
for (int i = 0; i < array->size; ++i) {
const TfLiteTensor* t = context->tensors + array->data[i];
bool const type_supported =
(t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16);
if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
return false;
}
}
return true;
}
std::string GetOpNameByRegistration(const TfLiteRegistration* registration) {
auto op = registration->builtin_code;
std::string result =
EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
if (op == kTfLiteBuiltinCustom) {
result += " " + std::string(registration->custom_name);
}
return result;
}
Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
TfLiteNode** tflite_node,
TfLiteRegistration** registration) {
if (context->GetNodeAndRegistration(context, node_id, tflite_node,
registration) != kTfLiteOk) {
return InvalidArgumentError(absl::StrCat(
"Couldn't get node and registration info for op: ", node_id));
}
return OkStatus();
}
TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
TfLiteIntArray* execution_plan = nullptr;
if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
context->ReportError(context, "Unable to get graph execution plan.");
return nullptr;
}
std::set<std::string> errors;
std::unordered_map<int, int> dequant_nodes;
std::vector<int> ops_to_replace;
std::vector<int> dequant_nodes_to_save;
// Map the output tensor of a Dequantize nodes to its input tensor.
std::unordered_map<int, int> node_map;
for (int i = 0; i < execution_plan->size; ++i) {
bool replace_node = false;
// Keep track of any inputs from a Dequantize node.
std::vector<int> inputs_from_dequant;
std::vector<int> orig_inputs;
const int node_id = execution_plan->data[i];
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
auto status =
GetNodeAndRegistration(context, node_id, &node, &registration);
if (!status.ok()) {
context->ReportError(context, status.error_message().c_str());
return nullptr;
}
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) {
// Record the output->input mapping for the op.
node_map[node->outputs->data[0]] = node->inputs->data[0];
// For now, add the node to the list of ops to replace.
ops_to_replace.push_back(node_id);
// Record the dequant node id, indexed by output id.
dequant_nodes[node->outputs->data[0]] = node_id;
continue;
}
TfLiteIntArray* inputs = node->inputs;
// Fix the node's inputs (i.e. prune out the preceding dequantize node)
// in order to test if it is supported on the GPU.
for (int j = 0; j < inputs->size; ++j) {
orig_inputs.push_back(inputs->data[j]);
if (node_map.find(inputs->data[j]) != node_map.end()) {
inputs_from_dequant.push_back(dequant_nodes[inputs->data[j]]);
// Remap inputs of this node to the inputs of the preceding dequant.
inputs->data[j] = node_map[inputs->data[j]];
}
}
status = IsSupported(context, node, registration);
if (status.ok() &&
// TODO(eignasheva): resolve sub operation support for metal delegate
// registration->builtin_code != kTfLiteBuiltinSub &&
IsAllFloatTensors(context, node->inputs) &&
IsAllFloatTensors(context, node->outputs)) {
if (errors.empty()) {
replace_node = true;
ops_to_replace.push_back(i);
}
} else {
// Unable to replace this node. Restore the inputs to the original
// if they were modified.
if (!inputs_from_dequant.empty()) {
TfLiteIntArray* inputs = node->inputs;
for (int j = 0; j < inputs->size; ++j) {
inputs->data[j] = orig_inputs[j];
}
}
errors.insert(GetOpNameByRegistration(registration) + ": " +
status.error_message());
}
// if any input is the output of a dequantize node AND we failed to
// replace this op, mark the corresponding dequantize node as a node to
// save.
if (!replace_node && !inputs_from_dequant.empty()) {
dequant_nodes_to_save.insert(dequant_nodes_to_save.end(),
inputs_from_dequant.begin(),
inputs_from_dequant.end());
}
}
if (!errors.empty()) {
std::string unsupported = absl::StrJoin(errors, "\n");
std::string error_message =
"Next operations are not supported by GPU delegate:\n" + unsupported +
"\nFirst " + std::to_string(ops_to_replace.size()) +
" operations will run on the GPU, and the remaining " +
std::to_string(execution_plan->size - ops_to_replace.size()) +
" on the CPU.";
context->ReportError(context, error_message.c_str());
}
// Pop all dequantize nodes that must be preserved.
for (int i = 0; i < dequant_nodes_to_save.size(); ++i) {
auto it = std::find(ops_to_replace.begin(), ops_to_replace.end(),
dequant_nodes_to_save[i]);
if (it != ops_to_replace.end()) {
ops_to_replace.erase(it);
}
}
return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
TfLiteIntArray* execution_plan = nullptr;
if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
context->ReportError(context, "Unable to get graph execution plan.");
return nullptr;
}
// Dispatch to another function if graph has Dequantize nodes.
for (int i = 0; i < execution_plan->size; ++i) {
const int node_id = execution_plan->data[i];
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
auto status =
GetNodeAndRegistration(context, node_id, &node, &registration);
if (!status.ok()) {
context->ReportError(context, status.error_message().c_str());
return nullptr;
}
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) {
return GetOpsToReplaceFromGraphWithDequantize(context);
}
}
// No Dequantize nodes. Iterate through graph and find ops to replace.
TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
subgraph->size = 0;
std::set<std::string> errors;
for (int i = 0; i < execution_plan->size; ++i) {
const int node_id = execution_plan->data[i];
TfLiteNode* node;
TfLiteRegistration* registration;
auto status =
GetNodeAndRegistration(context, node_id, &node, &registration);
if (!status.ok()) {
context->ReportError(context, status.error_message().c_str());
return nullptr;
}
status = IsSupported(context, node, registration);
if (status.ok() &&
// TODO(eignasheva): resolve sub operation support for metal delegate
// registration->builtin_code != kTfLiteBuiltinSub &&
IsAllFloatTensors(context, node->inputs) &&
IsAllFloatTensors(context, node->outputs)) {
if (errors.empty()) subgraph->data[subgraph->size++] = node_id;
} else {
errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ",
status.error_message()));
}
}
if (!errors.empty()) {
std::string unsupported = absl::StrJoin(errors, "\n");
std::string error_message =
"Next operations are not supported by GPU delegate:\n" + unsupported +
"\nFirst " + std::to_string(subgraph->size) +
" operations will run on the GPU, and the remaining " +
std::to_string(execution_plan->size - subgraph->size) + " on the CPU.";
context->ReportError(context, error_message.c_str());
}
return subgraph;
}
Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph) {
std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
std::vector<int> tflite_nodes;
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
TfLiteNode* tflite_node = nullptr;
TfLiteRegistration* registration = nullptr;
RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[i], &tflite_node,
&registration));
if (registration->builtin_code == kTfLiteBuiltinDequantize) {
// Ignore Dequantize nodes.
continue;
}
auto op_parser = NewOperationParser(registration);
if (!op_parser) {
return UnimplementedError(
absl::StrCat("Operation ", registration->builtin_code, "(",
registration->custom_name,
") is not supported by TFLite GPU Delegate."));
}
operations.push_back(std::move(op_parser));
tflite_nodes.push_back(i);
}
std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
nullptr);
for (int i = 0; i < operations.size(); ++i) {
TfLiteNode* tflite_node;
TfLiteRegistration* registration;
RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
&tflite_node, &registration));
ObjectReader reader(graph, context, tflite_node, &tensor_to_value);
const auto status =
operations[i]->Parse(tflite_node, registration, graph, &reader);
if (!status.ok()) {
return InternalError(absl::StrCat(GetOpNameByRegistration(registration),
": ", status.error_message()));
}
}
return OkStatus();
}
} // namespace gpu
} // namespace tflite