Optimized packing code path for row-major 8bit inputs on the kNeon path. Written in intrinsics so that a single implementation handles 3 cases at once (see the sketch below):
ARM64, KernelCols==4
ARM32, KernelCols==4 (LHS)
ARM32, KernelCols==2 (RHS)
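
A rough sketch (hypothetical helper, not part of the change) of how the two
store strides interleave one transposed 16x8 block into the packed buffer:
with KernelCols==4 the second pair of packed columns lands 32 bytes after the
first (both pairs belong to the same 16x4 kernel cell), while with
KernelCols==2 it lands 2 * packed_stride further (in the next 16x2 column
block).

  #include <cstdint>

  // Byte offset of the i-th pair of packed columns (pair_index in [0, 4)) of
  // one 16x8 block, mirroring the dst_ptr arithmetic in
  // Pack8bitRowMajorForNeon.
  std::int64_t PackedPairOffset(int pair_index, int kernel_cols,
                                std::int64_t packed_stride) {
    // Pairs {0,1} and {2,3} form two store groups, 4 * packed_stride apart.
    const std::int64_t group_base = (pair_index / 2) * 4 * packed_stride;
    // Within a group, the second pair is 32 bytes further (kernel_cols == 4)
    // or 2 * packed_stride further (kernel_cols == 2).
    const std::int64_t within =
        (pair_index % 2) * ((kernel_cols == 2) ? 2 * packed_stride : 32);
    return group_base + within;
  }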

PiperOrigin-RevId: 322921106
diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc
index 1dcac5c..f29b214 100644
--- a/ruy/pack_arm.cc
+++ b/ruy/pack_arm.cc
@@ -23,6 +23,10 @@
 #include "ruy/platform.h"
 #include "ruy/profiler/instrumentation.h"
 
+#if RUY_PLATFORM_NEON
+#include <arm_neon.h>
+#endif
+
 namespace ruy {
 
 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
@@ -2216,4 +2220,263 @@
 }
 #endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
 
+#if RUY_PLATFORM_NEON
+
+namespace {
+// transpose_*bit_vals are wrappers around the ARM TRN1 / VTRN instructions,
+// allowing us to use these instructions as we would in assembly --- this is
+// one instance where assembly is more idiomatic than intrinsics.
+//
+// The way that TRN1 is exposed by the vtrn_* intrinsics makes its usage very
+// cumbersome. The issue is that transposing groups of values is exposed
+// only as transposing values of a wider type, so this requires many
+// vreinterpret's, and to make it worse, the vtrn_* intrinsics return NEON
+// array types like int8x8x2_t for which vreinterpret's are not defined!
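+//
+// For example, transpose_8bit_vals(a, b) with a = {a0,...,a7} and
+// b = {b0,...,b7} yields a = {a0,b0,a2,b2,a4,b4,a6,b6} and
+// b = {a1,b1,a3,b3,a5,b5,a7,b7}.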
+void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) {
+  int8x8x2_t t = vtrn_s8(a, b);
+  a = t.val[0];
+  b = t.val[1];
+}
+
+void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) {
+  int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b));
+  a = vreinterpret_s8_s16(t.val[0]);
+  b = vreinterpret_s8_s16(t.val[1]);
+}
+
+void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) {
+  int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b));
+  a = vreinterpret_s8_s32(t.val[0]);
+  b = vreinterpret_s8_s32(t.val[1]);
+}
+}  // namespace
+
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+                             int src_rows, int src_cols, int block_row,
+                             int start_col, int end_col,
+                             std::int8_t* packed_ptr, int packed_stride,
+                             int packed_zero_point, std::int32_t* sums,
+                             int input_xor, int kernel_cols) {
+  profiler::ScopeLabel label("Pack (kNeon, from row-major)");
+
+  int src_end_col = std::min(end_col, src_cols);
+  int col = start_col;
+  for (; col <= src_end_col - 8; col += 8) {
+    // Each iteration of this loop handles 8 columns, and the kernel format
+    // has 16 rows, so each iteration handles a 16x8 block.
+    //
+    // Since the source is row-major, handling 8 columns at a time means
+    // loading only 8 bytes i.e. 64bit from each row. This may seem surprising
+    // on 128bit SIMD like NEON. While we could handle 16 columns at a time,
+    // we prefer to stick with 8 for the following reasons:
+    // 1. The arithmetic (computing sums and transposing data) done on these
+    //    values is such that even though we initially start from 64bit vectors,
+    //    most of our NEON instructions are full 128bit instructions. For the
+    //    sums computation, that is because summing 8bit values requires
+    //    expansion to 16bit anyway. For the matrix transposition code, that is
+    //    because the ARM ZIP instructions take 64bit of data from two input
+    //    registers and zip it into a 128bit output. If we had 128bit of data
+    //    in each input register, we would need 2x more ARM NEON instructions
+    //    to zip it.
+    // 2. The main optimization target for this (ARM, 8bit, non-dotprod)
+    //    code path is in-order ARM cores such as the Cortex-A53, which prefer
+    //    64bit loads anyway.
+    // 3. Handling only 8 columns at a time limits the size of the final
+    //    leftover columns handled with slow scalar code.
+    //
+    // This code is not very optimized anyway, as evidenced by the facts that
+    // (1) it's written in intrinsics, (2) it's not using separate versions
+    // tuned for different types of CPU cores. At its current level of
+    // optimization, this seems like a fair compromise. If one wanted to
+    // maximize performance at the cost of more code complexity/size, one could
+    // have code handling 16 columns at a time (maybe limited to
+    // Tuning::kOutOfOrder), then 8, then 4 to minimize the amount of slow
+    // leftovers.
+    //
+    // Load 8 sums in sums0, sums1.
+    int32x4_t sums0 = vld1q_s32(sums + col);
+    int32x4_t sums1 = vld1q_s32(sums + col + 4);
+    // Load the 16x8 block from the source matrix.
+    // Each val* here is the data from one row.
+    int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10,
+        val11, val12, val13, val14, val15;
+    // Even though this function takes a uint8_t* src_ptr, that's only a
+    // type-erased pointer (using uint8_t* so that pointer arithmetic is
+    // allowed). The actual type may be either uint8_t or int8_t. The only
+    // difference it makes is that if it's uint8_t then we need to flip the
+    // sign bit. This is specified by the input_xor value (which is 0x80 if the
+    // input data is uint8_t, and 0x0 otherwise).
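+    // For example, with input_xor == 0x80, a uint8_t source value of 0 maps
+    // to the int8_t value -128, and 255 maps to 127.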
+    auto load_and_convert = [=](const std::uint8_t* from) {
+      return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from)));
+    };
+    if (block_row <= src_rows - 16) {
+      // Load data in the regular case: there are still 16 rows to be read from
+      // the source matrix.
+      val0 = load_and_convert(src_ptr + 0 * src_stride);
+      val1 = load_and_convert(src_ptr + 1 * src_stride);
+      val2 = load_and_convert(src_ptr + 2 * src_stride);
+      val3 = load_and_convert(src_ptr + 3 * src_stride);
+      val4 = load_and_convert(src_ptr + 4 * src_stride);
+      val5 = load_and_convert(src_ptr + 5 * src_stride);
+      val6 = load_and_convert(src_ptr + 6 * src_stride);
+      val7 = load_and_convert(src_ptr + 7 * src_stride);
+      val8 = load_and_convert(src_ptr + 8 * src_stride);
+      val9 = load_and_convert(src_ptr + 9 * src_stride);
+      val10 = load_and_convert(src_ptr + 10 * src_stride);
+      val11 = load_and_convert(src_ptr + 11 * src_stride);
+      val12 = load_and_convert(src_ptr + 12 * src_stride);
+      val13 = load_and_convert(src_ptr + 13 * src_stride);
+      val14 = load_and_convert(src_ptr + 14 * src_stride);
+      val15 = load_and_convert(src_ptr + 15 * src_stride);
+    } else {
+      // Boundary case: there are fewer than 16 rows to be read from the source
+      // matrix. We pad with the zero_point.
+      val0 = vdup_n_s8(packed_zero_point);
+      val1 = val0;
+      val2 = val0;
+      val3 = val0;
+      val4 = val0;
+      val5 = val0;
+      val6 = val0;
+      val7 = val0;
+      val8 = val0;
+      val9 = val0;
+      val10 = val0;
+      val11 = val0;
+      val12 = val0;
+      val13 = val0;
+      val14 = val0;
+      val15 = val0;
+      if (block_row + 0 < src_rows)
+        val0 = load_and_convert(src_ptr + 0 * src_stride);
+      if (block_row + 1 < src_rows)
+        val1 = load_and_convert(src_ptr + 1 * src_stride);
+      if (block_row + 2 < src_rows)
+        val2 = load_and_convert(src_ptr + 2 * src_stride);
+      if (block_row + 3 < src_rows)
+        val3 = load_and_convert(src_ptr + 3 * src_stride);
+      if (block_row + 4 < src_rows)
+        val4 = load_and_convert(src_ptr + 4 * src_stride);
+      if (block_row + 5 < src_rows)
+        val5 = load_and_convert(src_ptr + 5 * src_stride);
+      if (block_row + 6 < src_rows)
+        val6 = load_and_convert(src_ptr + 6 * src_stride);
+      if (block_row + 7 < src_rows)
+        val7 = load_and_convert(src_ptr + 7 * src_stride);
+      if (block_row + 8 < src_rows)
+        val8 = load_and_convert(src_ptr + 8 * src_stride);
+      if (block_row + 9 < src_rows)
+        val9 = load_and_convert(src_ptr + 9 * src_stride);
+      if (block_row + 10 < src_rows)
+        val10 = load_and_convert(src_ptr + 10 * src_stride);
+      if (block_row + 11 < src_rows)
+        val11 = load_and_convert(src_ptr + 11 * src_stride);
+      if (block_row + 12 < src_rows)
+        val12 = load_and_convert(src_ptr + 12 * src_stride);
+      if (block_row + 13 < src_rows)
+        val13 = load_and_convert(src_ptr + 13 * src_stride);
+      if (block_row + 14 < src_rows)
+        val14 = load_and_convert(src_ptr + 14 * src_stride);
+      if (block_row + 15 < src_rows)
+        val15 = load_and_convert(src_ptr + 15 * src_stride);
+    }
+    src_ptr += 8;
+    // Compute sums.
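+    // vaddl_s8 widens pairs of rows to 16bit as it adds them; summing all 16
+    // rows in 16bit cannot overflow (at most 16 * 128 = 2048 per lane), and
+    // vaddw_s16 then widens the totals to 32bit as they are accumulated into
+    // sums0/sums1.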
+    int16x8_t sums16_0 = vaddl_s8(val0, val1);
+    int16x8_t sums16_1 = vaddl_s8(val2, val3);
+    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5));
+    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7));
+    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9));
+    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11));
+    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13));
+    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15));
+    int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1);
+    sums0 = vaddw_s16(sums0, vget_low_s16(sums16));
+    sums1 = vaddw_s16(sums1, vget_high_s16(sums16));
+    // Store sums.
+    vst1q_s32(sums + col, sums0);
+    vst1q_s32(sums + col + 4, sums1);
+
+    // Transpose the data, i.e. change the storage order of the
+    // 16x8 block, to convert from the row-major source to the
+    // column-major packed format.
+    //
+    // Before, for i in [0, 15], val<i> is the i-th row.
+    // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column.
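+    //
+    // The three rounds of TRN at 8bit, 16bit and 32bit granularity amount to
+    // a standard 8x8 byte transpose of rows 0-7 and, independently, of rows
+    // 8-15, which is exactly what the above column layout requires.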
+    transpose_8bit_vals(val0, val1);
+    transpose_8bit_vals(val2, val3);
+    transpose_8bit_vals(val4, val5);
+    transpose_8bit_vals(val6, val7);
+    transpose_8bit_vals(val8, val9);
+    transpose_8bit_vals(val10, val11);
+    transpose_8bit_vals(val12, val13);
+    transpose_8bit_vals(val14, val15);
+    transpose_16bit_vals(val0, val2);
+    transpose_16bit_vals(val1, val3);
+    transpose_16bit_vals(val4, val6);
+    transpose_16bit_vals(val5, val7);
+    transpose_16bit_vals(val8, val10);
+    transpose_16bit_vals(val9, val11);
+    transpose_16bit_vals(val12, val14);
+    transpose_16bit_vals(val13, val15);
+    transpose_32bit_vals(val0, val4);
+    transpose_32bit_vals(val1, val5);
+    transpose_32bit_vals(val2, val6);
+    transpose_32bit_vals(val3, val7);
+    transpose_32bit_vals(val8, val12);
+    transpose_32bit_vals(val9, val13);
+    transpose_32bit_vals(val10, val14);
+    transpose_32bit_vals(val11, val15);
+    // Store to the packed_matrix.
+    std::int8_t* dst_ptr = packed_ptr;
+    vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
+    vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
+    dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+    vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
+    vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
+    packed_ptr += 4 * packed_stride;
+    dst_ptr = packed_ptr;
+    vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
+    vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
+    dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+    vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
+    vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
+    packed_ptr += 4 * packed_stride;
+  }
+  // Handle remaining columns, not fitting in a full block of 8 columns, but
+  // still real columns from the source matrix (as opposed to the final
+  // padding columns handled below).
+  for (; col < src_end_col; col++) {
+    std::int32_t accum = 0;
+    std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+    for (int r = 0; r < 16; r++) {
+      std::int8_t packed_val = (block_row + r < src_rows)
+                                   ? (src_ptr[r * src_stride] ^ input_xor)
+                                   : packed_zero_point;
+      accum += packed_val;
+      dst_ptr[r] = packed_val;
+    }
+    if (sums) {
+      sums[col] += accum;
+    }
+    src_ptr++;
+    if (((col + 1) & (kernel_cols - 1)) == 0) {
+      packed_ptr += kernel_cols * packed_stride;
+    }
+  }
+  // Handle the final columns of the packed matrix, beyond the last column of
+  // the source matrix. The values here don't matter, we just want to avoid
+  // leaving uninitialized data. Since the sums are already initialized above,
+  // we don't need to do anything about them here.
+  for (; col < end_col; col++) {
+    std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+    std::memset(dst_ptr, 0, 16);
+    if (((col + 1) & (kernel_cols - 1)) == 0) {
+      packed_ptr += kernel_cols * packed_stride;
+    }
+  }
+}
+
+#endif  // RUY_PLATFORM_NEON
+
 }  // namespace ruy
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
index 27498b9..12d5cab 100644
--- a/ruy/pack_arm.h
+++ b/ruy/pack_arm.h
@@ -51,6 +51,15 @@
 };
 #endif
 
+#if RUY_PLATFORM_NEON
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+                             int src_rows, int src_cols, int block_row,
+                             int start_col, int end_col,
+                             std::int8_t* packed_ptr, int packed_stride,
+                             int packed_zero_point, std::int32_t* sums_ptr,
+                             int input_xor, int kernel_cols);
+#endif
+
 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
 
 void Pack8bitColMajorForNeonOutOfOrder(
@@ -81,7 +90,6 @@
                                     int src_zero_point, std::int8_t* packed_ptr,
                                     int packed_stride, std::int32_t* sums_ptr,
                                     int input_xor);
-
 #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
 
 struct PackParams8bit {
@@ -564,6 +572,42 @@
 
 #endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
 
+#if RUY_PLATFORM_NEON
+
+template <typename Scalar, int KernelCols>
+struct PackImpl<Path::kNeon,
+                FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
+                std::int8_t, std::int32_t, Order::kRowMajor> {
+  static void Run(Tuning, const Mat<Scalar>& src_matrix,
+                  PMat<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    profiler::ScopeLabel label("Pack (KNeon, from row-major source)");
+    static constexpr int kInputXor =
+        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+    RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
+    std::int32_t* sums = packed_matrix->sums;
+    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+    int block_row = 0;
+    for (; block_row < packed_matrix->layout.rows; block_row += 16) {
+      int src_stride = src_matrix.layout.stride;
+      int packed_stride = packed_matrix->layout.stride;
+      const Scalar* src_ptr =
+          src_matrix.data.get() + block_row * src_stride + start_col;
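+      // In the packed col-major kernel layout, each group of 16 rows of a
+      // KernelCols-wide column block occupies 16 * KernelCols contiguous
+      // bytes, hence the block_row * KernelCols term below.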
+      std::int8_t* packed_ptr = packed_matrix->data +
+                                start_col * packed_stride +
+                                block_row * KernelCols;
+
+      Pack8bitRowMajorForNeon(
+          reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
+          src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
+          end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
+          kInputXor, KernelCols);
+    }
+  }
+};
+#endif  // RUY_PLATFORM_NEON
+
 }  // namespace ruy
 
 #endif  // RUY_RUY_PACK_ARM_H_