QS8 e2e benchmark for C2 neon microkernels

Change GEMM remainder code for MLA to use if instead of while.

PiperOrigin-RevId: 357759630
diff --git a/bench/qs8-gemm-e2e.cc b/bench/qs8-gemm-e2e.cc
index 6ccff5c..ff7f03d 100644
--- a/bench/qs8-gemm-e2e.cc
+++ b/bench/qs8-gemm-e2e.cc
@@ -172,6 +172,170 @@
   }
 
 #if XNN_ENABLE_FULL_BENCHMARKS
+  static void qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      1 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      1 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+#endif  // XNN_ENABLE_FULL_BENCHMARKS
+
+  static void qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      2 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      2 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      3 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      3 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+      4 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+      4 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+#if XNN_ENABLE_FULL_BENCHMARKS
+  static void qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      1 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      1 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+#endif  // XNN_ENABLE_FULL_BENCHMARKS
+
+  static void qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      2 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      2 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      3 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      3 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+      4 /* mr */, 8  /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+  static void qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {
+    GEMMEnd2EndBenchmark(state, model,
+      xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup,
+      xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+      4 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+      benchmark::utils::CheckNEON);
+  }
+
+#if XNN_ENABLE_FULL_BENCHMARKS
   static void qs8_gemm_minmax_ukernel_1x8c4__neondot(benchmark::State& state, models::ExecutionPlanFactory model) {
     GEMMEnd2EndBenchmark(state, model,
       xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot,
@@ -370,6 +534,28 @@
   BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
 
 #if XNN_ENABLE_FULL_BENCHMARKS
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup);
+#endif  // XNN_ENABLE_FULL_BENCHMARKS
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup);
+
+#if XNN_ENABLE_FULL_BENCHMARKS
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup);
+#endif  // XNN_ENABLE_FULL_BENCHMARKS
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup);
+  BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup);
+
+#if XNN_ENABLE_FULL_BENCHMARKS
   BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8__neon_mlal_lane);
   BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16__neon_mlal_lane);
 #endif  // XNN_ENABLE_FULL_BENCHMARKS
diff --git a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
index 31e34c5..ba19f77 100644
--- a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
+++ b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
@@ -88,7 +88,7 @@
         k -= 16 * sizeof(int8_t);
       }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    ${"if" if MLA else "while"} (k >= 8 * sizeof(int8_t)) {
       $for M in range(MR):
         const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
 
diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
index 2c0bf43..43eec29 100644
--- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
@@ -136,7 +136,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
 
       const int8x8_t vb0123c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
index f442eef..01eea15 100644
--- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
@@ -94,7 +94,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
 
       const int8x8_t vb0123c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
index a38902c..91f2a9d 100644
--- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
@@ -196,7 +196,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
 
diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
index e08a473..e3d3570 100644
--- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
@@ -128,7 +128,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
 
diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
index 3f58900..7750eb3 100644
--- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
@@ -256,7 +256,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
       const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
index aa376ba..7f095a9 100644
--- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
@@ -162,7 +162,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
       const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
index 41f4a21..19d1d86 100644
--- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
@@ -316,7 +316,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
       const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
index 10afbe4..b7403d4 100644
--- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
@@ -196,7 +196,7 @@
       k -= 16 * sizeof(int8_t);
     }
 
-    while (k >= 8 * sizeof(int8_t)) {
+    if (k >= 8 * sizeof(int8_t)) {
       const int8x8_t va0 = vld1_s8(a0); a0 += 8;
       const int8x8_t va1 = vld1_s8(a1); a1 += 8;
       const int8x8_t va2 = vld1_s8(a2); a2 += 8;