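Add QS8 end-to-end (e2e) benchmarks for the C2 NEON microkernels
Change the C2 GEMM template so that the 8-byte remainder step is emitted as `if` instead of `while` for MLA variants; non-MLA variants keep `while`.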
PiperOrigin-RevId: 357759630
diff --git a/bench/qs8-gemm-e2e.cc b/bench/qs8-gemm-e2e.cc
index 6ccff5c..ff7f03d 100644
--- a/bench/qs8-gemm-e2e.cc
+++ b/bench/qs8-gemm-e2e.cc
@@ -172,6 +172,170 @@
}
#if XNN_ENABLE_FULL_BENCHMARKS
+ static void qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 1x8c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,  // The same GEMM/IGEMM pair is passed a second time (this kernel is already the 1-row variant).
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ 1 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 1x16c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,  // The same GEMM/IGEMM pair is passed a second time (this kernel is already the 1-row variant).
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ 1 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+
+ static void qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 2x8c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ 2 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 2x16c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ 2 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 3x8c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ 3 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 3x16c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ 3 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 4x8c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup,
+ 4 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 4x16c2 neon_mlal_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup,
+ 4 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+#if XNN_ENABLE_FULL_BENCHMARKS
+ static void qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 1x8c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,  // The same GEMM/IGEMM pair is passed a second time (this kernel is already the 1-row variant).
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ 1 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 1x16c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,  // The same GEMM/IGEMM pair is passed a second time (this kernel is already the 1-row variant).
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ 1 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+
+ static void qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 2x8c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ 2 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 2x16c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ 2 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 3x8c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ 3 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 3x16c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ 3 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 4x8c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x8c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup,
+ 4 /* mr */, 8 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(benchmark::State& state, models::ExecutionPlanFactory model) {  // End-to-end benchmark entry for the 4x16c2 neon_mull_padal_dup QS8 kernels; delegates to GEMMEnd2EndBenchmark.
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup,
+ xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup,
+ xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,  // NOTE(review): second pair is the 1x16c2 variant, presumably the single-row (mr = 1) fallback; confirm against GEMMEnd2EndBenchmark.
+ xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup,
+ 4 /* mr */, 16 /* nr */, 1 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);  // NOTE(review): presumably an ISA-support check for the benchmark; confirm against GEMMEnd2EndBenchmark.
+ }
+
+#if XNN_ENABLE_FULL_BENCHMARKS
static void qs8_gemm_minmax_ukernel_1x8c4__neondot(benchmark::State& state, models::ExecutionPlanFactory model) {
GEMMEnd2EndBenchmark(state, model,
xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot,
@@ -370,6 +534,28 @@
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
#if XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup);  // 1-row MLAL variants are registered inside the XNN_ENABLE_FULL_BENCHMARKS guard, like their definitions.
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup);
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup);  // 2- to 4-row MLAL variants are registered outside the XNN_ENABLE_FULL_BENCHMARKS guard.
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup);
+
+#if XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup);  // 1-row MULL variants are registered inside the XNN_ENABLE_FULL_BENCHMARKS guard, like their definitions.
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup);
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup);  // 2- to 4-row MULL variants are registered outside the XNN_ENABLE_FULL_BENCHMARKS guard.
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup);
+
+#if XNN_ENABLE_FULL_BENCHMARKS
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8__neon_mlal_lane);
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16__neon_mlal_lane);
#endif // XNN_ENABLE_FULL_BENCHMARKS
diff --git a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
index 31e34c5..ba19f77 100644
--- a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
+++ b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
@@ -88,7 +88,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ ${"if" if MLA else "while"} (k >= 8 * sizeof(int8_t)) {
$for M in range(MR):
const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
index 2c0bf43..43eec29 100644
--- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
@@ -136,7 +136,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0123c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
index f442eef..01eea15 100644
--- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
@@ -94,7 +94,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0123c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
index a38902c..91f2a9d 100644
--- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
@@ -196,7 +196,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
index e08a473..e3d3570 100644
--- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
@@ -128,7 +128,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
index 3f58900..7750eb3 100644
--- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
@@ -256,7 +256,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
index aa376ba..7f095a9 100644
--- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
@@ -162,7 +162,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
index 41f4a21..19d1d86 100644
--- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
@@ -316,7 +316,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
index 10afbe4..b7403d4 100644
--- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
@@ -196,7 +196,7 @@
k -= 16 * sizeof(int8_t);
}
- while (k >= 8 * sizeof(int8_t)) {
+ if (k >= 8 * sizeof(int8_t)) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;