Use efficient ARM fp16 dot product for gemm_transa_ general case (#127451)

Summary: This doesn't change the overall gemm algorithm away from repeated dot products; it just reuses the efficient fp16 dot-product kernel developed for the gemv case. It seems to improve performance for every prompt length I tested.
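
For context, the structure being kept looks roughly like the following. This is a minimal sketch, not the actual ATen kernel: the index and leading-dimension conventions are illustrative only, `dot_fp16_fp32acc` stands in for the real `fp16_dot_with_fp32_arith`, and an aarch64 toolchain providing `float16_t` via `<arm_neon.h>` is assumed.

```
#include <arm_neon.h>  // for float16_t on aarch64
#include <cstdint>

// Stand-in for fp16_dot_with_fp32_arith: fp16 inputs, fp32 accumulation.
static float dot_fp16_fp32acc(const float16_t* x, const float16_t* a, int64_t len) {
  float sum = 0.f;
  for (int64_t i = 0; i < len; ++i) {
    sum += static_cast<float>(x[i]) * static_cast<float>(a[i]);
  }
  return sum;
}

// gemm_transa_-style loop: every output element is one length-k dot product,
// so speeding up the dot product speeds up the whole kernel.
static void gemm_transa_sketch(int64_t m, int64_t n, int64_t k,
                               const float16_t* a, int64_t lda,
                               const float16_t* b, int64_t ldb,
                               float16_t* c, int64_t ldc) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      c[i + j * ldc] = static_cast<float16_t>(
          dot_fp16_fp32acc(a + i * lda, b + j * ldb, k));
    }
  }
}
```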

Test Plan: Run https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py, edited to test only the trans_b (really gemm_transa_) case for the sizes shown in the output below.
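
For a quick sanity check without the script, a standalone ATen loop along these lines can time the same trans_b path. This is a hedged sketch, not the linked benchmark: the shape convention is my reading of the m/n/k labels below (k as prompt length, second operand transposed), and the iteration counts are arbitrary.

```
#include <ATen/ATen.h>
#include <chrono>
#include <cstdio>

int main() {
  const int64_t m = 4096, n = 4096, k = 8;
  // Generate in fp32, then convert, to keep the sketch independent of
  // which dtypes CPU randn supports.
  auto w = at::randn({m, n}).to(at::kHalf);
  auto x = at::randn({k, n}).to(at::kHalf);

  // Warm-up so allocation/first-touch costs don't pollute the measurement.
  for (int i = 0; i < 10; ++i) {
    (void)at::mm(x, w.t());
  }

  const int iters = 100;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    (void)at::mm(x, w.t());  // trans_b / gemm_transa_ path
  }
  auto end = std::chrono::steady_clock::now();
  double usec =
      std::chrono::duration<double, std::micro>(end - start).count() / iters;
  std::printf("trans_b torch.float16 m=%lld n=%lld k=%lld: %.2f usec\n",
              (long long)m, (long long)n, (long long)k, usec);
  return 0;
}
```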

Before:
```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    0.97 usec
trans_b torch.bfloat16    1.06 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.80 usec
trans_b  torch.float16    0.97 usec
trans_b torch.bfloat16    1.00 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2160.75 usec
trans_b  torch.float16  659.77 usec
trans_b torch.bfloat16 3800.13 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 6343.68 usec
trans_b  torch.float16 1789.42 usec
trans_b torch.bfloat16 10098.34 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6217.20 usec
trans_b  torch.float16 1874.47 usec
trans_b torch.bfloat16 10490.30 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 17934.45 usec
trans_b  torch.float16 5323.81 usec
trans_b torch.bfloat16 29320.80 usec

Matrix-matrix (prompt len 4:
m=8, n=128, k=4
====================
trans_b  torch.float32    2.40 usec
trans_b  torch.float16    1.22 usec
trans_b torch.bfloat16    1.22 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.52 usec
trans_b  torch.float16    1.33 usec
trans_b torch.bfloat16    1.77 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4317.09 usec
trans_b  torch.float16 15541.04 usec
trans_b torch.bfloat16 15032.29 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6191.19 usec
trans_b  torch.float16 40436.29 usec
trans_b torch.bfloat16 40626.93 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6049.22 usec
trans_b  torch.float16 42367.16 usec
trans_b torch.bfloat16 42482.43 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17611.36 usec
trans_b  torch.float16 117368.54 usec
trans_b torch.bfloat16 116958.85 usec

Matrix-matrix (prompt len 8:
m=8, n=128, k=8
====================
trans_b  torch.float32    1.04 usec
trans_b  torch.float16    1.71 usec
trans_b torch.bfloat16    1.74 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.10 usec
trans_b  torch.float16    2.01 usec
trans_b torch.bfloat16    2.91 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2456.23 usec
trans_b  torch.float16 30112.76 usec
trans_b torch.bfloat16 29941.58 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6236.12 usec
trans_b  torch.float16 80361.22 usec
trans_b torch.bfloat16 80466.64 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6236.10 usec
trans_b  torch.float16 82990.74 usec
trans_b torch.bfloat16 83899.80 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 17606.43 usec
trans_b  torch.float16 234397.38 usec
trans_b torch.bfloat16 237057.29 usec

Matrix-matrix (prompt len 16:
m=8, n=128, k=16
====================
trans_b  torch.float32    1.31 usec
trans_b  torch.float16    2.67 usec
trans_b torch.bfloat16    2.72 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.66 usec
trans_b  torch.float16    3.36 usec
trans_b torch.bfloat16    5.18 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2504.24 usec
trans_b  torch.float16 60896.53 usec
trans_b torch.bfloat16 59852.49 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6407.11 usec
trans_b  torch.float16 163294.92 usec
trans_b torch.bfloat16 161199.10 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 6132.30 usec
trans_b  torch.float16 167244.77 usec
trans_b torch.bfloat16 170064.35 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 17635.56 usec
trans_b  torch.float16 475020.00 usec
trans_b torch.bfloat16 476332.29 usec

Matrix-matrix (prompt len 32:
m=8, n=128, k=32
====================
trans_b  torch.float32    1.40 usec
trans_b  torch.float16    4.67 usec
trans_b torch.bfloat16    4.80 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.24 usec
trans_b  torch.float16    6.10 usec
trans_b torch.bfloat16   10.03 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2660.63 usec
trans_b  torch.float16 122436.04 usec
trans_b torch.bfloat16 121687.96 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6405.60 usec
trans_b  torch.float16 324708.42 usec
trans_b torch.bfloat16 324866.67 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6566.74 usec
trans_b  torch.float16 330801.04 usec
trans_b torch.bfloat16 332561.79 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 18610.84 usec
trans_b  torch.float16 944578.75 usec
trans_b torch.bfloat16 940674.33 usec

Matrix-matrix (prompt len 128:
m=8, n=128, k=128
====================
trans_b  torch.float32    2.48 usec
trans_b  torch.float16   16.43 usec
trans_b torch.bfloat16   17.11 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.83 usec
trans_b  torch.float16   22.31 usec
trans_b torch.bfloat16   37.00 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4806.59 usec
trans_b  torch.float16 485338.83 usec
trans_b torch.bfloat16 478835.08 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 12109.51 usec
trans_b  torch.float16 1300928.58 usec
trans_b torch.bfloat16 1293181.63 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 11223.70 usec
trans_b  torch.float16 1326119.92 usec
trans_b torch.bfloat16 1330395.12 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33485.34 usec
trans_b  torch.float16 3869227.17 usec
trans_b torch.bfloat16 3792905.00 usec
```

After:
```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.71 usec
trans_b torch.bfloat16    0.81 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.93 usec
trans_b torch.bfloat16    0.98 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2194.31 usec
trans_b  torch.float16  661.27 usec
trans_b torch.bfloat16 3758.42 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 5792.04 usec
trans_b  torch.float16 1789.98 usec
trans_b torch.bfloat16 10120.67 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6101.22 usec
trans_b  torch.float16 1927.34 usec
trans_b torch.bfloat16 10469.47 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 18353.20 usec
trans_b  torch.float16 5161.06 usec
trans_b torch.bfloat16 29601.69 usec

Matrix-matrix (prompt len 4:
m=8, n=128, k=4
====================
trans_b  torch.float32    2.14 usec
trans_b  torch.float16    0.85 usec
trans_b torch.bfloat16    1.19 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.47 usec
trans_b  torch.float16    1.85 usec
trans_b torch.bfloat16    1.75 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4416.40 usec
trans_b  torch.float16 2688.36 usec
trans_b torch.bfloat16 14987.33 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6140.24 usec
trans_b  torch.float16 7467.26 usec
trans_b torch.bfloat16 40295.52 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6143.10 usec
trans_b  torch.float16 7298.04 usec
trans_b torch.bfloat16 41393.43 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17650.72 usec
trans_b  torch.float16 21346.63 usec
trans_b torch.bfloat16 116849.98 usec

Matrix-matrix (prompt len 8:
m=8, n=128, k=8
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    1.03 usec
trans_b torch.bfloat16    1.69 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.05 usec
trans_b  torch.float16    3.08 usec
trans_b torch.bfloat16    2.95 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2323.99 usec
trans_b  torch.float16 5265.45 usec
trans_b torch.bfloat16 29942.40 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6202.01 usec
trans_b  torch.float16 14677.90 usec
trans_b torch.bfloat16 80625.18 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6112.05 usec
trans_b  torch.float16 14340.52 usec
trans_b torch.bfloat16 82799.99 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 17650.65 usec
trans_b  torch.float16 42551.43 usec
trans_b torch.bfloat16 236081.08 usec

Matrix-matrix (prompt len 16:
m=8, n=128, k=16
====================
trans_b  torch.float32    1.26 usec
trans_b  torch.float16    1.34 usec
trans_b torch.bfloat16    2.69 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.60 usec
trans_b  torch.float16    5.81 usec
trans_b torch.bfloat16    5.34 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2328.05 usec
trans_b  torch.float16 10526.58 usec
trans_b torch.bfloat16 60028.28 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6243.35 usec
trans_b  torch.float16 28505.08 usec
trans_b torch.bfloat16 163670.15 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 5870.11 usec
trans_b  torch.float16 28597.89 usec
trans_b torch.bfloat16 165404.88 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 17746.27 usec
trans_b  torch.float16 83393.87 usec
trans_b torch.bfloat16 472313.13 usec

Matrix-matrix (prompt len 32:
m=8, n=128, k=32
====================
trans_b  torch.float32    1.35 usec
trans_b  torch.float16    2.01 usec
trans_b torch.bfloat16    4.68 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.19 usec
trans_b  torch.float16   10.98 usec
trans_b torch.bfloat16   10.13 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2525.29 usec
trans_b  torch.float16 23106.71 usec
trans_b torch.bfloat16 122987.04 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6131.34 usec
trans_b  torch.float16 57537.41 usec
trans_b torch.bfloat16 327825.00 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6395.01 usec
trans_b  torch.float16 57456.33 usec
trans_b torch.bfloat16 331325.58 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 19078.68 usec
trans_b  torch.float16 167735.08 usec
trans_b torch.bfloat16 975736.88 usec

Matrix-matrix (prompt len 128:
m=8, n=128, k=128
====================
trans_b  torch.float32    2.40 usec
trans_b  torch.float16    6.07 usec
trans_b torch.bfloat16   16.83 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.78 usec
trans_b  torch.float16   40.35 usec
trans_b torch.bfloat16   37.21 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4827.60 usec
trans_b  torch.float16 84341.24 usec
trans_b torch.bfloat16 478917.75 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 11879.96 usec
trans_b  torch.float16 226484.33 usec
trans_b torch.bfloat16 1289465.50 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 10707.75 usec
trans_b  torch.float16 229200.58 usec
trans_b torch.bfloat16 1327416.67 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33306.32 usec
trans_b  torch.float16 662898.21 usec
trans_b torch.bfloat16 3815866.63 usec
```

torch.float16 performance seems to be improved for every tested size except
the m=128, n=8, k=128 case, where it is roughly neutral. That case
motivated the addition of the "first-tier tail fixup" in the dot
kernel.
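
Concretely, the tail handling splits into two tiers, roughly as sketched below. This is a simplified fragment written with plain NEON intrinsics and packaged as a function for readability; the actual kernel in the diff keeps the fully unrolled main loop and uses the f32_fma_f16/vpaddq_f32 helpers instead.

```
#include <arm_neon.h>
#include <cstdint>

// len_aligned is how many elements the fully unrolled main loop consumed;
// reducedSum is its horizontal sum so far.
float fp16_dot_tail_sketch(const float16_t* x, const float16_t* a, int64_t len,
                           int64_t len_aligned, float reducedSum) {
  // First-tier tail fixup: vectorize any remaining groups of 4 elements
  // that the unrolled main loop missed.
  float32x4_t tailSum = vdupq_n_f32(0);
  const int64_t len_aligned_4 = len & ~int64_t{3};
  for (int64_t j = len_aligned; j < len_aligned_4; j += 4) {
    tailSum = vfmaq_f32(tailSum,
                        vcvt_f32_f16(vld1_f16(x + j)),
                        vcvt_f32_f16(vld1_f16(a + j)));
  }
  reducedSum += vaddvq_f32(tailSum);  // horizontal add of the 4 lanes

  // Second-tier tail fixup: scalar loop for the final 0-3 elements.
  for (int64_t j = len_aligned_4; j < len; ++j) {
    reducedSum += static_cast<float>(x[j]) * static_cast<float>(a[j]);
  }
  return reducedSum;
}
```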

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127451
Approved by: https://github.com/malfet
ghstack dependencies: #127435
diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp
index df0512f..930b2ae 100644
--- a/aten/src/ATen/native/BlasKernel.cpp
+++ b/aten/src/ATen/native/BlasKernel.cpp
@@ -344,6 +344,10 @@
 #endif
 }
 
+static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) {
+  return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0)));
+}
+
 // The below reduce overload and fp16_dot_with_fp32_arith are adapted
 // from llama.cpp's ggml_vec_dot_f32 and surrounding utility
 // functions. See NOTE [ GGML Copyright Notice ] above for the
@@ -375,7 +379,7 @@
   return vaddvq_f32(x[0]);
 }
 
-static float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int len) {
+float fp16_dot_with_fp32_arith(const float16_t* x, const float16_t* a, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
   for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
@@ -392,7 +396,21 @@
   }
   auto reducedSum = reduce(sum);
 
-  for (int j = len_aligned; j < len; ++j) {
+  // First-tier tail fixup: make sure we handle workloads that can
+  // benefit from vectorization, but don't fit into our fully unrolled
+  // loop above.
+  float32x4_t tailSum = vdupq_n_f32(0);
+  const auto len_aligned_4 = len & ~3;
+  for (int j = len_aligned; j < len_aligned_4; j += 4) {
+    const auto temp_x = vld1_f16(x + j);
+    const auto temp_a = vld1_f16(a + j);
+    tailSum = f32_fma_f16(tailSum, temp_x, temp_a);
+  }
+  auto reducedTail = vpaddq_f32(tailSum, tailSum);
+  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
+
+  // Second-tier tail fixup: handle all workloads.
+  for (int j = len_aligned_4; j < len; ++j) {
     reducedSum += x[j] * a[j];
   }
   return reducedSum;
diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp
index 587809e..387d684 100644
--- a/aten/src/ATen/native/cpu/BlasKernel.cpp
+++ b/aten/src/ATen/native/cpu/BlasKernel.cpp
@@ -33,6 +33,11 @@
     const float beta,
     float16_t* y,
     const int incy);
+
+float fp16_dot_with_fp32_arith(
+  const float16_t* x,
+  const float16_t* a,
+  int64_t len);
 }
 #endif
 
@@ -308,18 +313,20 @@
 }
 
 
-inline float32x4_t load_as_float32x4(const Half* ptr) {
-    return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
-}
-
 inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
   int32x4_t shift = vdupq_n_s32(16);
   uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
   return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
 }
 
-template<typename T>
-static float compute_dot(const T* a, const T* b, int64_t l) {
+static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
+  return at::native::blas_impl::fp16_dot_with_fp32_arith(
+    reinterpret_cast<const float16_t*>(a),
+    reinterpret_cast<const float16_t*>(b),
+    len);
+}
+
+static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t l) {
     if ((l&3) != 0) {
       return sum(l, [&](int64_t i) -> float {
         return float(a[i]) * float(b[i]);