Add NNAPI delegate support for fused HardSwish
PiperOrigin-RevId: 313217072
Change-Id: I492a1d6c7b2b5968a29a24b0cef1c82e15898dad
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index b396780..fd6703b 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -660,8 +660,10 @@
// Lower hardswish according to the following equation:
// hard_swish[x] = x (ReLU6(x + 3)) / 6 == x * (Relu_N1_to_1(x/3) * 3 + 3) / 6
// = 0.5x * Relu_N1_to_1(x/3) + 0.5x
- TfLiteStatus AddHardSwish(int lite_input_index, int lite_output_index,
- bool need_int8_conversion, int lite_node_index) {
+ TfLiteStatus TransformHardSwishIntoSupportedOps(int lite_input_index,
+ int lite_output_index,
+ bool need_int8_conversion,
+ int lite_node_index) {
const TfLiteTensor& tensor = context_->tensors[lite_input_index];
float input_scale = tensor.params.scale;
int input_zero_point = tensor.params.zero_point;
@@ -2425,6 +2427,9 @@
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
*nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
} break;
+ case kTfLiteBuiltinHardSwish: {
+ *nn_op_type = ANEURALNETWORKS_HARD_SWISH;
+ } break;
case kTfLiteBuiltinSoftmax: {
auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(
mapping_args.node->builtin_data);
@@ -3635,10 +3640,14 @@
input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
}
- // h_swish will be lowered into supported NNAPI operations.
- if (reg->builtin_code == kTfLiteBuiltinHardSwish) {
- builder.AddHardSwish(node->inputs->data[0], node->outputs->data[0],
- need_int8_conversion, node_index);
+ // On SDK levels below 30, h_swish is lowered into supported NNAPI
+ // operations. From SDK level 30 onwards, h_swish is supported as a
+ // single operation.
+ if (reg->builtin_code == kTfLiteBuiltinHardSwish &&
+ nnapi_->android_sdk_version < kMinSdkVersionForNNAPI13) {
+ builder.TransformHardSwishIntoSupportedOps(
+ node->inputs->data[0], node->outputs->data[0], need_int8_conversion,
+ node_index);
continue;
}
// Map inputs to NN API tensor indices.