Move some custom ops back to model builder.
PiperOrigin-RevId: 365116594
Change-Id: Ieea3ad01cc6518c5ccd1cb971ef941864fc9db8b
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index af152a5..a9f9bbf 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -82,6 +82,16 @@
return absl::OkStatus();
}
+template <typename ParamsT>
+absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
+ const ParamsT** tf_options) {
+ *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
+ if (!*tf_options) {
+ return absl::InternalError("Unable to retrieve custom_initial_data.");
+ }
+ return absl::OkStatus();
+}
+
absl::Status CheckDilation(int dilation_h, int dilation_w) {
if (dilation_h <= 0 || dilation_w <= 0) {
return absl::InvalidArgumentError(absl::StrCat(
@@ -1360,10 +1370,17 @@
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
const TfLitePoolParams* tf_options;
- RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
- RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
- /*runtime_inputs=*/1,
- /*outputs=*/1));
+ auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
+ if (status.ok()) { // custom case with indices as a second output
+ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+ /*runtime_inputs=*/1,
+ /*outputs=*/2));
+ } else { // common pooling with 1 output
+ RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+ /*runtime_inputs=*/1,
+ /*outputs=*/1));
+ }
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
@@ -1386,12 +1403,28 @@
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
+ // Check whether there are custom options encoded. It happens if operation
+ // is MaxPoolingWithArgmax2D. There is no way to read
+ // tflite_node->builtin_code, so, simply check whether custom data is
+ // available.
const TfLitePoolParams* tf_options;
- RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+ if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
+ RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+ }
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
+ // Second output is optional. It is not required, it but must be added after
+ // MaybeAddFusedActivation function is called
+ reader->AddOutput(node, 1).IgnoreError();
- attr.output_indices = false;
+ // First output is the result of pooling operation, while second output is
+ // indices used for pooling.
+ auto outputs = graph->FindOutputs(node->id);
+ attr.output_indices = outputs.size() == 2;
+ if (attr.output_indices) {
+ // Fix data type for output indices. In the model it is set as float32.
+ outputs[1]->tensor.type = DataType::INT32;
+ }
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
node->operation.attributes = attr;
return absl::OkStatus();
@@ -2172,6 +2205,45 @@
}
};
+// Custom op version of TRANSPOSE_CONV.
+class TransposeConvCustomOperationParser : public TFLiteOperationParser {
+ public:
+ absl::Status IsSupported(const TfLiteContext* context,
+ const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration) final {
+ RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
+ const TfLiteTransposeConvParams* tf_options;
+ RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+ RETURN_IF_ERROR(
+ CheckStrides(tf_options->stride_height, tf_options->stride_width));
+ return absl::OkStatus();
+ }
+
+ absl::Status Parse(const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration,
+ GraphFloat32* graph, ObjectReader* reader) final {
+ auto* node = graph->NewNode();
+ node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
+ RETURN_IF_ERROR(reader->AddInput(node, 0));
+ RETURN_IF_ERROR(reader->AddOutputs(node));
+
+ const TfLiteTransposeConvParams* tf_options;
+ auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
+
+ ConvolutionTransposedAttributes attr;
+ attr.stride = status.ok()
+ ? HW(tf_options->stride_height, tf_options->stride_width)
+ : HW(1, 1);
+ RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
+ reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
+
+ UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
+ graph->FindInputs(node->id)[0]->tensor.shape, &attr);
+ node->operation.attributes = std::move(attr);
+ return absl::OkStatus();
+ }
+};
+
class TransposeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@@ -2223,6 +2295,47 @@
}
};
+class Unpooling2DOperationParser : public TFLiteOperationParser {
+ public:
+ absl::Status IsSupported(const TfLiteContext* context,
+ const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration) final {
+ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+ /*runtime_inputs=*/2, /*outputs=*/1));
+ const TfLitePoolParams* tf_options;
+ RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+ RETURN_IF_ERROR(CheckKernelsAndStrides(
+ tf_options->filter_height, tf_options->filter_width,
+ tf_options->stride_height, tf_options->stride_width));
+ return absl::OkStatus();
+ }
+
+ absl::Status Parse(const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration,
+ GraphFloat32* graph, ObjectReader* reader) final {
+ Node* node = graph->NewNode();
+ node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
+ RETURN_IF_ERROR(reader->AddInput(node, 0));
+ RETURN_IF_ERROR(reader->AddInput(node, 1));
+ RETURN_IF_ERROR(reader->AddOutputs(node));
+ auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
+ MaxUnpooling2DAttributes attr;
+
+ const TfLitePoolParams* tf_options;
+ RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+
+ attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
+ attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
+ UpdatePadding(tf_options->padding, input_shape, &attr);
+
+ node->operation.attributes = attr;
+
+ auto output_value = graph->FindOutputs(node->id)[0];
+ output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
+ return absl::OkStatus();
+ }
+};
+
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
class BatchToSpaceOperationParser : public TFLiteOperationParser {
public:
@@ -2496,8 +2609,19 @@
return std::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return std::make_unique<TransposeConvBuiltinOperationParser>();
- case kTfLiteBuiltinCustom:
+ case kTfLiteBuiltinCustom: {
+ const absl::string_view custom_name = registration->custom_name;
+ if (custom_name == "Convolution2DTransposeBias") {
+ return std::make_unique<TransposeConvCustomOperationParser>();
+ }
+ if (custom_name == "MaxPoolingWithArgmax2D") {
+ return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
+ }
+ if (custom_name == "MaxUnpooling2D") {
+ return std::make_unique<Unpooling2DOperationParser>();
+ }
return NewCustomOperationParser(registration->custom_name);
+ }
}
return std::make_unique<UnsupportedOperationParser>();
}