blob: 061c65095eb16952eef6c7ba05b06d129bdab401 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace gpu {
namespace {
// Creates a node that consumes output from the given node. Because output need
// to stay the same, newly created node will inherit the output from the given
// node, which will in turn get newly created copy of output. This is necessary
// to preserve reference consistency if another node was pointing at that
// output:
// node(output)
// will turn into:
// node(copy(output)) <- passthrough_node(output)
absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
const Value* output, Node** passthru_node) {
*passthru_node = graph->NewNode();
// Make copies for every output in the original node.
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
Value* copy_output = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
copy_output->tensor = output->tensor;
copy_output->tensor.ref = -1;
return absl::OkStatus();
}
template <typename T>
inline void DequantizeConstantTensor(const TfLiteTensor& tensor,
const T* source_data,
float* dequantized_data) {
TfLiteAffineQuantization* quant_params =
reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
if (quant_params->scale->size > 1) {
// Tensor is per-channel quantized.
PerChannelDequantizationParams op_params;
op_params.zero_point = quant_params->zero_point->data;
op_params.scale = quant_params->scale->data;
op_params.quantized_dimension = quant_params->quantized_dimension;
reference_ops::PerChannelDequantize(op_params, GetTensorShape(&tensor),
source_data, GetTensorShape(&tensor),
dequantized_data);
} else {
DequantizationParams op_params;
op_params.zero_point = tensor.params.zero_point;
op_params.scale = tensor.params.scale;
reference_ops::Dequantize(op_params, GetTensorShape(&tensor), source_data,
GetTensorShape(&tensor), dequantized_data);
}
}
absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
const TfLiteNode* tflite_node, int idx) {
// If tensor id is in range, it's guaranteed that it'll be available.
if (idx >= tflite_node->inputs->size) {
return absl::OutOfRangeError(
absl::StrCat("Requested index goes beyond array size: ", idx, " vs ",
idx, tflite_node->inputs->size));
}
return absl::OkStatus();
}
// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
public:
virtual ~TFLiteOperationParser() = default;
// Parses TFLite operation. This method allows expanding fused operations
// into more than one node.
virtual absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) = 0;
// Verifies whether passed tflite node may be built by GPU delegate or not.
virtual absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) = 0;
};
absl::Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
switch (fused_activation) {
case kTfLiteActNone:
case kTfLiteActRelu:
case kTfLiteActRelu1:
case kTfLiteActRelu6:
case kTfLiteActTanh:
return absl::OkStatus();
case kTfLiteActSignBit:
return absl::UnimplementedError(
"TfLiteFusedActivation.kTfLiteActSignBit");
case kTfLiteActSigmoid:
return absl::UnimplementedError(
"TfLiteFusedActivation.kTfLiteActSigmoid");
// Do not add default; we want compilation error rather than run-time
// error.
}
}
// If there is fused activation present, then there will be another node created
// that will have identical output as the given node. New operation node will
// depend on the given node output.
absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
const std::vector<uint32_t>& output_indices,
GraphFloat32* graph, Node* node) {
if (fused_activation == kTfLiteActNone) {
return absl::OkStatus();
}
const auto outputs = graph->FindOutputs(node->id);
if (outputs.empty()) {
return absl::InternalError("Empty outputs in fused node");
}
switch (fused_activation) {
case kTfLiteActRelu:
case kTfLiteActRelu1:
case kTfLiteActRelu6: {
ReLUAttributes attr;
attr.clip = fused_activation == kTfLiteActRelu
? 0.0f
: (fused_activation == kTfLiteActRelu1 ? 1.0f : 6.0f);
for (auto index : output_indices) {
Node* activation_node;
RETURN_IF_ERROR(
NewPassthroughNode(graph, node, outputs[index], &activation_node));
activation_node->operation.type = ToString(OperationType::RELU);
activation_node->operation.attributes = attr;
}
break;
}
case kTfLiteActTanh:
for (auto index : output_indices) {
Node* activation_node;
RETURN_IF_ERROR(
NewPassthroughNode(graph, node, outputs[index], &activation_node));
activation_node->operation.type = ToString(OperationType::TANH);
}
break;
default:
return absl::NotFoundError(
absl::StrCat("Unsupported fused activation: ", fused_activation));
}
return absl::OkStatus();
}
absl::Status MaybeFuseActivationToTheSingleOutput(
TfLiteFusedActivation fused_activation, GraphFloat32* graph, Node* node) {
if (graph->FindOutputs(node->id).size() != 1) {
return absl::InternalError("Number of outputs exceeds 1");
}
return MaybeFuseActivation(fused_activation, {0}, graph, node);
}
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
template <typename AttrT>
void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape,
AttrT* attr) {
if (padding == kTfLitePaddingSame) {
attr->padding = CalculateSamePadding(input_shape, *attr);
} else {
attr->padding.prepended = HW(0, 0);
attr->padding.appended = HW(0, 0);
}
}
absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
int bias_tensor_id,
ObjectReader* reader,
FullyConnectedAttributes* attr) {
Tensor<HW, DataType::FLOAT32> weights;
RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
attr->weights.data = std::move(weights.data);
attr->weights.id = weights.id;
attr->weights.shape.h = 1;
attr->weights.shape.w = 1;
attr->weights.shape.o = weights.shape.h;
attr->weights.shape.i = weights.shape.w;
reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional
return absl::OkStatus();
}
template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
ParamsT** tf_options) {
const auto* params =
reinterpret_cast<const ParamsT*>(tflite_node->builtin_data);
if (!params) {
return absl::InternalError("Unable to retrieve builtin_data.");
}
*tf_options = const_cast<ParamsT*>(params);
return absl::OkStatus();
}
template <typename ParamsType>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
ParamsType** tf_options) {
const auto* params =
reinterpret_cast<const ParamsType*>(tflite_node->custom_initial_data);
if (!params) {
return absl::InternalError("Unable to retrieve custom_initial_data.");
}
*tf_options = const_cast<ParamsType*>(params);
return absl::OkStatus();
}
absl::Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration,
int max_version) {
const int op_version = registration->version;
if (op_version > max_version) {
return absl::UnimplementedError(
absl::StrCat("Max version supported: ", max_version,
". Requested version ", op_version));
}
return absl::OkStatus();
}
absl::Status CheckExactSupportedOpVersion(
const TfLiteRegistration* registration, int expected_version) {
int op_version = registration->version;
if (op_version != expected_version) {
return absl::UnimplementedError(
absl::StrCat("Only version ", expected_version,
" is supported. Requested version ", op_version));
}
return absl::OkStatus();
}
absl::Status CheckKernels(int kernel_h, int kernel_w) {
if (kernel_h <= 0 || kernel_w <= 0) {
return absl::InvalidArgumentError(
absl::StrCat("Incorrect kernel values: kernel_height = ", kernel_h,
", kernel_width = ", kernel_w));
}
return absl::OkStatus();
}
absl::Status CheckStrides(int strides_h, int strides_w) {
if (strides_h <= 0 || strides_w <= 0) {
return absl::InvalidArgumentError(
absl::StrCat("Incorrect stride values: stride_height = ", strides_h,
", stride_width = ", strides_w));
}
return absl::OkStatus();
}
absl::Status CheckDilation(int dilation_h, int dilation_w) {
if (dilation_h <= 0 || dilation_w <= 0) {
return absl::InvalidArgumentError(absl::StrCat(
"Incorrect dilation values: dilation_factor = ", dilation_h,
", dilation_factor = ", dilation_w));
}
return absl::OkStatus();
}
absl::Status CheckStridesAndDilation(int strides_h, int strides_w,
int dilation_h, int dilation_w) {
RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w));
return absl::OkStatus();
}
absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
int strides_w) {
RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w));
RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
return absl::OkStatus();
}
// Creates a simple node that holds tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
ConstTensorAttributes attr;
attr.tensor = std::move(t);
Node* node = graph->NewNode();
node->operation.attributes = attr;
node->operation.type = ToString(OperationType::CONST);
*value = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
// Keep data inside this tensor.
(*value)->tensor.ref = attr.tensor.id;
(*value)->tensor.type = attr.tensor.kType;
(*value)->tensor.shape = attr.tensor.shape;
return absl::OkStatus();
}
absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
const BHWC& input_shape,
Pooling2DAttributes* attr) {
attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width);
UpdatePadding(tf_options->padding, input_shape, attr);
return absl::OkStatus();
}
absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
TensorOrScalar* tensor_or_scalar) {
const std::string& opname = node->operation.type;
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
opname);
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
opname);
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return absl::InvalidArgumentError("No runtime input tensors for " + opname);
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
if (runtime_tensor0 && runtime_tensor1) {
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else {
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
}
}
return absl::OkStatus();
}
class AddOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
if (tflite_node->inputs->size != 2) {
return absl::UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
TfLiteAddParams* tf_options = nullptr;
return RetrieveBuiltinData(tflite_node, &tf_options);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
// TFLite currently only supports 2 input ADDs. Thus, the logic below only
// considers 2 input cases. The underlying GPU shader programs can accept
// more inputs, but the logic below would have to be expanded.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::ADD);
RETURN_IF_ERROR(reader->AddOutputs(node));
AddAttributes attr;
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
node->operation.attributes = std::move(attr);
const auto* tf_options =
reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
node);
}
};
class ConcatenationOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
// TODO(eignasheva): add proper tensor availability checking
// for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
// RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
// }
// TODO(eignasheva): add axis checking.
TfLiteConcatenationParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
ConcatAttributes attr;
// Read inputs first to make sure const node is added to a graph before
// concat node to ensure topological order.
std::vector<const Value*> inputs;
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
Value* value;
const auto status = reader->ReadValue(idx, &value);
if (status.ok()) {
inputs.push_back(value);
} else {
TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
Value* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
inputs.push_back(value);
}
}
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONCAT);
RETURN_IF_ERROR(reader->AddOutputs(node));
for (const Value* input : inputs) {
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
}
std::vector<BHWC> input_shapes;
for (auto input : graph->FindInputs(node->id)) {
input_shapes.push_back(input->tensor.shape);
}
RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));
// Guess axis.
BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
for (auto input : graph->FindInputs(node->id)) {
if (input->tensor.shape.h != output_shape.h) {
attr.axis = Axis::HEIGHT;
break;
}
if (input->tensor.shape.w != output_shape.w) {
attr.axis = Axis::WIDTH;
break;
}
if (input->tensor.shape.c != output_shape.c) {
attr.axis = Axis::CHANNELS;
break;
}
}
const auto* tf_options = reinterpret_cast<const TfLiteConcatenationParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
*axis = Axis::BATCH;
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].h != input_shapes[i].h &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::HEIGHT;
break;
}
}
if (*axis == Axis::BATCH) return absl::OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::WIDTH;
break;
}
}
if (*axis == Axis::HEIGHT) return absl::OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].h != input_shapes[i].h &&
input_shapes[0].c != input_shapes[i].c) {
*axis = Axis::CHANNELS;
break;
}
}
if (*axis == Axis::WIDTH) return absl::OkStatus();
for (int i = 1; i < input_shapes.size(); i++) {
if (input_shapes[0].b != input_shapes[i].b &&
input_shapes[0].w != input_shapes[i].w &&
input_shapes[0].h != input_shapes[i].h) {
return absl::UnimplementedError(
"Can concatenate tensors only by batch, height, width, or "
"channels.");
}
}
return absl::OkStatus();
}
};
class Conv2DOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
const int runtime_inputs =
GetNumberOfRuntimeInputsForNode(context, tflite_node);
if (runtime_inputs > 2) {
return absl::InternalError(
absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
runtime_inputs, " runtime inputs."));
}
const int runtime_outputs =
GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != 1) {
return absl::InternalError(
absl::StrCat("Expected 1 output tensor(s), but node has ",
runtime_outputs, " runtime outputs."));
}
if (runtime_inputs == 1) {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
}
TfLiteConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckStridesAndDilation(
tf_options->stride_height, tf_options->stride_width,
tf_options->dilation_height_factor, tf_options->dilation_width_factor));
return IsActivationSupported(tf_options->activation);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
Convolution2DAttributes attr;
const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
if (runtime_inputs == 2) {
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else { // runtime_inputs == 1;
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
}
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
const auto* tf_options =
reinterpret_cast<const TfLiteConvParams*>(tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
attr.dilations = HW(tf_options->dilation_height_factor,
tf_options->dilation_width_factor);
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteTransposeConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* params = reinterpret_cast<const TfLiteTransposeConvParams*>(
tflite_node->custom_initial_data);
ConvolutionTransposedAttributes attr;
attr.stride =
params ? HW(params->stride_height, params->stride_width) : HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape,
&attr);
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteDepthwiseConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckStridesAndDilation(
tf_options->stride_height, tf_options->stride_width,
tf_options->dilation_height_factor, tf_options->dilation_width_factor));
RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
const int depth_multiplier = tf_options->depth_multiplier;
const auto* input = context->tensors + tflite_node->inputs->data[0];
const auto* filter = context->tensors + tflite_node->inputs->data[1];
const auto* bias = tflite_node->inputs->size > 2
? context->tensors + tflite_node->inputs->data[2]
: nullptr;
const auto* output = context->tensors + tflite_node->outputs->data[0];
if (!input->dims || input->dims->size != 4) {
return absl::InvalidArgumentError("input.dims.size != 4");
}
if (!filter->dims || filter->dims->size != 4) {
return absl::InvalidArgumentError("filter.dims.size != 4");
}
if (!output->dims || output->dims->size != 4) {
return absl::InvalidArgumentError("output.dims.size != 4");
}
if (input->dims->data[0] != output->dims->data[0]) {
return absl::InvalidArgumentError("input.b != output.b");
}
const int input_depth = input->dims->data[3];
const int output_depth = output->dims->data[3];
if (filter->dims->data[3] != output_depth) {
return absl::InvalidArgumentError("filter.i != output.c");
}
if (output_depth != input_depth * depth_multiplier) {
return absl::InvalidArgumentError(
"output.c != input.c * depth_multiplier");
}
if (bias && NumElements(bias) != output_depth) {
return absl::InvalidArgumentError("bias.size != output.c");
}
if (depth_multiplier != 1 && input_depth != 1) {
return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1");
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
DepthwiseConvolution2DAttributes attr;
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
TfLiteDepthwiseConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
std::max(1, tf_options->dilation_width_factor));
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
const int depth_multiplier = tf_options->depth_multiplier;
if (depth_multiplier != 1) {
const TfLiteTensor* input = reader->GetInputTensor(0);
const TfLiteTensor* filter = reader->GetInputTensor(1);
const TfLiteTensor* output = reader->GetOutputTensor(0);
TransposeWeights(input, filter, output, depth_multiplier, &attr);
}
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
private:
// TFLite CPU stores weights as:
// [1, kernel_height, kernel_width, input_depth * depth_multiplier]
// TFLite GPU stores weights as:
// [depth_multiplier, kernel_height, kernel_width, input_depth]
static void TransposeWeights(const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* output, int depth_multiplier,
DepthwiseConvolution2DAttributes* attr) {
const int input_depth = input->dims->data[3];
const int filter_height = filter->dims->data[1];
const int filter_width = filter->dims->data[2];
const int output_depth = output->dims->data[3];
Tensor<OHWI, DataType::FLOAT32> weights;
weights.id = attr->weights.id;
weights.shape =
OHWI(output_depth, filter_height, filter_width, input_depth);
weights.data.resize(weights.shape.DimensionsProduct());
float* dst = &weights.data[0];
for (int j = 0; j < output_depth; ++j) {
const float* src = attr->weights.data.data() + j;
for (int i = 0; i < filter_height * filter_width; ++i) {
*dst = *src;
dst++;
src += output_depth;
}
}
attr->weights = std::move(weights);
}
};
class DequantizeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
// 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
// with floating-point versions of the original tensors.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Quantization attributes should already be present in the input tensor.
auto input_value = graph->FindInputs(node->id)[0];
if (!input_value->quant_params) {
return absl::InvalidArgumentError(
"Encountered Dequantize input with no quant params");
}
QuantizeAndDequantizeAttributes attr;
attr.min = input_value->quant_params.value().min;
attr.max = input_value->quant_params.value().max;
attr.scale = input_value->quant_params.value().scale;
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class ElementwiseOperationParser : public TFLiteOperationParser {
public:
explicit ElementwiseOperationParser(OperationType operation_type)
: operation_type_(operation_type) {}
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/0,
/*outputs=*/1));
// For some elementwise operations (currently only for SUB operation)
// second condition may be false. But it's worth checking the next case
// with const input, which may be supported.
} else if (IsTwoArgumentOperation() &&
CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/2,
/*const_inputs=*/0,
/*outputs=*/1)
.ok()) {
} else if (IsTwoArgumentOperationWithConst()) {
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/1,
/*outputs=*/1));
} else {
return absl::InvalidArgumentError(
"Op can only handle 1 or 2 operand(s).");
}
TfLiteFusedActivation activation;
RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
return IsActivationSupported(activation);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(operation_type_);
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/0,
/*outputs=*/1));
RETURN_IF_ERROR(reader->AddInput(node, 0));
} else if (IsTwoArgumentOperation() &&
reader
->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/2,
/*const_inputs=*/0,
/*outputs=*/1)
.ok()) {
if (tflite_node->inputs->size != 2) {
return absl::InvalidArgumentError("Applies only two input tensors");
}
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
TfLiteFusedActivation activation = kTfLiteActNone;
switch (operation_type_) {
case OperationType::SUB: {
const auto* tf_options = reinterpret_cast<const TfLiteSubParams*>(
tflite_node->builtin_data);
if (tf_options != nullptr) {
activation = tf_options->activation;
}
break;
}
case OperationType::DIV: {
const auto* tf_options = reinterpret_cast<const TfLiteDivParams*>(
tflite_node->builtin_data);
if (tf_options != nullptr) {
activation = tf_options->activation;
}
break;
}
default:
// No activation expected.
activation = kTfLiteActNone;
}
if (activation) {
RETURN_IF_ERROR(
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
}
} else if (IsTwoArgumentOperationWithConst()) {
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/1,
/*outputs=*/1));
ElementwiseAttributes attr;
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
node->operation.attributes = std::move(attr);
} else {
return absl::InvalidArgumentError("Incorrect operation type passed");
}
return reader->AddOutputs(node);
}
private:
absl::Status GetActivation(const TfLiteNode* tflite_node,
TfLiteFusedActivation* activation) const {
if (operation_type_ == OperationType::DIV) {
TfLiteDivParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
*activation = tf_options ? tf_options->activation : kTfLiteActNone;
return absl::OkStatus();
}
if (operation_type_ == OperationType::SUB) {
TfLiteSubParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
*activation = tf_options ? tf_options->activation : kTfLiteActNone;
return absl::OkStatus();
}
// Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
// TfLiteXxxParams.activation.
*activation = kTfLiteActNone;
return absl::OkStatus();
}
bool IsOneArgumentOperation() const {
switch (operation_type_) {
case OperationType::ABS:
case OperationType::COS:
case OperationType::EXP:
case OperationType::LOG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN:
case OperationType::SQRT:
case OperationType::SQUARE:
case OperationType::TANH:
return true;
default:
return false;
}
}
bool IsTwoArgumentOperation() const {
switch (operation_type_) {
case OperationType::DIV:
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB:
return true;
default:
return false;
}
}
bool IsTwoArgumentOperationWithConst() const {
switch (operation_type_) {
case OperationType::MINIMUM:
case OperationType::MAXIMUM:
case OperationType::SUB:
return true;
default:
return false;
}
}
OperationType operation_type_;
};
class FullyConnectedOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
TfLiteFullyConnectedParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->weights_format !=
kTfLiteFullyConnectedWeightsFormatDefault) {
return absl::UnimplementedError(
"Unsupported FullyConnected weights format.");
}
// TODO(eignasheva): check input shape
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0));
const auto* tf_options =
reinterpret_cast<const TfLiteFullyConnectedParams*>(
tflite_node->builtin_data);
if (tf_options->weights_format !=
kTfLiteFullyConnectedWeightsFormatDefault) {
return absl::UnimplementedError(
"Unsupported FullyConnected weights format.");
}
FullyConnectedAttributes attr;
RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
Tensor<HW, DataType::FLOAT32> weights;
RETURN_IF_ERROR(reader->ReadTensor(1, &weights));
auto input = graph->FindInputs(node->id)[0];
int batch_size = input->tensor.shape.b;
if (input->tensor.shape.DimensionsProduct() / batch_size !=
weights.shape.w) {
return absl::UnimplementedError(
"Amount of input data should match weights width");
}
Node* conv = node;
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
auto& reshape = node;
conv = graph->NewNode(); // reset conv pointer!
Value* reshaped_value = graph->NewValue();
reshaped_value->tensor.type = DataType::FLOAT32;
reshaped_value->tensor.shape =
BHWC(input->tensor.shape.b, 1, 1, weights.shape.w);
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
reshape->operation.type = ToString(OperationType::RESHAPE);
ReshapeAttributes attr;
attr.new_shape = reshaped_value->tensor.shape;
reshape->operation.attributes = attr;
RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
}
conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
conv->operation.attributes = std::move(attr);
absl::Status result = reader->AddOutputs(conv);
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, conv));
return result;
}
};
class HardSwishOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration*) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::HARD_SWISH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
return reader->AddOutputs(node);
}
};
// Basic LSTM Cell:
//
// 1name = name is at input index 1
// name1 = name is at output index 1
//
// 0input 1prev_activ
// \ /
// [[concat]]
// \
// concat_temp2 2weights 3biases
// \ / /
// [[fully-connected]]
// \
// activ_temp3 4prev_state
// \ /
// [[LSTM]]
// / \
// new_state1 activation0
//
class LSTMOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
// TODO(eignasheva): Fix bad check.
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
// /*runtime_inputs=*/5,
// /*outputs=*/4));
TfLiteLSTMParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckParameters(tf_options));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
if (tflite_node->inputs->size != 5) {
return absl::InvalidArgumentError("LSTM should have 5 input tensors");
}
if (tflite_node->outputs->size != 4) {
return absl::InvalidArgumentError("LSTM should have 4 output tensors");
}
const auto* params =
reinterpret_cast<const TfLiteLSTMParams*>(tflite_node->builtin_data);
if (!params) {
return absl::InternalError("Missing tflite params");
}
RETURN_IF_ERROR(CheckParameters(params));
Node* concat_node = graph->NewNode();
concat_node->operation.type = ToString(OperationType::CONCAT);
ConcatAttributes concat_attr;
concat_attr.axis = Axis::CHANNELS;
concat_node->operation.attributes = concat_attr;
Node* fc_node = graph->NewNode();
fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
FullyConnectedAttributes fc_attr;
RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
fc_node->operation.attributes = std::move(fc_attr);
Node* lstm_node = graph->NewNode();
lstm_node->operation.type = ToString(OperationType::LSTM);
LstmAttributes lstm_attr;
lstm_attr.kernel_type = LstmKernelType::BASIC;
lstm_node->operation.attributes = lstm_attr;
Value* concat_temp;
int concat_tensor_idx = tflite_node->outputs->data[2];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
Value* activ_temp;
int activ_tensor_idx = tflite_node->outputs->data[3];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input
RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ
RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state
RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state
RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation
return absl::OkStatus();
}
private:
absl::Status CheckParameters(const TfLiteLSTMParams* tf_options) {
if (tf_options->kernel_type !=
TfLiteLSTMKernelType::kTfLiteLSTMBasicKernel) {
return absl::UnimplementedError(
"Only kTfLiteLSTMBasicKernel is supported.");
}
if (tf_options->activation != kTfLiteActTanh) {
return absl::UnimplementedError("Only TANH activation is supported.");
}
if (tf_options->cell_clip != 0.0f) {
return absl::UnimplementedError("cell_clip is not supported.");
}
if (tf_options->proj_clip != 0.0f) {
return absl::UnimplementedError("proj_clip is not supported.");
}
return absl::OkStatus();
}
};
class MulOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
if (tflite_node->inputs->size != 2) {
return absl::UnimplementedError("MUL requires two input tensors.");
}
TfLiteMulParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return IsActivationSupported(tf_options->activation);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return absl::InvalidArgumentError(
"Couldn't get the 1st input tensor for MUL.");
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return absl::InvalidArgumentError(
"Couldn't get the 2nd input tensor for MUL.");
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return absl::InvalidArgumentError("No runtime input tensors for MUL.");
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MUL);
// The "larger" input tensor must be bound to 1st input and the "smaller"
// input tensor ("mask") must be bound to 2nd input.
if (runtime_tensor0 && runtime_tensor1) {
BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
BHWC shape1;
RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
int input_tensor0 = 0;
int input_tensor1 = 1;
if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
shape0.c == shape1.c) {
input_tensor0 = 1;
input_tensor1 = 0;
}
RETURN_IF_ERROR(
ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader));
} else {
// The runtime input tensor must be bound to 1st input and the constant
// input tensor must be bound to 2nd input.
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
constant_dims, graph, reader));
}
const auto* tf_options =
reinterpret_cast<const TfLiteMulParams*>(tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing TfLiteMulParams");
}
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
node);
}
private:
absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor,
int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
MultiplyAttributes attr;
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
};
class PReLUOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
// TODO(eignasheva): add params check
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::PRELU);
RETURN_IF_ERROR(reader->AddInput(node, 0));
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
PReLUAttributes attr;
Tensor<Linear, DataType::FLOAT32> linear_alpha;
absl::Status status = reader->ReadTensor(1, &linear_alpha);
if (status.ok()) {
if (linear_alpha.shape.v != input_shape.c) {
return absl::InvalidArgumentError(
"Linear alpha shape does not match the number of input channels.");
}
attr.alpha = std::move(linear_alpha);
} else {
Tensor<HWC, DataType::FLOAT32> hwc_alpha;
RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
if (hwc_alpha.shape.h != input_shape.h ||
hwc_alpha.shape.w != input_shape.w ||
hwc_alpha.shape.c != input_shape.c) {
return absl::InvalidArgumentError(
"Alpha shape does not match input shape.");
}
attr.alpha = std::move(hwc_alpha);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
};
class PadOperationParser : public TFLiteOperationParser {
public:
explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
if (mirror_pad_) {
auto* tf_options = reinterpret_cast<const TfLiteMirrorPaddingParams*>(
tflite_node->builtin_data);
if (tf_options->mode !=
TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
return absl::InvalidArgumentError(
"Only Reflective padding is supported for Mirror Pad operation.");
}
}
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
auto pad_tensor = tflite::GetInput(context, tflite_node, 1);
if (pad_tensor->dims->size != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor dimension: expected 2 dim, got ",
pad_tensor->dims->size, " dim"));
}
if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor shape: expected 4x2, got ",
pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::PAD);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
PadAttributes attr;
if (mirror_pad_) {
attr.type = PaddingContentType::REFLECT;
} else /*zero pad*/ {
attr.type = PaddingContentType::ZEROS;
}
Tensor<HW, DataType::INT32> paddings;
RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
// 4x2 tensor with paddings.
if (paddings.shape.h != 4 || paddings.shape.w != 2) {
// It shouldn't fail here since it's checked at IsSupported().
return absl::InvalidArgumentError(
"Paddings tensor has unexpected shape.");
}
attr.prepended = BHWC(paddings.data[0], paddings.data[2], paddings.data[4],
paddings.data[6]);
attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
paddings.data[7]);
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
bool mirror_pad_ = false;
};
class Pooling2DOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
TfLitePoolParams* tf_options = nullptr;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
if (status.ok()) { // custom case with indices as a second output
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/2));
} else { // common pooling with 1 output
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/1));
}
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
return IsActivationSupported(tf_options->activation);
}
public:
explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::POOLING_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutput(node, 0));
Pooling2DAttributes attr;
attr.type = type_;
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
// check whether there are custom options encoded. It happens if operation
// is MaxPoolingWithArgmax2D. There is no way to read
// tflite_node->builtin_code, so, simply check whether custom data is
// available.
auto* tf_options = reinterpret_cast<const TfLitePoolParams*>(
tflite_node->custom_initial_data);
if (!tf_options) {
tf_options =
reinterpret_cast<const TfLitePoolParams*>(tflite_node->builtin_data);
}
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
std::vector<uint32_t> max_tensor_id{0};
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, max_tensor_id,
graph, node));
// Second output is optional. It is not required, it but must be added after
// MaybeAddFusedActivation function is called
reader->AddOutput(node, 1).IgnoreError();
// First output is the result of pooling operation, while second output is
// indices used for pooling.
auto outputs = graph->FindOutputs(node->id);
attr.output_indices = outputs.size() == 2;
if (attr.output_indices) {
// Fix data type for output indices. In the model it is set as float32.
outputs[1]->tensor.type = DataType::INT32;
}
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
const PoolingType type_;
};
class QuantizeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
// 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
// with floating-point versions of the original tensors.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Quantization attributes should already be present in the output tensor.
auto output_value = graph->FindOutputs(node->id)[0];
if (!output_value->quant_params) {
return absl::InvalidArgumentError(
"Encountered Quantize output with no quant params");
}
QuantizeAndDequantizeAttributes attr;
attr.min = output_value->quant_params.value().min;
attr.max = output_value->quant_params.value().max;
attr.scale = output_value->quant_params.value().scale;
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class ReLUOperationParser : public TFLiteOperationParser {
public:
explicit ReLUOperationParser(int clip) : clip_(clip) {}
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::RELU);
RETURN_IF_ERROR(reader->AddInput(node, 0));
ReLUAttributes attr;
TfLiteLeakyReluParams* tf_options = nullptr;
RetrieveBuiltinData(tflite_node, &tf_options).IgnoreError();
attr.alpha = tf_options ? tf_options->alpha : 0;
attr.clip = clip_;
node->operation.attributes = attr;
return reader->AddOutputs(node);
}
private:
const int clip_;
};
class ReshapeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
// TODO(eignasheva): add shape checking
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::RESHAPE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Here we may have extra inputs. Other tensors were supposed to
// define new shape, but in TFLite these are ignored.
// TODO(akulik): check that shapes match?
// New shape comes from output shape.
ReshapeAttributes attr;
attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class Resize2DOperationParser : public TFLiteOperationParser {
public:
explicit Resize2DOperationParser(SamplingType sampling_type)
: sampling_type_(sampling_type) {}
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
bool align_corners;
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
bool half_pixel_centers;
RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::RESIZE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Here we may have extra inputs. Other tensors were supposed to
// define new shape, but in TFLite these are ignored.
Resize2DAttributes attr;
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
RETURN_IF_ERROR(
GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
attr.type = sampling_type_;
attr.new_shape.CopyAllDefinedAxis(
graph->FindOutputs(node->id)[0]->tensor.shape);
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
bool* align_corners) {
switch (sampling_type_) {
case SamplingType::BILINEAR:
return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
tflite_node, align_corners);
case SamplingType::NEAREST:
return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
tflite_node, align_corners);
case SamplingType::UNKNOWN:
return absl::InternalError("Sampling type is not specified");
}
return absl::OkStatus();
}
template <class T>
absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
bool* align_corners) {
const auto* tf_options =
reinterpret_cast<const T*>(tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
*align_corners = tf_options->align_corners;
return absl::OkStatus();
}
absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
bool* half_pixel_centers) {
if (sampling_type_ == SamplingType::BILINEAR) {
const auto* tf_options = reinterpret_cast<TfLiteResizeBilinearParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError(
"Missing tflite params for ResizeBilinear op");
}
if (tf_options->align_corners && tf_options->half_pixel_centers) {
return absl::InternalError(
"If half_pixel_centers is True, align_corners must be False.");
}
*half_pixel_centers = tf_options->half_pixel_centers;
} else {
*half_pixel_centers = false;
}
return absl::OkStatus();
}
absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
const auto* input = context->tensors + tflite_node->inputs->data[0];
const auto* output = context->tensors + tflite_node->outputs->data[0];
if (!input->dims || input->dims->size != 4) {
return absl::InvalidArgumentError("input.dims.size != 4");
}
if (!output->dims || output->dims->size != 4) {
return absl::InvalidArgumentError("output.dims.size != 4");
}
if (output->dims->data[1] < input->dims->data[1] ||
output->dims->data[2] < input->dims->data[2]) {
return absl::InvalidArgumentError(absl::StrCat(
"Only upsampling is supported, received output h,w = ",
output->dims->data[1], ",", output->dims->data[2],
" input h,w = ", input->dims->data[1], ",", input->dims->data[2]));
}
return absl::OkStatus();
}
SamplingType sampling_type_ = SamplingType::UNKNOWN;
};
class SliceOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
SliceAttributes attr;
attr.strides = BHWC(1, 1, 1, 1);
Tensor<Linear, DataType::INT32> starts, sizes;
RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
if (starts.data.size() != sizes.data.size()) {
return absl::InvalidArgumentError("Starts amount != sizes amount.");
}
const auto& in_shape = input->tensor.shape;
if (starts.data.size() == 4) {
sizes.data[0] =
sizes.data[0] != -1 ? sizes.data[0] : in_shape.b - starts.data[0];
sizes.data[1] =
sizes.data[1] != -1 ? sizes.data[1] : in_shape.h - starts.data[1];
sizes.data[2] =
sizes.data[2] != -1 ? sizes.data[2] : in_shape.w - starts.data[2];
sizes.data[3] =
sizes.data[3] != -1 ? sizes.data[3] : in_shape.c - starts.data[3];
attr.starts =
BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]);
attr.ends =
BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1],
starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]);
} else if (starts.data.size() == 3) {
sizes.data[0] =
sizes.data[0] != -1 ? sizes.data[0] : in_shape.h - starts.data[0];
sizes.data[1] =
sizes.data[1] != -1 ? sizes.data[1] : in_shape.w - starts.data[1];
sizes.data[2] =
sizes.data[2] != -1 ? sizes.data[2] : in_shape.c - starts.data[2];
attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]);
attr.ends =
BHWC(in_shape.b, starts.data[0] + sizes.data[0],
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
} else {
return absl::UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
if ((attr.ends.b - attr.starts.b) != out_shape.b) {
return absl::UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h) != out_shape.h) {
return absl::UnimplementedError("Output height doesn't match");
}
if ((attr.ends.w - attr.starts.w) != out_shape.w) {
return absl::UnimplementedError("Output width doesn't match");
}
if ((attr.ends.c - attr.starts.c) != out_shape.c) {
return absl::UnimplementedError("Output channels don't match");
}
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
absl::Status UpdateIfNegative(const BHWC& input_shape,
SliceAttributes* attr) {
if (attr->ends.h < 0) {
attr->ends.h = input_shape.h + attr->ends.h;
}
if (attr->ends.w < 0) {
attr->ends.w = input_shape.w + attr->ends.w;
}
if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c;
}
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return absl::OkStatus();
}
};
class SoftmaxOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
TfLiteSoftmaxParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->beta != 1) {
// TODO(eignasheva): figure out, what's wrong with softmax.
return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SOFTMAX);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options =
reinterpret_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
if (tf_options->beta != 1) {
// there is multiply by scalar operation fused in softmax. Make a layer
// out of it before softmax.
return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
// auto mul_node = reader->NewPassthroughNode(node);
// mul_node->operation.type = ToString(OperationType::MUL);
}
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS; // always by channels
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class SpaceToDepthOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
// TODO(impjdi): Dims check.
TfLiteSpaceToDepthParams* s2d_params = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
if (s2d_params->block_size == 1) {
return absl::InvalidArgumentError(
"SPACE_TO_DEPTH block_size = 1 is a no-op.");
}
if (s2d_params->block_size < 1) {
return absl::InvalidArgumentError(
"SPACE_TO_DEPTH block_size must be > 1.");
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options = reinterpret_cast<const TfLiteSpaceToDepthParams*>(
tflite_node->builtin_data);
SpaceToDepthAttributes attr;
attr.block_size = tf_options->block_size;
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class StridedSliceOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
TfLiteStridedSliceParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
Tensor<Linear, DataType::INT32> tmp;
RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
bool read_without_batch = tmp.data.size() == 3;
bool read_with_batch = tmp.data.size() == 4;
if (!read_without_batch && !read_with_batch) {
return absl::UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
const auto* tf_options = reinterpret_cast<const TfLiteStridedSliceParams*>(
tflite_node->builtin_data);
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
SliceAttributes attr;
if (read_without_batch) {
RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
input->tensor.shape, &attr));
}
if (read_with_batch) {
RETURN_IF_ERROR(
ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
}
if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
attr.strides.c == 0) {
return absl::InvalidArgumentError("stride values must be non-zero");
}
if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
attr.strides.c < 0) {
return absl::UnimplementedError("Reverse slices are not supported.");
}
if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
out_shape.b) {
return absl::UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
out_shape.h) {
return absl::UnimplementedError("Output height doesn't match");
}
if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
out_shape.w) {
return absl::UnimplementedError("Output width doesn't match");
}
if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
out_shape.c) {
return absl::UnimplementedError("Output channels don't match");
}
node->operation.attributes = attr;
return absl::OkStatus();
}
private:
absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, int ignore_b,
int ignore_h, int ignore_w, int ignore_c,
SliceAttributes* attr) {
if (tf_options->begin_mask & ignore_h) {
attr->starts.h = 0;
}
if (tf_options->begin_mask & ignore_w) {
attr->starts.w = 0;
}
if (tf_options->begin_mask & ignore_c) {
attr->starts.c = 0;
}
if (tf_options->begin_mask & ignore_b) {
attr->starts.b = 0;
}
if (tf_options->end_mask & ignore_h) {
attr->ends.h = input_shape.h;
}
if (tf_options->end_mask & ignore_w) {
attr->ends.w = input_shape.w;
}
if (tf_options->end_mask & ignore_c) {
attr->ends.c = input_shape.c;
}
if (tf_options->end_mask & ignore_b) {
attr->ends.b = input_shape.b;
}
return absl::OkStatus();
}
absl::Status UpdateIfNegative(const BHWC& input_shape,
SliceAttributes* attr) {
if (attr->ends.h < 0) {
attr->ends.h = input_shape.h + attr->ends.h;
}
if (attr->ends.w < 0) {
attr->ends.w = input_shape.w + attr->ends.w;
}
if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c;
}
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return absl::OkStatus();
}
absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape,
SliceAttributes* attr) {
auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
return absl::OkStatus();
};
RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
return absl::OkStatus();
}
absl::Status ReadAttribsWithoutBatch(
const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
return absl::OkStatus();
};
RETURN_IF_ERROR(read_hwc(1, &attr->starts));
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
attr->starts.b = 0;
attr->ends.b = input_shape.b;
attr->strides.b = 1;
return absl::OkStatus();
}
absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
if (tf_options->ellipsis_mask) {
return absl::UnimplementedError("Slice does not support ellipsis_mask.");
}
if (tf_options->new_axis_mask) {
return absl::UnimplementedError("Slice does not support new_axis_mask.");
}
if (tf_options->shrink_axis_mask) {
return absl::UnimplementedError(
"Slice does not support shrink_axis_mask parameter. ");
}
return absl::OkStatus();
}
};
class TransposeConvOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteTransposeConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
// TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input)
// and allows configurable padding & stride.
// TODO(impjdi): Translate output_shape to attr.adjacent.
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
Value* input;
RETURN_IF_ERROR(reader->ReadValue(2, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options = reinterpret_cast<const TfLiteTransposeConvParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return absl::InternalError("Missing tflite options.");
}
ConvolutionTransposedAttributes attr;
attr.stride = tf_options
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
// TFLite does not support bias.
UpdatePadding(tf_options->padding,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class TransposeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::TRANSPOSE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
TransposeAttributes attr;
Tensor<Linear, DataType::INT32> perm;
RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
if (perm.data.size() == 4) {
attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
} else if (perm.data.size() == 3) {
attr.perm = BHWC(0, perm.data[0] + 1, perm.data[1] + 1, perm.data[2] + 1);
} else if (perm.data.size() == 2) {
attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2);
} else {
return absl::InvalidArgumentError(
"Permutation for transpose is invalid.");
}
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class Unpooling2DOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
TfLitePoolParams* tf_options = nullptr;
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
RETURN_IF_ERROR(reader->AddOutputs(node));
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
MaxUnpooling2DAttributes attr;
const auto* tf_options = reinterpret_cast<const TfLitePoolParams*>(
tflite_node->custom_initial_data);
if (!tf_options) {
return absl::InternalError("Missing tflite params");
}
attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
UpdatePadding(tf_options->padding, input_shape, &attr);
node->operation.attributes = attr;
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
return absl::OkStatus();
}
};
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
class BatchToSpaceOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
BatchToSpaceAttributes bs_attr;
Tensor<Linear, DataType::INT32> block;
RETURN_IF_ERROR(reader->ReadTensor(1, &block));
if (block.shape.v != 2) {
return absl::InternalError("Space has to be HxW.");
}
bs_attr.block.h = block.data[0];
bs_attr.block.w = block.data[1];
Tensor<HW, DataType::INT32> crop;
RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
auto crop_shape = crop.shape;
if (crop_shape.h != 2 && crop_shape.w != 2) {
return absl::InternalError("Space has to be HxW.");
}
bs_attr.crop.prepended.h = crop.data[0];
bs_attr.crop.prepended.w = crop.data[2];
bs_attr.crop.appended.h = crop.data[1];
bs_attr.crop.appended.w = crop.data[3];
node->operation.attributes = std::move(bs_attr);
return absl::OkStatus();
}
};
class SpaceToBatchOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
SpaceToBatchAttributes sb_attr;
Tensor<Linear, DataType::INT32> block;
RETURN_IF_ERROR(reader->ReadTensor(1, &block));
if (block.shape.v != 2) {
return absl::InternalError("Space has to be HxW.");
}
sb_attr.block.h = block.data[0];
sb_attr.block.w = block.data[1];
Tensor<HW, DataType::INT32> padding;
RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
auto padding_shape = padding.shape;
if (padding_shape.h != 2 && padding_shape.w != 2) {
return absl::InternalError("Space has to be HxW.");
}
sb_attr.padding.prepended.h = padding.data[0];
sb_attr.padding.prepended.w = padding.data[2];
sb_attr.padding.appended.h = padding.data[1];
sb_attr.padding.appended.w = padding.data[3];
node->operation.attributes = std::move(sb_attr);
return absl::OkStatus();
}
};
class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "roi_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
private:
};
class RoIToTransformMatrixV2OperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "roi_to_transform_matrix_v2";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
private:
};
class TransformTensorOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_tensor";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape =
BHWC(1, output_shape.h, output_shape.w,
graph->FindInputs(node->id)[0]->tensor.shape.c);
return absl::OkStatus();
}
private:
};
class TransformTensorBilinearV2OperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_tensor_bilinear_v2";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape =
BHWC(1, output_shape.h, output_shape.w,
graph->FindInputs(node->id)[0]->tensor.shape.c);
return absl::OkStatus();
}
private:
};
class TransformLandmarksOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
return absl::OkStatus();
}
private:
};
class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks_v2";
node->operation.type = op_name;
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
BHWC output_shape = output_value->tensor.shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
return absl::OkStatus();
}
private:
};
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
const std::string op_name = "landmarks_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
};
class Landmarks2TransformMatrixV2OperationParser
: public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
const std::string op_name = "landmarks_to_transform_matrix_v2";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
};
class AlignmentPointsToTransformMatrixOperationParser
: public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
const std::string op_name = "alignment_points_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
private:
};
class MeanOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::MEAN);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
MeanAttributes attr;
Tensor<Linear, DataType::INT32> channel;
RETURN_IF_ERROR(reader->ReadTensor(1, &channel));
for (int i = 0; i < channel.data.size(); i++) {
std::string unsupported;
switch (channel.data[i]) {
case 1:
attr.dims.insert(Axis::HEIGHT);
break;
case 2:
attr.dims.insert(Axis::WIDTH);
break;
case 0:
unsupported = unsupported.empty() ? "batch" : unsupported;
ABSL_FALLTHROUGH_INTENDED;
case 3:
unsupported = unsupported.empty() ? "channels" : unsupported;
ABSL_FALLTHROUGH_INTENDED;
default:
return absl::UnimplementedError(
absl::StrCat("Unsupported mean dimension: ", unsupported));
}
}
node->operation.attributes = attr;
return absl::OkStatus();
}
};
class UnsupportedOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return absl::UnimplementedError("Operation is not supported.");
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
return absl::UnimplementedError("Operation is not supported.");
}
};
std::unique_ptr<TFLiteOperationParser> NewOperationParser(
const TfLiteRegistration* registration, bool allow_quant_ops = false) {
const auto builtin_code = registration->builtin_code;
switch (builtin_code) {
case kTfLiteBuiltinAbs:
return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
case kTfLiteBuiltinAdd:
return std::make_unique<AddOperationParser>();
case kTfLiteBuiltinAveragePool2d:
return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
case kTfLiteBuiltinConcatenation:
return std::make_unique<ConcatenationOperationParser>();
case kTfLiteBuiltinConv2d:
return std::make_unique<Conv2DOperationParser>();
case kTfLiteBuiltinCos:
return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
case kTfLiteBuiltinDepthwiseConv2d:
return std::make_unique<DepthwiseConvolutionOperationParser>();
case kTfLiteBuiltinDequantize:
if (allow_quant_ops) {
return std::make_unique<DequantizeOperationParser>();
}
break;
case kTfLiteBuiltinDiv:
return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
case kTfLiteBuiltinFullyConnected:
return std::make_unique<FullyConnectedOperationParser>();
case kTfLiteBuiltinHardSwish:
return std::make_unique<HardSwishOperationParser>();
case kTfLiteBuiltinLogistic:
return std::make_unique<ElementwiseOperationParser>(
OperationType::SIGMOID);
case kTfLiteBuiltinLog:
return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
case kTfLiteBuiltinLstm:
return std::make_unique<LSTMOperationParser>();
case kTfLiteBuiltinMaximum:
return std::make_unique<ElementwiseOperationParser>(
OperationType::MAXIMUM);
case kTfLiteBuiltinMaxPool2d:
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
case kTfLiteBuiltinMean:
return std::make_unique<MeanOperationParser>();
case kTfLiteBuiltinMinimum:
return std::make_unique<ElementwiseOperationParser>(
OperationType::MINIMUM);
case kTfLiteBuiltinMirrorPad:
return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
case kTfLiteBuiltinMul:
return std::make_unique<MulOperationParser>();
case kTfLiteBuiltinPad:
return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
case kTfLiteBuiltinPow:
return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
case kTfLiteBuiltinQuantize:
if (allow_quant_ops) {
return std::make_unique<QuantizeOperationParser>();
}
break;
case kTfLiteBuiltinRelu:
return std::make_unique<ReLUOperationParser>(0);
case kTfLiteBuiltinRelu6:
return std::make_unique<ReLUOperationParser>(6);
case kTfLiteBuiltinLeakyRelu:
return std::make_unique<ReLUOperationParser>(0);
case kTfLiteBuiltinPrelu:
return std::make_unique<PReLUOperationParser>();
case kTfLiteBuiltinReshape:
return std::make_unique<ReshapeOperationParser>();
case kTfLiteBuiltinResizeBilinear:
return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
case kTfLiteBuiltinResizeNearestNeighbor:
return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
case kTfLiteBuiltinRsqrt:
return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
case kTfLiteBuiltinSin:
return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
case kTfLiteBuiltinSlice:
return std::make_unique<SliceOperationParser>();
case kTfLiteBuiltinSoftmax:
return std::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSpaceToDepth:
return std::make_unique<SpaceToDepthOperationParser>();
case kTfLiteBuiltinSqrt:
return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
case kTfLiteBuiltinSquare:
return std::make_unique<ElementwiseOperationParser>(
OperationType::SQUARE);
case kTfLiteBuiltinSquaredDifference:
return std::make_unique<ElementwiseOperationParser>(
OperationType::SQUARED_DIFF);
case kTfLiteBuiltinStridedSlice:
return std::make_unique<StridedSliceOperationParser>();
case kTfLiteBuiltinSub:
return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
case kTfLiteBuiltinTanh:
return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
case kTfLiteBuiltinTranspose:
return std::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return std::make_unique<TransposeConvOperationParser>();
case kTfLiteBuiltinCustom:
const absl::string_view custom_name = registration->custom_name;
if (custom_name == "Convolution2DTransposeBias") {
return std::make_unique<Convolution2DTransposeBiasParser>();
}
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
}
if (custom_name == "MaxUnpooling2D") {
return std::make_unique<Unpooling2DOperationParser>();
}
if (custom_name == "RoIToTransformMatrix") {
return std::make_unique<RoIToTransformMatrixOperationParser>();
}
if (custom_name == "RoIToTransformMatrixV2") {
return std::make_unique<RoIToTransformMatrixV2OperationParser>();
}
if (custom_name == "TransformTensor") {
return std::make_unique<TransformTensorOperationParser>();
}
if (custom_name == "TransformTensorBilinearV2") {
return std::make_unique<TransformTensorBilinearV2OperationParser>();
}
if (custom_name == "TransformLandmarks") {
return std::make_unique<TransformLandmarksOperationParser>();
}
if (custom_name == "TransformLandmarksV2") {
return std::make_unique<TransformLandmarksV2OperationParser>();
}
if (custom_name == "Landmarks2TransformMatrix") {
return std::make_unique<Landmarks2TransformMatrixOperationParser>();
}
if (custom_name == "Landmarks2TransformMatrixV2") {
return std::make_unique<Landmarks2TransformMatrixV2OperationParser>();
}
if (custom_name == "AlignmentPointsToTransformMatrix") {
return std::make_unique<
AlignmentPointsToTransformMatrixOperationParser>();
}
break;
}
return std::make_unique<UnsupportedOperationParser>();
}
absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
const TfLiteRegistration* registration,
bool allow_quant_ops = false) {
return NewOperationParser(registration, allow_quant_ops)
->IsSupported(context, node, registration);
}
bool IsAllAllowedTensors(TfLiteContext* context,
const TfLiteIntArray* tensor_indices,
bool allow_quant_ops = false) {
for (int i = 0; i < tensor_indices->size; ++i) {
int tensor_idx = tensor_indices->data[i];
if (tensor_idx == kTfLiteOptionalTensor) continue;
const TfLiteTensor* t = &context->tensors[tensor_idx];
bool type_supported =
(t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16);
if (allow_quant_ops) {
// Since we only check non-constant tensors, type cannot be Int32.
type_supported =
type_supported || t->type == kTfLiteInt8 || t->type == kTfLiteUInt8;
}
if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
return false;
}
}
return true;
}
} // namespace
// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
int max_delegated_partitions) {
delegates::IsNodeSupportedFn node_supported_fn =
[=](TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration,
std::string* unsupported_details) -> bool {
const auto status =
IsSupported(context, node, registration, allow_quant_ops);
if (!status.ok()) {
if (unsupported_details) {
*unsupported_details = std::string(status.message());
}
return false;
}
if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) ||
!IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) {
if (unsupported_details) {
*unsupported_details =
"OP is supported, but tensor type isn't matched!";
}
return false;
}
return true;
};
delegates::FP16GraphPartitionHelper partition_helper(context,
node_supported_fn);
std::set<std::string> unsupported_nodes_info;
if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
return TfLiteIntArrayCreate(0);
}
// By default, we simply get 1st largest partition as 'max_delegate_partions'
// is set to 1 by default.
std::vector<int> ops_to_replace =
partition_helper.GetNodesOfFirstNLargestPartitions(
max_delegated_partitions);
if (!unsupported_nodes_info.empty()) {
std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
std::string error_message = absl::StrCat(
"Following operations are not supported by GPU delegate:\n",
unsupported, "\n");
if (!ops_to_replace.empty()) {
absl::StrAppend(
&error_message, ops_to_replace.size(),
" operations will run on the GPU, and the remaining ",
partition_helper.num_total_nodes() - ops_to_replace.size());
} else {
absl::StrAppend(&error_message,
"No operations will run on the GPU, and all ",
partition_helper.num_total_nodes());
}
absl::StrAppend(&error_message, " operations will run on the CPU.");
TF_LITE_KERNEL_LOG(context, error_message.c_str());
}
return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
// Creates inputs and outputs passed by io_tensors parameters in the resulting
// graph. We force it to make sure that delegated subgraph has same order of
// inputs and outputs with the original one. When delegated model is built from
// the tflite model representation tensors are created lazily, so there is no
// guarantee that the order will match the source model tensors order.
absl::Status PrecreateIOTensors(
TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
std::unordered_map<int, int>* quant_conversion_map,
std::unordered_map<int, Value*>* tensor_to_value) {
for (int i = 0; i < io_tensors->size; ++i) {
const int tensor_index = io_tensors->data[i];
const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
if (tflite::IsConstantTensor(&tflite_tensor)) continue;
RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
context, tensor_to_value, quant_conversion_map, graph, tensor_index));
}
return absl::OkStatus();
}
absl::Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
std::unordered_map<int, int>* quant_conversion_map) {
std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
std::vector<int> tflite_nodes;
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
TfLiteNode* tflite_node = nullptr;
TfLiteRegistration* registration = nullptr;
RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[i], &tflite_node,
&registration));
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context->tensors[tflite_node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) {
// Ignore Fp16 Dequantize nodes.
continue;
}
auto op_parser = NewOperationParser(
registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
if (!op_parser) {
return absl::UnimplementedError(
absl::StrCat("Operation ", registration->builtin_code, "(",
registration->custom_name,
") is not supported by TFLite GPU Delegate."));
}
operations.push_back(std::move(op_parser));
tflite_nodes.push_back(i);
}
std::unordered_map<int, Value*> tensor_to_value;
RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
delegate_params->input_tensors,
quant_conversion_map, &tensor_to_value));
RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
delegate_params->output_tensors,
quant_conversion_map, &tensor_to_value));
for (int i = 0; i < operations.size(); ++i) {
TfLiteNode* tflite_node;
TfLiteRegistration* registration;
RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
&tflite_node, &registration));
ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
quant_conversion_map);
const auto status =
operations[i]->Parse(tflite_node, registration, graph, &reader);
if (!status.ok()) {
return absl::InternalError(absl::StrCat(
GetOpNameByRegistration(*registration), ": ", status.message()));
}
}
return absl::OkStatus();
}
absl::Status BuildFinalModel(
TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph, std::unordered_map<int, int>* quant_conversion_map) {
RETURN_IF_ERROR(
BuildModel(context, delegate_params, graph, quant_conversion_map));
// Apply general transformations on the graph.
NullTransformationReporter reporter;
ModelTransformer transformer(graph, &reporter);
if (!ApplyGeneralTransformations(&transformer)) {
return absl::InternalError("Graph general transformations failed");
}
return absl::OkStatus();
}
} // namespace gpu
} // namespace tflite