Enable ROI v2 ops: TransformTensorBilinear, TransformLandmarks, ROI2TransformMatrix.
PiperOrigin-RevId: 313470232
Change-Id: Id581304ec9313070369c68bffe2fa12690e45c0b
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index daedc27..061c650 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -2453,15 +2453,15 @@
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks_v2";
node->operation.type = op_name;
- BHWC output_shape;
+
+ auto output_value = graph->FindOutputs(node->id)[0];
+ output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
+ BHWC output_shape = output_value->tensor.shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
- auto output_value = graph->FindOutputs(node->id)[0];
-
- output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
return absl::OkStatus();
}