Optimized the packing code path for row-major 8-bit inputs on the x86 paths (AVX2/FMA and AVX-512).
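
This adds dedicated Pack8bitRowMajorForAvx2 / Pack8bitRowMajorForAvx512
routines, plus PackImpl specializations that dispatch to them when the source
matrix is row-major.

For reference, each 4-row step of the new SIMD loops is the vectorized form of
scalar logic along these lines (a simplified sketch mirroring the tail loops
in the patch; the function name is illustrative, not part of the change):

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    void Pack8bitRowMajorScalar(const std::uint8_t* src_ptr, int src_stride,
                                int src_zero_point, std::int8_t* packed_ptr,
                                int start_col, int end_col, int src_cols,
                                int block_row, int src_rows, int input_xor,
                                std::int32_t* sums) {
      int col = start_col;
      for (; col < std::min(end_col, src_cols); ++col) {
        std::int32_t accum = 0;
        for (int r = 0; r < 4; ++r) {
          // Rows past the end of the source are padded with the zero point.
          // input_xor is 0x80 for uint8 sources (flipping the sign bit to
          // reinterpret the value as int8) and 0 for int8 sources.
          const std::int8_t packed_val =
              input_xor ^ (block_row + r < src_rows ? src_ptr[r * src_stride]
                                                    : src_zero_point);
          accum += packed_val;
          *packed_ptr++ = packed_val;
        }
        if (sums) sums[col] += accum;
        ++src_ptr;
      }
      // Columns beyond the source are zero-filled and excluded from the sums.
      for (; col < end_col; ++col) {
        std::memset(packed_ptr, 0, 4);
        packed_ptr += 4;
      }
    }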

PiperOrigin-RevId: 322925521
diff --git a/ruy/pack_avx2_fma.cc b/ruy/pack_avx2_fma.cc
index 601d464..55879a1 100644
--- a/ruy/pack_avx2_fma.cc
+++ b/ruy/pack_avx2_fma.cc
@@ -44,6 +44,11 @@
   RUY_DCHECK(false);
 }
 
+void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
+                             int, int, int, int, int, int, std::int32_t*) {
+  RUY_DCHECK(false);
+}
+
 #else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
 
 // The first int8_t template parameter is arbitrary: this routine is common to
@@ -811,6 +816,91 @@
   }
 }
 
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+                             int src_zero_point, std::int8_t* packed_ptr,
+                             int packed_stride, int start_col, int end_col,
+                             int src_cols, int block_row, int src_rows,
+                             int input_xor, std::int32_t* sums) {
+  int col = start_col;
+  int src_end_col = std::min(end_col, src_cols);
+
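+  // Each iteration of the main loop packs a 4x8 block: 4 source rows starting
+  // at block_row, 8 source columns starting at col.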
+  for (; col <= src_end_col - 8; col += 8) {
+    std::int8_t* dst_ptr = packed_ptr;
+    __m128i val0, val1, val2, val3;
+    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+    // Load a 4x8 block.
+    if (block_row + 4 <= src_rows) {
+      val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+      val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+      val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+      val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+    } else {
+      val0 = _mm_set1_epi8(src_zero_point);
+      val1 = val0;
+      val2 = val0;
+      val3 = val0;
+      if (block_row + 0 < src_rows)
+        val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+      if (block_row + 1 < src_rows)
+        val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+      if (block_row + 2 < src_rows)
+        val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+      if (block_row + 3 < src_rows)
+        val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+    }
+    // Maybe xor the sign bit to convert from uint8 to int8.
+    val0 = _mm_xor_si128(val0, input_xor_dup);
+    val1 = _mm_xor_si128(val1, input_xor_dup);
+    val2 = _mm_xor_si128(val2, input_xor_dup);
+    val3 = _mm_xor_si128(val3, input_xor_dup);
+    // Update the sums.
+    __m128i val16_0 = _mm_cvtepi8_epi16(val0);
+    __m128i val16_1 = _mm_cvtepi8_epi16(val1);
+    __m128i val16_2 = _mm_cvtepi8_epi16(val2);
+    __m128i val16_3 = _mm_cvtepi8_epi16(val3);
+    __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
+                                      _mm_add_epi16(val16_2, val16_3));
+    __m256i sum =
+        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
+    sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16));
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
+    // Transpose 4x4 blocks: each source column's 4 bytes become contiguous.
+    __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
+    __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
+    __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
+    __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
+    src_ptr += 8;
+    packed_ptr += packed_stride * 8;
+  }
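+  // Scalar tail: pack the remaining in-range columns one at a time.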
+  for (; col < src_end_col; col++) {
+    std::int32_t accum = 0;
+    for (int r = 0; r < 4; r++) {
+      std::int8_t packed_val;
+      if (block_row + r < src_rows) {
+        packed_val = input_xor ^ src_ptr[r * src_stride];
+      } else {
+        packed_val = input_xor ^ src_zero_point;
+      }
+      accum += packed_val;
+      *packed_ptr++ = packed_val;
+    }
+    if (sums) {
+      sums[col] += accum;
+    }
+    src_ptr++;
+  }
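+  // Zero-fill padding columns beyond src_cols; they don't feed into sums.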
+  for (; col < end_col; col++) {
+    std::memset(packed_ptr, 0, 4);
+    packed_ptr += 4;
+  }
+}
+
 #endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
 
 }  // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
index bd04aca..b38bc01 100644
--- a/ruy/pack_avx512.cc
+++ b/ruy/pack_avx512.cc
@@ -44,6 +44,11 @@
   RUY_DCHECK(false);
 }
 
+void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
+                               int, int, int, int, int, int, std::int32_t*) {
+  RUY_DCHECK(false);
+}
+
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
 // The first int8_t template parameter is arbitrary: this routine is common to
@@ -717,6 +722,114 @@
   }
 }
 
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+                               int src_zero_point, std::int8_t* packed_ptr,
+                               int packed_stride, int start_col, int end_col,
+                               int src_cols, int block_row, int src_rows,
+                               int input_xor, std::int32_t* sums) {
+  int col = start_col;
+  int src_end_col = std::min(end_col, src_cols);
+
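+  // Each iteration of the main loop packs a 4x16 block: 4 source rows
+  // starting at block_row, 16 source columns starting at col.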
+  for (; col <= src_end_col - 16; col += 16) {
+    std::int8_t* dst_ptr = packed_ptr;
+    __m128i val0, val1, val2, val3;
+    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+    // Load a 4x16 block.
+    if (block_row + 4 <= src_rows) {
+      val0 = _mm_loadu_si128(
+          reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+      val1 = _mm_loadu_si128(
+          reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+      val2 = _mm_loadu_si128(
+          reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+      val3 = _mm_loadu_si128(
+          reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+    } else {
+      val0 = _mm_set1_epi8(src_zero_point);
+      val1 = val0;
+      val2 = val0;
+      val3 = val0;
+      if (block_row + 0 < src_rows)
+        val0 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+      if (block_row + 1 < src_rows)
+        val1 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+      if (block_row + 2 < src_rows)
+        val2 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+      if (block_row + 3 < src_rows)
+        val3 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+    }
+    // Maybe xor the sign bit to convert from uint8 to int8.
+    val0 = _mm_xor_si128(val0, input_xor_dup);
+    val1 = _mm_xor_si128(val1, input_xor_dup);
+    val2 = _mm_xor_si128(val2, input_xor_dup);
+    val3 = _mm_xor_si128(val3, input_xor_dup);
+    // Update the sums.
+    __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
+    __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
+    __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
+    __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
+    __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
+                                         _mm256_add_epi16(val16_2, val16_3));
+    __m512i sum =
+        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
+    sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
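+    // zip(x, y) interleaves all 16 bytes of x and y into one 256-bit value:
+    // bytes 0..7 of each input land in lane 0, bytes 8..15 in lane 1. The
+    // 64-bit permute compensates for unpack working within 128-bit lanes.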
+    auto zip = [](__m128i x, __m128i y) {
+      auto perm_64_0_64_0 = [](__m128i x) {
+        return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
+                                        _mm256_castsi128_si256(x));
+      };
+      return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
+    };
+    __m256i t2_val0 = zip(val0, val1);
+    __m256i t2_val1 = zip(val2, val3);
+    __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
+    __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
+                     _mm256_extractf128_si256(t4_val0, 0));
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
+                     _mm256_extractf128_si256(t4_val1, 0));
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
+                     _mm256_extractf128_si256(t4_val0, 1));
+    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
+                     _mm256_extractf128_si256(t4_val1, 1));
+    src_ptr += 16;
+    packed_ptr += packed_stride * 16;
+  }
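+  // Scalar tail: pack the remaining in-range columns one at a time.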
+  for (; col < src_end_col; col++) {
+    std::int32_t accum = 0;
+    for (int r = 0; r < 4; r++) {
+      std::int8_t packed_val;
+      if (block_row + r < src_rows) {
+        packed_val = input_xor ^ src_ptr[r * src_stride];
+      } else {
+        packed_val = input_xor ^ src_zero_point;
+      }
+      accum += packed_val;
+      *packed_ptr++ = packed_val;
+    }
+    if (sums) {
+      sums[col] += accum;
+    }
+    src_ptr++;
+  }
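+  // Zero-fill padding columns beyond src_cols; they don't feed into sums.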
+  for (; col < end_col; col++) {
+    std::memset(packed_ptr, 0, 4);
+    packed_ptr += 4;
+  }
+}
+
 #endif  // RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
 
 }  // namespace ruy
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index d78892a..bc84359 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -217,6 +217,84 @@
     }
   }
 };
+
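+// Packs a 4-row strip of a row-major 8-bit source matrix into the col-major
+// 4x8 packed layout, updating the per-column sums. Defined in
+// pack_avx2_fma.cc.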
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+                             int src_zero_point, std::int8_t* packed_ptr,
+                             int packed_stride, int start_col, int end_col,
+                             int src_cols, int block_row, int src_rows,
+                             int input_xor, std::int32_t* sums);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+  static void Run(Tuning, const Mat<Scalar>& src_matrix,
+                  PMat<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)");
+    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
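+    // XOR with 0x80 flips the sign bit, mapping uint8 values to int8.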
+    static constexpr int kInputXor =
+        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+    std::int32_t* sums = packed_matrix->sums;
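+    // The packing routine accumulates into sums, so clear them up front.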
+    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+    int block_row = 0;
+    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+      int src_stride = src_matrix.layout.stride;
+      int packed_stride = packed_matrix->layout.stride;
+      const Scalar* src_ptr =
+          src_matrix.data.get() + block_row * src_stride + start_col;
+      std::int8_t* packed_ptr =
+          packed_matrix->data + start_col * packed_stride + block_row * 8;
+      Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
+                              src_stride, src_matrix.zero_point, packed_ptr,
+                              packed_stride, start_col, end_col,
+                              src_matrix.layout.cols, block_row,
+                              src_matrix.layout.rows, kInputXor, sums);
+    }
+  }
+};
+
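+// Same as Pack8bitRowMajorForAvx2, but for the 4x16 packed layout. Defined
+// in pack_avx512.cc.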
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+                               int src_zero_point, std::int8_t* packed_ptr,
+                               int packed_stride, int start_col, int end_col,
+                               int src_cols, int block_row, int src_rows,
+                               int input_xor, std::int32_t* sums);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+  static void Run(Tuning, const Mat<Scalar>& src_matrix,
+                  PMat<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)");
+    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+    static constexpr int kInputXor =
+        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+    std::int32_t* sums = packed_matrix->sums;
+    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+    int block_row = 0;
+    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+      int src_stride = src_matrix.layout.stride;
+      int packed_stride = packed_matrix->layout.stride;
+      const Scalar* src_ptr =
+          src_matrix.data.get() + block_row * src_stride + start_col;
+      std::int8_t* packed_ptr =
+          packed_matrix->data + start_col * packed_stride + block_row * 16;
+      Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
+                                src_stride, src_matrix.zero_point, packed_ptr,
+                                packed_stride, start_col, end_col,
+                                src_matrix.layout.cols, block_row,
+                                src_matrix.layout.rows, kInputXor, sums);
+    }
+  }
+};
+
 #endif  // RUY_PLATFORM_X86
 
 }  // namespace ruy