Fall back on xtensa reference when filter does not fit in input.
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index 59096c4..7d18411 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -299,9 +299,13 @@
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias, TfLiteEvalTensor* output,
TfLiteEvalTensor* im2col) {
+ const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+ const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
/* Dilation is currently not supported on HiFi 4 NN Library */
if ((params.dilation_width_factor == 1) &&
- (params.dilation_height_factor == 1)) {
+ (params.dilation_height_factor == 1) &&
+ input_shape.Dims(1) >= filter_shape.Dims(1) &&
+ input_shape.Dims(2) >= filter_shape.Dims(2)) {
const int32_t input_offset = -data.reference_op_data.input_zero_point;
const int32_t output_offset = data.reference_op_data.output_zero_point;
const int stride_width = params.stride_width;
@@ -313,8 +317,6 @@
const int32_t output_activation_max =
data.reference_op_data.output_activation_max;
- const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
- const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_patch.patch b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_patch.patch
index 5b2c325..cad2381 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_patch.patch
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_patch.patch
@@ -32,15 +32,3 @@
if(inp_data_format == 0)
{
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxasym8s.c b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxasym8s.c
-index b16b9fc..38e69d3 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxasym8s.c
-+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxasym8s.c
-@@ -198,7 +198,6 @@ WORD32 xa_nn_conv2d_std_per_chan_sym8sxasym8s(
- XA_NNLIB_ARG_CHK_COND((input_channels <= 0), -1);
- XA_NNLIB_ARG_CHK_COND((kernel_height <= 0 || kernel_width <= 0), -1);
- XA_NNLIB_ARG_CHK_COND((kernel_height > input_height), -1);
-- XA_NNLIB_ARG_CHK_COND((kernel_width > input_width), -1);
- XA_NNLIB_ARG_CHK_COND((out_channels <= 0), -1);
- XA_NNLIB_ARG_CHK_COND((y_stride <= 0 || x_stride <= 0), -1);
- XA_NNLIB_ARG_CHK_COND((y_padding < 0 || x_padding < 0), -1);