Branch TransformTensor operation to TransformTensorV2.
PiperOrigin-RevId: 308880491
Change-Id: Ia7f2eb04f354882d3b1e75e3c55515715d15ba06
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index eca86fa..e8a899c 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -2316,6 +2316,43 @@
private:
};
+class TransformTensorV2OperationParser : public TFLiteOperationParser {
+ public:
+ absl::Status IsSupported(const TfLiteContext* context,
+ const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration) final {
+ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+ /*runtime_inputs=*/2, /*outputs=*/1));
+ return absl::OkStatus();
+ }
+
+ absl::Status Parse(const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration,
+ GraphFloat32* graph, ObjectReader* reader) final {
+ Node* node = graph->NewNode();
+ RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
+ RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
+ RETURN_IF_ERROR(reader->AddOutputs(node));
+
+ std::string op_name = "transform_tensor_v2";
+ node->operation.type = op_name;
+ BHWC output_shape;
+ RETURN_IF_ERROR(
+ ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+ tflite_node->custom_initial_data_size,
+ &(node->operation.attributes), &output_shape));
+
+ auto output_value = graph->FindOutputs(node->id)[0];
+
+ output_value->tensor.shape =
+ BHWC(1, output_shape.h, output_shape.w,
+ graph->FindInputs(node->id)[0]->tensor.shape.c);
+ return absl::OkStatus();
+ }
+
+ private:
+};
+
class TransformLandmarksOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@@ -2595,6 +2632,9 @@
if (custom_name == "TransformTensor") {
return std::make_unique<TransformTensorOperationParser>();
}
+ if (custom_name == "TransformTensorV2") {
+ return std::make_unique<TransformTensorV2OperationParser>();
+ }
if (custom_name == "TransformLandmarks") {
return std::make_unique<TransformLandmarksOperationParser>();
}