blob: e21e28785017fb76f0cc3a169f9358858f31e051 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include <algorithm>
#include <any>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace gpu {
namespace {
// Verifies that input slot `idx` exists on `tflite_node`.
//
// Only the slot index is range-checked against `tflite_node->inputs->size`;
// `context` is currently unused. Returns OutOfRangeError when `idx` is
// negative or not smaller than the number of inputs.
absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
                                    const TfLiteNode* tflite_node, int idx) {
  // If tensor id is in range, it's guaranteed that it'll be available.
  const int num_inputs = tflite_node->inputs->size;
  if (idx < 0 || idx >= num_inputs) {
    // Report the requested index and the actual input count (previously the
    // index was emitted twice, fusing it with the size in the message).
    return absl::OutOfRangeError(
        absl::StrCat("Requested index goes beyond array size: ", idx, " vs ",
                     num_inputs));
  }
  return absl::OkStatus();
}
// Interface for turning one TFLite operation into node(s) of a GraphFloat32.
// Each supported TFLite operation kind has its own concrete parser.
class TFLiteOperationParser {
 public:
  virtual ~TFLiteOperationParser() = default;

  // Converts `tflite_node` into graph nodes. Fused TFLite operations may be
  // expanded into more than one node.
  virtual absl::Status Parse(const TfLiteNode* tflite_node,
                             const TfLiteRegistration* registration,
                             GraphFloat32* graph, ObjectReader* reader) = 0;

  // Returns OK when the GPU delegate is able to build `tflite_node`.
  virtual absl::Status IsSupported(const TfLiteContext* context,
                                   const TfLiteNode* tflite_node,
                                   const TfLiteRegistration* registration) = 0;

  // Returns the graph value ids that hold the updated values of variable
  // input tensors, keyed by an int (presumably the variable tensor's index —
  // confirm against overriding parsers). The default reports none.
  virtual absl::flat_hash_map<int, ValueId>
  GetNewValueIdsForVariableInputNodes() {
    return {};
  }
};
// Builds an HW pair from TFLite height/width values; any non-positive value
// is replaced by 1.
HW ToHW(int32_t h, int32_t w) {
  const int32_t safe_h = h <= 0 ? 1 : h;
  const int32_t safe_w = w <= 0 ? 1 : w;
  return HW(safe_h, safe_w);
}
// Sets `attr->padding` from the TFLite padding mode. SAME padding is computed
// from `input_shape` and the rest of `*attr`; every other mode yields zero
// padding on both sides.
template <typename AttrT>
void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape,
                   AttrT* attr) {
  if (padding != kTfLitePaddingSame) {
    attr->padding.prepended = HW(0, 0);
    attr->padding.appended = HW(0, 0);
    return;
  }
  attr->padding = CalculateSamePadding(input_shape, *attr);
}
// Reads the HW-shaped weights tensor at `weights_tensor_id` into
// `attr->weights` as a 1x1 OHWI kernel: weights.h becomes o, and weights.w
// becomes i. The bias at `bias_tensor_id` is optional, so a failure to read
// it is ignored. Returns the error from reading the weights, if any.
absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
                                         int bias_tensor_id,
                                         ObjectReader* reader,
                                         FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> hw_weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &hw_weights));

  auto& dst = attr->weights;
  dst.id = hw_weights.id;
  dst.shape.o = hw_weights.shape.h;
  dst.shape.h = 1;
  dst.shape.w = 1;
  dst.shape.i = hw_weights.shape.w;
  dst.data = std::move(hw_weights.data);

  // Bias is optional.
  reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();
  return absl::OkStatus();
}
// Exposes `tflite_node->builtin_data` as `const ParamsT*` via `tf_options`.
// `*tf_options` is always written; InternalError is returned when the node
// has no builtin data.
template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  const auto* params = static_cast<const ParamsT*>(tflite_node->builtin_data);
  *tf_options = params;
  if (params == nullptr) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}
// Exposes `tflite_node->custom_initial_data` as `const ParamsT*` via
// `tf_options`. `*tf_options` is always written; InternalError is returned
// when the node has no custom initial data.
template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  const auto* params =
      static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  *tf_options = params;
  return params != nullptr
             ? absl::OkStatus()
             : absl::InternalError("Unable to retrieve custom_initial_data.");
}
// Returns UnimplementedError when the operation version in `registration`
// exceeds `max_version`; OK otherwise.
absl::Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration,
                                        int max_version) {
  const int op_version = registration->version;
  if (op_version <= max_version) {
    return absl::OkStatus();
  }
  return absl::UnimplementedError(absl::StrCat("Max version supported: ",
                                               max_version,
                                               ". Requested version ",
                                               op_version, "."));
}
// Returns InvalidArgumentError unless both kernel dimensions are positive.
absl::Status CheckKernels(int kernel_h, int kernel_w) {
  const bool both_positive = kernel_h > 0 && kernel_w > 0;
  if (both_positive) return absl::OkStatus();
  return absl::InvalidArgumentError(
      absl::StrCat("Incorrect kernel values: kernel_height = ", kernel_h,
                   ", kernel_width = ", kernel_w));
}
// Returns InvalidArgumentError unless both stride values are positive.
absl::Status CheckStrides(int strides_h, int strides_w) {
  if (strides_h > 0 && strides_w > 0) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Incorrect stride values: stride_height = ", strides_h,
                   ", stride_width = ", strides_w));
}
// Returns InvalidArgumentError unless both dilation factors are positive.
// The message names each factor separately (height vs width), so the caller
// can tell which one is wrong, matching CheckKernels/CheckStrides.
absl::Status CheckDilation(int dilation_h, int dilation_w) {
  if (dilation_h <= 0 || dilation_w <= 0) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Incorrect dilation values: dilation_height_factor = ", dilation_h,
        ", dilation_width_factor = ", dilation_w));
  }
  return absl::OkStatus();
}
// Validates strides, then dilation factors; returns the first failure.
absl::Status CheckStridesAndDilation(int strides_h, int strides_w,
                                     int dilation_h, int dilation_w) {
  const absl::Status strides_status = CheckStrides(strides_h, strides_w);
  if (!strides_status.ok()) return strides_status;
  return CheckDilation(dilation_h, dilation_w);
}
// Validates kernel sizes, then strides; returns the first failure.
absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
                                    int strides_w) {
  const absl::Status kernel_status = CheckKernels(kernel_h, kernel_w);
  return kernel_status.ok() ? CheckStrides(strides_h, strides_w)
                            : kernel_status;
}
// Adds a CONST node that owns tensor `t` (stored in its attributes), plus a
// new value produced by that node. The new value is returned via `*value`,
// which is set even if wiring the producer fails.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
  ConstTensorAttributes const_attr;
  const_attr.tensor = std::move(t);

  Node* const_node = graph->NewNode();
  const_node->operation.type = ToString(OperationType::CONST);
  const_node->operation.attributes = const_attr;

  Value* output = graph->NewValue();
  *value = output;
  RETURN_IF_ERROR(graph->SetProducer(const_node->id, output->id));

  // Describe the produced value with the constant tensor's id, type, shape.
  output->tensor.ref = const_attr.tensor.id;
  output->tensor.type = const_attr.tensor.kType;
  output->tensor.shape = const_attr.tensor.shape;
  return absl::OkStatus();
}
// Fills kernel, strides and padding of `attr` from TFLite pooling params.
// Kernel and stride sizes go through ToHW, which replaces non-positive
// values with 1. Always returns OK.
absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
                                    const BHWC& input_shape,
                                    Pooling2DAttributes* attr) {
  const HW kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
  const HW strides = ToHW(tf_options->stride_height, tf_options->stride_width);
  attr->kernel = kernel;
  attr->strides = strides;
  UpdatePadding(tf_options->padding, input_shape, attr);
  return absl::OkStatus();
}
// Wires the two inputs of a binary op into `node`.
//
// - If both inputs are runtime tensors, both are added as node inputs.
// - If exactly one is constant, only the runtime one is added, and the
//   constant is read into `tensor_or_scalar`. It is read as:
//   - a scalar, when it has no dims or a single element;
//   - a Linear tensor, when CheckIfLinearConvertible accepts its dims;
//   - an HWC tensor otherwise.
//
// Returns InvalidArgumentError if either input is missing or both are
// constant.
absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;

  const TfLiteTensor* first = reader->GetInputTensor(0);
  if (first == nullptr) {
    return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
                                      opname);
  }
  const TfLiteTensor* second = reader->GetInputTensor(1);
  if (second == nullptr) {
    return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
                                      opname);
  }

  const bool first_is_const = IsConstantTensor(first);
  const bool second_is_const = IsConstantTensor(second);
  if (first_is_const && second_is_const) {
    return absl::InvalidArgumentError("No runtime input tensors for " + opname);
  }
  if (!first_is_const && !second_is_const) {
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
    return absl::OkStatus();
  }

  // Exactly one of the two inputs is constant from here on.
  const int const_idx = first_is_const ? 0 : 1;
  const int runtime_idx = first_is_const ? 1 : 0;
  TfLiteIntArray* constant_dims = first_is_const ? first->dims : second->dims;
  RETURN_IF_ERROR(reader->AddInput(node, runtime_idx));

  if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
    Tensor<Scalar, DataType::FLOAT32> scalar;
    RETURN_IF_ERROR(reader->ReadTensor(const_idx, &scalar));
    *tensor_or_scalar = scalar.data[0];
  } else if (CheckIfLinearConvertible(constant_dims).ok()) {
    Tensor<Linear, DataType::FLOAT32> linear;
    RETURN_IF_ERROR(reader->ReadTensor(const_idx, &linear));
    *tensor_or_scalar = std::move(linear);
  } else {
    Tensor<HWC, DataType::FLOAT32> hwc;
    RETURN_IF_ERROR(reader->ReadTensor(const_idx, &hwc));
    *tensor_or_scalar = std::move(hwc);
  }
  return absl::OkStatus();
}
// Parser for TFLite ADD. Exactly two inputs are accepted, one of which may be
// constant; a fused activation from TfLiteAddParams is applied afterwards.
class AddOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    const bool has_two_inputs = tflite_node->inputs->size == 2;
    if (!has_two_inputs) {
      return absl::UnimplementedError("ADD requires two input tensors.");
    }
    // TODO(eignasheva): Add shapes check.
    const TfLiteAddParams* add_params = nullptr;
    return RetrieveBuiltinData(tflite_node, &add_params);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // TFLite currently only supports 2 input ADDs. Thus, the logic below only
    // considers 2 input cases. The underlying GPU shader programs can accept
    // more inputs, but the logic below would have to be expanded.
    Node* add_node = graph->NewNode();
    add_node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(reader->AddOutputs(add_node));

    ElementwiseAttributes add_attr;
    RETURN_IF_ERROR(
        ParseInputsWithConstTensor(add_node, reader, &add_attr.param));
    add_node->operation.attributes = std::move(add_attr);

    const TfLiteAddParams* add_params = nullptr;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &add_params));
    return MaybeFuseActivation(add_params->activation, graph, add_node);
  }
};
// Parser for TFLite BATCH_MATMUL: two runtime inputs, one output.
class BatchedMatMulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckInputsOutputs(context, tflite_node,
                              /*runtime_inputs=*/2, /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* matmul = graph->NewNode();
    matmul->operation.type = ToString(OperationType::BATCHED_MATMUL);
    for (const int input_idx : {0, 1}) {
      RETURN_IF_ERROR(reader->AddInput(matmul, input_idx));
    }
    return reader->AddOutputs(matmul);
  }
};
// Parser for TFLite CONCATENATION.
//
// Inputs that are not runtime values are read as constant tensors and wrapped
// in CONST nodes. The concat axis is inferred from shapes. SetAxis first
// validates the inputs, then the axis is refined by comparing each input with
// the output shape.
class ConcatenationOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    // TODO(eignasheva): add proper tensor availability checking
    // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
    //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
    // }
    // TODO(eignasheva): add axis checking.
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    ConcatAttributes attr;
    // Read inputs first to make sure const node is added to a graph before
    // concat node to ensure topological order.
    std::vector<const Value*> inputs;
    for (int idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* runtime_value = nullptr;
      if (reader->ReadValue(idx, &runtime_value).ok()) {
        inputs.push_back(runtime_value);
        continue;
      }
      // Not a runtime value: read it as a constant and wrap it in a CONST
      // node.
      TensorFloat32 tensor;
      RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
      Value* const_value = nullptr;
      RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &const_value));
      inputs.push_back(const_value);
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (const Value* input : inputs) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
    }

    std::vector<BHWC> input_shapes;
    for (auto input : graph->FindInputs(node->id)) {
      input_shapes.push_back(input->tensor.shape);
    }
    RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));

    // Guess axis: the first of h/w/c where some input differs from the
    // output wins. This also resolves concatenation of equally shaped inputs,
    // which SetAxis alone reports as BATCH.
    BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    for (auto input : graph->FindInputs(node->id)) {
      if (input->tensor.shape.h != output_shape.h) {
        attr.axis = Axis::HEIGHT;
        break;
      }
      if (input->tensor.shape.w != output_shape.w) {
        attr.axis = Axis::WIDTH;
        break;
      }
      if (input->tensor.shape.c != output_shape.c) {
        attr.axis = Axis::CHANNELS;
        break;
      }
    }
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  // Picks the first axis among BATCH, HEIGHT, WIDTH, CHANNELS such that every
  // input agrees with input_shapes[0] on all *other* dimensions. Fewer than
  // two inputs always yields BATCH. Returns UnimplementedError when no such
  // axis exists, i.e. the inputs differ in more than one dimension.
  absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
    *axis = Axis::BATCH;
    if (input_shapes.size() < 2) return absl::OkStatus();

    // True if all inputs match input_shapes[0] everywhere except `candidate`.
    const auto others_match = [&input_shapes](Axis candidate) {
      const BHWC& ref = input_shapes[0];
      for (size_t i = 1; i < input_shapes.size(); ++i) {
        const BHWC& s = input_shapes[i];
        if ((candidate != Axis::BATCH && s.b != ref.b) ||
            (candidate != Axis::HEIGHT && s.h != ref.h) ||
            (candidate != Axis::WIDTH && s.w != ref.w) ||
            (candidate != Axis::CHANNELS && s.c != ref.c)) {
          return false;
        }
      }
      return true;
    };

    for (const Axis candidate :
         {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS}) {
      *axis = candidate;
      if (others_match(candidate)) return absl::OkStatus();
    }
    return absl::UnimplementedError(
        "Can concatenate tensors only by batch, height, width, or "
        "channels.");
  }
};
// Parser for TFLite CONV_2D.
//
// Weights come either from a second runtime input or from a constant tensor
// at input index 1. Bias (input index 2) is read if available.
class Conv2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5));

    const int num_runtime_inputs =
        GetNumberOfRuntimeInputsForNode(context, tflite_node);
    if (num_runtime_inputs > 2) {
      return absl::InternalError(
          absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
                       num_runtime_inputs, " runtime inputs."));
    }
    const int num_outputs = NumOutputs(tflite_node);
    if (num_outputs != 1) {
      return absl::InternalError(
          absl::StrCat("Expected 1 output tensor(s), but node has ",
                       num_outputs, " runtime outputs."));
    }
    // With one runtime input, Parse reads index 1 as constant weights, so
    // that slot must exist.
    if (num_runtime_inputs == 1) {
      RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
    }

    const TfLiteConvParams* conv_params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &conv_params));
    RETURN_IF_ERROR(CheckStridesAndDilation(
        conv_params->stride_height, conv_params->stride_width,
        conv_params->dilation_height_factor,
        conv_params->dilation_width_factor));
    return IsActivationSupported(conv_params->activation);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* conv_node = graph->NewNode();
    conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
    RETURN_IF_ERROR(reader->AddInput(conv_node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(conv_node));

    Convolution2DAttributes conv_attr;
    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Weights arrive at runtime as the second input.
      RETURN_IF_ERROR(reader->AddInput(conv_node, 1));
    } else {
      // Weights are a constant tensor at index 1.
      RETURN_IF_ERROR(reader->ReadTensor(1, &conv_attr.weights));
    }
    // Bias is optional; a failed read is ignored.
    reader->ReadTensor(2, &conv_attr.bias).IgnoreError();

    const TfLiteConvParams* conv_params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &conv_params));
    conv_attr.strides =
        ToHW(conv_params->stride_height, conv_params->stride_width);
    conv_attr.dilations = HW(conv_params->dilation_height_factor,
                             conv_params->dilation_width_factor);
    UpdatePadding(conv_params->padding,
                  graph->FindInputs(conv_node->id)[0]->tensor.shape,
                  &conv_attr);
    RETURN_IF_ERROR(
        MaybeFuseActivation(conv_params->activation, graph, conv_node));
    conv_node->operation.attributes = std::move(conv_attr);
    return absl::OkStatus();
  }
};
// Parser for TFLite DEPTHWISE_CONV_2D.
//
// Weights come either from a second runtime input or from a constant tensor
// at input index 1. Bias (input index 2) is optional. IsSupported only
// accepts depth_multiplier != 1 when input.c == 1.
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    const int runtime_inputs =
        GetNumberOfRuntimeInputsForNode(context, tflite_node);
    if (runtime_inputs > 2) {
      return absl::InternalError(
          absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
                       runtime_inputs, " runtime inputs."));
    }
    const int runtime_outputs = NumOutputs(tflite_node);
    if (runtime_outputs != 1) {
      return absl::InternalError(
          absl::StrCat("Expected 1 output tensor(s), but node has ",
                       runtime_outputs, " runtime outputs."));
    }
    if (runtime_inputs == 1) {
      RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
    }
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(CheckStridesAndDilation(
        tf_options->stride_height, tf_options->stride_width,
        tf_options->dilation_height_factor, tf_options->dilation_width_factor));
    RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));

    // Input (slot 0) and filter (slot 1) are dereferenced below, so both
    // slots must exist.
    if (tflite_node->inputs->size < 2) {
      return absl::InvalidArgumentError("Expected input and filter tensors.");
    }
    const int depth_multiplier = tf_options->depth_multiplier;
    const auto* input = context->tensors + tflite_node->inputs->data[0];
    const auto* filter = context->tensors + tflite_node->inputs->data[1];
    // TFLite encodes an omitted optional input as a negative index
    // (kTfLiteOptionalTensor). Such an index must not be used to address
    // context->tensors.
    const bool has_bias =
        tflite_node->inputs->size > 2 && tflite_node->inputs->data[2] >= 0;
    const auto* bias =
        has_bias ? context->tensors + tflite_node->inputs->data[2] : nullptr;
    const auto* output = context->tensors + tflite_node->outputs->data[0];
    if (!input->dims || input->dims->size != 4) {
      return absl::InvalidArgumentError("input.dims.size != 4");
    }
    if (!filter->dims || filter->dims->size != 4) {
      return absl::InvalidArgumentError("filter.dims.size != 4");
    }
    if (!output->dims || output->dims->size != 4) {
      return absl::InvalidArgumentError("output.dims.size != 4");
    }
    if (input->dims->data[0] != output->dims->data[0]) {
      return absl::InvalidArgumentError("input.b != output.b");
    }
    const int input_depth = input->dims->data[3];
    const int output_depth = output->dims->data[3];
    if (filter->dims->data[3] != output_depth) {
      return absl::InvalidArgumentError("filter.i != output.c");
    }
    if (output_depth != input_depth * depth_multiplier) {
      return absl::InvalidArgumentError(
          "output.c != input.c * depth_multiplier");
    }
    if (bias && NumElements(bias) != output_depth) {
      return absl::InvalidArgumentError("bias.size != output.c");
    }
    if (depth_multiplier != 1 && input_depth != 1) {
      return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1");
    }
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    DepthwiseConvolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else {  // runtime_inputs == 1;
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
                        std::max(1, tf_options->dilation_width_factor));
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    const int depth_multiplier = tf_options->depth_multiplier;
    if (depth_multiplier != 1) {
      const TfLiteTensor* input = reader->GetInputTensor(0);
      const TfLiteTensor* filter = reader->GetInputTensor(1);
      const TfLiteTensor* output = reader->GetOutputTensor(0);
      TransposeWeights(input, filter, output, depth_multiplier, &attr);
    }
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }

 private:
  // TFLite CPU stores weights as:
  //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
  // TFLite GPU stores weights as:
  //   [depth_multiplier, kernel_height, kernel_width, input_depth]
  // Only reached when depth_multiplier != 1, which IsSupported allows only
  // for input_depth == 1. The loop therefore fills exactly
  // output_depth * filter_height * filter_width values.
  static void TransposeWeights(const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* output, int depth_multiplier,
                               DepthwiseConvolution2DAttributes* attr) {
    const int input_depth = input->dims->data[3];
    const int filter_height = filter->dims->data[1];
    const int filter_width = filter->dims->data[2];
    const int output_depth = output->dims->data[3];
    Tensor<OHWI, DataType::FLOAT32> weights;
    weights.id = attr->weights.id;
    weights.shape =
        OHWI(output_depth, filter_height, filter_width, input_depth);
    weights.data.resize(weights.shape.DimensionsProduct());
    float* dst = weights.data.data();
    for (int j = 0; j < output_depth; ++j) {
      // Source is interleaved by output channel: stride output_depth.
      const float* src = attr->weights.data.data() + j;
      for (int i = 0; i < filter_height * filter_width; ++i) {
        *dst = *src;
        dst++;
        src += output_depth;
      }
    }
    attr->weights = std::move(weights);
  }
};
// Parser for TFLite DEQUANTIZE.
//
// The graph works on float versions of the original tensors, so the op is
// emitted as QUANTIZE_AND_DEQUANTIZE. Its parameters are taken from the
// input value's quant_params.
class DequantizeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node,
                              /*runtime_inputs=*/1, /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* qdq_node = graph->NewNode();
    qdq_node->operation.type =
        ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    RETURN_IF_ERROR(reader->AddInput(qdq_node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(qdq_node));

    // The input value is expected to already carry quantization params.
    const Value* input_value = graph->FindInputs(qdq_node->id)[0];
    if (!input_value->quant_params.has_value()) {
      return absl::InvalidArgumentError(
          "Encountered Dequantize input with no quant params");
    }
    const auto& qparams = input_value->quant_params.value();
    QuantizeAndDequantizeAttributes qdq_attr;
    qdq_attr.min = qparams.min;
    qdq_attr.max = qparams.max;
    qdq_attr.scale = qparams.scale;
    qdq_node->operation.attributes = qdq_attr;
    return absl::OkStatus();
  }
};
// Parser for elementwise operations of type `operation_type_`.
//
// Three input layouts are handled:
//   - one runtime input (unary ops);
//   - two runtime inputs;
//   - one runtime input plus one constant input.
// SUB and DIV may carry a fused activation (see GetActivation). It is fused
// after the node's outputs have been added, for every layout.
class ElementwiseOperationParser : public TFLiteOperationParser {
 public:
  explicit ElementwiseOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
                                               /*runtime_inputs=*/1,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1));
    } else if (IsTwoArgumentOperation() &&
               CheckInputsConstsOutputs(context, tflite_node,
                                        /*runtime_inputs=*/2,
                                        /*const_inputs=*/0,
                                        /*outputs=*/1)
                   .ok()) {
      // Two runtime inputs: nothing more to check here.
    } else if (IsTwoArgumentOperationWithConst()) {
      // For some elementwise operations (currently only for SUB operation)
      // the previous condition may be false. But it's worth checking the
      // case with a const input, which may be supported.
      RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
                                               /*runtime_inputs=*/1,
                                               /*const_inputs=*/1,
                                               /*outputs=*/1));
    } else {
      return absl::InvalidArgumentError(
          "Op can only handle 1 or 2 operand(s).");
    }
    TfLiteFusedActivation activation;
    RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
    return IsActivationSupported(activation);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);
    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/0,
                                                        /*outputs=*/1));
      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else if (IsTwoArgumentOperation() &&
               reader
                   ->VerifyInputsConstsOutputs(tflite_node,
                                               /*runtime_inputs=*/2,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1)
                   .ok()) {
      if (tflite_node->inputs->size != 2) {
        return absl::InvalidArgumentError("Applies only two input tensors");
      }
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else if (IsTwoArgumentOperationWithConst()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/1,
                                                        /*outputs=*/1));
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      attr.runtime_tensor_is_second =
          IsConstantTensor(reader->GetInputTensor(0));
      node->operation.attributes = std::move(attr);
    } else {
      return absl::InvalidArgumentError("Incorrect operation type passed");
    }
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Fuse the activation only after outputs exist, as the ADD, CONV_2D and
    // FULLY_CONNECTED parsers do. MaybeFuseActivation (defined elsewhere)
    // presumably needs the node's output to be registered — confirm there.
    // Previously the const-input branch dropped the activation entirely,
    // although IsSupported accepted it.
    TfLiteFusedActivation activation = kTfLiteActNone;
    RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
    if (activation != kTfLiteActNone) {
      RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
    }
    return absl::OkStatus();
  }

 private:
  // Reads the fused activation from SUB/DIV builtin params. Falls back to
  // kTfLiteActNone when the params are missing or the op has none.
  // Always returns OK.
  absl::Status GetActivation(const TfLiteNode* tflite_node,
                             TfLiteFusedActivation* activation) const {
    if (operation_type_ == OperationType::DIV) {
      const TfLiteDivParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }
    if (operation_type_ == OperationType::SUB) {
      const TfLiteSubParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }
    // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
    // TfLiteXxxParams.activation.
    *activation = kTfLiteActNone;
    return absl::OkStatus();
  }

  bool IsOneArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::ABS:
      case OperationType::COPY:
      case OperationType::COS:
      case OperationType::ELU:
      case OperationType::EXP:
      case OperationType::LOG:
      case OperationType::NEG:
      case OperationType::RSQRT:
      case OperationType::SIGMOID:
      case OperationType::SIN:
      case OperationType::SQRT:
      case OperationType::SQUARE:
      case OperationType::TANH:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperationWithConst() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  OperationType operation_type_;
};
// Parser for TFLite FULLY_CONNECTED.
//
// With two runtime inputs (runtime weights) the op is emitted as a 1x1
// CONVOLUTION_2D. Otherwise constant weights are read into a FULLY_CONNECTED
// node. If the input is not already 1x1 spatially, it is first flattened by
// a RESHAPE to (b, 1, 1, weights_width).
class FullyConnectedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
    const TfLiteFullyConnectedParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    if (tf_options->weights_format !=
        kTfLiteFullyConnectedWeightsFormatDefault) {
      return absl::UnimplementedError(
          "Unsupported FullyConnected weights format.");
    }
    if (GetNumberOfRuntimeInputsForNode(context, tflite_node) > 2) {
      return absl::UnimplementedError(
          "FullyConnected doesn't support more than 2 runtime inputs.");
    }
    if (tf_options->keep_num_dims == true) {
      return absl::UnimplementedError(
          "FullyConnected doesn't support keep_num_dims.");
    }
    // TODO(eignasheva): check input shape
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteFullyConnectedParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));

    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Create Convolution2D, so as it supports runtime weights.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));

      Convolution2DAttributes attr;
      reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional

      attr.strides = HW(1, 1);
      attr.dilations = HW(1, 1);
      attr.padding.appended = HW(0, 0);
      attr.padding.prepended = HW(0, 0);
      RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
      node->operation.attributes = std::move(attr);
      return absl::OkStatus();
    }

    // Reject unsupported formats before mutating the graph.
    if (tf_options->weights_format !=
        kTfLiteFullyConnectedWeightsFormatDefault) {
      return absl::UnimplementedError(
          "Unsupported FullyConnected weights format.");
    }
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));

    FullyConnectedAttributes attr;
    RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
    const int weights_width = attr.weights.shape.i;

    auto input = graph->FindInputs(node->id)[0];
    const int batch_size = input->tensor.shape.b;
    // Guard the division below; a zero batch would be undefined behavior.
    if (batch_size <= 0) {
      return absl::InvalidArgumentError(
          "FullyConnected input must have a positive batch size.");
    }
    if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) {
      return absl::UnimplementedError(
          "Amount of input data should match weights width");
    }

    Node* conv = node;
    if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
      // `node` becomes a RESHAPE to (b, 1, 1, weights_width), and a fresh
      // node becomes the FULLY_CONNECTED op consuming the reshaped value.
      Node* reshape = node;
      conv = graph->NewNode();  // reset conv pointer!
      Value* reshaped_value = graph->NewValue();
      reshaped_value->tensor.type = DataType::FLOAT32;
      reshaped_value->tensor.shape =
          BHWC(input->tensor.shape.b, 1, 1, weights_width);
      RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
      reshape->operation.type = ToString(OperationType::RESHAPE);
      ReshapeAttributes reshape_attr;
      reshape_attr.new_shape = reshaped_value->tensor.shape;
      reshape->operation.attributes = reshape_attr;
      RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
    }

    conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
    conv->operation.attributes = std::move(attr);
    // Propagate an AddOutputs failure directly instead of masking it with
    // whatever MaybeFuseActivation reports on a node without outputs.
    RETURN_IF_ERROR(reader->AddOutputs(conv));
    return MaybeFuseActivation(tf_options->activation, graph, conv);
  }
};
class HardSwishOperationParser : public TFLiteOperationParser {
 public:
  // A HARD_SWISH node is accepted when it has exactly one runtime input and
  // exactly one output.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration*) final {
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                              /*outputs=*/1);
  }

  // Emits one HARD_SWISH graph node fed by input 0 and producing the node's
  // outputs.
  absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* hard_swish = graph->NewNode();
    hard_swish->operation.type = ToString(OperationType::HARD_SWISH);
    const absl::Status input_status = reader->AddInput(hard_swish, 0);
    if (!input_status.ok()) {
      return input_status;
    }
    return reader->AddOutputs(hard_swish);
  }
};
// Basic LSTM Cell:
//
// The basic kernel is lowered into a CONCAT -> FULLY_CONNECTED -> LSTM chain.
// Indices below refer to the TFLite node's inputs/outputs; the intermediate
// results are materialized in outputs 2 and 3:
//
//   concat_temp (output 2) = CONCAT(input 0 "input", input 1 "prev_activ")
//   activ_temp  (output 3) = FULLY_CONNECTED(concat_temp,
//                                            weights = input 2,
//                                            biases  = input 3)
//   [output 1 "new_state", output 0 "activation"] =
//       LSTM(activ_temp, input 4 "prev_state")
//
// NOTE: avoid ASCII-art lines ending in a backslash here; a trailing '\'
// splices the following source line into a `//` comment (-Wcomment).
//
// For full LSTM cells, see this blog post:
// https://colah.github.io/posts/2015-08-Understanding-LSTMs/
// In addition to Peephole connections and Combined Input Forget Gates (CIFG)
// described in that post, this code also adds the following optional features:
// - Configurable activations (sigmoid or TANH)
// - Layer normalization of gates: https://arxiv.org/abs/1607.06450
// - Output projection:
//   https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
// - Configurable clipping of cell state and output state.
class LSTMOperationParser : public TFLiteOperationParser {
 public:
  // Accepts op versions up to 4.
  // Full kernel: 20 or 24 inputs, exactly one output, and a sigmoid/tanh
  // activation.
  // Basic kernel: 3 runtime inputs, 2 constant inputs, 4 outputs, tanh
  // activation, and no clipping.
  // Any other kernel_type is rejected instead of falling off the switch.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
    const TfLiteLSTMParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    switch (tf_options->kernel_type) {
      case kTfLiteLSTMFullKernel: {
        const int inputs = NumInputs(tflite_node);
        if (inputs != 20 && inputs != 24) {
          return absl::InternalError(
              absl::StrCat("Expected 20 or 24 input tensors, but node has ",
                           inputs, " input(s)."));
        }
        const int runtime_outputs = NumOutputs(tflite_node);
        if (runtime_outputs != 1) {
          return absl::InternalError(
              absl::StrCat("Expected 1 output tensor, but node has ",
                           runtime_outputs, " output(s)."));
        }
        return CheckFullParameters(tf_options);
      }
      case kTfLiteLSTMBasicKernel:
        RETURN_IF_ERROR(
            CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3,
                                     /*const_inputs=*/2, /*outputs=*/4));
        return CheckBasicParameters(tf_options);
    }
    // kernel_type held a value outside the enumerators handled above. Without
    // this return, control would flow off the end of a non-void function,
    // which is undefined behavior.
    return UnknownKernelTypeError(tf_options);
  }

  // Dispatches to the full or basic lowering based on kernel_type.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteLSTMParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    switch (tf_options->kernel_type) {
      case kTfLiteLSTMFullKernel:
        return ParseFull(tflite_node, registration, graph, reader, tf_options);
      case kTfLiteLSTMBasicKernel:
        return ParseBasic(tflite_node, registration, graph, reader, tf_options);
    }
    // See IsSupported(): never fall off the end of a non-void function.
    return UnknownKernelTypeError(tf_options);
  }

  // Returns the map that ParseFull() hands to ParseLSTMAttributes(). It
  // presumably maps variable-input tensor indices to the graph values created
  // for them — confirm against lstm_parser.
  absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
      final {
    return new_variable_input_value_map_;
  }

 private:
  // Error returned for a kernel_type that neither lowering handles.
  static absl::Status UnknownKernelTypeError(
      const TfLiteLSTMParams* tf_options) {
    return absl::UnimplementedError(
        absl::StrCat("Unsupported LSTM kernel type: ",
                     static_cast<int>(tf_options->kernel_type)));
  }

  // Lowers a basic LSTM cell into CONCAT -> FULLY_CONNECTED -> LSTM. The two
  // intermediate values are the TFLite node's outputs 2 and 3 (see the
  // diagram above the class).
  absl::Status ParseBasic(const TfLiteNode* tflite_node,
                          const TfLiteRegistration* registration,
                          GraphFloat32* graph, ObjectReader* reader,
                          const TfLiteLSTMParams* tf_options) {
    if (tflite_node->inputs->size != 5) {
      return absl::InvalidArgumentError("LSTM should have 5 input tensors");
    }
    if (tflite_node->outputs->size != 4) {
      return absl::InvalidArgumentError("LSTM should have 4 output tensors");
    }
    RETURN_IF_ERROR(CheckBasicParameters(tf_options));

    Node* concat_node = graph->NewNode();
    concat_node->operation.type = ToString(OperationType::CONCAT);
    ConcatAttributes concat_attr;
    concat_attr.axis = Axis::CHANNELS;
    concat_node->operation.attributes = concat_attr;

    Node* fc_node = graph->NewNode();
    fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    // Weights come from input 2 and biases from input 3.
    RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
    fc_node->operation.attributes = std::move(fc_attr);

    Node* lstm_node = graph->NewNode();
    lstm_node->operation.type = ToString(OperationType::LSTM);
    LstmAttributes lstm_attr;
    lstm_attr.kernel_type = LstmKernelType::BASIC;
    lstm_node->operation.attributes = lstm_attr;

    // The intermediate values are addressed via the node's output tensor
    // indices.
    Value* concat_temp;
    int concat_tensor_idx = tflite_node->outputs->data[2];
    RETURN_IF_ERROR(
        reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
    Value* activ_temp;
    int activ_tensor_idx = tflite_node->outputs->data[3];
    RETURN_IF_ERROR(
        reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));

    RETURN_IF_ERROR(reader->AddInput(concat_node, 0));  // input
    RETURN_IF_ERROR(reader->AddInput(concat_node, 1));  // prev_activ
    RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));

    RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
    RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));

    RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
    RETURN_IF_ERROR(reader->AddInput(lstm_node, 4));   // prev_state
    RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1));  // new_state
    RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0));  // activation

    return absl::OkStatus();
  }

  // The basic kernel supports only tanh activation and no cell/projection
  // clipping.
  absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
    if (tf_options->activation != kTfLiteActTanh) {
      return absl::UnimplementedError("Only TANH activation is supported.");
    }
    if (tf_options->cell_clip != 0.0f) {
      return absl::UnimplementedError("cell_clip is not supported.");
    }
    if (tf_options->proj_clip != 0.0f) {
      return absl::UnimplementedError("proj_clip is not supported.");
    }
    return absl::OkStatus();
  }

  // Delegates the full-kernel lowering to ParseLSTMAttributes(). That helper
  // also receives new_variable_input_value_map_, presumably to record the
  // values it creates for variable inputs.
  absl::Status ParseFull(const TfLiteNode* tflite_node,
                         const TfLiteRegistration* registration,
                         GraphFloat32* graph, ObjectReader* reader,
                         const TfLiteLSTMParams* tf_options) {
    // Invoke full LSTM parser
    RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
                                        reader, tf_options,
                                        &new_variable_input_value_map_));
    return absl::OkStatus();
  }

  // The full kernel accepts sigmoid or tanh activation only.
  absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
    if (tf_options->activation != kTfLiteActSigmoid &&
        tf_options->activation != kTfLiteActTanh) {
      return absl::UnimplementedError(
          "Only sigmoid or tanh activation is supported.");
    }
    return absl::OkStatus();
  }

  absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
};
class MulOperationParser : public TFLiteOperationParser {
 public:
  // Accepts op versions up to 3 with exactly two inputs and a supported fused
  // activation. Rejects equal-rank inputs where each is smaller than the other
  // in some dimension (an outer-product-like broadcast the GPU MUL does not
  // handle).
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    if (tflite_node->inputs->size != 2) {
      return absl::UnimplementedError("MUL requires two input tensors.");
    }
    const TfLiteTensor* input0 = GetInput(context, tflite_node, 0);
    const TfLiteTensor* input1 = GetInput(context, tflite_node, 1);
    if (input0 == nullptr || input1 == nullptr) {
      return absl::InvalidArgumentError("At least one input tensor is null");
    }
    if (input0->dims->size == input1->dims->size) {
      // This checks that at least one input of Mul is not smaller in all
      // dimensions. Sometimes Mul is used for matrix-vector multiplication,
      // which is currently unsupported. For example, input0 HWC(1, 256, 1) and
      // input1 HWC(1, 1, 256) give output HWC(1, 256, 256). In this case it
      // can be replaced with a Convolution operation.
      bool first_has_smaller_dim = false;
      bool second_has_smaller_dim = false;
      for (int i = 0; i < input0->dims->size; ++i) {
        if (input0->dims->data[i] < input1->dims->data[i]) {
          first_has_smaller_dim = true;
        }
        if (input1->dims->data[i] < input0->dims->data[i]) {
          second_has_smaller_dim = true;
        }
      }
      if (first_has_smaller_dim && second_has_smaller_dim) {
        return absl::UnimplementedError(
            "MUL requires one tensor that not less than second in all "
            "dimensions.");
      }
    }
    const TfLiteMulParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return IsActivationSupported(tf_options->activation);
  }

  // Lowers MUL to one graph node and then applies the fused activation:
  // - Two runtime inputs that are the same tensor become POW(A, 2.0).
  // - Two distinct runtime inputs become MUL, with the "larger" tensor bound
  //   first.
  // - One runtime and one constant input become MUL with the constant stored
  //   in ElementwiseAttributes::param.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteTensor* input0 = reader->GetInputTensor(0);
    if (!input0) {
      return absl::InvalidArgumentError(
          "Couldn't get the 1st input tensor for MUL.");
    }
    const TfLiteTensor* input1 = reader->GetInputTensor(1);
    if (!input1) {
      return absl::InvalidArgumentError(
          "Couldn't get the 2nd input tensor for MUL.");
    }
    const bool constant_tensor0 = IsConstantTensor(input0);
    const bool constant_tensor1 = IsConstantTensor(input1);
    if (constant_tensor0 && constant_tensor1) {
      return absl::InvalidArgumentError("No runtime input tensors for MUL.");
    }
    const bool runtime_tensor0 = !constant_tensor0;
    const bool runtime_tensor1 = !constant_tensor1;

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Determine runtime/constant tensors.
    if (runtime_tensor0 && runtime_tensor1) {
      if (input0 == input1) {
        // Replace MUL(A, A) with POW(A, 2.0).
        // TODO(b/166831113): Support the same inputs for operations.
        node->operation.type = ToString(OperationType::POW);
        ElementwiseAttributes attr;
        attr.param = 2.0f;
        node->operation.attributes = std::move(attr);
        // Do not return early: the fused activation below must still be
        // applied to the POW result, otherwise it would be silently dropped.
        RETURN_IF_ERROR(reader->AddInput(node, 0));
      } else {
        // The "larger" input tensor must be bound to the 1st input and the
        // "smaller" input tensor to the 2nd input.
        BHWC shape0;
        RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
        BHWC shape1;
        RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
        int input_tensor0 = 0;
        int input_tensor1 = 1;
        if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
            shape0.c == shape1.c) {
          input_tensor0 = 1;
          input_tensor1 = 0;
        }
        RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
        RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
      }
    } else {
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      node->operation.attributes = std::move(attr);
    }

    const TfLiteMulParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return MaybeFuseActivation(tf_options->activation, graph, node);
  }
};
class PackOperationParser : public TFLiteOperationParser {
 public:
  // Only requires that PACK builtin options can be retrieved; shape handling
  // happens in Parse().
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    const TfLitePackParams* tf_options;
    return RetrieveBuiltinData(tflite_node, &tf_options);
  }

  // Lowers PACK depending on the number of inputs:
  // - a single input becomes a RESHAPE to the output shape;
  // - several inputs become a CONCAT along the axis derived from the PACK
  //   `axis` option relative to the output tensor.
  // Constant inputs are materialized as const nodes.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    if (tflite_node->inputs->size == 1) {
      // Pack with single input can be replaced with Reshape.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::RESHAPE);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      // New shape comes from output shape.
      ReshapeAttributes attr;
      attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
      node->operation.attributes = attr;
      return absl::OkStatus();
    }

    // Pack with several inputs can be replaced with Concat.
    const TfLitePackParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    // Read inputs first so that const nodes are added to the graph before the
    // concat node, keeping the graph in topological order.
    std::vector<const Value*> inputs;
    for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* runtime_value;
      if (reader->ReadValue(idx, &runtime_value).ok()) {
        inputs.push_back(runtime_value);
        continue;
      }
      // Not a runtime value: read it as a constant tensor and materialize a
      // const node for it.
      TensorFloat32 tensor;
      RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
      Value* const_value;
      RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &const_value));
      inputs.push_back(const_value);
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (const Value* input : inputs) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
    }

    const TfLiteTensor* output = reader->GetOutputTensor(0);
    if (!output) {
      return absl::InvalidArgumentError(
          "Couldn't get the output tensor for PACK.");
    }
    ConcatAttributes attr;
    RETURN_IF_ERROR(
        ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};
class PReLUOperationParser : public TFLiteOperationParser {
 public:
  // Only the op version is validated here (must be <= 1).
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    // TODO(eignasheva): add params check
    return CheckMaxSupportedOpVersion(registration, 1);
  }

  // Emits a PRELU node. Alpha (input 1) is first read as a linear tensor whose
  // length must equal the input channel count. If that read fails, it is read
  // as an HWC tensor whose h/w/c must all equal the input's.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* prelu = graph->NewNode();
    prelu->operation.type = ToString(OperationType::PRELU);
    RETURN_IF_ERROR(reader->AddInput(prelu, 0));
    const auto input_shape = graph->FindInputs(prelu->id)[0]->tensor.shape;

    PReLUAttributes attr;
    Tensor<Linear, DataType::FLOAT32> linear_alpha;
    if (!reader->ReadTensor(1, &linear_alpha).ok()) {
      // Fall back to a full HWC alpha.
      Tensor<HWC, DataType::FLOAT32> hwc_alpha;
      RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
      const bool matches_input = hwc_alpha.shape.h == input_shape.h &&
                                 hwc_alpha.shape.w == input_shape.w &&
                                 hwc_alpha.shape.c == input_shape.c;
      if (!matches_input) {
        return absl::InvalidArgumentError(
            "Alpha shape does not match input shape.");
      }
      attr.alpha = std::move(hwc_alpha);
    } else {
      if (linear_alpha.shape.v != input_shape.c) {
        return absl::InvalidArgumentError(
            "Linear alpha shape does not match the number of input channels.");
      }
      attr.alpha = std::move(linear_alpha);
    }

    prelu->operation.attributes = std::move(attr);
    return reader->AddOutputs(prelu);
  }
};
class PadOperationParser : public TFLiteOperationParser {
 public:
  // `mirror_pad` selects MIRROR_PAD handling (reflective content) instead of
  // PAD (zero content).
  explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}

  // Requires:
  // - reflect mode for mirror padding;
  // - op version <= 2;
  // - one runtime input and one output;
  // - an available, non-null 2D paddings tensor (input 1) of shape 3x2 or 4x2.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    if (mirror_pad_) {
      const TfLiteMirrorPaddingParams* mirror_options;
      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &mirror_options));
      if (mirror_options->mode !=
          TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
        return absl::InvalidArgumentError(
            "Only Reflective padding is supported for Mirror Pad operation.");
      }
    }
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));

    const TfLiteTensor* paddings = GetInput(context, tflite_node, 1);
    if (paddings == nullptr) {
      return absl::InvalidArgumentError("Padding tensor was null");
    }
    const auto* dims = paddings->dims;
    if (dims->size != 2) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Invalid paddings tensor dimension: expected 2 dim, got ",
          dims->size, " dim"));
    }
    const int rows = dims->data[0];
    const int cols = dims->data[1];
    const bool rows_ok = rows == 3 || rows == 4;
    if (!rows_ok || cols != 2) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Invalid paddings tensor shape: expected 4x2 or 3x2, got ", rows,
          "x", cols));
    }
    return absl::OkStatus();
  }

  // Emits a PAD node. Amounts are taken from the paddings tensor: entries at
  // even flat indices become `prepended`, odd ones become `appended`.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::PAD);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    PadAttributes attr;
    attr.type = mirror_pad_ ? PaddingContentType::REFLECT
                            : PaddingContentType::ZEROS;

    Tensor<HW, DataType::INT32> paddings;
    RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));

    const bool two_columns = paddings.shape.w == 2;
    const bool four_rows = paddings.shape.h == 4;
    const bool three_rows = paddings.shape.h == 3;
    if (!two_columns || (!four_rows && !three_rows)) {
      // It shouldn't fail here since it's checked at IsSupported().
      return absl::InvalidArgumentError(
          "Paddings tensor has unexpected shape.");
    }

    const auto& p = paddings.data;
    if (four_rows) {
      attr.prepended = BHWC(p[0], p[2], p[4], p[6]);
      attr.appended = BHWC(p[1], p[3], p[5], p[7]);
    } else {
      // NOTE(review): the batch padding is 1 on both sides here even though a
      // 3x2 paddings tensor has no batch entry; 0 looks like the intended
      // value. Confirm against the PAD kernels before changing it.
      attr.prepended = BHWC(1, p[0], p[2], p[4]);
      attr.appended = BHWC(1, p[1], p[3], p[5]);
    }

    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  bool mirror_pad_ = false;
};
class Pooling2DOperationParser : public TFLiteOperationParser {
 public:
  // `type` is copied into Pooling2DAttributes::type for every parsed node.
  explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}

  // Requires op version <= 2, one runtime input, valid kernel/stride sizes
  // and a supported fused activation. If custom initial data is present
  // (the variant that also emits indices), two outputs are expected;
  // otherwise builtin pool params and a single output are required.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    const TfLitePoolParams* tf_options;
    const bool has_custom_options =
        RetrieveCustomInitialData(tflite_node, &tf_options).ok();
    if (!has_custom_options) {
      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    }
    const int expected_outputs = has_custom_options ? 2 : 1;
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1,
                                       /*outputs=*/expected_outputs));
    RETURN_IF_ERROR(CheckKernelsAndStrides(
        tf_options->filter_height, tf_options->filter_width,
        tf_options->stride_height, tf_options->stride_width));
    return IsActivationSupported(tf_options->activation);
  }

  // Emits a POOLING_2D node. The optional second (indices) output is attached
  // only after the fused activation has been handled, and its data type is
  // forced to INT32.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* pooling = graph->NewNode();
    pooling->operation.type = ToString(OperationType::POOLING_2D);
    RETURN_IF_ERROR(reader->AddInput(pooling, 0));
    RETURN_IF_ERROR(reader->AddOutput(pooling, 0));

    Pooling2DAttributes attr;
    attr.type = type_;
    auto input_shape = graph->FindInputs(pooling->id)[0]->tensor.shape;

    // There is no way to read the builtin code here, so the presence of custom
    // initial data is what identifies the MaxPoolingWithArgmax2D variant.
    const TfLitePoolParams* tf_options;
    if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    }

    RETURN_IF_ERROR(
        MaybeFuseActivation(tf_options->activation, graph, pooling));
    // The second output is optional and must be added only after
    // MaybeFuseActivation() has run; failure to add it is deliberately
    // ignored.
    reader->AddOutput(pooling, 1).IgnoreError();

    // Output 0 is the pooling result; output 1, when present, holds indices.
    const auto outputs = graph->FindOutputs(pooling->id);
    attr.output_indices = outputs.size() == 2;
    if (attr.output_indices) {
      // The model declares the indices as float32; the kernel produces int32.
      outputs[1]->tensor.type = DataType::INT32;
    }

    RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
    pooling->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  const PoolingType type_;
};
class ReduceOperationParser : public TFLiteOperationParser {
 public:
  // `operation_type` is the graph operation emitted for every parsed node.
  explicit ReduceOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  // Requires one runtime input, one output, and an axes tensor at input 1
  // that is read-only memory-mapped int32 data with exactly one element.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    // Guard the axes lookup: indexing inputs->data[1] on a node with fewer
    // inputs (or an optional/absent tensor) would read out of bounds.
    if (tflite_node->inputs->size < 2) {
      return absl::UnimplementedError("Reduce requires an axes input tensor.");
    }
    const TfLiteTensor* axes = GetInput(context, tflite_node, 1);
    if (axes == nullptr) {
      return absl::InvalidArgumentError("Reduce axes tensor is null.");
    }
    if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
      return absl::UnimplementedError(
          "Reduce has unsupported tensor for axes.");
    }
    if (tflite::NumElements(axes) != 1) {
      return absl::UnimplementedError(
          "Supported reduce in single dimensions only.");
    }
    return absl::OkStatus();
  }

  // Emits a reduce node over the single axis given by input 1, translated to
  // a graph Axis relative to input 0.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Retrieved only to validate that builtin reducer options exist; no field
    // of tf_options is read here.
    const TfLiteReducerParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));

    Tensor<Scalar, DataType::INT32> axes;
    RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
    if (axes.data.empty()) {
      return absl::InvalidArgumentError("Reduce axes tensor is empty.");
    }
    const TfLiteTensor* input = reader->GetInputTensor(0);
    if (!input) {
      return absl::InvalidArgumentError(
          "Couldn't get the input tensor for Reduce.");
    }

    ReduceAttributes attr;
    Axis axis;
    RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[0], &axis));
    attr.dims = {axis};
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  const OperationType operation_type_;
};
class QuantizeOperationParser : public TFLiteOperationParser {
 public:
  // Requires op version <= 2 and exactly one runtime input and one output.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node,
                              /*runtime_inputs=*/1, /*outputs=*/1);
  }

  // QUANTIZE is emitted as QUANTIZE_AND_DEQUANTIZE because the graph holds
  // floating-point versions of the original tensors. min/max/scale are copied
  // from the quantization params already attached to the output value.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* qdq = graph->NewNode();
    qdq->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    RETURN_IF_ERROR(reader->AddInput(qdq, 0));
    RETURN_IF_ERROR(reader->AddOutputs(qdq));

    const Value* output_value = graph->FindOutputs(qdq->id)[0];
    if (!output_value->quant_params) {
      return absl::InvalidArgumentError(
          "Encountered Quantize output with no quant params");
    }
    const auto& params = output_value->quant_params.value();

    QuantizeAndDequantizeAttributes attr;
    attr.min = params.min;
    attr.max = params.max;
    attr.scale = params.scale;
    qdq->operation.attributes = attr;
    return absl::OkStatus();
  }
};
class ReLUOperationParser : public TFLiteOperationParser {
 public:
  // `clip` is copied into ReLUAttributes::clip for every parsed node.
  explicit ReLUOperationParser(int clip) : clip_(clip) {}

  // Only the op version is validated (must be <= 2).
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckMaxSupportedOpVersion(registration, 2);
  }

  // Emits a RELU node. alpha comes from leaky-ReLU builtin options when they
  // can be retrieved, and is 0 otherwise.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* relu = graph->NewNode();
    relu->operation.type = ToString(OperationType::RELU);
    RETURN_IF_ERROR(reader->AddInput(relu, 0));

    ReLUAttributes attr;
    attr.clip = clip_;
    const TfLiteLeakyReluParams* leaky_params;
    if (RetrieveBuiltinData(tflite_node, &leaky_params).ok()) {
      attr.alpha = leaky_params->alpha;
    } else {
      attr.alpha = 0;
    }
    relu->operation.attributes = attr;

    return reader->AddOutputs(relu);
  }

 private:
  const int clip_;
};
class ReshapeOperationParser : public TFLiteOperationParser {
 public:
  // Requires op version <= 1 and exactly one runtime input and one output.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    // TODO(eignasheva): add shape checking
    return CheckInputsOutputs(context, tflite_node,
                              /*runtime_inputs=*/1, /*outputs=*/1);
  }

  // Emits a RESHAPE node whose target shape is the shape of its output value.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* reshape = graph->NewNode();
    reshape->operation.type = ToString(OperationType::RESHAPE);
    RETURN_IF_ERROR(reader->AddInput(reshape, 0));
    RETURN_IF_ERROR(reader->AddOutputs(reshape));

    // Any extra TFLite inputs (nominally describing the new shape) are
    // ignored; the output value's shape is used instead.
    // TODO(akulik): check that shapes match?
    const Value* output = graph->FindOutputs(reshape->id)[0];
    ReshapeAttributes attr;
    attr.new_shape = output->tensor.shape;
    reshape->operation.attributes = attr;
    return absl::OkStatus();
  }
};
class Resize2DOperationParser : public TFLiteOperationParser {
 public:
  // `sampling_type` selects bilinear or nearest-neighbor options parsing.
  explicit Resize2DOperationParser(SamplingType sampling_type)
      : sampling_type_(sampling_type) {}

  // Requires:
  // - op version <= 3;
  // - one runtime input and one output;
  // - 4D input/output where output H and W are not smaller than the input's;
  // - align_corners / half_pixel_centers options that can be read.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
    // The flags are read only to validate the options; their values are
    // discarded.
    bool unused_align_corners;
    RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &unused_align_corners));
    bool unused_half_pixel_centers;
    return GetHalfPixelCentersValue(tflite_node, &unused_half_pixel_centers);
  }

  // Emits a RESIZE node. The target shape is copied from the output value;
  // extra TFLite inputs that would define the new size are ignored.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* resize = graph->NewNode();
    resize->operation.type = ToString(OperationType::RESIZE);
    RETURN_IF_ERROR(reader->AddInput(resize, 0));
    RETURN_IF_ERROR(reader->AddOutputs(resize));

    Resize2DAttributes attr;
    attr.type = sampling_type_;
    RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
    RETURN_IF_ERROR(
        GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
    const Value* output = graph->FindOutputs(resize->id)[0];
    attr.new_shape.CopyAllDefinedAxis(output->tensor.shape);
    resize->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  // Reads align_corners from the builtin options matching sampling_type_.
  absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
                                    bool* align_corners) {
    if (sampling_type_ == SamplingType::BILINEAR) {
      return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
          tflite_node, align_corners);
    }
    if (sampling_type_ == SamplingType::NEAREST) {
      return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
          tflite_node, align_corners);
    }
    if (sampling_type_ == SamplingType::UNKNOWN) {
      return absl::InternalError("Sampling type is not specified");
    }
    return absl::OkStatus();
  }

  // Copies `align_corners` out of builtin options of type T.
  template <class T>
  absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
                                           bool* align_corners) {
    const T* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    *align_corners = tf_options->align_corners;
    return absl::OkStatus();
  }

  // Reads half_pixel_centers. For bilinear sampling, also rejects the
  // combination of half_pixel_centers with align_corners.
  absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
                                        bool* half_pixel_centers) {
    if (sampling_type_ != SamplingType::BILINEAR) {
      const TfLiteResizeNearestNeighborParams* nearest_options;
      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &nearest_options));
      *half_pixel_centers = nearest_options->half_pixel_centers;
      return absl::OkStatus();
    }
    const TfLiteResizeBilinearParams* bilinear_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &bilinear_options));
    if (bilinear_options->align_corners &&
        bilinear_options->half_pixel_centers) {
      return absl::InternalError(
          "If half_pixel_centers is True, align_corners must be False.");
    }
    *half_pixel_centers = bilinear_options->half_pixel_centers;
    return absl::OkStatus();
  }

  // Rejects non-4D tensors and any output whose dims[1] or dims[2] is smaller
  // than the input's (i.e. downsampling).
  absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
                                              const TfLiteNode* tflite_node) {
    const TfLiteTensor* input = &context->tensors[tflite_node->inputs->data[0]];
    const TfLiteTensor* output =
        &context->tensors[tflite_node->outputs->data[0]];
    if (input->dims == nullptr || input->dims->size != 4) {
      return absl::InvalidArgumentError("input.dims.size != 4");
    }
    if (output->dims == nullptr || output->dims->size != 4) {
      return absl::InvalidArgumentError("output.dims.size != 4");
    }
    const int in_h = input->dims->data[1];
    const int in_w = input->dims->data[2];
    const int out_h = output->dims->data[1];
    const int out_w = output->dims->data[2];
    const bool is_upsampling = out_h >= in_h && out_w >= in_w;
    if (!is_upsampling) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Only upsampling is supported, received output h,w = ", out_h, ",",
          out_w, " input h,w = ", in_h, ",", in_w));
    }
    return absl::OkStatus();
  }

  SamplingType sampling_type_ = SamplingType::UNKNOWN;
};
class SliceOperationParser : public TFLiteOperationParser {
 public:
  // Requires op version <= 2, at least three inputs (input, starts, sizes)
  // and a non-null 3D or 4D input tensor.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    if (tflite_node->inputs->size < 3) {
      return absl::UnimplementedError("SLICE requires 3 inputs.");
    }
    const TfLiteTensor* input = GetInput(context, tflite_node, 0);
    // GetInput may yield null (as the Mul/Pad parsers guard against); do not
    // dereference it blindly.
    if (input == nullptr) {
      return absl::InvalidArgumentError("SLICE input tensor is null.");
    }
    if (input->dims->size != 3 && input->dims->size != 4) {
      return absl::UnimplementedError(
          "SLICE supports for 3 or 4 dimensional tensors only.");
    }
    return absl::OkStatus();
  }

  // Emits a SLICE node with unit strides.
  // - Starts come from input 1 and sizes from input 2.
  // - A size of -1 means "to the end of that axis".
  // - Negative end coordinates are wrapped around the input extent.
  // - The resulting extent must match the output value's shape on every axis.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SLICE);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    Value* input;
    RETURN_IF_ERROR(reader->ReadValue(0, &input));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));

    const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
    if (!tfl_input) {
      return absl::InvalidArgumentError(
          "Couldn't get the input tensor for SLICE.");
    }
    const int input_dims = tfl_input->dims->size;

    SliceAttributes attr;
    attr.strides = BHWC(1, 1, 1, 1);
    Tensor<Linear, DataType::INT32> starts, sizes;
    RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
    RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
    if (starts.data.size() != sizes.data.size()) {
      return absl::InvalidArgumentError("Starts amount != sizes amount.");
    }

    // Axes not covered by the starts/sizes arguments default to the full
    // extent.
    BHWC bhwc_starts(0, 0, 0, 0);
    BHWC bhwc_sizes = input->tensor.shape;
    if (input_dims == 4) {
      // Input in BHWC layout.
      if (starts.data.size() == 4) {
        bhwc_starts.b = starts.data[0];
        bhwc_starts.h = starts.data[1];
        bhwc_starts.w = starts.data[2];
        bhwc_starts.c = starts.data[3];
        bhwc_sizes.b = sizes.data[0];
        bhwc_sizes.h = sizes.data[1];
        bhwc_sizes.w = sizes.data[2];
        bhwc_sizes.c = sizes.data[3];
      } else if (starts.data.size() == 3) {
        // If the input is 4D (BHWC) and args are 3D, assume args are HWC.
        bhwc_starts.h = starts.data[0];
        bhwc_starts.w = starts.data[1];
        bhwc_starts.c = starts.data[2];
        bhwc_sizes.h = sizes.data[0];
        bhwc_sizes.w = sizes.data[1];
        bhwc_sizes.c = sizes.data[2];
      } else {
        return absl::UnimplementedError(
            "Slicing is supported for 3 or 4 dimensional tensors only.");
      }
    } else if (input_dims == 3) {
      // Input in BWC layout.
      if (starts.data.size() == 3) {
        bhwc_starts.b = starts.data[0];
        bhwc_starts.w = starts.data[1];
        bhwc_starts.c = starts.data[2];
        bhwc_sizes.b = sizes.data[0];
        bhwc_sizes.w = sizes.data[1];
        bhwc_sizes.c = sizes.data[2];
      } else {
        return absl::UnimplementedError(
            "Slicing is supported for 3 or 4 dimensional tensors only.");
      }
    } else {
      return absl::UnimplementedError(
          "Slicing is supported for 3 or 4 dimensional tensors only.");
    }

    // A size of -1 selects everything from the start to the end of the axis.
    const auto& in_shape = input->tensor.shape;
    if (bhwc_sizes.b == -1) {
      bhwc_sizes.b = in_shape.b - bhwc_starts.b;
    }
    if (bhwc_sizes.h == -1) {
      bhwc_sizes.h = in_shape.h - bhwc_starts.h;
    }
    if (bhwc_sizes.w == -1) {
      bhwc_sizes.w = in_shape.w - bhwc_starts.w;
    }
    if (bhwc_sizes.c == -1) {
      bhwc_sizes.c = in_shape.c - bhwc_starts.c;
    }

    attr.starts = bhwc_starts;
    attr.ends =
        BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
             bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
    RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));

    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    if ((attr.ends.b - attr.starts.b) != out_shape.b) {
      return absl::UnimplementedError("Output batch don't match");
    }
    if ((attr.ends.h - attr.starts.h) != out_shape.h) {
      return absl::UnimplementedError("Output height doesn't match");
    }
    if ((attr.ends.w - attr.starts.w) != out_shape.w) {
      return absl::UnimplementedError("Output width doesn't match");
    }
    if ((attr.ends.c - attr.starts.c) != out_shape.c) {
      return absl::UnimplementedError("Output channels don't match");
    }
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  // Converts negative end coordinates into offsets from the end of the
  // corresponding input axis (end += extent). Always returns OK.
  absl::Status UpdateIfNegative(const BHWC& input_shape,
                                SliceAttributes* attr) {
    if (attr->ends.h < 0) {
      attr->ends.h = input_shape.h + attr->ends.h;
    }
    if (attr->ends.w < 0) {
      attr->ends.w = input_shape.w + attr->ends.w;
    }
    if (attr->ends.c < 0) {
      attr->ends.c = input_shape.c + attr->ends.c;
    }
    if (attr->ends.b < 0) {
      attr->ends.b = input_shape.b + attr->ends.b;
    }
    return absl::OkStatus();
  }
};
class SoftmaxOperationParser : public TFLiteOperationParser {
 public:
  // Requires op version <= 2, one runtime input, one output and beta == 1.
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    const TfLiteSoftmaxParams* softmax_params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &softmax_params));
    // TODO(eignasheva): figure out, what's wrong with softmax.
    if (softmax_params->beta == 1) {
      return absl::OkStatus();
    }
    return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
  }

  // Emits a SOFTMAX node that always reduces over the channel axis.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* softmax = graph->NewNode();
    softmax->operation.type = ToString(OperationType::SOFTMAX);
    RETURN_IF_ERROR(reader->AddInput(softmax, 0));
    RETURN_IF_ERROR(reader->AddOutputs(softmax));

    const TfLiteSoftmaxParams* softmax_params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &softmax_params));
    // beta != 1 amounts to a multiply-by-scalar fused into the softmax; it
    // would need its own layer before the softmax, which is not implemented.
    if (softmax_params->beta != 1) {
      return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
    }

    SoftmaxAttributes attr;
    attr.axis = Axis::CHANNELS;
    softmax->operation.attributes = attr;
    return absl::OkStatus();
  }
};
// Parses SPACE_TO_DEPTH. The TFLite block size is forwarded unchanged into
// SpaceToDepthAttributes; IsSupported rejects block sizes below 2.
class SpaceToDepthOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    // TODO(impjdi): Dims check.
    const TfLiteSpaceToDepthParams* params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &params));
    const int block_size = params->block_size;
    if (block_size > 1) {
      return absl::OkStatus();
    }
    // A block size of exactly 1 would be an identity op; anything smaller is
    // meaningless.
    return absl::InvalidArgumentError(
        block_size == 1 ? "SPACE_TO_DEPTH block_size = 1 is a no-op."
                        : "SPACE_TO_DEPTH block_size must be > 1.");
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const TfLiteSpaceToDepthParams* params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &params));
    SpaceToDepthAttributes s2d_attr;
    s2d_attr.block_size = params->block_size;
    node->operation.attributes = s2d_attr;
    return absl::OkStatus();
  }
};
// Parses STRIDED_SLICE into a SLICE operation with per-axis (BHWC) starts,
// ends and strides.
//
// Supported: 3D and 4D inputs, positive strides, begin_mask and end_mask.
// Rejected: ellipsis, new-axis and shrink-axis masks.
// The computed slice size must match the graph's output shape, otherwise the
// op is reported as unimplemented.
class StridedSliceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    const TfLiteStridedSliceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
    if (tflite_node->inputs->size < 4) {
      return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
    }
    const TfLiteTensor* input = GetInput(context, tflite_node, 0);
    if (input->dims->size != 3 && input->dims->size != 4) {
      return absl::UnimplementedError(
          "STRIDED_SLICE supports for 3 or 4 dimensional tensors only.");
    }
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SLICE);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    Value* input;
    RETURN_IF_ERROR(reader->ReadValue(0, &input));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));

    // The number of entries in the `begin` tensor (input #1) tells whether the
    // slice spec covers the batch axis (4 entries) or only HWC (3 entries).
    Tensor<Linear, DataType::INT32> tmp;
    RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
    bool read_without_batch = tmp.data.size() == 3;
    bool read_with_batch = tmp.data.size() == 4;
    if (!read_without_batch && !read_with_batch) {
      // Error: Must be caught in IsSupported()
      return absl::UnimplementedError(
          "Slicing is supported for 3 or 4 dimensional tensors only.");
    }
    const TfLiteStridedSliceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    SliceAttributes attr;
    if (read_without_batch) {
      RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
                                              input->tensor.shape, &attr));
    }
    if (read_with_batch) {
      RETURN_IF_ERROR(
          ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
    }
    if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
        attr.strides.c == 0) {
      return absl::InvalidArgumentError("stride values must be non-zero");
    }
    if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
        attr.strides.c < 0) {
      return absl::UnimplementedError("Reverse slices are not supported.");
    }
    // Number of selected elements per axis is ceil((end - start) / stride).
    if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
        out_shape.b) {
      return absl::UnimplementedError("Output batch don't match");
    }
    if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
        out_shape.h) {
      return absl::UnimplementedError("Output height doesn't match");
    }
    if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
        out_shape.w) {
      return absl::UnimplementedError("Output width doesn't match");
    }
    if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
        out_shape.c) {
      return absl::UnimplementedError("Output channels don't match");
    }
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  // Applies begin_mask / end_mask.
  // A set bit makes the corresponding start 0 or the corresponding end the
  // full input dimension. ignore_b/h/w/c are the mask bits belonging to each
  // BHWC axis; a bit value of 0 never matches, so that axis is left alone.
  absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
                              const BHWC& input_shape, int ignore_b,
                              int ignore_h, int ignore_w, int ignore_c,
                              SliceAttributes* attr) {
    if (tf_options->begin_mask & ignore_h) {
      attr->starts.h = 0;
    }
    if (tf_options->begin_mask & ignore_w) {
      attr->starts.w = 0;
    }
    if (tf_options->begin_mask & ignore_c) {
      attr->starts.c = 0;
    }
    if (tf_options->begin_mask & ignore_b) {
      attr->starts.b = 0;
    }
    if (tf_options->end_mask & ignore_h) {
      attr->ends.h = input_shape.h;
    }
    if (tf_options->end_mask & ignore_w) {
      attr->ends.w = input_shape.w;
    }
    if (tf_options->end_mask & ignore_c) {
      attr->ends.c = input_shape.c;
    }
    if (tf_options->end_mask & ignore_b) {
      attr->ends.b = input_shape.b;
    }
    return absl::OkStatus();
  }

  // Resolves negative start and end indices, which count from the end of a
  // dimension, by adding the size of the corresponding input dimension.
  absl::Status UpdateIfNegative(const BHWC& input_shape,
                                SliceAttributes* attr) {
    auto& starts = attr->starts;
    if (starts.b < 0) starts.b += input_shape.b;
    if (starts.h < 0) starts.h += input_shape.h;
    if (starts.w < 0) starts.w += input_shape.w;
    if (starts.c < 0) starts.c += input_shape.c;
    auto& ends = attr->ends;
    if (ends.b < 0) ends.b += input_shape.b;
    if (ends.h < 0) ends.h += input_shape.h;
    if (ends.w < 0) ends.w += input_shape.w;
    if (ends.c < 0) ends.c += input_shape.c;
    return absl::OkStatus();
  }

  // Reads begin/end/strides (inputs #1..#3) for a 4D (BHWC) slice spec.
  absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
                                    const TfLiteStridedSliceParams* tf_options,
                                    const BHWC& input_shape,
                                    SliceAttributes* attr) {
    auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
      Tensor<Linear, DataType::INT32> t;
      RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
      // Each spec tensor is indexed below; guard against a short tensor.
      if (t.data.size() != 4) {
        return absl::InvalidArgumentError(
            absl::StrCat("STRIDED_SLICE input #", tensor_index,
                         " must have 4 elements."));
      }
      *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
      return absl::OkStatus();
    };
    RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
    RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
    RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
    RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
    RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
    return absl::OkStatus();
  }

  // Reads begin/end/strides (inputs #1..#3) for a 3D (HWC) slice spec; the
  // batch axis is then taken in full with stride 1.
  absl::Status ReadAttribsWithoutBatch(
      const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
      const BHWC& input_shape, SliceAttributes* attr) {
    auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
      Tensor<Linear, DataType::INT32> t;
      RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
      // Each spec tensor is indexed below; guard against a short tensor.
      if (t.data.size() != 3) {
        return absl::InvalidArgumentError(
            absl::StrCat("STRIDED_SLICE input #", tensor_index,
                         " must have 3 elements."));
      }
      *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
      return absl::OkStatus();
    };
    RETURN_IF_ERROR(read_hwc(1, &attr->starts));
    RETURN_IF_ERROR(read_hwc(2, &attr->ends));
    RETURN_IF_ERROR(read_hwc(3, &attr->strides));
    RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
    RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
    attr->starts.b = 0;
    attr->ends.b = input_shape.b;
    attr->strides.b = 1;
    return absl::OkStatus();
  }

  // Rejects mask options the GPU SLICE operation cannot express.
  absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
    if (tf_options->ellipsis_mask) {
      return absl::UnimplementedError("Slice does not support ellipsis_mask.");
    }
    if (tf_options->new_axis_mask) {
      return absl::UnimplementedError("Slice does not support new_axis_mask.");
    }
    if (tf_options->shrink_axis_mask) {
      return absl::UnimplementedError(
          "Slice does not support shrink_axis_mask parameter. ");
    }
    return absl::OkStatus();
  }
};
// Builtin op version of TRANSPOSE_CONV.
//
// Input tensor layout of the TFLite op: #0 output shape, #1 weights, #2 input,
// and an optional #3 bias. The weights are either a constant tensor or a
// second runtime input.
class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    const int num_runtime_inputs =
        GetNumberOfRuntimeInputsForNode(context, tflite_node);
    if (num_runtime_inputs > 2) {
      return absl::InternalError(
          absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
                       num_runtime_inputs, " runtime inputs."));
    }
    const int num_runtime_outputs = NumOutputs(tflite_node);
    if (num_runtime_outputs != 1) {
      return absl::InternalError(
          absl::StrCat("Expected 1 output tensor(s), but node has ",
                       num_runtime_outputs, " runtime outputs."));
    }
    // With a single runtime input, the weights (#1) must be available as a
    // constant tensor.
    if (num_runtime_inputs == 1) {
      RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
    }
    const TfLiteTransposeConvParams* params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &params));
    return CheckStrides(params->stride_height, params->stride_width);
  }

  // Builds a CONVOLUTION_TRANSPOSED node; padding & stride come from the
  // builtin params.
  // TODO(impjdi): Translate output_shape to attr.adjacent.
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
    Value* input;
    RETURN_IF_ERROR(reader->ReadValue(2, &input));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    const TfLiteTransposeConvParams* params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &params));

    ConvolutionTransposedAttributes attr;
    if (params) {
      attr.stride = HW(params->stride_height, params->stride_width);
    } else {
      attr.stride = HW(1, 1);
    }

    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Runtime weights: wire them as a second input and record their shape.
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      const auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
      attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
                                weights_shape.w, weights_shape.c);
    } else {
      // Single runtime input: weights are a constant tensor.
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    // Bias is optional, so a failure to read it is deliberately ignored.
    reader->ReadTensor(3, &attr.bias).IgnoreError();
    UpdatePadding(params->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }
};
// Custom op version of TRANSPOSE_CONV.
//
// Inputs: #0 runtime input, #1 constant weights, and an optional #2 bias.
// Stride and padding are read from the custom initial data. If those cannot be
// retrieved, Parse falls back to stride 1x1 and unknown padding.
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
    const TfLiteTransposeConvParams* params;
    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &params));
    return CheckStrides(params->stride_height, params->stride_width);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    const TfLiteTransposeConvParams* params = nullptr;
    const bool has_params =
        RetrieveCustomInitialData(tflite_node, &params).ok();

    ConvolutionTransposedAttributes attr;
    if (has_params) {
      attr.stride = HW(params->stride_height, params->stride_width);
    } else {
      attr.stride = HW(1, 1);
    }
    RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    // Bias is optional, so a failure to read it is deliberately ignored.
    reader->ReadTensor(2, &attr.bias).IgnoreError();
    UpdatePadding(has_params ? params->padding : kTfLitePaddingUnknown,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }
};
// Parses TRANSPOSE for 2D, 3D and 4D tensors.
//
// The permutation (constant input #1) is converted to a BHWC permutation in
// attr.perm. Each component holds a BHWC axis index (0=B, 1=H, 2=W, 3=C).
// Lower-rank tensors are mapped onto BHWC as (B, W, C) for 3D and (B, C) for
// 2D. BHWC axes that the input does not have keep their identity position.
class TransposeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1, /*outputs=*/1));
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::TRANSPOSE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    Tensor<Linear, DataType::INT32> perm;
    RETURN_IF_ERROR(reader->ReadTensor(1, &perm));

    // to_bhwc[i] is the BHWC axis that axis i of a tensor of the given rank
    // corresponds to.
    static constexpr int kRank2ToBhwc[] = {0, 3};
    static constexpr int kRank3ToBhwc[] = {0, 2, 3};
    static constexpr int kRank4ToBhwc[] = {0, 1, 2, 3};
    const int* to_bhwc = nullptr;
    switch (perm.data.size()) {
      case 2:
        to_bhwc = kRank2ToBhwc;
        break;
      case 3:
        to_bhwc = kRank3ToBhwc;
        break;
      case 4:
        to_bhwc = kRank4ToBhwc;
        break;
      default:
        return absl::InvalidArgumentError(
            "Permutation for transpose is invalid.");
    }
    const int rank = static_cast<int>(perm.data.size());

    // BHWC axes absent from a lower-rank input stay in place.
    int bhwc_perm[4] = {0, 1, 2, 3};
    bool seen[4] = {false, false, false, false};
    for (int i = 0; i < rank; ++i) {
      const int axis = perm.data[i];
      // Out-of-range entries would index past the lookup tables. Repeated
      // entries would not form a permutation.
      if (axis < 0 || axis >= rank || seen[axis]) {
        return absl::InvalidArgumentError(
            absl::StrCat("Permutation for transpose is invalid: entry ", i,
                         " is ", axis, "."));
      }
      seen[axis] = true;
      bhwc_perm[to_bhwc[i]] = to_bhwc[axis];
    }
    TransposeAttributes attr;
    attr.perm = BHWC(bhwc_perm[0], bhwc_perm[1], bhwc_perm[2], bhwc_perm[3]);
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};
// Parses the custom MaxUnpooling2D op.
//
// It has two runtime inputs: #0 is the tensor to unpool, and #1 is presumably
// the indices produced by a max-pooling-with-argmax op (TODO confirm). Kernel,
// strides and padding come from the custom initial data. The output shape is
// computed here from the input shape and those attributes.
class Unpooling2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/2, /*outputs=*/1));
    const TfLitePoolParams* pool_params;
    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &pool_params));
    return CheckKernelsAndStrides(
        pool_params->filter_height, pool_params->filter_width,
        pool_params->stride_height, pool_params->stride_width);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;

    const TfLitePoolParams* pool_params;
    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &pool_params));
    MaxUnpooling2DAttributes unpool_attr;
    unpool_attr.kernel =
        ToHW(pool_params->filter_height, pool_params->filter_width);
    unpool_attr.strides =
        ToHW(pool_params->stride_height, pool_params->stride_width);
    UpdatePadding(pool_params->padding, input_shape, &unpool_attr);
    node->operation.attributes = unpool_attr;

    graph->FindOutputs(node->id)[0]->tensor.shape =
        CalculateOutputShape(input_shape, unpool_attr);
    return absl::OkStatus();
  }
};
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
//
// Parses BATCH_TO_SPACE_ND for the spatial case.
// - Input #1 is the 2-element block shape (H, W).
// - Input #2 is a 2x2 crops tensor. Row 0 holds the H crops and row 1 the W
//   crops, each as [prepended, appended].
class BatchToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    BatchToSpaceAttributes bs_attr;
    Tensor<Linear, DataType::INT32> block;
    RETURN_IF_ERROR(reader->ReadTensor(1, &block));
    if (block.shape.v != 2) {
      return absl::InternalError("Space has to be HxW.");
    }
    bs_attr.block.h = block.data[0];
    bs_attr.block.w = block.data[1];
    Tensor<HW, DataType::INT32> crop;
    RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
    // Both dimensions must be 2. Checking only one of them would let the
    // reads of crop.data[2] / crop.data[3] below go out of bounds.
    if (crop.shape.h != 2 || crop.shape.w != 2) {
      return absl::InternalError("Space has to be HxW.");
    }
    bs_attr.crop.prepended.h = crop.data[0];
    bs_attr.crop.prepended.w = crop.data[2];
    bs_attr.crop.appended.h = crop.data[1];
    bs_attr.crop.appended.w = crop.data[3];
    node->operation.attributes = std::move(bs_attr);
    return absl::OkStatus();
  }
};
// Parses SPACE_TO_BATCH_ND for the spatial case.
// - Input #1 is the 2-element block shape (H, W).
// - Input #2 is a 2x2 paddings tensor. Row 0 holds the H padding and row 1
//   the W padding, each as [prepended, appended].
class SpaceToBatchOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    SpaceToBatchAttributes sb_attr;
    Tensor<Linear, DataType::INT32> block;
    RETURN_IF_ERROR(reader->ReadTensor(1, &block));
    if (block.shape.v != 2) {
      return absl::InternalError("Space has to be HxW.");
    }
    sb_attr.block.h = block.data[0];
    sb_attr.block.w = block.data[1];
    Tensor<HW, DataType::INT32> padding;
    RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
    // Both dimensions must be 2. Checking only one of them would let the
    // reads of padding.data[2] / padding.data[3] below go out of bounds.
    if (padding.shape.h != 2 || padding.shape.w != 2) {
      return absl::InternalError("Space has to be HxW.");
    }
    sb_attr.padding.prepended.h = padding.data[0];
    sb_attr.padding.prepended.w = padding.data[2];
    sb_attr.padding.appended.h = padding.data[1];
    sb_attr.padding.appended.w = padding.data[3];
    node->operation.attributes = std::move(sb_attr);
    return absl::OkStatus();
  }
};
// Parses the custom RoIToTransformMatrix op. Its single runtime input is a
// bounding box. Attributes and the output shape are decoded from the custom
// initial data by ParseCustomAttributes.
class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                              /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const std::string op_name = "roi_to_transform_matrix";
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // bbox
    RETURN_IF_ERROR(reader->AddOutputs(node));
    node->operation.type = op_name;

    BHWC parsed_shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        op_name, registration->version, tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &(node->operation.attributes),
        &parsed_shape));
    graph->FindOutputs(node->id)[0]->tensor.shape = parsed_shape;
    return absl::OkStatus();
  }
};
// Parses the custom TransformTensorBilinear op.
// - Inputs: #0 data and #1 bbox.
// - Output: batch 1, spatial size taken from the shape parsed out of the
//   custom initial data, and channel count copied from input #0.
class TransformTensorBilinearOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2,
                              /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const std::string op_name = "transform_tensor_bilinear";
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
    RETURN_IF_ERROR(reader->AddOutputs(node));
    node->operation.type = op_name;

    BHWC parsed_shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        op_name, registration->version, tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &(node->operation.attributes),
        &parsed_shape));
    const int channels = graph->FindInputs(node->id)[0]->tensor.shape.c;
    graph->FindOutputs(node->id)[0]->tensor.shape =
        BHWC(1, parsed_shape.h, parsed_shape.w, channels);
    return absl::OkStatus();
  }
};
// Parses the custom TransformLandmarks op.
// - Inputs: #0 data and #1 bbox.
// - Attributes are decoded from the custom initial data.
// - The output shape is set to the shape of input #0.
class TransformLandmarksOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/2,
                              /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const std::string op_name = "transform_landmarks";
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
    RETURN_IF_ERROR(reader->AddOutputs(node));
    node->operation.type = op_name;

    auto* output = graph->FindOutputs(node->id)[0];
    // ParseCustomAttributes reports a shape, but it is not used: the output
    // shape is overwritten with the shape of input #0 below.
    BHWC parsed_shape = output->tensor.shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        op_name, registration->version, tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &(node->operation.attributes),
        &parsed_shape));
    output->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
    return absl::OkStatus();
  }
};
// Parses the custom Landmarks2TransformMatrix op (versions 1 and 2). It maps
// one landmarks input to one transform-matrix output. Attributes and the
// output shape are decoded from the custom initial data.
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                              /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const std::string op_name = "landmarks_to_transform_matrix";
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // landmarks
    RETURN_IF_ERROR(reader->AddOutputs(node));   // transform matrix
    node->operation.type = op_name;

    BHWC parsed_shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        op_name, registration->version, tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &(node->operation.attributes),
        &parsed_shape));
    graph->FindOutputs(node->id)[0]->tensor.shape = parsed_shape;
    return absl::OkStatus();
  }
};
// Parses the custom AlignmentPointsToTransformMatrix op. It maps one
// alignment-points input to one transform-matrix output. Attributes and the
// output shape are decoded from the custom initial data.
// NOTE(review): unlike the sibling custom-op parsers, IsSupported applies no
// CheckMaxSupportedOpVersion here.
class AlignmentPointsToTransformMatrixOperationParser
    : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                              /*outputs=*/1);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const std::string op_name = "alignment_points_to_transform_matrix";
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // alignment points
    RETURN_IF_ERROR(reader->AddOutputs(node));   // transform matrix
    node->operation.type = op_name;

    BHWC parsed_shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        op_name, registration->version, tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &(node->operation.attributes),
        &parsed_shape));
    graph->FindOutputs(node->id)[0]->tensor.shape = parsed_shape;
    return absl::OkStatus();
  }
};
// Parses MEAN. Only reductions over the spatial (H and W) plane are supported.
// The reduction axes must be a constant (mmap'ed) INT32 tensor at input #1.
class MeanOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/1,
                                       /*outputs=*/1));
    // The axes tensor is dereferenced below, so make sure it exists.
    if (tflite_node->inputs->size < 2 ||
        tflite_node->inputs->data[1] == kTfLiteOptionalTensor) {
      return absl::UnimplementedError("Mean has unsupported tensor for axes");
    }
    // Simple mechanism to check if MEAN is to be performed only on HW plane.
    const TfLiteTensor* axes = &context->tensors[tflite_node->inputs->data[1]];
    if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
      return absl::UnimplementedError("Mean has unsupported tensor for axes");
    }
    const auto* axes_data = axes->data.i32;
    const bool is_hw_mean = tflite::NumElements(axes) == 2 &&
                            ((axes_data[0] == 1 && axes_data[1] == 2) ||
                             (axes_data[0] == 2 && axes_data[1] == 1));
    if (!is_hw_mean) {
      return absl::UnimplementedError("Mean operation supports only HW plane");
    }
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    MeanAttributes attr;
    Tensor<Linear, DataType::INT32> axes;
    RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
    // Axis indices are in TFLite (BHWC) order: 0=B, 1=H, 2=W, 3=C.
    for (const int axis : axes.data) {
      switch (axis) {
        case 1:
          attr.dims.insert(Axis::HEIGHT);
          break;
        case 2:
          attr.dims.insert(Axis::WIDTH);
          break;
        case 0:
          return absl::UnimplementedError(
              "Unsupported mean dimension: batch");
        case 3:
          return absl::UnimplementedError(
              "Unsupported mean dimension: channels");
        default:
          return absl::UnimplementedError(
              absl::StrCat("Unsupported mean dimension: ", axis));
      }
    }
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};
// Fallback parser for ops the GPU delegate cannot handle. Both IsSupported
// and Parse unconditionally report the operation as unimplemented.
class UnsupportedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return NotSupported();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    return NotSupported();
  }

 private:
  static absl::Status NotSupported() {
    return absl::UnimplementedError("Operation is not supported.");
  }
};
// Returns the parser responsible for the TFLite op described by
// `registration`.
// - Quantize/dequantize ops only get a dedicated parser when `allow_quant_ops`
//   is true.
// - Builtin codes and custom op names without a dedicated parser fall back to
//   UnsupportedOperationParser.
std::unique_ptr<TFLiteOperationParser> NewOperationParser(
    const TfLiteRegistration* registration, bool allow_quant_ops = false) {
  const auto builtin_code = registration->builtin_code;
  switch (builtin_code) {
    case kTfLiteBuiltinAbs:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
    case kTfLiteBuiltinAdd:
      return std::make_unique<AddOperationParser>();
    case kTfLiteBuiltinAveragePool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
    case kTfLiteBuiltinBatchMatmul:
      return std::make_unique<BatchedMatMulOperationParser>();
    case kTfLiteBuiltinConcatenation:
      return std::make_unique<ConcatenationOperationParser>();
    case kTfLiteBuiltinConv2d:
      return std::make_unique<Conv2DOperationParser>();
    case kTfLiteBuiltinCos:
      return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
    case kTfLiteBuiltinDepthwiseConv2d:
      return std::make_unique<DepthwiseConvolutionOperationParser>();
    case kTfLiteBuiltinDequantize:
      if (allow_quant_ops) {
        return std::make_unique<DequantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinDiv:
      return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
    case kTfLiteBuiltinElu:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
    case kTfLiteBuiltinExp:
      return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
    case kTfLiteBuiltinFullyConnected:
      return std::make_unique<FullyConnectedOperationParser>();
    case kTfLiteBuiltinHardSwish:
      return std::make_unique<HardSwishOperationParser>();
    case kTfLiteBuiltinLogistic:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SIGMOID);
    case kTfLiteBuiltinLog:
      return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
    case kTfLiteBuiltinLstm:
      return std::make_unique<LSTMOperationParser>();
    case kTfLiteBuiltinMaximum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MAXIMUM);
    case kTfLiteBuiltinMaxPool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
    case kTfLiteBuiltinMean:
      return std::make_unique<MeanOperationParser>();
    case kTfLiteBuiltinMinimum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MINIMUM);
    case kTfLiteBuiltinMirrorPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
    case kTfLiteBuiltinMul:
      return std::make_unique<MulOperationParser>();
    case kTfLiteBuiltinNeg:
      return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
    case kTfLiteBuiltinPack:
      return std::make_unique<PackOperationParser>();
    case kTfLiteBuiltinPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
    case kTfLiteBuiltinPow:
      return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
    case kTfLiteBuiltinReduceMax:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MAXIMUM);
    case kTfLiteBuiltinReduceMin:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MINIMUM);
    case kTfLiteBuiltinReduceProd:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_PRODUCT);
    case kTfLiteBuiltinQuantize:
      if (allow_quant_ops) {
        return std::make_unique<QuantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinRelu6:
      return std::make_unique<ReLUOperationParser>(6);
    case kTfLiteBuiltinLeakyRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinPrelu:
      return std::make_unique<PReLUOperationParser>();
    case kTfLiteBuiltinReshape:
      return std::make_unique<ReshapeOperationParser>();
    case kTfLiteBuiltinResizeBilinear:
      return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
    case kTfLiteBuiltinResizeNearestNeighbor:
      return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
    case kTfLiteBuiltinRsqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
    case kTfLiteBuiltinSin:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
    case kTfLiteBuiltinSlice:
      return std::make_unique<SliceOperationParser>();
    case kTfLiteBuiltinSoftmax:
      return std::make_unique<SoftmaxOperationParser>();
    case kTfLiteBuiltinSpaceToDepth:
      return std::make_unique<SpaceToDepthOperationParser>();
    case kTfLiteBuiltinSqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
    case kTfLiteBuiltinSquare:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARE);
    case kTfLiteBuiltinSquaredDifference:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARED_DIFF);
    case kTfLiteBuiltinStridedSlice:
      return std::make_unique<StridedSliceOperationParser>();
    case kTfLiteBuiltinSub:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
    case kTfLiteBuiltinSum:
      return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
    case kTfLiteBuiltinTanh:
      return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
    case kTfLiteBuiltinTranspose:
      return std::make_unique<TransposeOperationParser>();
    case kTfLiteBuiltinTransposeConv:
      return std::make_unique<TransposeConvBuiltinOperationParser>();
    case kTfLiteBuiltinCustom: {
      // Scoped so that adding another case label after this one cannot jump
      // over the initialization of `custom_name`.
      const absl::string_view custom_name = registration->custom_name;
      if (custom_name == "Convolution2DTransposeBias") {
        return std::make_unique<TransposeConvCustomOperationParser>();
      }
      if (custom_name == "MaxPoolingWithArgmax2D") {
        return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
      }
      if (custom_name == "MaxUnpooling2D") {
        return std::make_unique<Unpooling2DOperationParser>();
      }
      if (custom_name == "RoIToTransformMatrix") {
        return std::make_unique<RoIToTransformMatrixOperationParser>();
      }
      if (custom_name == "TransformTensor" /*for version 1*/ ||
          custom_name == "TransformTensorBilinear" /*for version 2*/) {
        return std::make_unique<TransformTensorBilinearOperationParser>();
      }
      if (custom_name == "TransformLandmarks") {
        return std::make_unique<TransformLandmarksOperationParser>();
      }
      if (custom_name == "Landmarks2TransformMatrix" ||
          custom_name == "Landmarks2TransformMatrixV2") {
        return std::make_unique<Landmarks2TransformMatrixOperationParser>();
      }
      if (custom_name == "AlignmentPointsToTransformMatrix") {
        return std::make_unique<
            AlignmentPointsToTransformMatrixOperationParser>();
      }
      break;
    }
  }
  return std::make_unique<UnsupportedOperationParser>();
}
// Asks the operation parser selected for `registration` whether `node` can be
// handled by the GPU delegate, and returns that parser's verdict unchanged.
// `allow_quant_ops` is forwarded to NewOperationParser to pick the parser.
absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
                         const TfLiteRegistration* registration,
                         bool allow_quant_ops = false) {
  const auto parser = NewOperationParser(registration, allow_quant_ops);
  return parser->IsSupported(context, node, registration);
}
// Returns true when every tensor referenced by `tensor_indices` has a type the
// GPU delegate accepts.
//
// Optional tensors (kTfLiteOptionalTensor) are skipped, and only tensors whose
// allocation_type is kTfLiteArenaRw are type-checked; all others pass.
// Accepted types are FP32 and FP16, plus INT8 and UINT8 when `allow_quant_ops`
// is true.
bool IsAllAllowedTensors(TfLiteContext* context,
                         const TfLiteIntArray* tensor_indices,
                         bool allow_quant_ops = false) {
  const auto is_allowed_type = [allow_quant_ops](TfLiteType type) {
    if (type == kTfLiteFloat32 || type == kTfLiteFloat16) return true;
    // Since we only check non-constant tensors, type cannot be Int32.
    return allow_quant_ops && (type == kTfLiteInt8 || type == kTfLiteUInt8);
  };
  for (int pos = 0; pos < tensor_indices->size; ++pos) {
    const int tensor_index = tensor_indices->data[pos];
    if (tensor_index == kTfLiteOptionalTensor) continue;
    const TfLiteTensor& tensor = context->tensors[tensor_index];
    if (tensor.allocation_type != kTfLiteArenaRw) continue;
    if (!is_allowed_type(tensor.type)) return false;
  }
  return true;
}
} // namespace
// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
// Selects the nodes of the context's graph that the GPU delegate should take
// over.
//
// A node is accepted when the op parser chosen for it reports it as supported
// and every non-optional input/output tensor allocated as kTfLiteArenaRw has an
// allowed type (FP32/FP16, plus INT8/UINT8 when `allow_quant_ops` is true).
// Accepted nodes are grouped by FP16GraphPartitionHelper, and the nodes of the
// first `max_delegated_partitions` largest partitions are returned.
//
// When some nodes are rejected, a summary listing the unsupported ops and the
// GPU/CPU split is logged through `context`. If partitioning itself fails, an
// empty array is returned.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
                                int max_delegated_partitions) {
  delegates::IsNodeSupportedFn node_supported_fn =
      [=](TfLiteContext* context, TfLiteNode* node,
          TfLiteRegistration* registration,
          std::string* unsupported_details) -> bool {
    const auto status =
        IsSupported(context, node, registration, allow_quant_ops);
    if (!status.ok()) {
      if (unsupported_details) {
        *unsupported_details = std::string(status.message());
      }
      return false;
    }
    if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) ||
        !IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) {
      if (unsupported_details) {
        *unsupported_details =
            "OP is supported, but tensor type isn't matched!";
      }
      return false;
    }
    return true;
  };
  delegates::FP16GraphPartitionHelper partition_helper(context,
                                                       node_supported_fn);
  std::set<std::string> unsupported_nodes_info;
  if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
    return TfLiteIntArrayCreate(0);
  }
  // By default, we simply get 1st largest partition as
  // 'max_delegated_partitions' is set to 1 by default.
  std::vector<int> ops_to_replace =
      partition_helper.GetNodesOfFirstNLargestPartitions(
          max_delegated_partitions);
  if (!unsupported_nodes_info.empty()) {
    std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
    std::string error_message = absl::StrCat(
        "Following operations are not supported by GPU delegate:\n",
        unsupported, "\n");
    if (!ops_to_replace.empty()) {
      absl::StrAppend(
          &error_message, ops_to_replace.size(),
          " operations will run on the GPU, and the remaining ",
          partition_helper.num_total_nodes() - ops_to_replace.size());
    } else {
      absl::StrAppend(&error_message,
                      "No operations will run on the GPU, and all ",
                      partition_helper.num_total_nodes());
    }
    absl::StrAppend(&error_message, " operations will run on the CPU.");
    // The message embeds op names and parser messages, so it must never be
    // used as the printf-style format string itself: a '%' inside it would be
    // interpreted as a conversion specifier.
    TF_LITE_KERNEL_LOG(context, "%s", error_message.c_str());
  }
  return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
// Creates, in the resulting graph, a value for every non-constant tensor listed
// in `io_tensors`. Constant tensors are skipped.
//
// This is done up front so that the delegated subgraph keeps the same order of
// inputs and outputs as the original one. When the delegated model is built
// from the TFLite model representation, tensors are created lazily, so without
// this step there would be no guarantee that their order matches the order of
// the source model's tensors.
absl::Status PrecreateIOTensors(
    TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
    absl::flat_hash_map<int, int>* quant_conversion_map,
    absl::flat_hash_map<int, Value*>* tensor_to_value) {
  for (int pos = 0; pos < io_tensors->size; ++pos) {
    const int idx = io_tensors->data[pos];
    if (!tflite::IsConstantTensor(&context->tensors[idx])) {
      RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
          context, tensor_to_value, quant_conversion_map, graph, idx));
    }
  }
  return absl::OkStatus();
}
// Adds a COPY node for every variable input tensor of `tflite_node` that the
// node's parser reported a new value for.
//
// `new_variable_tensor_values` maps a node input position (an index into
// tflite_node->inputs) to the ValueId holding that input's new value. For each
// input whose Value has `tensor.is_variable_input` set, a COPY node consuming
// the mapped ValueId is created and registered through `reader.AddUpdate`.
// Inputs that `reader` cannot resolve to a Value are skipped.
//
// Returns InvalidArgumentError if a variable input has no entry in the map, or
// if the map holds entries that did not match any variable input of the node.
absl::Status CopyVariableTensorOutputs(
    TfLiteNode* tflite_node, TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader& reader,
    const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
  // Matched entries are removed from this copy; anything left over at the end
  // had no corresponding variable input.
  absl::flat_hash_map<int, ValueId> pending_values(new_variable_tensor_values);
  // Retrieve the final value id for the variable input tensors.
  for (int i = 0; i < tflite_node->inputs->size; ++i) {
    const int tensor_idx = tflite_node->inputs->data[i];
    Value* value = nullptr;
    if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
    if (!value->tensor.is_variable_input) continue;
    // Single lookup; the iterator is reused for both the read and the erase.
    auto it = pending_values.find(i);
    if (it == pending_values.end()) {
      return absl::InvalidArgumentError(
          absl::StrCat(GetOpNameByRegistration(*registration),
                       " did not provide a new value for the variable input "
                       "tensor with index ",
                       tensor_idx));
    }
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::COPY);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, it->second));
    RETURN_IF_ERROR(reader.AddUpdate(node, i));
    pending_values.erase(it);
  }
  if (!pending_values.empty()) {
    return absl::InvalidArgumentError(
        "More input variable tensors asked to be copied than present on the "
        "node");
  }
  return absl::OkStatus();
}
// Translates the TFLite nodes listed in `delegate_params->nodes_to_replace`
// into operations of `graph`.
//
// Dequantize nodes whose first input is FP16 are skipped. The delegate's input
// and output tensors are pre-created before any node is parsed, so that the
// graph keeps the source model's I/O ordering. Each remaining node is then
// parsed by its op parser. Any new values the parser reports for variable input
// tensors are wired up through CopyVariableTensorOutputs.
//
// `quant_conversion_map` may be null. When it is non-null, quantized ops are
// allowed and the map is passed on to the tensor readers.
absl::Status BuildModel(TfLiteContext* context,
                        const TfLiteDelegateParams* delegate_params,
                        GraphFloat32* graph,
                        absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
  std::vector<int> tflite_nodes;
  for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
    TfLiteNode* tflite_node = nullptr;
    TfLiteRegistration* registration = nullptr;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[i], &tflite_node,
        &registration));
    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        context->tensors[tflite_node->inputs->data[0]].type ==
            TfLiteType::kTfLiteFloat16) {
      // Ignore Fp16 Dequantize nodes.
      continue;
    }
    auto op_parser = NewOperationParser(
        registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
    if (!op_parser) {
      // custom_name may be null (e.g. for builtin ops); never hand a null C
      // string to StrCat.
      const char* custom_name = registration->custom_name != nullptr
                                    ? registration->custom_name
                                    : "";
      return absl::UnimplementedError(
          absl::StrCat("Operation ", registration->builtin_code, "(",
                       custom_name,
                       ") is not supported by TFLite GPU Delegate."));
    }
    operations.push_back(std::move(op_parser));
    // Remember the position in nodes_to_replace so the node can be fetched
    // again in the parse loop below.
    tflite_nodes.push_back(i);
  }
  absl::flat_hash_map<int, Value*> tensor_to_value;
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
                                     delegate_params->input_tensors,
                                     quant_conversion_map, &tensor_to_value));
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
                                     delegate_params->output_tensors,
                                     quant_conversion_map, &tensor_to_value));
  for (size_t i = 0; i < operations.size(); ++i) {
    TfLiteNode* tflite_node = nullptr;
    TfLiteRegistration* registration = nullptr;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
        &tflite_node, &registration));
    ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
                        quant_conversion_map);
    const auto status =
        operations[i]->Parse(tflite_node, registration, graph, &reader);
    if (!status.ok()) {
      return absl::InternalError(absl::StrCat(
          GetOpNameByRegistration(*registration), ": ", status.message()));
    }
    absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
        operations[i]->GetNewValueIdsForVariableInputNodes();
    RETURN_IF_ERROR(
        CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
                                  new_value_for_variable_input_tensors));
  }
  // TODO: Variable input tensors are expected to stay unchanged throughout
  // model execution, which requires them to be graph outputs. That invariant is
  // not verified here. The former check iterated over a vector that was never
  // populated, so it could not run.
  return absl::OkStatus();
}
// Builds `graph` from the delegated TFLite nodes via BuildModel, then runs the
// general graph transformations (ApplyModelTransformations) over it.
//
// A NullTransformationReporter is used, presumably so that per-transformation
// reports are discarded (TODO confirm).
//
// Returns any BuildModel error unchanged, or InternalError if the
// transformations fail.
absl::Status BuildFinalModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  RETURN_IF_ERROR(
      BuildModel(context, delegate_params, graph, quant_conversion_map));
  NullTransformationReporter null_reporter;
  ModelTransformer model_transformer(graph, &null_reporter);
  const bool transformed = ApplyModelTransformations(&model_transformer);
  return transformed ? absl::OkStatus()
                     : absl::InternalError("Graph transformations failed");
}
} // namespace gpu
} // namespace tflite