| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/lite/delegates/gpu/common/model_builder.h" |
| |
| #include <algorithm> |
| #include <any> |
| #include <cstdint> |
| #include <map> |
| #include <memory> |
| #include <optional> |
| #include <set> |
| #include <string> |
| #include <utility> |
| #include <variant> |
| #include <vector> |
| |
| #include "absl/base/attributes.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/status/status.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/string_view.h" |
| #include "tensorflow/lite/builtin_ops.h" |
| #include "tensorflow/lite/c/builtin_op_data.h" |
| #include "tensorflow/lite/c/common.h" |
| #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" |
| #include "tensorflow/lite/delegates/gpu/common/data_type.h" |
| #include "tensorflow/lite/delegates/gpu/common/lstm_parser.h" |
| #include "tensorflow/lite/delegates/gpu/common/model.h" |
| #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" |
| #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" |
| #include "tensorflow/lite/delegates/gpu/common/object_reader.h" |
| #include "tensorflow/lite/delegates/gpu/common/operations.h" |
| #include "tensorflow/lite/delegates/gpu/common/shape.h" |
| #include "tensorflow/lite/delegates/gpu/common/status.h" |
| #include "tensorflow/lite/delegates/gpu/common/tensor.h" |
| #include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" |
| #include "tensorflow/lite/delegates/utils.h" |
| #include "tensorflow/lite/kernels/internal/reference/dequantize.h" |
| #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
| #include "tensorflow/lite/kernels/kernel_util.h" |
| #include "tensorflow/lite/util.h" |
| |
| namespace tflite { |
| namespace gpu { |
| namespace { |
| |
| absl::Status CheckTensorIsAvailable(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, int idx) { |
| // If tensor id is in range, it's guaranteed that it'll be available. |
| if (idx >= tflite_node->inputs->size) { |
| return absl::OutOfRangeError( |
| absl::StrCat("Requested index goes beyond array size: ", idx, " vs ", |
| idx, tflite_node->inputs->size)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| // A parser responsible for parsing TFLite operation and adding it to a graph. |
| class TFLiteOperationParser { |
| public: |
| virtual ~TFLiteOperationParser() = default; |
| |
| // Parses TFLite operation. This method allows expanding fused operations |
| // into more than one node. |
| virtual absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) = 0; |
| |
| // Verifies whether passed tflite node may be built by GPU delegate or not. |
| virtual absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) = 0; |
| |
| // Return the value ids in the graph that correspond to the updated values of |
| // the variable input tensor. |
| virtual absl::flat_hash_map<int, ValueId> |
| GetNewValueIdsForVariableInputNodes() { |
| return absl::flat_hash_map<int, ValueId>(); |
| } |
| }; |
| |
| HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); } |
| |
| template <typename AttrT> |
| void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape, |
| AttrT* attr) { |
| if (padding == kTfLitePaddingSame) { |
| attr->padding = CalculateSamePadding(input_shape, *attr); |
| } else { |
| attr->padding.prepended = HW(0, 0); |
| attr->padding.appended = HW(0, 0); |
| } |
| } |
| |
| absl::Status GetFullyConnectedAttributes(int weights_tensor_id, |
| int bias_tensor_id, |
| ObjectReader* reader, |
| FullyConnectedAttributes* attr) { |
| Tensor<HW, DataType::FLOAT32> weights; |
| RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights)); |
| attr->weights.data = std::move(weights.data); |
| attr->weights.id = weights.id; |
| attr->weights.shape.h = 1; |
| attr->weights.shape.w = 1; |
| attr->weights.shape.o = weights.shape.h; |
| attr->weights.shape.i = weights.shape.w; |
| reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional |
| return absl::OkStatus(); |
| } |
| |
| template <typename ParamsT> |
| absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node, |
| const ParamsT** tf_options) { |
| *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data); |
| if (!*tf_options) { |
| return absl::InternalError("Unable to retrieve builtin_data."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| template <typename ParamsT> |
| absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node, |
| const ParamsT** tf_options) { |
| *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data); |
| if (!*tf_options) { |
| return absl::InternalError("Unable to retrieve custom_initial_data."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration, |
| int max_version) { |
| const int op_version = registration->version; |
| if (op_version > max_version) { |
| return absl::UnimplementedError( |
| absl::StrCat("Max version supported: ", max_version, |
| ". Requested version ", op_version, ".")); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckKernels(int kernel_h, int kernel_w) { |
| if (kernel_h <= 0 || kernel_w <= 0) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Incorrect kernel values: kernel_height = ", kernel_h, |
| ", kernel_width = ", kernel_w)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckStrides(int strides_h, int strides_w) { |
| if (strides_h <= 0 || strides_w <= 0) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Incorrect stride values: stride_height = ", strides_h, |
| ", stride_width = ", strides_w)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckDilation(int dilation_h, int dilation_w) { |
| if (dilation_h <= 0 || dilation_w <= 0) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "Incorrect dilation values: dilation_factor = ", dilation_h, |
| ", dilation_factor = ", dilation_w)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckStridesAndDilation(int strides_h, int strides_w, |
| int dilation_h, int dilation_w) { |
| RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); |
| RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h, |
| int strides_w) { |
| RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w)); |
| RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); |
| return absl::OkStatus(); |
| } |
| |
| // Creates a simple node that holds tensor value. |
| absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) { |
| ConstTensorAttributes attr; |
| attr.tensor = std::move(t); |
| Node* node = graph->NewNode(); |
| node->operation.attributes = attr; |
| node->operation.type = ToString(OperationType::CONST); |
| *value = graph->NewValue(); |
| RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id)); |
| // Keep data inside this tensor. |
| (*value)->tensor.ref = attr.tensor.id; |
| (*value)->tensor.type = attr.tensor.kType; |
| (*value)->tensor.shape = attr.tensor.shape; |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options, |
| const BHWC& input_shape, |
| Pooling2DAttributes* attr) { |
| attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width); |
| attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width); |
| UpdatePadding(tf_options->padding, input_shape, attr); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, |
| TensorOrScalar* tensor_or_scalar) { |
| const std::string& opname = node->operation.type; |
| |
| // Determine runtime/constant tensors. |
| const TfLiteTensor* input0 = reader->GetInputTensor(0); |
| if (!input0) { |
| return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " + |
| opname); |
| } |
| const TfLiteTensor* input1 = reader->GetInputTensor(1); |
| if (!input1) { |
| return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " + |
| opname); |
| } |
| const bool constant_tensor0 = IsConstantTensor(input0); |
| const bool constant_tensor1 = IsConstantTensor(input1); |
| if (constant_tensor0 && constant_tensor1) { |
| return absl::InvalidArgumentError("No runtime input tensors for " + opname); |
| } |
| const bool runtime_tensor0 = !constant_tensor0; |
| const bool runtime_tensor1 = !constant_tensor1; |
| |
| if (runtime_tensor0 && runtime_tensor1) { |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| } else { |
| int runtime_tensor = 0; |
| int constant_tensor = 1; |
| TfLiteIntArray* constant_dims = input1->dims; |
| if (constant_tensor0 && runtime_tensor1) { |
| runtime_tensor = 1; |
| constant_tensor = 0; |
| constant_dims = input0->dims; |
| } |
| RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); |
| if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) { |
| Tensor<Scalar, DataType::FLOAT32> tensor; |
| RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); |
| *tensor_or_scalar = tensor.data[0]; |
| } else { |
| if (CheckIfLinearConvertible(constant_dims).ok()) { |
| Tensor<Linear, DataType::FLOAT32> tensor; |
| RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); |
| *tensor_or_scalar = std::move(tensor); |
| } else { |
| Tensor<HWC, DataType::FLOAT32> tensor; |
| RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); |
| *tensor_or_scalar = std::move(tensor); |
| } |
| } |
| } |
| return absl::OkStatus(); |
| } |
| |
| class AddOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| if (tflite_node->inputs->size != 2) { |
| return absl::UnimplementedError("ADD requires two input tensors."); |
| } |
| // TODO(eignasheva): Add shapes check. |
| |
| const TfLiteAddParams* tf_options; |
| return RetrieveBuiltinData(tflite_node, &tf_options); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| // TFLite currently only supports 2 input ADDs. Thus, the logic below only |
| // considers 2 input cases. The underlying GPU shader programs can accept |
| // more inputs, but the logic below would have to be expanded. |
| |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::ADD); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| ElementwiseAttributes attr; |
| RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); |
| node->operation.attributes = std::move(attr); |
| const TfLiteAddParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| return MaybeFuseActivation(tf_options->activation, graph, node); |
| } |
| }; |
| |
| class BatchedMatMulOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| return CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/2, /*outputs=*/1); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::BATCHED_MATMUL); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class ConcatenationOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| |
| // TODO(eignasheva): add proper tensor availability checking |
| // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { |
| // RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx)); |
| // } |
| // TODO(eignasheva): add axis checking. |
| const TfLiteConcatenationParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| ConcatAttributes attr; |
| // Read inputs first to make sure const node is added to a graph before |
| // concat node to ensure topological order. |
| std::vector<const Value*> inputs; |
| for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { |
| Value* value; |
| const auto status = reader->ReadValue(idx, &value); |
| if (status.ok()) { |
| inputs.push_back(value); |
| } else { |
| TensorFloat32 tensor; |
| RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor)); |
| Value* value; |
| RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); |
| inputs.push_back(value); |
| } |
| } |
| |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONCAT); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| for (const Value* input : inputs) { |
| RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); |
| } |
| |
| std::vector<BHWC> input_shapes; |
| for (auto input : graph->FindInputs(node->id)) { |
| input_shapes.push_back(input->tensor.shape); |
| } |
| RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis)); |
| |
| // Guess axis. |
| BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| for (auto input : graph->FindInputs(node->id)) { |
| if (input->tensor.shape.h != output_shape.h) { |
| attr.axis = Axis::HEIGHT; |
| break; |
| } |
| if (input->tensor.shape.w != output_shape.w) { |
| attr.axis = Axis::WIDTH; |
| break; |
| } |
| if (input->tensor.shape.c != output_shape.c) { |
| attr.axis = Axis::CHANNELS; |
| break; |
| } |
| } |
| const TfLiteConcatenationParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) { |
| *axis = Axis::BATCH; |
| for (int i = 1; i < input_shapes.size(); i++) { |
| if (input_shapes[0].h != input_shapes[i].h && |
| input_shapes[0].w != input_shapes[i].w && |
| input_shapes[0].c != input_shapes[i].c) { |
| *axis = Axis::HEIGHT; |
| break; |
| } |
| } |
| if (*axis == Axis::BATCH) return absl::OkStatus(); |
| for (int i = 1; i < input_shapes.size(); i++) { |
| if (input_shapes[0].b != input_shapes[i].b && |
| input_shapes[0].w != input_shapes[i].w && |
| input_shapes[0].c != input_shapes[i].c) { |
| *axis = Axis::WIDTH; |
| break; |
| } |
| } |
| if (*axis == Axis::HEIGHT) return absl::OkStatus(); |
| for (int i = 1; i < input_shapes.size(); i++) { |
| if (input_shapes[0].b != input_shapes[i].b && |
| input_shapes[0].h != input_shapes[i].h && |
| input_shapes[0].c != input_shapes[i].c) { |
| *axis = Axis::CHANNELS; |
| break; |
| } |
| } |
| if (*axis == Axis::WIDTH) return absl::OkStatus(); |
| for (int i = 1; i < input_shapes.size(); i++) { |
| if (input_shapes[0].b != input_shapes[i].b && |
| input_shapes[0].w != input_shapes[i].w && |
| input_shapes[0].h != input_shapes[i].h) { |
| return absl::UnimplementedError( |
| "Can concatenate tensors only by batch, height, width, or " |
| "channels."); |
| } |
| } |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class Conv2DOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5)); |
| const int runtime_inputs = |
| GetNumberOfRuntimeInputsForNode(context, tflite_node); |
| if (runtime_inputs > 2) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 or 2 input tensor(s), but node has ", |
| runtime_inputs, " runtime inputs.")); |
| } |
| const int runtime_outputs = NumOutputs(tflite_node); |
| if (runtime_outputs != 1) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 output tensor(s), but node has ", |
| runtime_outputs, " runtime outputs.")); |
| } |
| if (runtime_inputs == 1) { |
| RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); |
| } |
| const TfLiteConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckStridesAndDilation( |
| tf_options->stride_height, tf_options->stride_width, |
| tf_options->dilation_height_factor, tf_options->dilation_width_factor)); |
| return IsActivationSupported(tf_options->activation); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONVOLUTION_2D); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| Convolution2DAttributes attr; |
| const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); |
| if (runtime_inputs == 2) { |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| } else { // runtime_inputs == 1; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); |
| } |
| reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional |
| |
| const TfLiteConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); |
| attr.dilations = HW(tf_options->dilation_height_factor, |
| tf_options->dilation_width_factor); |
| UpdatePadding(tf_options->padding, |
| graph->FindInputs(node->id)[0]->tensor.shape, &attr); |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); |
| node->operation.attributes = std::move(attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6)); |
| const int runtime_inputs = |
| GetNumberOfRuntimeInputsForNode(context, tflite_node); |
| if (runtime_inputs > 2) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 or 2 input tensor(s), but node has ", |
| runtime_inputs, " runtime inputs.")); |
| } |
| const int runtime_outputs = NumOutputs(tflite_node); |
| if (runtime_outputs != 1) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 output tensor(s), but node has ", |
| runtime_outputs, " runtime outputs.")); |
| } |
| if (runtime_inputs == 1) { |
| RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); |
| } |
| const TfLiteDepthwiseConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckStridesAndDilation( |
| tf_options->stride_height, tf_options->stride_width, |
| tf_options->dilation_height_factor, tf_options->dilation_width_factor)); |
| RETURN_IF_ERROR(IsActivationSupported(tf_options->activation)); |
| |
| const int depth_multiplier = tf_options->depth_multiplier; |
| const auto* input = context->tensors + tflite_node->inputs->data[0]; |
| const auto* filter = context->tensors + tflite_node->inputs->data[1]; |
| const auto* bias = tflite_node->inputs->size > 2 |
| ? context->tensors + tflite_node->inputs->data[2] |
| : nullptr; |
| const auto* output = context->tensors + tflite_node->outputs->data[0]; |
| if (!input->dims || input->dims->size != 4) { |
| return absl::InvalidArgumentError("input.dims.size != 4"); |
| } |
| if (!filter->dims || filter->dims->size != 4) { |
| return absl::InvalidArgumentError("filter.dims.size != 4"); |
| } |
| if (!output->dims || output->dims->size != 4) { |
| return absl::InvalidArgumentError("output.dims.size != 4"); |
| } |
| if (input->dims->data[0] != output->dims->data[0]) { |
| return absl::InvalidArgumentError("input.b != output.b"); |
| } |
| const int input_depth = input->dims->data[3]; |
| const int output_depth = output->dims->data[3]; |
| if (filter->dims->data[3] != output_depth) { |
| return absl::InvalidArgumentError("filter.i != output.c"); |
| } |
| if (output_depth != input_depth * depth_multiplier) { |
| return absl::InvalidArgumentError( |
| "output.c != input.c * depth_multiplier"); |
| } |
| if (bias && NumElements(bias) != output_depth) { |
| return absl::InvalidArgumentError("bias.size != output.c"); |
| } |
| if (depth_multiplier != 1 && input_depth != 1) { |
| return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1"); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| DepthwiseConvolution2DAttributes attr; |
| const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); |
| if (runtime_inputs == 2) { |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| } else { // runtime_inputs == 1; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); |
| } |
| reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional |
| const TfLiteDepthwiseConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); |
| attr.dilations = HW(std::max(1, tf_options->dilation_height_factor), |
| std::max(1, tf_options->dilation_width_factor)); |
| UpdatePadding(tf_options->padding, |
| graph->FindInputs(node->id)[0]->tensor.shape, &attr); |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); |
| const int depth_multiplier = tf_options->depth_multiplier; |
| if (depth_multiplier != 1) { |
| const TfLiteTensor* input = reader->GetInputTensor(0); |
| const TfLiteTensor* filter = reader->GetInputTensor(1); |
| const TfLiteTensor* output = reader->GetOutputTensor(0); |
| TransposeWeights(input, filter, output, depth_multiplier, &attr); |
| } |
| node->operation.attributes = std::move(attr); |
| return absl::OkStatus(); |
| } |
| |
| private: |
| // TFLite CPU stores weights as: |
| // [1, kernel_height, kernel_width, input_depth * depth_multiplier] |
| // TFLite GPU stores weights as: |
| // [depth_multiplier, kernel_height, kernel_width, input_depth] |
| static void TransposeWeights(const TfLiteTensor* input, |
| const TfLiteTensor* filter, |
| const TfLiteTensor* output, int depth_multiplier, |
| DepthwiseConvolution2DAttributes* attr) { |
| const int input_depth = input->dims->data[3]; |
| const int filter_height = filter->dims->data[1]; |
| const int filter_width = filter->dims->data[2]; |
| const int output_depth = output->dims->data[3]; |
| Tensor<OHWI, DataType::FLOAT32> weights; |
| weights.id = attr->weights.id; |
| weights.shape = |
| OHWI(output_depth, filter_height, filter_width, input_depth); |
| weights.data.resize(weights.shape.DimensionsProduct()); |
| float* dst = &weights.data[0]; |
| for (int j = 0; j < output_depth; ++j) { |
| const float* src = attr->weights.data.data() + j; |
| for (int i = 0; i < filter_height * filter_width; ++i) { |
| *dst = *src; |
| dst++; |
| src += output_depth; |
| } |
| } |
| attr->weights = std::move(weights); |
| } |
| }; |
| |
| class DequantizeOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing |
| // with floating-point versions of the original tensors. |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| // Quantization attributes should already be present in the input tensor. |
| auto input_value = graph->FindInputs(node->id)[0]; |
| if (!input_value->quant_params) { |
| return absl::InvalidArgumentError( |
| "Encountered Dequantize input with no quant params"); |
| } |
| QuantizeAndDequantizeAttributes attr; |
| attr.min = input_value->quant_params.value().min; |
| attr.max = input_value->quant_params.value().max; |
| attr.scale = input_value->quant_params.value().scale; |
| |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class ElementwiseOperationParser : public TFLiteOperationParser { |
| public: |
| explicit ElementwiseOperationParser(OperationType operation_type) |
| : operation_type_(operation_type) {} |
| |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| if (IsOneArgumentOperation()) { |
| RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, |
| /*const_inputs=*/0, |
| /*outputs=*/1)); |
| // For some elementwise operations (currently only for SUB operation) |
| // second condition may be false. But it's worth checking the next case |
| // with const input, which may be supported. |
| } else if (IsTwoArgumentOperation() && |
| CheckInputsConstsOutputs(context, tflite_node, |
| /*runtime_inputs=*/2, |
| /*const_inputs=*/0, |
| /*outputs=*/1) |
| .ok()) { |
| } else if (IsTwoArgumentOperationWithConst()) { |
| RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, |
| /*const_inputs=*/1, |
| /*outputs=*/1)); |
| } else { |
| return absl::InvalidArgumentError( |
| "Op can only handle 1 or 2 operand(s)."); |
| } |
| TfLiteFusedActivation activation; |
| RETURN_IF_ERROR(GetActivation(tflite_node, &activation)); |
| return IsActivationSupported(activation); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(operation_type_); |
| |
| if (IsOneArgumentOperation()) { |
| RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, |
| /*runtime_inputs=*/1, |
| /*const_inputs=*/0, |
| /*outputs=*/1)); |
| |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| } else if (IsTwoArgumentOperation() && |
| reader |
| ->VerifyInputsConstsOutputs(tflite_node, |
| /*runtime_inputs=*/2, |
| /*const_inputs=*/0, |
| /*outputs=*/1) |
| .ok()) { |
| if (tflite_node->inputs->size != 2) { |
| return absl::InvalidArgumentError("Applies only two input tensors"); |
| } |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| |
| TfLiteFusedActivation activation = kTfLiteActNone; |
| switch (operation_type_) { |
| case OperationType::SUB: { |
| const TfLiteSubParams* tf_options; |
| if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) { |
| activation = tf_options->activation; |
| } |
| break; |
| } |
| case OperationType::DIV: { |
| const TfLiteDivParams* tf_options; |
| if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) { |
| activation = tf_options->activation; |
| } |
| break; |
| } |
| default: |
| // No activation expected. |
| activation = kTfLiteActNone; |
| } |
| |
| if (activation) { |
| RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node)); |
| } |
| } else if (IsTwoArgumentOperationWithConst()) { |
| RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, |
| /*runtime_inputs=*/1, |
| /*const_inputs=*/1, |
| /*outputs=*/1)); |
| ElementwiseAttributes attr; |
| RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); |
| attr.runtime_tensor_is_second = |
| IsConstantTensor(reader->GetInputTensor(0)); |
| node->operation.attributes = std::move(attr); |
| } else { |
| return absl::InvalidArgumentError("Incorrect operation type passed"); |
| } |
| |
| return reader->AddOutputs(node); |
| } |
| |
| private: |
| absl::Status GetActivation(const TfLiteNode* tflite_node, |
| TfLiteFusedActivation* activation) const { |
| if (operation_type_ == OperationType::DIV) { |
| const TfLiteDivParams* tf_options; |
| auto status = RetrieveBuiltinData(tflite_node, &tf_options); |
| *activation = status.ok() ? tf_options->activation : kTfLiteActNone; |
| return absl::OkStatus(); |
| } |
| if (operation_type_ == OperationType::SUB) { |
| const TfLiteSubParams* tf_options; |
| auto status = RetrieveBuiltinData(tflite_node, &tf_options); |
| *activation = status.ok() ? tf_options->activation : kTfLiteActNone; |
| return absl::OkStatus(); |
| } |
| |
| // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or |
| // TfLiteXxxParams.activation. |
| *activation = kTfLiteActNone; |
| return absl::OkStatus(); |
| } |
| |
| bool IsOneArgumentOperation() const { |
| switch (operation_type_) { |
| case OperationType::ABS: |
| case OperationType::COPY: |
| case OperationType::COS: |
| case OperationType::ELU: |
| case OperationType::EXP: |
| case OperationType::LOG: |
| case OperationType::NEG: |
| case OperationType::RSQRT: |
| case OperationType::SIGMOID: |
| case OperationType::SIN: |
| case OperationType::SQRT: |
| case OperationType::SQUARE: |
| case OperationType::TANH: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| bool IsTwoArgumentOperation() const { |
| switch (operation_type_) { |
| case OperationType::DIV: |
| case OperationType::MAXIMUM: |
| case OperationType::MINIMUM: |
| case OperationType::POW: |
| case OperationType::SQUARED_DIFF: |
| case OperationType::SUB: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| bool IsTwoArgumentOperationWithConst() const { |
| switch (operation_type_) { |
| case OperationType::DIV: |
| case OperationType::MAXIMUM: |
| case OperationType::MINIMUM: |
| case OperationType::POW: |
| case OperationType::SQUARED_DIFF: |
| case OperationType::SUB: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| OperationType operation_type_; |
| }; |
| |
| class FullyConnectedOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9)); |
| const TfLiteFullyConnectedParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| if (tf_options->weights_format != |
| kTfLiteFullyConnectedWeightsFormatDefault) { |
| return absl::UnimplementedError( |
| "Unsupported FullyConnected weights format."); |
| } |
| if (GetNumberOfRuntimeInputsForNode(context, tflite_node) > 2) { |
| return absl::UnimplementedError( |
| "FullyConnected doesn't support more than 2 runtime inputs."); |
| } |
| if (tf_options->keep_num_dims == true) { |
| return absl::UnimplementedError( |
| "FullyConnected doesn't support keep_num_dims."); |
| } |
| // TODO(eignasheva): check input shape |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| const TfLiteFullyConnectedParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| |
| if (reader->GetNumberOfRuntimeInputs() == 2) { |
| // Create Convolution2D, so as it supports runtime weights. |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONVOLUTION_2D); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| Convolution2DAttributes attr; |
| reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional |
| |
| attr.strides = HW(1, 1); |
| attr.dilations = HW(1, 1); |
| attr.padding.appended = HW(0, 0); |
| attr.padding.prepended = HW(0, 0); |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); |
| node->operation.attributes = std::move(attr); |
| return absl::OkStatus(); |
| } |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| |
| if (tf_options->weights_format != |
| kTfLiteFullyConnectedWeightsFormatDefault) { |
| return absl::UnimplementedError( |
| "Unsupported FullyConnected weights format."); |
| } |
| |
| FullyConnectedAttributes attr; |
| RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr)); |
| const int weights_width = attr.weights.shape.i; |
| |
| auto input = graph->FindInputs(node->id)[0]; |
| int batch_size = input->tensor.shape.b; |
| if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) { |
| return absl::UnimplementedError( |
| "Amount of input data should match weights width"); |
| } |
| |
| Node* conv = node; |
| if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) { |
| auto& reshape = node; |
| conv = graph->NewNode(); // reset conv pointer! |
| Value* reshaped_value = graph->NewValue(); |
| reshaped_value->tensor.type = DataType::FLOAT32; |
| reshaped_value->tensor.shape = |
| BHWC(input->tensor.shape.b, 1, 1, weights_width); |
| RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id)); |
| reshape->operation.type = ToString(OperationType::RESHAPE); |
| ReshapeAttributes attr; |
| attr.new_shape = reshaped_value->tensor.shape; |
| reshape->operation.attributes = attr; |
| RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id)); |
| } |
| |
| conv->operation.type = ToString(OperationType::FULLY_CONNECTED); |
| conv->operation.attributes = std::move(attr); |
| absl::Status result = reader->AddOutputs(conv); |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv)); |
| |
| return result; |
| } |
| }; |
| |
| class HardSwishOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration*) final { |
| return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, |
| /*outputs=*/1); |
| } |
| |
| absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::HARD_SWISH); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| return reader->AddOutputs(node); |
| } |
| }; |
| |
| // Basic LSTM Cell: |
| // |
| // 1name = name is at input index 1 |
| // name1 = name is at output index 1 |
| // |
| // 0input 1prev_activ |
| // \ / |
| // [[concat]] |
| // \ |
| // concat_temp2 2weights 3biases |
| // \ / / |
| // [[fully-connected]] |
| // \ |
| // activ_temp3 4prev_state |
| // \ / |
| // [[LSTM]] |
| // / \ |
| // new_state1 activation0 |
| // |
| // For full LSTM cells, see this blog post: |
| // https://colah.github.io/posts/2015-08-Understanding-LSTMs/ |
| // In addition to Peephole connections and Combined Input Forget Gates (CIFG) |
| // described in that post, this code also adds the following optional features: |
| // - Configurable activations (sigmoid or TANH) |
// - Layer normalization of gates: https://arxiv.org/abs/1607.06450
| // - Output projection: |
| // https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html |
| // - Configurable clipping of cell state and output state. |
| class LSTMOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4)); |
| const TfLiteLSTMParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| switch (tf_options->kernel_type) { |
| case kTfLiteLSTMFullKernel: { |
| const int inputs = NumInputs(tflite_node); |
| if (inputs != 20 && inputs != 24) { |
| return absl::InternalError( |
| absl::StrCat("Expected 20 or 24 input tensors, but node has ", |
| inputs, " input(s).")); |
| } |
| const int runtime_outputs = NumOutputs(tflite_node); |
| if (runtime_outputs != 1) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 output tensor, but node has ", |
| runtime_outputs, " output(s).")); |
| } |
| return CheckFullParameters(tf_options); |
| } |
| case kTfLiteLSTMBasicKernel: |
| RETURN_IF_ERROR( |
| CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3, |
| /*const_inputs=*/2, /*outputs=*/4)); |
| return CheckBasicParameters(tf_options); |
| } |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| const TfLiteLSTMParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| switch (tf_options->kernel_type) { |
| case kTfLiteLSTMFullKernel: |
| return ParseFull(tflite_node, registration, graph, reader, tf_options); |
| case kTfLiteLSTMBasicKernel: |
| return ParseBasic(tflite_node, registration, graph, reader, tf_options); |
| } |
| } |
| |
| absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes() |
| final { |
| return new_variable_input_value_map_; |
| } |
| |
| private: |
| absl::Status ParseBasic(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader, |
| const TfLiteLSTMParams* tf_options) { |
| if (tflite_node->inputs->size != 5) { |
| return absl::InvalidArgumentError("LSTM should have 5 input tensors"); |
| } |
| if (tflite_node->outputs->size != 4) { |
| return absl::InvalidArgumentError("LSTM should have 4 output tensors"); |
| } |
| RETURN_IF_ERROR(CheckBasicParameters(tf_options)); |
| |
| Node* concat_node = graph->NewNode(); |
| concat_node->operation.type = ToString(OperationType::CONCAT); |
| ConcatAttributes concat_attr; |
| concat_attr.axis = Axis::CHANNELS; |
| concat_node->operation.attributes = concat_attr; |
| |
| Node* fc_node = graph->NewNode(); |
| fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED); |
| FullyConnectedAttributes fc_attr; |
| RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr)); |
| fc_node->operation.attributes = std::move(fc_attr); |
| |
| Node* lstm_node = graph->NewNode(); |
| lstm_node->operation.type = ToString(OperationType::LSTM); |
| LstmAttributes lstm_attr; |
| lstm_attr.kernel_type = LstmKernelType::BASIC; |
| lstm_node->operation.attributes = lstm_attr; |
| |
| Value* concat_temp; |
| int concat_tensor_idx = tflite_node->outputs->data[2]; |
| RETURN_IF_ERROR( |
| reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp)); |
| Value* activ_temp; |
| int activ_tensor_idx = tflite_node->outputs->data[3]; |
| RETURN_IF_ERROR( |
| reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp)); |
| |
| RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input |
| RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ |
| RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id)); |
| |
| RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id)); |
| RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id)); |
| |
| RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id)); |
| RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state |
| RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state |
| RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) { |
| if (tf_options->activation != kTfLiteActTanh) { |
| return absl::UnimplementedError("Only TANH activation is supported."); |
| } |
| if (tf_options->cell_clip != 0.0f) { |
| return absl::UnimplementedError("cell_clip is not supported."); |
| } |
| if (tf_options->proj_clip != 0.0f) { |
| return absl::UnimplementedError("proj_clip is not supported."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ParseFull(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader, |
| const TfLiteLSTMParams* tf_options) { |
| // Invoke full LSTM parser |
| RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph, |
| reader, tf_options, |
| &new_variable_input_value_map_)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) { |
| if (tf_options->activation != kTfLiteActSigmoid && |
| tf_options->activation != kTfLiteActTanh) { |
| return absl::UnimplementedError( |
| "Only sigmoid or tanh activation is supported."); |
| } |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::flat_hash_map<int, ValueId> new_variable_input_value_map_; |
| }; |
| |
| class MulOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); |
| if (tflite_node->inputs->size != 2) { |
| return absl::UnimplementedError("MUL requires two input tensors."); |
| } |
| const TfLiteTensor* input0 = GetInput(context, tflite_node, 0); |
| const TfLiteTensor* input1 = GetInput(context, tflite_node, 1); |
| if (input0 == nullptr || input1 == nullptr) { |
| return absl::InvalidArgumentError("At least one input tensor is null"); |
| } |
| if (input0->dims->size == input1->dims->size) { |
| // this code checks that at least one input of Mul not smaller in all |
| // dimensions. Sometimes Mul used for matrix-vector multiplication that we |
| // currently don't support. For example input0 HWC(1, 256, 1), input1 |
| // HWC(1, 1, 256) -> output HWC (1, 256, 256). In this case it can be |
| // replaced with Convolution operation. |
| bool first_has_smaller_dim = false; |
| bool second_has_smaller_dim = false; |
| for (int i = 0; i < input0->dims->size; ++i) { |
| if (input0->dims->data[i] < input1->dims->data[i]) { |
| first_has_smaller_dim = true; |
| } |
| if (input1->dims->data[i] < input0->dims->data[i]) { |
| second_has_smaller_dim = true; |
| } |
| } |
| if (first_has_smaller_dim && second_has_smaller_dim) { |
| return absl::UnimplementedError( |
| "MUL requires one tensor that not less than second in all " |
| "dimensions."); |
| } |
| } |
| const TfLiteMulParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| return IsActivationSupported(tf_options->activation); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| const TfLiteTensor* input0 = reader->GetInputTensor(0); |
| if (!input0) { |
| return absl::InvalidArgumentError( |
| "Couldn't get the 1st input tensor for MUL."); |
| } |
| const TfLiteTensor* input1 = reader->GetInputTensor(1); |
| if (!input1) { |
| return absl::InvalidArgumentError( |
| "Couldn't get the 2nd input tensor for MUL."); |
| } |
| const bool constant_tensor0 = IsConstantTensor(input0); |
| const bool constant_tensor1 = IsConstantTensor(input1); |
| if (constant_tensor0 && constant_tensor1) { |
| return absl::InvalidArgumentError("No runtime input tensors for MUL."); |
| } |
| const bool runtime_tensor0 = !constant_tensor0; |
| const bool runtime_tensor1 = !constant_tensor1; |
| |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::MUL); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| // Determine runtime/constant tensors. |
| if (runtime_tensor0 && runtime_tensor1) { |
| if (input0 == input1) { |
| // replace MUL(A, A) with POW(A, 2.0) |
| // TODO(b/166831113): Support the same inputs for operations. |
| node->operation.type = ToString(OperationType::POW); |
| ElementwiseAttributes attr; |
| attr.param = 2.0f; |
| node->operation.attributes = std::move(attr); |
| return reader->AddInput(node, 0); |
| } |
| |
| // The "larger" input tensor must be bound to 1st input and the "smaller" |
| // input tensor must be bound to 2nd input. |
| BHWC shape0; |
| RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); |
| BHWC shape1; |
| RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1)); |
| int input_tensor0 = 0; |
| int input_tensor1 = 1; |
| if (shape0.h <= shape1.h && shape0.w <= shape1.w && |
| shape0.c == shape1.c) { |
| input_tensor0 = 1; |
| input_tensor1 = 0; |
| } |
| RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); |
| RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); |
| } else { |
| ElementwiseAttributes attr; |
| RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); |
| node->operation.attributes = std::move(attr); |
| } |
| |
| const TfLiteMulParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| return MaybeFuseActivation(tf_options->activation, graph, node); |
| } |
| }; |
| |
| class PackOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| const TfLitePackParams* tf_options; |
| return RetrieveBuiltinData(tflite_node, &tf_options); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| if (tflite_node->inputs->size == 1) { |
| // Pack with single input can be replaced with Reshape |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::RESHAPE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| // New shape comes from output shape. |
| ReshapeAttributes attr; |
| attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } else { |
| // Pack with few inputs can be replaced with Concat |
| const TfLitePackParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| |
| // Read inputs first to make sure const node is added to a graph before |
| // concat node to ensure topological order. |
| std::vector<const Value*> inputs; |
| for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { |
| Value* value; |
| const auto status = reader->ReadValue(idx, &value); |
| if (status.ok()) { |
| inputs.push_back(value); |
| } else { |
| TensorFloat32 tensor; |
| RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor)); |
| Value* value; |
| RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); |
| inputs.push_back(value); |
| } |
| } |
| |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONCAT); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| for (const Value* input : inputs) { |
| RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); |
| } |
| const TfLiteTensor* output = reader->GetOutputTensor(0); |
| ConcatAttributes attr; |
| RETURN_IF_ERROR( |
| ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis)); |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| } |
| }; |
| |
| class PReLUOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); |
| // TODO(eignasheva): add params check |
| return absl::OkStatus(); |
| } |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::PRELU); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; |
| |
| PReLUAttributes attr; |
| Tensor<Linear, DataType::FLOAT32> linear_alpha; |
| absl::Status status = reader->ReadTensor(1, &linear_alpha); |
| if (status.ok()) { |
| if (linear_alpha.shape.v != input_shape.c) { |
| return absl::InvalidArgumentError( |
| "Linear alpha shape does not match the number of input channels."); |
| } |
| attr.alpha = std::move(linear_alpha); |
| } else { |
| Tensor<HWC, DataType::FLOAT32> hwc_alpha; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha)); |
| if (hwc_alpha.shape.h != input_shape.h || |
| hwc_alpha.shape.w != input_shape.w || |
| hwc_alpha.shape.c != input_shape.c) { |
| return absl::InvalidArgumentError( |
| "Alpha shape does not match input shape."); |
| } |
| attr.alpha = std::move(hwc_alpha); |
| } |
| node->operation.attributes = std::move(attr); |
| return reader->AddOutputs(node); |
| } |
| }; |
| |
| class PadOperationParser : public TFLiteOperationParser { |
| public: |
| explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {} |
| |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| if (mirror_pad_) { |
| const TfLiteMirrorPaddingParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| if (tf_options->mode != |
| TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) { |
| return absl::InvalidArgumentError( |
| "Only Reflective padding is supported for Mirror Pad operation."); |
| } |
| } |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); |
| const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1); |
| if (pad_tensor == nullptr) { |
| return absl::InvalidArgumentError("Padding tensor was null"); |
| } |
| if (pad_tensor->dims->size != 2) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "Invalid paddings tensor dimension: expected 2 dim, got ", |
| pad_tensor->dims->size, " dim")); |
| } |
| bool supported = |
| pad_tensor->dims->data[0] == 3 || pad_tensor->dims->data[0] == 4; |
| if (!supported || pad_tensor->dims->data[1] != 2) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "Invalid paddings tensor shape: expected 4x2 or 3x2, got ", |
| pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1])); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::PAD); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| PadAttributes attr; |
| if (mirror_pad_) { |
| attr.type = PaddingContentType::REFLECT; |
| } else /*zero pad*/ { |
| attr.type = PaddingContentType::ZEROS; |
| } |
| |
| Tensor<HW, DataType::INT32> paddings; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &paddings)); |
| |
| if (paddings.shape.h == 4 && paddings.shape.w == 2) { |
| // 4x2 tensor with paddings. |
| attr.prepended = BHWC(paddings.data[0], paddings.data[2], |
| paddings.data[4], paddings.data[6]); |
| attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5], |
| paddings.data[7]); |
| } else if (paddings.shape.h == 3 && paddings.shape.w == 2) { |
| // 3x2 tensor with paddings. |
| attr.prepended = |
| BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]); |
| attr.appended = |
| BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]); |
| } else { |
| // It shouldn't fail here since it's checked at IsSupported(). |
| return absl::InvalidArgumentError( |
| "Paddings tensor has unexpected shape."); |
| } |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| bool mirror_pad_ = false; |
| }; |
| |
| class Pooling2DOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| const TfLitePoolParams* tf_options; |
| auto status = RetrieveCustomInitialData(tflite_node, &tf_options); |
| if (status.ok()) { // custom case with indices as a second output |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, |
| /*outputs=*/2)); |
| } else { // common pooling with 1 output |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, |
| /*outputs=*/1)); |
| } |
| RETURN_IF_ERROR(CheckKernelsAndStrides( |
| tf_options->filter_height, tf_options->filter_width, |
| tf_options->stride_height, tf_options->stride_width)); |
| return IsActivationSupported(tf_options->activation); |
| } |
| |
| public: |
| explicit Pooling2DOperationParser(PoolingType type) : type_(type) {} |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::POOLING_2D); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutput(node, 0)); |
| |
| Pooling2DAttributes attr; |
| attr.type = type_; |
| |
| auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; |
| |
| // check whether there are custom options encoded. It happens if operation |
| // is MaxPoolingWithArgmax2D. There is no way to read |
| // tflite_node->builtin_code, so, simply check whether custom data is |
| // available. |
| const TfLitePoolParams* tf_options; |
| if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) { |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| } |
| |
| RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); |
| // Second output is optional. It is not required, it but must be added after |
| // MaybeAddFusedActivation function is called |
| reader->AddOutput(node, 1).IgnoreError(); |
| |
| // First output is the result of pooling operation, while second output is |
| // indices used for pooling. |
| auto outputs = graph->FindOutputs(node->id); |
| attr.output_indices = outputs.size() == 2; |
| if (attr.output_indices) { |
| // Fix data type for output indices. In the model it is set as float32. |
| outputs[1]->tensor.type = DataType::INT32; |
| } |
| RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr)); |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| const PoolingType type_; |
| }; |
| |
| class ReduceOperationParser : public TFLiteOperationParser { |
| public: |
| explicit ReduceOperationParser(OperationType operation_type) |
| : operation_type_(operation_type) {} |
| |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| auto* axes = &context->tensors[tflite_node->inputs->data[1]]; |
| if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) { |
| return absl::UnimplementedError( |
| "Reduce has unsupported tensor for axes."); |
| } |
| if (tflite::NumElements(axes) != 1) { |
| return absl::UnimplementedError( |
| "Supported reduce in single dimensions only."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(operation_type_); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| const TfLiteReducerParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| |
| Tensor<Scalar, DataType::INT32> axes; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); |
| const TfLiteTensor* input = reader->GetInputTensor(0); |
| ReduceAttributes attr; |
| Axis axis; |
| RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[0], &axis)); |
| attr.dims = {axis}; |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| const OperationType operation_type_; |
| }; |
| |
| class QuantizeOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing |
| // with floating-point versions of the original tensors. |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| // Quantization attributes should already be present in the output tensor. |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| if (!output_value->quant_params) { |
| return absl::InvalidArgumentError( |
| "Encountered Quantize output with no quant params"); |
| } |
| QuantizeAndDequantizeAttributes attr; |
| attr.min = output_value->quant_params.value().min; |
| attr.max = output_value->quant_params.value().max; |
| attr.scale = output_value->quant_params.value().scale; |
| |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class ReLUOperationParser : public TFLiteOperationParser { |
| public: |
| explicit ReLUOperationParser(int clip) : clip_(clip) {} |
| |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::RELU); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| |
| ReLUAttributes attr; |
| const TfLiteLeakyReluParams* tf_options; |
| auto status = RetrieveBuiltinData(tflite_node, &tf_options); |
| attr.alpha = status.ok() ? tf_options->alpha : 0; |
| attr.clip = clip_; |
| node->operation.attributes = attr; |
| return reader->AddOutputs(node); |
| } |
| |
| private: |
| const int clip_; |
| }; |
| |
| class ReshapeOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| // TODO(eignasheva): add shape checking |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::RESHAPE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| // Here we may have extra inputs. Other tensors were supposed to |
| // define new shape, but in TFLite these are ignored. |
| // TODO(akulik): check that shapes match? |
| |
| // New shape comes from output shape. |
| ReshapeAttributes attr; |
| attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class Resize2DOperationParser : public TFLiteOperationParser { |
| public: |
| explicit Resize2DOperationParser(SamplingType sampling_type) |
| : sampling_type_(sampling_type) {} |
| |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| |
| RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node)); |
| bool align_corners; |
| RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners)); |
| bool half_pixel_centers; |
| RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::RESIZE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| // Here we may have extra inputs. Other tensors were supposed to |
| // define new shape, but in TFLite these are ignored. |
| |
| Resize2DAttributes attr; |
| RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners)); |
| RETURN_IF_ERROR( |
| GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers)); |
| attr.type = sampling_type_; |
| attr.new_shape.CopyAllDefinedAxis( |
| graph->FindOutputs(node->id)[0]->tensor.shape); |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node, |
| bool* align_corners) { |
| switch (sampling_type_) { |
| case SamplingType::BILINEAR: |
| return GetAlignCornersValueForType<TfLiteResizeBilinearParams>( |
| tflite_node, align_corners); |
| case SamplingType::NEAREST: |
| return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>( |
| tflite_node, align_corners); |
| case SamplingType::UNKNOWN: |
| return absl::InternalError("Sampling type is not specified"); |
| } |
| return absl::OkStatus(); |
| } |
| |
| template <class T> |
| absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, |
| bool* align_corners) { |
| const T* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| *align_corners = tf_options->align_corners; |
| return absl::OkStatus(); |
| } |
| |
| absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, |
| bool* half_pixel_centers) { |
| if (sampling_type_ == SamplingType::BILINEAR) { |
| const TfLiteResizeBilinearParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| if (tf_options->align_corners && tf_options->half_pixel_centers) { |
| return absl::InternalError( |
| "If half_pixel_centers is True, align_corners must be False."); |
| } |
| *half_pixel_centers = tf_options->half_pixel_centers; |
| } else { |
| const TfLiteResizeNearestNeighborParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| *half_pixel_centers = tf_options->half_pixel_centers; |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node) { |
| const auto* input = context->tensors + tflite_node->inputs->data[0]; |
| const auto* output = context->tensors + tflite_node->outputs->data[0]; |
| |
| if (!input->dims || input->dims->size != 4) { |
| return absl::InvalidArgumentError("input.dims.size != 4"); |
| } |
| if (!output->dims || output->dims->size != 4) { |
| return absl::InvalidArgumentError("output.dims.size != 4"); |
| } |
| if (output->dims->data[1] < input->dims->data[1] || |
| output->dims->data[2] < input->dims->data[2]) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "Only upsampling is supported, received output h,w = ", |
| output->dims->data[1], ",", output->dims->data[2], |
| " input h,w = ", input->dims->data[1], ",", input->dims->data[2])); |
| } |
| return absl::OkStatus(); |
| } |
| |
| SamplingType sampling_type_ = SamplingType::UNKNOWN; |
| }; |
| |
| class SliceOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| if (tflite_node->inputs->size < 3) { |
| return absl::UnimplementedError("SLICE requires 3 inputs."); |
| } |
| const TfLiteTensor* input = GetInput(context, tflite_node, 0); |
| if (input->dims->size != 3 && input->dims->size != 4) { |
| return absl::UnimplementedError( |
| "SLICE supports for 3 or 4 dimensional tensors only."); |
| } |
| |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::SLICE); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| Value* input; |
| RETURN_IF_ERROR(reader->ReadValue(0, &input)); |
| RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); |
| |
| const TfLiteTensor* tfl_input = reader->GetInputTensor(0); |
| const int input_dims = tfl_input->dims->size; |
| |
| SliceAttributes attr; |
| attr.strides = BHWC(1, 1, 1, 1); |
| Tensor<Linear, DataType::INT32> starts, sizes; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &starts)); |
| RETURN_IF_ERROR(reader->ReadTensor(2, &sizes)); |
| if (starts.data.size() != sizes.data.size()) { |
| return absl::InvalidArgumentError("Starts amount != sizes amount."); |
| } |
| BHWC bhwc_starts(0, 0, 0, 0); |
| BHWC bhwc_sizes = input->tensor.shape; |
| if (input_dims == 4) { |
| // input in BHWC layout |
| if (starts.data.size() == 4) { |
| bhwc_starts.b = starts.data[0]; |
| bhwc_starts.h = starts.data[1]; |
| bhwc_starts.w = starts.data[2]; |
| bhwc_starts.c = starts.data[3]; |
| bhwc_sizes.b = sizes.data[0]; |
| bhwc_sizes.h = sizes.data[1]; |
| bhwc_sizes.w = sizes.data[2]; |
| bhwc_sizes.c = sizes.data[3]; |
| } else if (starts.data.size() == 3) { |
| // if input is 4D(BHWC) and args 3D, we assume that args in HWC layout |
| bhwc_starts.h = starts.data[0]; |
| bhwc_starts.w = starts.data[1]; |
| bhwc_starts.c = starts.data[2]; |
| bhwc_sizes.h = sizes.data[0]; |
| bhwc_sizes.w = sizes.data[1]; |
| bhwc_sizes.c = sizes.data[2]; |
| } else { |
| return absl::UnimplementedError( |
| "Slicing is supported for 3 or 4 dimensional tensors only."); |
| } |
| } else if (input_dims == 3) { |
| // input in BWC layout |
| if (starts.data.size() == 3) { |
| bhwc_starts.b = starts.data[0]; |
| bhwc_starts.w = starts.data[1]; |
| bhwc_starts.c = starts.data[2]; |
| bhwc_sizes.b = sizes.data[0]; |
| bhwc_sizes.w = sizes.data[1]; |
| bhwc_sizes.c = sizes.data[2]; |
| } else { |
| return absl::UnimplementedError( |
| "Slicing is supported for 3 or 4 dimensional tensors only."); |
| } |
| } else { |
| return absl::UnimplementedError( |
| "Slicing is supported for 3 or 4 dimensional tensors only."); |
| } |
| const auto& in_shape = input->tensor.shape; |
| if (bhwc_sizes.b == -1) { |
| bhwc_sizes.b = in_shape.b - bhwc_starts.b; |
| } |
| if (bhwc_sizes.h == -1) { |
| bhwc_sizes.h = in_shape.h - bhwc_starts.h; |
| } |
| if (bhwc_sizes.w == -1) { |
| bhwc_sizes.w = in_shape.w - bhwc_starts.w; |
| } |
| if (bhwc_sizes.c == -1) { |
| bhwc_sizes.c = in_shape.c - bhwc_starts.c; |
| } |
| attr.starts = bhwc_starts; |
| attr.ends = |
| BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h, |
| bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c); |
| RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr)); |
| |
| auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| if ((attr.ends.b - attr.starts.b) != out_shape.b) { |
| return absl::UnimplementedError("Output batch don't match"); |
| } |
| if ((attr.ends.h - attr.starts.h) != out_shape.h) { |
| return absl::UnimplementedError("Output height doesn't match"); |
| } |
| if ((attr.ends.w - attr.starts.w) != out_shape.w) { |
| return absl::UnimplementedError("Output width doesn't match"); |
| } |
| if ((attr.ends.c - attr.starts.c) != out_shape.c) { |
| return absl::UnimplementedError("Output channels don't match"); |
| } |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| absl::Status UpdateIfNegative(const BHWC& input_shape, |
| SliceAttributes* attr) { |
| if (attr->ends.h < 0) { |
| attr->ends.h = input_shape.h + attr->ends.h; |
| } |
| if (attr->ends.w < 0) { |
| attr->ends.w = input_shape.w + attr->ends.w; |
| } |
| if (attr->ends.c < 0) { |
| attr->ends.c = input_shape.c + attr->ends.c; |
| } |
| if (attr->ends.b < 0) { |
| attr->ends.b = input_shape.b + attr->ends.b; |
| } |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class SoftmaxOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| const TfLiteSoftmaxParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| if (tf_options->beta != 1) { |
| // TODO(eignasheva): figure out, what's wrong with softmax. |
| return absl::UnimplementedError("Softmax.beta != 1 is not supported."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::SOFTMAX); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| const TfLiteSoftmaxParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| if (tf_options->beta != 1) { |
| // there is multiply by scalar operation fused in softmax. Make a layer |
| // out of it before softmax. |
| return absl::UnimplementedError("Softmax.beta != 1 is not supported."); |
| // auto mul_node = reader->NewPassthroughNode(node); |
| // mul_node->operation.type = ToString(OperationType::MUL); |
| } |
| SoftmaxAttributes attr; |
| attr.axis = Axis::CHANNELS; // always by channels |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class SpaceToDepthOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| // TODO(impjdi): Dims check. |
| const TfLiteSpaceToDepthParams* s2d_params; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params)); |
| if (s2d_params->block_size == 1) { |
| return absl::InvalidArgumentError( |
| "SPACE_TO_DEPTH block_size = 1 is a no-op."); |
| } |
| if (s2d_params->block_size < 1) { |
| return absl::InvalidArgumentError( |
| "SPACE_TO_DEPTH block_size must be > 1."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::SPACE_TO_DEPTH); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| const TfLiteSpaceToDepthParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| SpaceToDepthAttributes attr; |
| attr.block_size = tf_options->block_size; |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class StridedSliceOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| const TfLiteStridedSliceParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); |
| |
| if (tflite_node->inputs->size < 4) { |
| return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs."); |
| } |
| const TfLiteTensor* input = GetInput(context, tflite_node, 0); |
| if (input->dims->size != 3 && input->dims->size != 4) { |
| return absl::UnimplementedError( |
| "STRIDED_SLICE supports for 3 or 4 dimensional tensors only."); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::SLICE); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| Value* input; |
| RETURN_IF_ERROR(reader->ReadValue(0, &input)); |
| RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); |
| |
| Tensor<Linear, DataType::INT32> tmp; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &tmp)); |
| |
| bool read_without_batch = tmp.data.size() == 3; |
| bool read_with_batch = tmp.data.size() == 4; |
| if (!read_without_batch && !read_with_batch) { |
| // Error: Must be catched in IsSupported() |
| return absl::UnimplementedError( |
| "Slicing is supported for 3 or 4 dimensional tensors only."); |
| } |
| |
| const TfLiteStridedSliceParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); |
| |
| auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| |
| SliceAttributes attr; |
| if (read_without_batch) { |
| RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options, |
| input->tensor.shape, &attr)); |
| } |
| if (read_with_batch) { |
| RETURN_IF_ERROR( |
| ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr)); |
| } |
| if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 || |
| attr.strides.c == 0) { |
| return absl::InvalidArgumentError("stride values must be non-zero"); |
| } |
| if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 || |
| attr.strides.c < 0) { |
| return absl::UnimplementedError("Reverse slices are not supported."); |
| } |
| if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b != |
| out_shape.b) { |
| return absl::UnimplementedError("Output batch don't match"); |
| } |
| if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h != |
| out_shape.h) { |
| return absl::UnimplementedError("Output height doesn't match"); |
| } |
| if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w != |
| out_shape.w) { |
| return absl::UnimplementedError("Output width doesn't match"); |
| } |
| if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c != |
| out_shape.c) { |
| return absl::UnimplementedError("Output channels don't match"); |
| } |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, |
| const BHWC& input_shape, int ignore_b, |
| int ignore_h, int ignore_w, int ignore_c, |
| SliceAttributes* attr) { |
| if (tf_options->begin_mask & ignore_h) { |
| attr->starts.h = 0; |
| } |
| if (tf_options->begin_mask & ignore_w) { |
| attr->starts.w = 0; |
| } |
| if (tf_options->begin_mask & ignore_c) { |
| attr->starts.c = 0; |
| } |
| if (tf_options->begin_mask & ignore_b) { |
| attr->starts.b = 0; |
| } |
| |
| if (tf_options->end_mask & ignore_h) { |
| attr->ends.h = input_shape.h; |
| } |
| if (tf_options->end_mask & ignore_w) { |
| attr->ends.w = input_shape.w; |
| } |
| if (tf_options->end_mask & ignore_c) { |
| attr->ends.c = input_shape.c; |
| } |
| if (tf_options->end_mask & ignore_b) { |
| attr->ends.b = input_shape.b; |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status UpdateIfNegative(const BHWC& input_shape, |
| SliceAttributes* attr) { |
| if (attr->ends.h < 0) { |
| attr->ends.h = input_shape.h + attr->ends.h; |
| } |
| if (attr->ends.w < 0) { |
| attr->ends.w = input_shape.w + attr->ends.w; |
| } |
| if (attr->ends.c < 0) { |
| attr->ends.c = input_shape.c + attr->ends.c; |
| } |
| if (attr->ends.b < 0) { |
| attr->ends.b = input_shape.b + attr->ends.b; |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ReadAttribsWithBatch(const ObjectReader* reader, |
| const TfLiteStridedSliceParams* tf_options, |
| const BHWC& input_shape, |
| SliceAttributes* attr) { |
| auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { |
| Tensor<Linear, DataType::INT32> t; |
| RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); |
| *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]); |
| return absl::OkStatus(); |
| }; |
| |
| RETURN_IF_ERROR(read_bhwc(1, &attr->starts)); |
| RETURN_IF_ERROR(read_bhwc(2, &attr->ends)); |
| RETURN_IF_ERROR(read_bhwc(3, &attr->strides)); |
| RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); |
| RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status ReadAttribsWithoutBatch( |
| const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options, |
| const BHWC& input_shape, SliceAttributes* attr) { |
| auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status { |
| Tensor<Linear, DataType::INT32> t; |
| RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); |
| *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]); |
| return absl::OkStatus(); |
| }; |
| |
| RETURN_IF_ERROR(read_hwc(1, &attr->starts)); |
| RETURN_IF_ERROR(read_hwc(2, &attr->ends)); |
| RETURN_IF_ERROR(read_hwc(3, &attr->strides)); |
| RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); |
| RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr)); |
| attr->starts.b = 0; |
| attr->ends.b = input_shape.b; |
| attr->strides.b = 1; |
| return absl::OkStatus(); |
| } |
| absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { |
| if (tf_options->ellipsis_mask) { |
| return absl::UnimplementedError("Slice does not support ellipsis_mask."); |
| } |
| if (tf_options->new_axis_mask) { |
| return absl::UnimplementedError("Slice does not support new_axis_mask."); |
| } |
| if (tf_options->shrink_axis_mask) { |
| return absl::UnimplementedError( |
| "Slice does not support shrink_axis_mask parameter. "); |
| } |
| return absl::OkStatus(); |
| } |
| }; |
| |
| // Builtin op version of TRANSPOSE_CONV. |
| class TransposeConvBuiltinOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3)); |
| const int runtime_inputs = |
| GetNumberOfRuntimeInputsForNode(context, tflite_node); |
| if (runtime_inputs > 2) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 or 2 input tensor(s), but node has ", |
| runtime_inputs, " runtime inputs.")); |
| } |
| const int runtime_outputs = NumOutputs(tflite_node); |
| if (runtime_outputs != 1) { |
| return absl::InternalError( |
| absl::StrCat("Expected 1 output tensor(s), but node has ", |
| runtime_outputs, " runtime outputs.")); |
| } |
| if (runtime_inputs == 1) { |
| RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); |
| } |
| const TfLiteTransposeConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR( |
| CheckStrides(tf_options->stride_height, tf_options->stride_width)); |
| return absl::OkStatus(); |
| } |
| |
| // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights, |
| // input, and an optional bias) and allows configurable padding & stride. |
| // TODO(impjdi): Translate output_shape to attr.adjacent. |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| auto* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); |
| Value* input; |
| RETURN_IF_ERROR(reader->ReadValue(2, &input)); |
| RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| const TfLiteTransposeConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); |
| |
| ConvolutionTransposedAttributes attr; |
| attr.stride = tf_options |
| ? HW(tf_options->stride_height, tf_options->stride_width) |
| : HW(1, 1); |
| const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); |
| if (runtime_inputs == 2) { |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape; |
| attr.weights.shape = OHWI(weights_shape.b, weights_shape.h, |
| weights_shape.w, weights_shape.c); |
| } else { // runtime_inputs == 1; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); |
| } |
| reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional |
| |
| UpdatePadding(tf_options->padding, |
| graph->FindInputs(node->id)[0]->tensor.shape, &attr); |
| node->operation.attributes = std::move(attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| // Custom op version of TRANSPOSE_CONV. |
| class TransposeConvCustomOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); |
| const TfLiteTransposeConvParams* tf_options; |
| RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR( |
| CheckStrides(tf_options->stride_height, tf_options->stride_width)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| auto* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| const TfLiteTransposeConvParams* tf_options; |
| auto status = RetrieveCustomInitialData(tflite_node, &tf_options); |
| |
| ConvolutionTransposedAttributes attr; |
| attr.stride = status.ok() |
| ? HW(tf_options->stride_height, tf_options->stride_width) |
| : HW(1, 1); |
| RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); |
| reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional |
| |
| UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown, |
| graph->FindInputs(node->id)[0]->tensor.shape, &attr); |
| node->operation.attributes = std::move(attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class TransposeOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::TRANSPOSE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| TransposeAttributes attr; |
| Tensor<Linear, DataType::INT32> perm; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &perm)); |
| std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0}, |
| {Axis::HEIGHT, 1}, |
| {Axis::WIDTH, 2}, |
| {Axis::CHANNELS, 3}}; |
| if (perm.data.size() == 4) { |
| attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]); |
| } else if (perm.data.size() == 3) { |
| std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH, |
| Axis::CHANNELS}; |
| attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]]; |
| attr.perm.h = 1; |
| attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]]; |
| attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]]; |
| } else if (perm.data.size() == 2) { |
| std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS}; |
| attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]]; |
| attr.perm.h = 1; |
| attr.perm.w = 2; |
| attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]]; |
| } else { |
| return absl::InvalidArgumentError( |
| "Permutation for transpose is invalid."); |
| } |
| |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class Unpooling2DOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/2, /*outputs=*/1)); |
| const TfLitePoolParams* tf_options; |
| RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); |
| RETURN_IF_ERROR(CheckKernelsAndStrides( |
| tf_options->filter_height, tf_options->filter_width, |
| tf_options->stride_height, tf_options->stride_width)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; |
| MaxUnpooling2DAttributes attr; |
| |
| const TfLitePoolParams* tf_options; |
| RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); |
| |
| attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); |
| attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); |
| UpdatePadding(tf_options->padding, input_shape, &attr); |
| |
| node->operation.attributes = attr; |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| output_value->tensor.shape = CalculateOutputShape(input_shape, attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported. |
| class BatchToSpaceOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| auto* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::BATCH_TO_SPACE); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| BatchToSpaceAttributes bs_attr; |
| Tensor<Linear, DataType::INT32> block; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &block)); |
| if (block.shape.v != 2) { |
| return absl::InternalError("Space has to be HxW."); |
| } |
| bs_attr.block.h = block.data[0]; |
| bs_attr.block.w = block.data[1]; |
| |
| Tensor<HW, DataType::INT32> crop; |
| RETURN_IF_ERROR(reader->ReadTensor(2, &crop)); |
| auto crop_shape = crop.shape; |
| if (crop_shape.h != 2 && crop_shape.w != 2) { |
| return absl::InternalError("Space has to be HxW."); |
| } |
| |
| bs_attr.crop.prepended.h = crop.data[0]; |
| bs_attr.crop.prepended.w = crop.data[2]; |
| |
| bs_attr.crop.appended.h = crop.data[1]; |
| bs_attr.crop.appended.w = crop.data[3]; |
| |
| node->operation.attributes = std::move(bs_attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class SpaceToBatchOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| auto* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::SPACE_TO_BATCH); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| SpaceToBatchAttributes sb_attr; |
| Tensor<Linear, DataType::INT32> block; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &block)); |
| if (block.shape.v != 2) { |
| return absl::InternalError("Space has to be HxW."); |
| } |
| sb_attr.block.h = block.data[0]; |
| sb_attr.block.w = block.data[1]; |
| |
| Tensor<HW, DataType::INT32> padding; |
| RETURN_IF_ERROR(reader->ReadTensor(2, &padding)); |
| auto padding_shape = padding.shape; |
| |
| if (padding_shape.h != 2 && padding_shape.w != 2) { |
| return absl::InternalError("Space has to be HxW."); |
| } |
| |
| sb_attr.padding.prepended.h = padding.data[0]; |
| sb_attr.padding.prepended.w = padding.data[2]; |
| |
| sb_attr.padding.appended.h = padding.data[1]; |
| sb_attr.padding.appended.w = padding.data[3]; |
| |
| node->operation.attributes = std::move(sb_attr); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| std::string op_name = "roi_to_transform_matrix"; |
| node->operation.type = op_name; |
| BHWC output_shape; |
| RETURN_IF_ERROR(ParseCustomAttributes( |
| op_name, registration->version, tflite_node->custom_initial_data, |
| tflite_node->custom_initial_data_size, &(node->operation.attributes), |
| &output_shape)); |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| output_value->tensor.shape = output_shape; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class TransformTensorBilinearOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/2, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); // data |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| std::string op_name = "transform_tensor_bilinear"; |
| node->operation.type = op_name; |
| BHWC output_shape; |
| RETURN_IF_ERROR(ParseCustomAttributes( |
| op_name, registration->version, tflite_node->custom_initial_data, |
| tflite_node->custom_initial_data_size, &(node->operation.attributes), |
| &output_shape)); |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| |
| output_value->tensor.shape = |
| BHWC(1, output_shape.h, output_shape.w, |
| graph->FindInputs(node->id)[0]->tensor.shape.c); |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class TransformLandmarksOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/2, /*outputs=*/1)); |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); // data |
| RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| std::string op_name = "transform_landmarks"; |
| node->operation.type = op_name; |
| BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape; |
| RETURN_IF_ERROR(ParseCustomAttributes( |
| op_name, registration->version, tflite_node->custom_initial_data, |
| tflite_node->custom_initial_data_size, &(node->operation.attributes), |
| &output_shape)); |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| |
| output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); |
| return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, |
| /*outputs=*/1); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks |
| RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix |
| |
| const std::string op_name = "landmarks_to_transform_matrix"; |
| node->operation.type = op_name; |
| BHWC output_shape; |
| RETURN_IF_ERROR(ParseCustomAttributes( |
| op_name, registration->version, tflite_node->custom_initial_data, |
| tflite_node->custom_initial_data_size, &(node->operation.attributes), |
| &output_shape)); |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| output_value->tensor.shape = output_shape; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class AlignmentPointsToTransformMatrixOperationParser |
| : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, |
| /*outputs=*/1); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| Node* node = graph->NewNode(); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points |
| RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix |
| |
| const std::string op_name = "alignment_points_to_transform_matrix"; |
| node->operation.type = op_name; |
| BHWC output_shape; |
| RETURN_IF_ERROR(ParseCustomAttributes( |
| op_name, registration->version, tflite_node->custom_initial_data, |
| tflite_node->custom_initial_data_size, &(node->operation.attributes), |
| &output_shape)); |
| |
| auto output_value = graph->FindOutputs(node->id)[0]; |
| output_value->tensor.shape = output_shape; |
| return absl::OkStatus(); |
| } |
| |
| private: |
| }; |
| |
| class MeanOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, |
| /*runtime_inputs=*/1, |
| /*outputs=*/1)); |
| |
| // Simple mechanism to check if MEAN is to be performed only on HW plane. |
| auto* axes = &context->tensors[tflite_node->inputs->data[1]]; |
| if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) { |
| return absl::UnimplementedError("Mean has unsupported tensor for axes"); |
| } |
| auto* axes_data = axes->data.i32; |
| const bool is_hw_mean = tflite::NumElements(axes) == 2 && |
| ((axes_data[0] == 1 && axes_data[1] == 2) || |
| (axes_data[0] == 2 && axes_data[1] == 1)); |
| if (!is_hw_mean) { |
| return absl::UnimplementedError("Mean operation supports only HW plane"); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| auto* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::MEAN); |
| RETURN_IF_ERROR(reader->AddInput(node, 0)); |
| RETURN_IF_ERROR(reader->AddOutputs(node)); |
| |
| MeanAttributes attr; |
| Tensor<Linear, DataType::INT32> channel; |
| RETURN_IF_ERROR(reader->ReadTensor(1, &channel)); |
| for (int i = 0; i < channel.data.size(); i++) { |
| std::string unsupported; |
| switch (channel.data[i]) { |
| case 1: |
| attr.dims.insert(Axis::HEIGHT); |
| break; |
| case 2: |
| attr.dims.insert(Axis::WIDTH); |
| break; |
| case 0: |
| unsupported = unsupported.empty() ? "batch" : unsupported; |
| ABSL_FALLTHROUGH_INTENDED; |
| case 3: |
| unsupported = unsupported.empty() ? "channels" : unsupported; |
| ABSL_FALLTHROUGH_INTENDED; |
| default: |
| return absl::UnimplementedError( |
| absl::StrCat("Unsupported mean dimension: ", unsupported)); |
| } |
| } |
| node->operation.attributes = attr; |
| return absl::OkStatus(); |
| } |
| }; |
| |
| class UnsupportedOperationParser : public TFLiteOperationParser { |
| public: |
| absl::Status IsSupported(const TfLiteContext* context, |
| const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration) final { |
| return absl::UnimplementedError("Operation is not supported."); |
| } |
| |
| absl::Status Parse(const TfLiteNode* tflite_node, |
| const TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader* reader) final { |
| return absl::UnimplementedError("Operation is not supported."); |
| } |
| }; |
| |
| std::unique_ptr<TFLiteOperationParser> NewOperationParser( |
| const TfLiteRegistration* registration, bool allow_quant_ops = false) { |
| const auto builtin_code = registration->builtin_code; |
| switch (builtin_code) { |
| case kTfLiteBuiltinAbs: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::ABS); |
| case kTfLiteBuiltinAdd: |
| return std::make_unique<AddOperationParser>(); |
| case kTfLiteBuiltinAveragePool2d: |
| return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE); |
| case kTfLiteBuiltinBatchMatmul: |
| return std::make_unique<BatchedMatMulOperationParser>(); |
| case kTfLiteBuiltinConcatenation: |
| return std::make_unique<ConcatenationOperationParser>(); |
| case kTfLiteBuiltinConv2d: |
| return std::make_unique<Conv2DOperationParser>(); |
| case kTfLiteBuiltinCos: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::COS); |
| case kTfLiteBuiltinDepthwiseConv2d: |
| return std::make_unique<DepthwiseConvolutionOperationParser>(); |
| case kTfLiteBuiltinDequantize: |
| if (allow_quant_ops) { |
| return std::make_unique<DequantizeOperationParser>(); |
| } |
| break; |
| case kTfLiteBuiltinDiv: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::DIV); |
| case kTfLiteBuiltinElu: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::ELU); |
| case kTfLiteBuiltinExp: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::EXP); |
| case kTfLiteBuiltinFullyConnected: |
| return std::make_unique<FullyConnectedOperationParser>(); |
| case kTfLiteBuiltinHardSwish: |
| return std::make_unique<HardSwishOperationParser>(); |
| case kTfLiteBuiltinLogistic: |
| return std::make_unique<ElementwiseOperationParser>( |
| OperationType::SIGMOID); |
| case kTfLiteBuiltinLog: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::LOG); |
| case kTfLiteBuiltinLstm: |
| return std::make_unique<LSTMOperationParser>(); |
| case kTfLiteBuiltinMaximum: |
| return std::make_unique<ElementwiseOperationParser>( |
| OperationType::MAXIMUM); |
| case kTfLiteBuiltinMaxPool2d: |
| return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX); |
| case kTfLiteBuiltinMean: |
| return std::make_unique<MeanOperationParser>(); |
| case kTfLiteBuiltinMinimum: |
| return std::make_unique<ElementwiseOperationParser>( |
| OperationType::MINIMUM); |
| case kTfLiteBuiltinMirrorPad: |
| return std::make_unique<PadOperationParser>(/*mirror_pad=*/true); |
| case kTfLiteBuiltinMul: |
| return std::make_unique<MulOperationParser>(); |
| case kTfLiteBuiltinNeg: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::NEG); |
| case kTfLiteBuiltinPack: |
| return std::make_unique<PackOperationParser>(); |
| case kTfLiteBuiltinPad: |
| return std::make_unique<PadOperationParser>(/*mirror_pad=*/false); |
| case kTfLiteBuiltinPow: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::POW); |
| case kTfLiteBuiltinReduceMax: |
| return std::make_unique<ReduceOperationParser>( |
| OperationType::REDUCE_MAXIMUM); |
| case kTfLiteBuiltinReduceMin: |
| return std::make_unique<ReduceOperationParser>( |
| OperationType::REDUCE_MINIMUM); |
| case kTfLiteBuiltinReduceProd: |
| return std::make_unique<ReduceOperationParser>( |
| OperationType::REDUCE_PRODUCT); |
| case kTfLiteBuiltinQuantize: |
| if (allow_quant_ops) { |
| return std::make_unique<QuantizeOperationParser>(); |
| } |
| break; |
| case kTfLiteBuiltinRelu: |
| return std::make_unique<ReLUOperationParser>(0); |
| case kTfLiteBuiltinRelu6: |
| return std::make_unique<ReLUOperationParser>(6); |
| case kTfLiteBuiltinLeakyRelu: |
| return std::make_unique<ReLUOperationParser>(0); |
| case kTfLiteBuiltinPrelu: |
| return std::make_unique<PReLUOperationParser>(); |
| case kTfLiteBuiltinReshape: |
| return std::make_unique<ReshapeOperationParser>(); |
| case kTfLiteBuiltinResizeBilinear: |
| return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR); |
| case kTfLiteBuiltinResizeNearestNeighbor: |
| return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST); |
| case kTfLiteBuiltinRsqrt: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT); |
| case kTfLiteBuiltinSin: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::SIN); |
| case kTfLiteBuiltinSlice: |
| return std::make_unique<SliceOperationParser>(); |
| case kTfLiteBuiltinSoftmax: |
| return std::make_unique<SoftmaxOperationParser>(); |
| case kTfLiteBuiltinSpaceToDepth: |
| return std::make_unique<SpaceToDepthOperationParser>(); |
| case kTfLiteBuiltinSqrt: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT); |
| case kTfLiteBuiltinSquare: |
| return std::make_unique<ElementwiseOperationParser>( |
| OperationType::SQUARE); |
| case kTfLiteBuiltinSquaredDifference: |
| return std::make_unique<ElementwiseOperationParser>( |
| OperationType::SQUARED_DIFF); |
| case kTfLiteBuiltinStridedSlice: |
| return std::make_unique<StridedSliceOperationParser>(); |
| case kTfLiteBuiltinSub: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::SUB); |
| case kTfLiteBuiltinSum: |
| return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM); |
| case kTfLiteBuiltinTanh: |
| return std::make_unique<ElementwiseOperationParser>(OperationType::TANH); |
| case kTfLiteBuiltinTranspose: |
| return std::make_unique<TransposeOperationParser>(); |
| case kTfLiteBuiltinTransposeConv: |
| return std::make_unique<TransposeConvBuiltinOperationParser>(); |
| |
| case kTfLiteBuiltinCustom: |
| const absl::string_view custom_name = registration->custom_name; |
| if (custom_name == "Convolution2DTransposeBias") { |
| return std::make_unique<TransposeConvCustomOperationParser>(); |
| } |
| if (custom_name == "MaxPoolingWithArgmax2D") { |
| return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX); |
| } |
| if (custom_name == "MaxUnpooling2D") { |
| return std::make_unique<Unpooling2DOperationParser>(); |
| } |
| if (custom_name == "RoIToTransformMatrix") { |
| return std::make_unique<RoIToTransformMatrixOperationParser>(); |
| } |
| if (custom_name == "TransformTensor" /*for version 1*/ || |
| custom_name == "TransformTensorBilinear" /*for version 2*/) { |
| return std::make_unique<TransformTensorBilinearOperationParser>(); |
| } |
| if (custom_name == "TransformLandmarks") { |
| return std::make_unique<TransformLandmarksOperationParser>(); |
| } |
| if (custom_name == "Landmarks2TransformMatrix" || |
| custom_name == "Landmarks2TransformMatrixV2") { |
| return std::make_unique<Landmarks2TransformMatrixOperationParser>(); |
| } |
| if (custom_name == "AlignmentPointsToTransformMatrix") { |
| return std::make_unique< |
| AlignmentPointsToTransformMatrixOperationParser>(); |
| } |
| break; |
| } |
| return std::make_unique<UnsupportedOperationParser>(); |
| } |
| |
| absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, |
| const TfLiteRegistration* registration, |
| bool allow_quant_ops = false) { |
| return NewOperationParser(registration, allow_quant_ops) |
| ->IsSupported(context, node, registration); |
| } |
| |
| bool IsAllAllowedTensors(TfLiteContext* context, |
| const TfLiteIntArray* tensor_indices, |
| bool allow_quant_ops = false) { |
| for (int i = 0; i < tensor_indices->size; ++i) { |
| int tensor_idx = tensor_indices->data[i]; |
| if (tensor_idx == kTfLiteOptionalTensor) continue; |
| const TfLiteTensor* t = &context->tensors[tensor_idx]; |
| bool type_supported = |
| (t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16); |
| if (allow_quant_ops) { |
| // Since we only check non-constant tensors, type cannot be Int32. |
| type_supported = |
| type_supported || t->type == kTfLiteInt8 || t->type == kTfLiteUInt8; |
| } |
| if (t->allocation_type == kTfLiteArenaRw && !type_supported) { |
| return false; |
| } |
| } |
| return true; |
| } |
| } // namespace |
| |
| // TODO(impjdi): Check number of input/output tensors and their dimensions. |
| // TODO(impjdi): Check ops' parameters. |
| TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops, |
| int max_delegated_partitions) { |
| delegates::IsNodeSupportedFn node_supported_fn = |
| [=](TfLiteContext* context, TfLiteNode* node, |
| TfLiteRegistration* registration, |
| std::string* unsupported_details) -> bool { |
| const auto status = |
| IsSupported(context, node, registration, allow_quant_ops); |
| if (!status.ok()) { |
| if (unsupported_details) { |
| *unsupported_details = std::string(status.message()); |
| } |
| return false; |
| } |
| |
| if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) || |
| !IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) { |
| if (unsupported_details) { |
| *unsupported_details = |
| "OP is supported, but tensor type isn't matched!"; |
| } |
| return false; |
| } |
| return true; |
| }; |
| |
| delegates::FP16GraphPartitionHelper partition_helper(context, |
| node_supported_fn); |
| std::set<std::string> unsupported_nodes_info; |
| if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) { |
| return TfLiteIntArrayCreate(0); |
| } |
| |
| // By default, we simply get 1st largest partition as 'max_delegate_partions' |
| // is set to 1 by default. |
| std::vector<int> ops_to_replace = |
| partition_helper.GetNodesOfFirstNLargestPartitions( |
| max_delegated_partitions); |
| |
| if (!unsupported_nodes_info.empty()) { |
| std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n"); |
| std::string error_message = absl::StrCat( |
| "Following operations are not supported by GPU delegate:\n", |
| unsupported, "\n"); |
| if (!ops_to_replace.empty()) { |
| absl::StrAppend( |
| &error_message, ops_to_replace.size(), |
| " operations will run on the GPU, and the remaining ", |
| partition_helper.num_total_nodes() - ops_to_replace.size()); |
| } else { |
| absl::StrAppend(&error_message, |
| "No operations will run on the GPU, and all ", |
| partition_helper.num_total_nodes()); |
| } |
| absl::StrAppend(&error_message, " operations will run on the CPU."); |
| TF_LITE_KERNEL_LOG(context, error_message.c_str()); |
| } |
| return ConvertVectorToTfLiteIntArray(ops_to_replace); |
| } |
| |
| // Creates inputs and outputs passed by io_tensors parameters in the resulting |
| // graph. We force it to make sure that delegated subgraph has same order of |
| // inputs and outputs with the original one. When delegated model is built from |
| // the tflite model representation tensors are created lazily, so there is no |
| // guarantee that the order will match the source model tensors order. |
| absl::Status PrecreateIOTensors( |
| TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors, |
| absl::flat_hash_map<int, int>* quant_conversion_map, |
| absl::flat_hash_map<int, Value*>* tensor_to_value) { |
| for (int i = 0; i < io_tensors->size; ++i) { |
| const int tensor_index = io_tensors->data[i]; |
| const TfLiteTensor& tflite_tensor = context->tensors[tensor_index]; |
| if (tflite::IsConstantTensor(&tflite_tensor)) continue; |
| RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor( |
| context, tensor_to_value, quant_conversion_map, graph, tensor_index)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status CopyVariableTensorOutputs( |
| TfLiteNode* tflite_node, TfLiteRegistration* registration, |
| GraphFloat32* graph, ObjectReader& reader, |
| const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) { |
| absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy( |
| new_variable_tensor_values); |
| // Retrieve the final value id for the variable input tensors. |
| for (int i = 0; i < tflite_node->inputs->size; i++) { |
| int tensor_idx = tflite_node->inputs->data[i]; |
| Value* value; |
| if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue; |
| if (value->tensor.is_variable_input) { |
| if (new_variable_tensor_values_copy.find(i) == |
| new_variable_tensor_values_copy.end()) { |
| return absl::InvalidArgumentError( |
| absl::StrCat(GetOpNameByRegistration(*registration), |
| " did not provide a new value for the variable input " |
| "tensor with index ", |
| tensor_idx)); |
| } else { |
| Node* node = graph->NewNode(); |
| node->operation.type = ToString(OperationType::COPY); |
| RETURN_IF_ERROR(graph->AddConsumer( |
| node->id, new_variable_tensor_values_copy.at(i))); |
| RETURN_IF_ERROR(reader.AddUpdate(node, i)); |
| new_variable_tensor_values_copy.erase( |
| new_variable_tensor_values_copy.find(i)); |
| } |
| } |
| } |
| if (!new_variable_tensor_values_copy.empty()) { |
| return absl::InvalidArgumentError( |
| "More input variable tensors asked to be copied than present on the " |
| "node"); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status BuildModel(TfLiteContext* context, |
| const TfLiteDelegateParams* delegate_params, |
| GraphFloat32* graph, |
| absl::flat_hash_map<int, int>* quant_conversion_map) { |
| std::vector<std::unique_ptr<TFLiteOperationParser>> operations; |
| std::vector<int> tflite_nodes; |
| for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { |
| TfLiteNode* tflite_node = nullptr; |
| TfLiteRegistration* registration = nullptr; |
| RETURN_IF_ERROR(GetNodeAndRegistration( |
| context, delegate_params->nodes_to_replace->data[i], &tflite_node, |
| ®istration)); |
| if (registration->builtin_code == kTfLiteBuiltinDequantize && |
| context->tensors[tflite_node->inputs->data[0]].type == |
| TfLiteType::kTfLiteFloat16) { |
| // Ignore Fp16 Dequantize nodes. |
| continue; |
| } |
| auto op_parser = NewOperationParser( |
| registration, /*allow_quant_ops=*/quant_conversion_map != nullptr); |
| if (!op_parser) { |
| return absl::UnimplementedError( |
| absl::StrCat("Operation ", registration->builtin_code, "(", |
| registration->custom_name, |
| ") is not supported by TFLite GPU Delegate.")); |
| } |
| operations.push_back(std::move(op_parser)); |
| tflite_nodes.push_back(i); |
| } |
| absl::flat_hash_map<int, Value*> tensor_to_value; |
| std::vector<ValueId> variable_inputs_to_value_id; |
| RETURN_IF_ERROR(PrecreateIOTensors(context, graph, |
| delegate_params->input_tensors, |
| quant_conversion_map, &tensor_to_value)); |
| RETURN_IF_ERROR(PrecreateIOTensors(context, graph, |
| delegate_params->output_tensors, |
| quant_conversion_map, &tensor_to_value)); |
| for (int i = 0; i < operations.size(); ++i) { |
| TfLiteNode* tflite_node; |
| TfLiteRegistration* registration; |
| RETURN_IF_ERROR(GetNodeAndRegistration( |
| context, delegate_params->nodes_to_replace->data[tflite_nodes[i]], |
| &tflite_node, ®istration)); |
| ObjectReader reader(graph, context, tflite_node, &tensor_to_value, |
| quant_conversion_map); |
| const auto status = |
| operations[i]->Parse(tflite_node, registration, graph, &reader); |
| if (!status.ok()) { |
| return absl::InternalError(absl::StrCat( |
| GetOpNameByRegistration(*registration), ": ", status.message())); |
| } |
| |
| absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors = |
| operations[i]->GetNewValueIdsForVariableInputNodes(); |
| |
| RETURN_IF_ERROR( |
| CopyVariableTensorOutputs(tflite_node, registration, graph, reader, |
| new_value_for_variable_input_tensors)); |
| } |
| |
| // Variable input tensors expect to be unchanged throughout model execution. |
| // They need to be an output of the graph in order to have them unchanged. |
| for (auto value_id : variable_inputs_to_value_id) { |
| if (!graph->IsGraphOutput(value_id)) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Variable input tensors must be a graph output. Value ", |
| value_id, " is not a graph output")); |
| } |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status BuildFinalModel( |
| TfLiteContext* context, const TfLiteDelegateParams* delegate_params, |
| GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) { |
| RETURN_IF_ERROR( |
| BuildModel(context, delegate_params, graph, quant_conversion_map)); |
| |
| // Apply general transformations on the graph. |
| NullTransformationReporter reporter; |
| ModelTransformer transformer(graph, &reporter); |
| if (!ApplyModelTransformations(&transformer)) { |
| return absl::InternalError("Graph transformations failed"); |
| } |
| return absl::OkStatus(); |
| } |
| |
| } // namespace gpu |
| } // namespace tflite |