// blob: 4c21fb04efcd82018c37d6970a7e8e863540db80 [file] [log] [blame]
// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// single_thread_gemm.h: programmatically generated GEMM library header.
#ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
#define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
#ifdef GEMMLOWP_NEON_32
#include <cassert>
namespace gemmlowp {
namespace meta {
namespace internal {
void zip_1x8_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_1_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_2_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_3_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_4_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_5_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_6_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_7_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_2x8_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_1_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vld1.8 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_2_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vld1.16 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_3_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.16 {d1[0]}, [r0]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vld1.8 {d1[2]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_4_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vld1.32 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_5_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vld1.8 {d1[4]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_6_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vld1.16 {d1[2]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_7_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.16 {d1[2]}, [r0]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vld1.8 {d1[6]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_3x8_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_1_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vld1.8 {d1[0]}, [r0]\n"
"vld1.8 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_2_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vld1.16 {d1[0]}, [r0]\n"
"vld1.16 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_3_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.16 {d1[0]}, [r0]!\n"
"vld1.16 {d2[0]}, [r1]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vld1.8 {d1[2]}, [r0]\n"
"vld1.8 {d2[2]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_4_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vld1.32 {d1[0]}, [r0]\n"
"vld1.32 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_5_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vld1.8 {d1[4]}, [r0]\n"
"vld1.8 {d2[4]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_6_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vld1.16 {d1[2]}, [r0]\n"
"vld1.16 {d2[2]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_7_aligned(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]:64]!\n"
"vld1.8 {d1}, [r0:64]!\n"
"vld1.8 {d2}, [r1:64]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.16 {d1[2]}, [r0]!\n"
"vld1.16 {d2[2]}, [r1]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vld1.8 {d1[6]}, [r0]\n"
"vld1.8 {d2[6]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_1x8(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset, std::int32_t additive_offset) {
asm volatile(
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_1(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_2(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_3(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_4(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_5(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_6(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_1x8_7(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vaddw.u8 q2, q2, d0\n"
"vst1.8 {d0}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d1[0], %[multiplicative_offset]\n"
"vdup.32 q1, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpadd.u32 d6, d4, d5\n"
"vpadd.u32 d8, d6, d6\n"
"vmul.i32 q4, q4, d1[0]\n"
"vadd.i32 q4, q4, q1\n"
"vst1.32 {d8[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
}
void zip_2x8(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset, std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_1(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vld1.8 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_2(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vld1.16 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_3(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.16 {d1[0]}, [r0]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vld1.8 {d1[2]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_4(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vld1.32 {d1[0]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_5(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vld1.8 {d1[4]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_6(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vld1.16 {d1[2]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_2x8_7(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.16 {d1[2]}, [r0]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vld1.8 {d1[6]}, [r0]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d2[0], %[multiplicative_offset]\n"
"vdup.32 q4, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpadd.u32 d3, d4, d5\n"
"vpadd.u32 d10, d6, d7\n"
"vpadd.u32 d12, d3, d10\n"
"vmul.i32 q6, q6, d2[0]\n"
"vadd.i32 q6, q6, q4\n"
"vst1.32 {d12}, [%[destination]:64]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d12", "d13", "cc", "memory");
}
void zip_3x8(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset, std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_1(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #1\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.8 {d0[0]}, [%[source]]\n"
"vld1.8 {d1[0]}, [r0]\n"
"vld1.8 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_2(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #2\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.16 {d0[0]}, [%[source]]\n"
"vld1.16 {d1[0]}, [r0]\n"
"vld1.16 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_3(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #3\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.16 {d0[0]}, [%[source]]!\n"
"vld1.16 {d1[0]}, [r0]!\n"
"vld1.16 {d2[0]}, [r1]!\n"
"vld1.8 {d0[2]}, [%[source]]\n"
"vld1.8 {d1[2]}, [r0]\n"
"vld1.8 {d2[2]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_4(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #4\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]\n"
"vld1.32 {d1[0]}, [r0]\n"
"vld1.32 {d2[0]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_5(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #5\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.8 {d0[4]}, [%[source]]\n"
"vld1.8 {d1[4]}, [r0]\n"
"vld1.8 {d2[4]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_6(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #6\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.16 {d0[2]}, [%[source]]\n"
"vld1.16 {d1[2]}, [r0]\n"
"vld1.16 {d2[2]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
void zip_3x8_7(const std::uint8_t* source, std::int32_t count,
std::int32_t stride, std::uint8_t* destination,
std::int32_t multiplicative_offset,
std::int32_t additive_offset) {
asm volatile(
"add r0, %[source], %[stride]\n"
"add r1, r0, %[stride]\n"
"sub %[count], %[count], #7\n"
"vmov.i16 q2, #0\n"
"vmov.i16 q3, #0\n"
"vmov.i16 q4, #0\n"
"1:"
"subs %[count], %[count], #8\n"
// Load Aggregate Store.
"vld1.8 {d0}, [%[source]]!\n"
"vld1.8 {d1}, [r0]!\n"
"vld1.8 {d2}, [r1]!\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
"bne 1b\n"
// Leftover Load Aggregate Store.
"vmov.i8 d0, #0\n"
"vmov.i8 d1, #0\n"
"vmov.i8 d2, #0\n"
"vld1.32 {d0[0]}, [%[source]]!\n"
"vld1.32 {d1[0]}, [r0]!\n"
"vld1.32 {d2[0]}, [r1]!\n"
"vld1.16 {d0[2]}, [%[source]]!\n"
"vld1.16 {d1[2]}, [r0]!\n"
"vld1.16 {d2[2]}, [r1]!\n"
"vld1.8 {d0[6]}, [%[source]]\n"
"vld1.8 {d1[6]}, [r0]\n"
"vld1.8 {d2[6]}, [r1]\n"
"vaddw.u8 q2, q2, d0\n"
"vaddw.u8 q3, q3, d1\n"
"vaddw.u8 q4, q4, d2\n"
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
// Aggregator Reduction.
"vmov.32 d3[0], %[multiplicative_offset]\n"
"vdup.32 q5, %[additive_offset]\n"
"vpaddl.u16 q2, q2\n"
"vpaddl.u16 q3, q3\n"
"vpaddl.u16 q4, q4\n"
"vpadd.u32 d12, d4, d5\n"
"vpadd.u32 d13, d6, d7\n"
"vpadd.u32 d14, d8, d9\n"
"vpadd.u32 d16, d12, d13\n"
"vpadd.u32 d17, d14, d14\n"
"vmul.i32 q8, q8, d3[0]\n"
"vadd.i32 q8, q8, q5\n"
"vst1.32 {d16}, [%[destination]:64]!\n"
"vst1.32 {d17[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
[destination] "+r"(destination), [source] "+r"(source)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}
// Dot product of one uint8 LHS lane with one uint8 RHS lane, plus an int32
// offset read from the memory that follows the RHS data; writes one int32.
//
//   lhs, rhs:      read 8 bytes per iteration with a 64-bit alignment
//                  qualifier; both pointers advance past the consumed data.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the single int32.
//   result_stride: byte amount post-added to the local result pointer.
inline void mul_1x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d2}, [%[lhs]:64]!\n"
"vld1.8 {d3}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// u8 x u8 -> u16 products, pairwise-accumulated into four u32 lanes.
"vmull.u8 q2, d3, d2\n"
"vpadal.u16 q0, q2\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data; load the offsets stored there.
// Two int32 are read although only lane 0 reaches the output.
"vld1.32 {d8}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d8\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d8", "cc", "memory");
}
// Dot products of one uint8 LHS lane with two uint8 RHS lanes, each plus its
// int32 offset read from the memory that follows the RHS data; writes two
// adjacent int32 values.
//
//   lhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   rhs:           16 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the two int32 values.
//   result_stride: byte amount post-added to the local result pointer.
inline void mul_1x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4}, [%[lhs]:64]!\n"
"vld1.8 {d5, d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per RHS lane, then pairwise accumulate into u32.
"vmull.u8 q4, d5, d4\n"
"vmull.u8 q5, d6, d4\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data; load the two int32 offsets.
"vld1.32 {d8}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d8\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// Dot products of one uint8 LHS lane with three uint8 RHS lanes, each plus
// its int32 offset read from the memory that follows the RHS data; writes
// three adjacent int32 values.
//
//   lhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   rhs:           24 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the three int32 values.
//   result_stride: byte distance from one output row to the next; 8 is
//                  subtracted locally because the first store already
//                  advances the pointer by 8.
inline void mul_1x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6}, [%[lhs]:64]!\n"
"vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per RHS lane, then pairwise accumulate into u32.
"vmull.u8 q5, d7, d6\n"
"vmull.u8 q6, d8, d6\n"
"vmull.u8 q7, d9, d6\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data. Four int32 are read; the fourth
// lane never reaches the output.
"vld1.32 {q4}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q4\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// Dot products of two uint8 LHS lanes with one uint8 RHS lane, each plus the
// int32 RHS offset read from the memory that follows the RHS data; writes
// one int32 per LHS lane, result_stride bytes apart.
//
//   lhs:           16 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   rhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_2x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4, d5}, [%[lhs]:64]!\n"
"vld1.8 {d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per LHS lane, then pairwise accumulate into u32.
"vmull.u8 q4, d6, d4\n"
"vmull.u8 q5, d6, d5\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data. Two int32 are read although only
// lane 0 reaches the output.
"vld1.32 {d8}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d8\n"
"vadd.s32 d2, d2, d8\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// 2x2 block of dot products between two uint8 LHS lanes and two uint8 RHS
// lanes; every output row gets the two int32 RHS offsets (read from the
// memory that follows the RHS data) added. Writes two rows of two int32.
//
//   lhs, rhs:      16 bytes per iteration each (8 per lane, lanes back to
//                  back), 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_2x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3 is zeroed by copying the already-cleared q0.
"vmov.i32 q3, q0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d8, d9}, [%[lhs]:64]!\n"
"vld1.8 {d10, d11}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0/q1 collect LHS lane 0 against both RHS lanes, q2/q3 LHS lane 1.
"vmull.u8 q6, d10, d8\n"
"vmull.u8 q7, d11, d8\n"
"vmull.u8 q8, d10, d9\n"
"vmull.u8 q9, d11, d9\n"
"vpadal.u16 q0, q6\n"
"vpadal.u16 q1, q7\n"
"vpadal.u16 q2, q8\n"
"vpadal.u16 q3, q9\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data; load the two int32 offsets.
"vld1.32 {d8}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d8\n"
"vadd.s32 d4, d4, d8\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
"memory");
}
// 2x3 block of dot products between two uint8 LHS lanes and three uint8 RHS
// lanes; every output row gets the int32 RHS offsets (read from the memory
// that follows the RHS data) added. Writes two rows of three int32.
//
//   lhs:           16 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   rhs:           24 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows; 8 is subtracted
//                  locally because the first store of each row already
//                  advances the pointer by 8.
inline void mul_2x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3..q5 are zeroed by copying the already-cleared q0..q2.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13}, [%[lhs]:64]!\n"
"vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0..q2 collect LHS lane 0 against the three RHS lanes, q3..q5 LHS lane 1.
"vmull.u8 q9, d14, d12\n"
"vmull.u8 q10, d15, d12\n"
"vmull.u8 q11, d16, d12\n"
"vmull.u8 q12, d14, d13\n"
"vmull.u8 q13, d15, d13\n"
"vmull.u8 q14, d16, d13\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data. Four int32 are read; the fourth
// lane never reaches the output.
"vld1.32 {q6}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q6\n"
"vadd.s32 q3, q3, q6\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d18", "d19", "d20", "d21",
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc", "memory");
}
// Dot products of three uint8 LHS lanes with one uint8 RHS lane, each plus
// the int32 RHS offset read from the memory that follows the RHS data;
// writes one int32 per LHS lane, result_stride bytes apart.
//
//   lhs:           24 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   rhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_3x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
"vld1.8 {d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per LHS lane, then pairwise accumulate into u32.
"vmull.u8 q5, d9, d6\n"
"vmull.u8 q6, d9, d7\n"
"vmull.u8 q7, d9, d8\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data. Two int32 are read although only
// lane 0 reaches the output.
"vld1.32 {d8}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
"vpadd.u32 d4, d4, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d8\n"
"vadd.s32 d2, d2, d8\n"
"vadd.s32 d4, d4, d8\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// 3x2 block of dot products between three uint8 LHS lanes and two uint8 RHS
// lanes; every output row gets the two int32 RHS offsets (read from the
// memory that follows the RHS data) added. Writes three rows of two int32.
//
//   lhs:           24 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   rhs:           16 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_3x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3..q5 are zeroed by copying the already-cleared q0..q2.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// Aggregator pairs (q0,q1), (q2,q3), (q4,q5) belong to LHS lanes 0, 1, 2.
"vmull.u8 q9, d15, d12\n"
"vmull.u8 q10, d16, d12\n"
"vmull.u8 q11, d15, d13\n"
"vmull.u8 q12, d16, d13\n"
"vmull.u8 q13, d15, d14\n"
"vmull.u8 q14, d16, d14\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data; load the two int32 offsets.
"vld1.32 {d12}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
"vpadd.u32 d8, d8, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d12\n"
"vadd.s32 d4, d4, d12\n"
"vadd.s32 d8, d8, d12\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d18", "d19", "d20", "d21",
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc", "memory");
}
// 3x3 block of dot products between three uint8 LHS lanes and three uint8
// RHS lanes; every output row gets the int32 RHS offsets (read from the
// memory that follows the RHS data) added. Writes three rows of three int32.
//
//   lhs, rhs:      24 bytes per iteration each (8 per lane, lanes back to
//                  back), 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows; 8 is subtracted
//                  locally because the first store of each row already
//                  advances the pointer by 8.
//
// All of d0..d31 are used, so the products are computed in two batches that
// reuse q12..q15 as temporaries.
inline void mul_3x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// Remaining aggregators are zeroed by copying already-cleared registers.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"vmov.i32 q6, q3\n"
"vmov.i32 q7, q4\n"
"vmov.i32 q8, q5\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// 3x3 lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// First batch: LHS lane 0 against all RHS lanes, LHS lane 1 against RHS 0.
"vmull.u8 q12, d18, d21\n"
"vmull.u8 q13, d18, d22\n"
"vmull.u8 q14, d18, d23\n"
"vmull.u8 q15, d19, d21\n"
"vpadal.u16 q0, q12\n"
"vpadal.u16 q1, q13\n"
"vpadal.u16 q2, q14\n"
"vpadal.u16 q3, q15\n"
// Second batch: the remaining five products.
"vmull.u8 q12, d19, d22\n"
"vmull.u8 q13, d19, d23\n"
"vmull.u8 q14, d20, d21\n"
"vmull.u8 q15, d20, d22\n"
// Last product overwrites q9 (d18, d19); those LHS lanes are no longer
// needed in this iteration.
"vmull.u8 q9, d20, d23\n"
"vpadal.u16 q4, q12\n"
"vpadal.u16 q5, q13\n"
"vpadal.u16 q6, q14\n"
"vpadal.u16 q7, q15\n"
"vpadal.u16 q8, q9\n"
// Loop break.
"bne 1b\n"
// rhs now points past the packed data. Four int32 are read; the fourth
// lane never reaches the output.
"vld1.32 {q9}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
"vpadd.u32 d12, d12, d13\n"
"vpadd.u32 d14, d14, d15\n"
"vpadd.u32 d16, d16, d17\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
"vpadd.u32 d12, d12, d14\n"
"vpadd.u32 d13, d16, d16\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q9\n"
"vadd.s32 q3, q3, q9\n"
"vadd.s32 q6, q6, q9\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d12}, [%[result]]!\n"
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
"d31", "cc", "memory");
}
// Dot product of one uint8 LHS lane with one uint8 RHS lane, plus an int32
// LHS offset and an int32 RHS offset read from the memory that follows the
// respective packed data; writes one int32.
//
//   lhs, rhs:      read 8 bytes per iteration with a 64-bit alignment
//                  qualifier; both pointers advance past the consumed data.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the single int32.
//   result_stride: byte amount post-added to the local result pointer.
inline void mul_1x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d2}, [%[lhs]:64]!\n"
"vld1.8 {d3}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// u8 x u8 -> u16 products, pairwise-accumulated into four u32 lanes.
"vmull.u8 q2, d3, d2\n"
"vpadal.u16 q0, q2\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. Each load reads two int32;
// the LHS value in lane 0 is broadcast, the RHS pair is used as loaded.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "cc", "memory");
}
// Dot products of one uint8 LHS lane with two uint8 RHS lanes. Each result
// gets the int32 LHS offset and its own int32 RHS offset added; both offset
// sets are read from the memory that follows the respective packed data.
// Writes two adjacent int32 values.
//
//   lhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   rhs:           16 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the two int32 values.
//   result_stride: byte amount post-added to the local result pointer.
inline void mul_1x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4}, [%[lhs]:64]!\n"
"vld1.8 {d5, d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per RHS lane, then pairwise accumulate into u32.
"vmull.u8 q4, d5, d4\n"
"vmull.u8 q5, d6, d4\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The LHS load reads two
// int32 but only lane 0 is broadcast; the RHS pair is used as loaded.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// Dot products of one uint8 LHS lane with three uint8 RHS lanes. Each result
// gets the int32 LHS offset and its own int32 RHS offset added; both offset
// sets are read from the memory that follows the respective packed data.
// Writes three adjacent int32 values.
//
//   lhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   rhs:           24 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the three int32 values.
//   result_stride: byte distance from one output row to the next; 8 is
//                  subtracted locally because the first store already
//                  advances the pointer by 8.
inline void mul_1x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6}, [%[lhs]:64]!\n"
"vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per RHS lane, then pairwise accumulate into u32.
"vmull.u8 q5, d7, d6\n"
"vmull.u8 q6, d8, d6\n"
"vmull.u8 q7, d9, d6\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The LHS load reads two
// int32 but only lane 0 is broadcast; the RHS load reads four int32 and
// the fourth never reaches the output.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 q5, d8[0]\n"
"vld1.32 {q6}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q5\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q6\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// Dot products of two uint8 LHS lanes with one uint8 RHS lane. Each result
// gets its own int32 LHS offset and the int32 RHS offset added; both offset
// sets are read from the memory that follows the respective packed data.
// Writes one int32 per LHS lane, result_stride bytes apart.
//
//   lhs:           16 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   rhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_2x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4, d5}, [%[lhs]:64]!\n"
"vld1.8 {d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per LHS lane, then pairwise accumulate into u32.
"vmull.u8 q4, d6, d4\n"
"vmull.u8 q5, d6, d5\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The two LHS offsets are
// broadcast into d4 / d5, one per output row. The RHS load reads two
// int32 although only lane 0 reaches the output.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vdup.32 d5, d8[1]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
"vadd.s32 d2, d2, d5\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
"vadd.s32 d2, d2, d9\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// 2x2 block of dot products between two uint8 LHS lanes and two uint8 RHS
// lanes. Each output row gets its own int32 LHS offset and the two int32
// RHS offsets added; both offset sets are read from the memory that follows
// the respective packed data. Writes two rows of two int32.
//
//   lhs, rhs:      16 bytes per iteration each (8 per lane, lanes back to
//                  back), 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_2x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3 is zeroed by copying the already-cleared q0.
"vmov.i32 q3, q0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d8, d9}, [%[lhs]:64]!\n"
"vld1.8 {d10, d11}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0/q1 collect LHS lane 0 against both RHS lanes, q2/q3 LHS lane 1.
"vmull.u8 q6, d10, d8\n"
"vmull.u8 q7, d11, d8\n"
"vmull.u8 q8, d10, d9\n"
"vmull.u8 q9, d11, d9\n"
"vpadal.u16 q0, q6\n"
"vpadal.u16 q1, q7\n"
"vpadal.u16 q2, q8\n"
"vpadal.u16 q3, q9\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The two LHS offsets are
// broadcast into d9 / d10, one per output row; d11 holds the RHS pair.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d9, d8[0]\n"
"vdup.32 d10, d8[1]\n"
"vld1.32 {d11}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d9\n"
"vadd.s32 d4, d4, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d11\n"
"vadd.s32 d4, d4, d11\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
"memory");
}
// 2x3 block of dot products between two uint8 LHS lanes and three uint8 RHS
// lanes. Each output row gets its own int32 LHS offset and the int32 RHS
// offsets added; both offset sets are read from the memory that follows the
// respective packed data. Writes two rows of three int32.
//
//   lhs:           16 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   rhs:           24 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows; 8 is subtracted
//                  locally because the first store of each row already
//                  advances the pointer by 8.
inline void mul_2x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3..q5 are zeroed by copying the already-cleared q0..q2.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13}, [%[lhs]:64]!\n"
"vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0..q2 collect LHS lane 0 against the three RHS lanes, q3..q5 LHS lane 1.
"vmull.u8 q9, d14, d12\n"
"vmull.u8 q10, d15, d12\n"
"vmull.u8 q11, d16, d12\n"
"vmull.u8 q12, d14, d13\n"
"vmull.u8 q13, d15, d13\n"
"vmull.u8 q14, d16, d13\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The two LHS offsets are
// broadcast into q7 / q8, one per output row. The RHS load reads four
// int32; the fourth never reaches the output.
"vld1.32 {d12}, [%[lhs]:64]\n"
"vdup.32 q7, d12[0]\n"
"vdup.32 q8, d12[1]\n"
"vld1.32 {q9}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q7\n"
"vadd.s32 q3, q3, q8\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q9\n"
"vadd.s32 q3, q3, q9\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
"memory");
}
// Dot products of three uint8 LHS lanes with one uint8 RHS lane. Each result
// gets its own int32 LHS offset and the int32 RHS offset added; both offset
// sets are read from the memory that follows the respective packed data.
// Writes one int32 per LHS lane, result_stride bytes apart.
//
//   lhs:           24 bytes per iteration (8 per lane, lanes back to back),
//                  64-bit alignment qualifier.
//   rhs:           8 bytes per iteration, 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_3x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
"vld1.8 {d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// One widening multiply per LHS lane, then pairwise accumulate into u32.
"vmull.u8 q5, d9, d6\n"
"vmull.u8 q6, d9, d7\n"
"vmull.u8 q7, d9, d8\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The LHS load reads four
// int32; the first three are broadcast into d6, d7, d10, one per output
// row. The RHS load reads two int32 although only lane 0 is stored.
"vld1.32 {q4}, [%[lhs]:64]\n"
"vdup.32 d6, d8[0]\n"
"vdup.32 d7, d8[1]\n"
"vdup.32 d10, d9[0]\n"
"vld1.32 {d11}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
"vpadd.u32 d4, d4, d4\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d6\n"
"vadd.s32 d2, d2, d7\n"
"vadd.s32 d4, d4, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d11\n"
"vadd.s32 d2, d2, d11\n"
"vadd.s32 d4, d4, d11\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// 3x2 block of dot products between three uint8 LHS lanes and two uint8 RHS
// lanes. Each output row gets its own int32 LHS offset and the two int32
// RHS offsets added; both offset sets are read from the memory that follows
// the respective packed data. Writes three rows of two int32.
//
//   lhs:           24 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   rhs:           16 bytes per iteration (8 per lane), 64-bit alignment
//                  qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows.
inline void mul_3x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// q3..q5 are zeroed by copying the already-cleared q0..q2.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// Aggregator pairs (q0,q1), (q2,q3), (q4,q5) belong to LHS lanes 0, 1, 2.
"vmull.u8 q9, d15, d12\n"
"vmull.u8 q10, d16, d12\n"
"vmull.u8 q11, d15, d13\n"
"vmull.u8 q12, d16, d13\n"
"vmull.u8 q13, d15, d14\n"
"vmull.u8 q14, d16, d14\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The LHS load reads four
// int32; the first three are broadcast into d14, d15, d16, one per output
// row. d17 holds the two RHS offsets.
"vld1.32 {q6}, [%[lhs]:64]\n"
"vdup.32 d14, d12[0]\n"
"vdup.32 d15, d12[1]\n"
"vdup.32 d16, d13[0]\n"
"vld1.32 {d17}, [%[rhs]:64]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
"vpadd.u32 d8, d8, d10\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d14\n"
"vadd.s32 d4, d4, d15\n"
"vadd.s32 d8, d8, d16\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d17\n"
"vadd.s32 d4, d4, d17\n"
"vadd.s32 d8, d8, d17\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
"memory");
}
// 3x3 block of dot products between three uint8 LHS lanes and three uint8
// RHS lanes. Each output row gets its own int32 LHS offset and the int32
// RHS offsets added; both offset sets are read from the memory that follows
// the respective packed data. Writes three rows of three int32.
//
//   lhs, rhs:      24 bytes per iteration each (8 per lane, lanes back to
//                  back), 64-bit alignment qualifier.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the first output row.
//   result_stride: byte distance between output rows; 8 is subtracted
//                  locally because the first store of each row already
//                  advances the pointer by 8.
//
// All of d0..d31 are used, so the products are computed in two batches that
// reuse q12..q15 as temporaries.
inline void mul_3x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count,
std::int32_t* result,
std::int32_t result_stride) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
// Remaining aggregators are zeroed by copying already-cleared registers.
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"vmov.i32 q6, q3\n"
"vmov.i32 q7, q4\n"
"vmov.i32 q8, q5\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// 3x3 lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// First batch: LHS lane 0 against all RHS lanes, LHS lane 1 against RHS 0.
"vmull.u8 q12, d18, d21\n"
"vmull.u8 q13, d18, d22\n"
"vmull.u8 q14, d18, d23\n"
"vmull.u8 q15, d19, d21\n"
"vpadal.u16 q0, q12\n"
"vpadal.u16 q1, q13\n"
"vpadal.u16 q2, q14\n"
"vpadal.u16 q3, q15\n"
// Second batch: the remaining five products.
"vmull.u8 q12, d19, d22\n"
"vmull.u8 q13, d19, d23\n"
"vmull.u8 q14, d20, d21\n"
"vmull.u8 q15, d20, d22\n"
// Last product overwrites q9 (d18, d19); those LHS lanes are no longer
// needed in this iteration.
"vmull.u8 q9, d20, d23\n"
"vpadal.u16 q4, q12\n"
"vpadal.u16 q5, q13\n"
"vpadal.u16 q6, q14\n"
"vpadal.u16 q7, q15\n"
"vpadal.u16 q8, q9\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. The LHS load reads four
// int32; the first three are broadcast into q10, q11, q12, one per output
// row. The RHS load reads four int32; the fourth never reaches the output.
"vld1.32 {q9}, [%[lhs]:64]\n"
"vdup.32 q10, d18[0]\n"
"vdup.32 q11, d18[1]\n"
"vdup.32 q12, d19[0]\n"
"vld1.32 {q13}, [%[rhs]:64]\n"
// Change stride because storing in two ops.
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
"vpadd.u32 d12, d12, d13\n"
"vpadd.u32 d14, d14, d15\n"
"vpadd.u32 d16, d16, d17\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
"vpadd.u32 d12, d12, d14\n"
"vpadd.u32 d13, d16, d16\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q10\n"
"vadd.s32 q3, q3, q11\n"
"vadd.s32 q6, q6, q12\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q13\n"
"vadd.s32 q3, q3, q13\n"
"vadd.s32 q6, q6, q13\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d12}, [%[result]]!\n"
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_stride] "+r"(result_stride),
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
"d31", "cc", "memory");
}
// Dot product of one uint8 LHS lane with one uint8 RHS lane, plus an int32
// LHS offset and an int32 RHS offset read from the memory that follows the
// respective packed data. The int32 sum is converted to float, multiplied
// by result_scale, and written as one float.
//
//   lhs, rhs:      read 8 bytes per iteration with a 64-bit alignment
//                  qualifier; both pointers advance past the consumed data.
//   count:         bytes per lane. The loop subtracts 8 and exits only when
//                  the counter is exactly zero, so it must be a positive
//                  multiple of 8.
//   result:        destination of the single float.
//   result_stride: byte amount post-added to the local result pointer.
//   result_scale:  factor applied after the int -> float conversion; passed
//                  to the asm through a general-purpose register operand.
inline void mul_1x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d2}, [%[lhs]:64]!\n"
"vld1.8 {d3}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// u8 x u8 -> u16 products, pairwise-accumulated into four u32 lanes.
"vmull.u8 q2, d3, d2\n"
"vpadal.u16 q0, q2\n"
// Loop break.
"bne 1b\n"
// Both pointers now sit past the packed data. Each load reads two int32;
// the LHS value in lane 0 is broadcast, the RHS pair is used as loaded.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
// Broadcast the scale factor to both lanes of d5.
"vdup.32 d5, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vmul.f32 d0, d0, d5\n"
// Store reduced rows.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "cc", "memory");
}
// 1x2 kernel: produces two adjacent floats from one lhs row and two rhs rows
// of uint8 data.
//
// Each loop iteration consumes 8 lhs bytes and 16 rhs bytes (8 per rhs row,
// loaded as consecutive d-registers) and accumulates the two dot products in
// q0 and q1. After the loop the 32-bit value stored right after the lhs data
// is added to both sums, the two 32-bit values stored right after the rhs data
// are added lane-wise, and the sums are converted to float, multiplied by
// `result_scale` and written as two floats at `result`.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_1x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4}, [%[lhs]:64]!\n"
"vld1.8 {d5, d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0 += lhs * rhs row 0, q1 += lhs * rhs row 1.
"vmull.u8 q4, d5, d4\n"
"vmull.u8 q5, d6, d4\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
"vdup.32 d5, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vmul.f32 d0, d0, d5\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// 1x3 kernel: produces three adjacent floats from one lhs row and three rhs
// rows of uint8 data.
//
// Each loop iteration consumes 8 lhs bytes and 24 rhs bytes (8 per rhs row)
// and accumulates the three dot products in q0..q2. After the loop the 32-bit
// value stored right after the lhs data is added to every sum, the 32-bit
// values stored right after the rhs data are added lane-wise (16 bytes are
// loaded, three lanes end up in the output), and the sums are converted to
// float, multiplied by `result_scale` and written as three floats at
// `result`.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_1x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6}, [%[lhs]:64]!\n"
"vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0/q1/q2 += lhs * rhs row 0/1/2.
"vmull.u8 q5, d7, d6\n"
"vmull.u8 q6, d8, d6\n"
"vmull.u8 q7, d9, d6\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 q5, d8[0]\n"
"vld1.32 {q6}, [%[rhs]:64]\n"
"vdup.32 q7, %[result_scale]\n"
// Change stride because storing in two ops.
// (The first store below already advances result by 8 bytes.)
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q5\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q6\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 q0, q0\n"
"vmul.f32 q0, q0, q7\n"
// Store reduced rows.
// Two floats with write-back, then the third float from lane d1[0].
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// 2x1 kernel: produces two floats (one per lhs row) from two lhs rows and one
// rhs row of uint8 data.
//
// Each loop iteration consumes 16 lhs bytes (8 per lhs row, loaded as
// consecutive d-registers) and 8 rhs bytes, accumulating the two dot products
// in q0 and q1. After the loop the two 32-bit values stored right after the
// lhs data are added to their respective rows, the 32-bit value stored right
// after the rhs data is added to both, and the sums are converted to float and
// multiplied by `result_scale`. The first float is stored at `result`, the
// second one `result_stride` bytes further.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_2x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d4, d5}, [%[lhs]:64]!\n"
"vld1.8 {d6}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0 += lhs row 0 * rhs, q1 += lhs row 1 * rhs.
"vmull.u8 q4, d6, d4\n"
"vmull.u8 q5, d6, d5\n"
"vpadal.u16 q0, q4\n"
"vpadal.u16 q1, q5\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data;
// one lhs term per lhs row.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d4, d8[0]\n"
"vdup.32 d5, d8[1]\n"
"vld1.32 {d9}, [%[rhs]:64]\n"
"vdup.32 d6, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d4\n"
"vadd.s32 d2, d2, d5\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d9\n"
"vadd.s32 d2, d2, d9\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vcvt.f32.s32 d2, d2\n"
"vmul.f32 d0, d0, d6\n"
"vmul.f32 d2, d2, d6\n"
// Store reduced rows.
// One float per output row; result advances by result_stride in between.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
"cc", "memory");
}
// 2x2 kernel: produces a 2x2 block of floats from two lhs rows and two rhs
// rows of uint8 data.
//
// Each loop iteration consumes 16 lhs bytes and 16 rhs bytes (8 per row) and
// accumulates the four dot products in q0..q3 (q0/q1: lhs row 0 against rhs
// rows 0/1, q2/q3: lhs row 1 against rhs rows 0/1). After the loop the 32-bit
// values stored right after the lhs data are added per lhs row, the 32-bit
// values stored right after the rhs data are added lane-wise, and the sums are
// converted to float and multiplied by `result_scale`. Each output row is two
// floats; the second row is stored `result_stride` bytes after the first.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_2x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
// (q3 is cleared by copying the already-zeroed q0.)
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"vmov.i32 q3, q0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d8, d9}, [%[lhs]:64]!\n"
"vld1.8 {d10, d11}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// d8/d9 hold lhs rows 0/1, d10/d11 hold rhs rows 0/1.
"vmull.u8 q6, d10, d8\n"
"vmull.u8 q7, d11, d8\n"
"vmull.u8 q8, d10, d9\n"
"vmull.u8 q9, d11, d9\n"
"vpadal.u16 q0, q6\n"
"vpadal.u16 q1, q7\n"
"vpadal.u16 q2, q8\n"
"vpadal.u16 q3, q9\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data.
"vld1.32 {d8}, [%[lhs]:64]\n"
"vdup.32 d9, d8[0]\n"
"vdup.32 d10, d8[1]\n"
"vld1.32 {d11}, [%[rhs]:64]\n"
"vdup.32 d12, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
// Reduce rows.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d9\n"
"vadd.s32 d4, d4, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d11\n"
"vadd.s32 d4, d4, d11\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vcvt.f32.s32 d4, d4\n"
"vmul.f32 d0, d0, d12\n"
"vmul.f32 d4, d4, d12\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
"memory");
}
// 2x3 kernel: produces a 2x3 block of floats from two lhs rows and three rhs
// rows of uint8 data.
//
// Each loop iteration consumes 16 lhs bytes and 24 rhs bytes (8 per row) and
// accumulates the six dot products in q0..q5 (q0..q2: lhs row 0, q3..q5: lhs
// row 1). After the loop the 32-bit values stored right after the lhs data are
// added per lhs row, the 32-bit values stored right after the rhs data are
// added lane-wise, and the sums are converted to float and multiplied by
// `result_scale`. Each output row is three floats; the second row starts
// `result_stride` bytes after the first.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_2x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
// (q3..q5 are cleared by copying the already-zeroed q0..q2.)
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13}, [%[lhs]:64]!\n"
"vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// d12/d13 hold lhs rows 0/1, d14..d16 hold rhs rows 0..2.
"vmull.u8 q9, d14, d12\n"
"vmull.u8 q10, d15, d12\n"
"vmull.u8 q11, d16, d12\n"
"vmull.u8 q12, d14, d13\n"
"vmull.u8 q13, d15, d13\n"
"vmull.u8 q14, d16, d13\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data.
"vld1.32 {d12}, [%[lhs]:64]\n"
"vdup.32 q7, d12[0]\n"
"vdup.32 q8, d12[1]\n"
"vld1.32 {q9}, [%[rhs]:64]\n"
"vdup.32 q10, %[result_scale]\n"
// Change stride because storing in two ops.
// (The first store of each row already advances result by 8 bytes.)
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
// Row 0 ends up in q0 (d0, d1), row 1 in q3 (d6, d7).
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q7\n"
"vadd.s32 q3, q3, q8\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q9\n"
"vadd.s32 q3, q3, q9\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 q0, q0\n"
"vcvt.f32.s32 q3, q3\n"
"vmul.f32 q0, q0, q10\n"
"vmul.f32 q3, q3, q10\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
"memory");
}
// 3x1 kernel: produces three floats (one per lhs row) from three lhs rows and
// one rhs row of uint8 data.
//
// Each loop iteration consumes 24 lhs bytes (8 per lhs row) and 8 rhs bytes
// and accumulates the three dot products in q0..q2. After the loop the 32-bit
// values stored right after the lhs data are added per lhs row (16 bytes are
// loaded, three lanes are used), the 32-bit value stored right after the rhs
// data is added to every row, and the sums are converted to float and
// multiplied by `result_scale`. One float is stored per output row, with
// consecutive rows `result_stride` bytes apart.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_3x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
"vld1.8 {d9}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// q0/q1/q2 += lhs row 0/1/2 * rhs.
"vmull.u8 q5, d9, d6\n"
"vmull.u8 q6, d9, d7\n"
"vmull.u8 q7, d9, d8\n"
"vpadal.u16 q0, q5\n"
"vpadal.u16 q1, q6\n"
"vpadal.u16 q2, q7\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data;
// one lhs term per lhs row.
"vld1.32 {q4}, [%[lhs]:64]\n"
"vdup.32 d6, d8[0]\n"
"vdup.32 d7, d8[1]\n"
"vdup.32 d10, d9[0]\n"
"vld1.32 {d11}, [%[rhs]:64]\n"
"vdup.32 d12, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
// Reduce rows.
"vpadd.u32 d0, d0, d0\n"
"vpadd.u32 d2, d2, d2\n"
"vpadd.u32 d4, d4, d4\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d6\n"
"vadd.s32 d2, d2, d7\n"
"vadd.s32 d4, d4, d10\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d11\n"
"vadd.s32 d2, d2, d11\n"
"vadd.s32 d4, d4, d11\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vcvt.f32.s32 d2, d2\n"
"vcvt.f32.s32 d4, d4\n"
"vmul.f32 d0, d0, d12\n"
"vmul.f32 d2, d2, d12\n"
"vmul.f32 d4, d4, d12\n"
// Store reduced rows.
// One float per output row; result advances by result_stride in between.
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
}
// 3x2 kernel: produces a 3x2 block of floats from three lhs rows and two rhs
// rows of uint8 data.
//
// Each loop iteration consumes 24 lhs bytes and 16 rhs bytes (8 per row) and
// accumulates the six dot products in q0..q5 (q0/q1: lhs row 0, q2/q3: lhs row
// 1, q4/q5: lhs row 2). After the loop the 32-bit values stored right after
// the lhs data are added per lhs row (16 bytes are loaded, three lanes are
// used), the two 32-bit values stored right after the rhs data are added
// lane-wise, and the sums are converted to float and multiplied by
// `result_scale`. Each output row is two floats; consecutive rows are
// `result_stride` bytes apart.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_3x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
// (q3..q5 are cleared by copying the already-zeroed q0..q2.)
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// General NxM lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// d12..d14 hold lhs rows 0..2, d15/d16 hold rhs rows 0/1.
"vmull.u8 q9, d15, d12\n"
"vmull.u8 q10, d16, d12\n"
"vmull.u8 q11, d15, d13\n"
"vmull.u8 q12, d16, d13\n"
"vmull.u8 q13, d15, d14\n"
"vmull.u8 q14, d16, d14\n"
"vpadal.u16 q0, q9\n"
"vpadal.u16 q1, q10\n"
"vpadal.u16 q2, q11\n"
"vpadal.u16 q3, q12\n"
"vpadal.u16 q4, q13\n"
"vpadal.u16 q5, q14\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data;
// one lhs term per lhs row.
"vld1.32 {q6}, [%[lhs]:64]\n"
"vdup.32 d14, d12[0]\n"
"vdup.32 d15, d12[1]\n"
"vdup.32 d16, d13[0]\n"
"vld1.32 {d17}, [%[rhs]:64]\n"
"vdup.32 d18, %[result_scale]\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
// Reduce rows.
// Rows 0/1/2 end up in d0/d4/d8.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d4, d4, d6\n"
"vpadd.u32 d8, d8, d10\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 d0, d0, d14\n"
"vadd.s32 d4, d4, d15\n"
"vadd.s32 d8, d8, d16\n"
// Add rhs offset to aggregated rows.
"vadd.s32 d0, d0, d17\n"
"vadd.s32 d4, d4, d17\n"
"vadd.s32 d8, d8, d17\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 d0, d0\n"
"vcvt.f32.s32 d4, d4\n"
"vcvt.f32.s32 d8, d8\n"
"vmul.f32 d0, d0, d18\n"
"vmul.f32 d4, d4, d18\n"
"vmul.f32 d8, d8, d18\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
"memory");
}
// 3x3 kernel: produces a 3x3 block of floats from three lhs rows and three rhs
// rows of uint8 data.
//
// Each loop iteration consumes 24 lhs bytes and 24 rhs bytes (8 per row) and
// accumulates the nine dot products in q0..q8 (q0..q2: lhs row 0, q3..q5: lhs
// row 1, q6..q8: lhs row 2). After the loop the 32-bit values stored right
// after the lhs data are added per lhs row, the 32-bit values stored right
// after the rhs data are added lane-wise (16 bytes are loaded from each side,
// three lanes end up in the output), and the sums are converted to float and
// multiplied by `result_scale`. Each output row is three floats; consecutive
// rows start `result_stride` bytes apart.
//
// `count` has to be a positive multiple of 8 (the loop exits only when the
// counter hits exactly zero). Every lhs/rhs load carries a 64-bit alignment
// qualifier.
// NOTE(review): the trailing 32-bit values are presumably the per-row terms
// appended by the zip_* packing routines - confirm against the callers.
inline void mul_3x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
const std::uint8_t* rhs,
std::int32_t count, float* result,
std::int32_t result_stride,
float result_scale) {
asm volatile(
// Clear aggregators.
// (q3..q8 are cleared by copying already-zeroed registers.)
"vmov.i32 q0, #0\n"
"vmov.i32 q1, #0\n"
"vmov.i32 q2, #0\n"
"vmov.i32 q3, q0\n"
"vmov.i32 q4, q1\n"
"vmov.i32 q5, q2\n"
"vmov.i32 q6, q3\n"
"vmov.i32 q7, q4\n"
"vmov.i32 q8, q5\n"
"pld [%[lhs]]\n"
"pld [%[rhs]]\n"
// 3x3 lanes loop.
"1:"
// Subtract counter.
"subs %[count], %[count], #8\n"
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
"pld [%[lhs], #64]\n"
"pld [%[rhs], #64]\n"
// Only q12..q15 are free for products, so the nine multiplies are issued in
// two batches that reuse them.
"vmull.u8 q12, d18, d21\n"
"vmull.u8 q13, d18, d22\n"
"vmull.u8 q14, d18, d23\n"
"vmull.u8 q15, d19, d21\n"
"vpadal.u16 q0, q12\n"
"vpadal.u16 q1, q13\n"
"vpadal.u16 q2, q14\n"
"vpadal.u16 q3, q15\n"
"vmull.u8 q12, d19, d22\n"
"vmull.u8 q13, d19, d23\n"
"vmull.u8 q14, d20, d21\n"
"vmull.u8 q15, d20, d22\n"
// The last product lands in q9 (d18, d19): lhs rows 0 and 1 are no longer
// needed at this point.
"vmull.u8 q9, d20, d23\n"
"vpadal.u16 q4, q12\n"
"vpadal.u16 q5, q13\n"
"vpadal.u16 q6, q14\n"
"vpadal.u16 q7, q15\n"
"vpadal.u16 q8, q9\n"
// Loop break.
"bne 1b\n"
// Fetch the additive terms stored right after the consumed lhs/rhs data;
// one lhs term per lhs row.
"vld1.32 {q9}, [%[lhs]:64]\n"
"vdup.32 q10, d18[0]\n"
"vdup.32 q11, d18[1]\n"
"vdup.32 q12, d19[0]\n"
"vld1.32 {q13}, [%[rhs]:64]\n"
"vdup.32 q14, %[result_scale]\n"
// Change stride because storing in two ops.
// (The first store of each row already advances result by 8 bytes.)
"sub %[result_stride], %[result_stride], #8\n"
// Horizontal reduce aggregators.
"vpadd.u32 d0, d0, d1\n"
"vpadd.u32 d2, d2, d3\n"
"vpadd.u32 d4, d4, d5\n"
"vpadd.u32 d6, d6, d7\n"
"vpadd.u32 d8, d8, d9\n"
"vpadd.u32 d10, d10, d11\n"
"vpadd.u32 d12, d12, d13\n"
"vpadd.u32 d14, d14, d15\n"
"vpadd.u32 d16, d16, d17\n"
// Reduce rows.
// Rows 0/1/2 end up in q0/q3/q6.
"vpadd.u32 d0, d0, d2\n"
"vpadd.u32 d1, d4, d4\n"
"vpadd.u32 d6, d6, d8\n"
"vpadd.u32 d7, d10, d10\n"
"vpadd.u32 d12, d12, d14\n"
"vpadd.u32 d13, d16, d16\n"
// Add lhs offsets to aggregated rows.
"vadd.s32 q0, q0, q10\n"
"vadd.s32 q3, q3, q11\n"
"vadd.s32 q6, q6, q12\n"
// Add rhs offset to aggregated rows.
"vadd.s32 q0, q0, q13\n"
"vadd.s32 q3, q3, q13\n"
"vadd.s32 q6, q6, q13\n"
// Convert to float. Multiply by result scale.
"vcvt.f32.s32 q0, q0\n"
"vcvt.f32.s32 q3, q3\n"
"vcvt.f32.s32 q6, q6\n"
"vmul.f32 q0, q0, q14\n"
"vmul.f32 q3, q3, q14\n"
"vmul.f32 q6, q6, q14\n"
// Store reduced rows.
"vst1.32 {d0}, [%[result]]!\n"
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d6}, [%[result]]!\n"
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
"vst1.32 {d12}, [%[result]]!\n"
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
: [count] "+r"(count), [result_scale] "+r"(result_scale),
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
[result] "+r"(result)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
"d31", "cc", "memory");
}
void qnt_1x8_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_1_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_2_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
"vst1.16 {d12[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_3_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8}, [%[source]:64]!\n"
"vld1.32 {d9[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
"vst1.16 {d12[0]}, [%[destination]]!\n"
"vst1.8 {d12[2]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_4_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8, d9}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
"vst1.32 {d12[0]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_5_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8, d9}, [%[source]:64]!\n"
"vld1.32 {d10[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.8 {d12[4]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_6_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8, d9, d10}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.16 {d12[2]}, [%[destination]]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_1x8_7_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]:64]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d8, d9, d10}, [%[source]:64]!\n"
"vld1.32 {d11[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.16 {d12[2]}, [%[destination]]!\n"
"vst1.8 {d12[6]}, [%[destination]]!\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
void qnt_2x8_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
"bne 1b\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 1. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. One is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 1 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses in the main
// loop (presumably guaranteed by callers -- confirm).
void qnt_2x8_1_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the single leftover value; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: one value per row, loaded into lane 0 only. The remaining lanes go
// through the same arithmetic but are never stored.
"vld1.32 {d10[0]}, [%[source]]\n"
"vld1.32 {d14[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store only byte 0 of each row's result.
"vst1.8 {d18[0]}, [%[destination]]\n"
"vst1.8 {d20[0]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 2. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Two is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 2 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_2_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the two leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: two values per row, loaded into the low half of q5 / q7. The upper
// lanes go through the same arithmetic but are never stored.
"vld1.32 {d10}, [%[source]:64]\n"
"vld1.32 {d14}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store the first two result bytes of each row as one 16-bit lane.
"vst1.16 {d18[0]}, [%[destination]]\n"
"vst1.16 {d20[0]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 3. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Three is subtracted up front; if the result is
//     zero the main loop is skipped, otherwise the loop runs until the
//     counter reaches exactly zero, so count - 3 must be a non-negative
//     multiple of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_3_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the three leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: three values per row -- two loaded as a pair, the third into lane 0
// of the next d register. The fourth lane is processed but never stored.
"vld1.32 {d10}, [%[source]:64]!\n"
"vld1.32 {d14}, [r0:64]!\n"
"vld1.32 {d11[0]}, [%[source]]\n"
"vld1.32 {d15[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store bytes 0-1 as one 16-bit lane, then byte 2 on its own.
"vst1.16 {d18[0]}, [%[destination]]!\n"
"vst1.16 {d20[0]}, [r1]!\n"
"vst1.8 {d18[2]}, [%[destination]]\n"
"vst1.8 {d20[2]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 4. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Four is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 4 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_4_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the four leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: four values per row, filling q5 (row 0) and q7 (row 1).
"vld1.32 {d10, d11}, [%[source]:64]\n"
"vld1.32 {d14, d15}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store the first four result bytes of each row as one 32-bit lane.
"vst1.32 {d18[0]}, [%[destination]]\n"
"vst1.32 {d20[0]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 5. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Five is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 5 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_5_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the five leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: five values per row -- four loaded as a block, the fifth into
// lane 0 of the following d register. The remaining lanes are processed but
// never stored.
"vld1.32 {d10, d11}, [%[source]:64]!\n"
"vld1.32 {d14, d15}, [r0:64]!\n"
"vld1.32 {d12[0]}, [%[source]]\n"
"vld1.32 {d16[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store bytes 0-3 as one 32-bit lane, then byte 4 on its own.
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.8 {d18[4]}, [%[destination]]\n"
"vst1.8 {d20[4]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 6. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Six is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 6 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_6_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the six leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: six values per row loaded into three d registers. The last two
// lanes of q6 / q8 are processed but never stored.
"vld1.32 {d10, d11, d12}, [%[source]:64]\n"
"vld1.32 {d14, d15, d16}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store bytes 0-3 as one 32-bit lane, then bytes 4-5 as 16-bit lane 2.
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.16 {d18[2]}, [%[destination]]\n"
"vst1.16 {d20[2]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Two-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 7. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; the second row starts `stride` bytes
//     after it.
//   offsets: two consecutive int32 values, one per row.
//   destination / destination_stride: first output row; the second row starts
//     `destination_stride` bytes after it.
//   count: values per row. Seven is subtracted up front; if the result is
//     zero the main loop is skipped, otherwise the loop runs until the
//     counter reaches exactly zero, so count - 7 must be a non-negative
//     multiple of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_2x8_7_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3 for row 0, q4 for row 1.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0 / r1: source and destination pointers of the second row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Set aside the seven leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offsets.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Scale by the multiplier held in q0.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store eight bytes to each output row.
"vst1.8 {d18}, [%[destination]:64]!\n"
"vst1.8 {d20}, [r1:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: seven values per row -- six loaded as a block, the seventh into
// lane 0 of the following d register. The eighth lane is processed but
// never stored.
"vld1.32 {d10, d11, d12}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16}, [r0:64]!\n"
"vld1.32 {d13[0]}, [%[source]]\n"
"vld1.32 {d17[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Store bytes 0-3 (32-bit lane 0), bytes 4-5 (16-bit lane 2) and byte 6.
// The post-increment on the final stores is not consumed by any later
// instruction.
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.16 {d18[2]}, [%[destination]]!\n"
"vst1.16 {d20[2]}, [r1]!\n"
"vst1.8 {d18[6]}, [%[destination]]!\n"
"vst1.8 {d20[6]}, [r1]!\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0/r1 are used as scratch pointers; d0-d21 cover q0-q10.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// Quantizes three rows of int32 values to uint8, eight values per row per
// loop pass. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and
// finally to uint8.
//   source / stride: first input row; each following row starts `stride`
//     bytes after the previous one.
//   offsets: three consecutive int32 values, one per row.
//   destination / destination_stride: first output row; each following row
//     starts `destination_stride` bytes after the previous one.
//   count: values per row. The loop subtracts 8 per pass and exits only when
//     the counter reaches exactly zero, so it must be a positive multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned row accesses
// (presumably guaranteed by callers of the `_aligned` kernels -- confirm).
// NOTE(review): per the ARM NEON reference, vshl by a register shifts right
// for negative counts; presumably callers pass a negative `shift` -- confirm.
void qnt_3x8_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3, q4, q5 for rows 0, 1, 2.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: source pointers of rows 1 and 2; r1 / r3: destination pointers
// of rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offsets.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplier held in q0.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store eight bytes to each output row.
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Three-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 1. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; each following row starts `stride`
//     bytes after the previous one.
//   offsets: three consecutive int32 values, one per row.
//   destination / destination_stride: first output row; each following row
//     starts `destination_stride` bytes after the previous one.
//   count: values per row. One is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 1 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses in the main
// loop (presumably guaranteed by callers -- confirm).
void qnt_3x8_1_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3, q4, q5 for rows 0, 1, 2.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: source pointers of rows 1 and 2; r1 / r3: destination pointers
// of rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Set aside the single leftover value; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offsets.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplier held in q0.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store eight bytes to each output row.
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: one value per row, loaded into lane 0 only. The remaining lanes go
// through the same arithmetic but are never stored.
"vld1.32 {d12[0]}, [%[source]]\n"
"vld1.32 {d16[0]}, [r0]\n"
"vld1.32 {d20[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store only byte 0 of each row's result.
"vst1.8 {d24[0]}, [%[destination]]\n"
"vst1.8 {d26[0]}, [r1]\n"
"vst1.8 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Three-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 2. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; each following row starts `stride`
//     bytes after the previous one.
//   offsets: three consecutive int32 values, one per row.
//   destination / destination_stride: first output row; each following row
//     starts `destination_stride` bytes after the previous one.
//   count: values per row. Two is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 2 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_3x8_2_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3, q4, q5 for rows 0, 1, 2.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: source pointers of rows 1 and 2; r1 / r3: destination pointers
// of rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Set aside the two leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offsets.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplier held in q0.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store eight bytes to each output row.
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: two values per row, loaded into the low half of q6 / q8 / q10. The
// upper lanes go through the same arithmetic but are never stored.
"vld1.32 {d12}, [%[source]:64]\n"
"vld1.32 {d16}, [r0:64]\n"
"vld1.32 {d20}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store the first two result bytes of each row as one 16-bit lane.
"vst1.16 {d24[0]}, [%[destination]]\n"
"vst1.16 {d26[0]}, [r1]\n"
"vst1.16 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Three-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 3. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; each following row starts `stride`
//     bytes after the previous one.
//   offsets: three consecutive int32 values, one per row.
//   destination / destination_stride: first output row; each following row
//     starts `destination_stride` bytes after the previous one.
//   count: values per row. Three is subtracted up front; if the result is
//     zero the main loop is skipped, otherwise the loop runs until the
//     counter reaches exactly zero, so count - 3 must be a non-negative
//     multiple of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_3x8_3_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3, q4, q5 for rows 0, 1, 2.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: source pointers of rows 1 and 2; r1 / r3: destination pointers
// of rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Set aside the three leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offsets.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplier held in q0.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store eight bytes to each output row.
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: three values per row -- two loaded as a pair, the third into lane 0
// of the next d register. The fourth lane is processed but never stored.
"vld1.32 {d12}, [%[source]:64]!\n"
"vld1.32 {d16}, [r0:64]!\n"
"vld1.32 {d20}, [r2:64]!\n"
"vld1.32 {d13[0]}, [%[source]]\n"
"vld1.32 {d17[0]}, [r0]\n"
"vld1.32 {d21[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store bytes 0-1 as one 16-bit lane, then byte 2 on its own.
"vst1.16 {d24[0]}, [%[destination]]!\n"
"vst1.16 {d26[0]}, [r1]!\n"
"vst1.16 {d28[0]}, [r3]!\n"
"vst1.8 {d24[2]}, [%[destination]]\n"
"vst1.8 {d26[2]}, [r1]\n"
"vst1.8 {d28[2]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Three-row int32 -> uint8 quantization kernel for row lengths of the form
// 8 * k + 4. For every element the assembly computes
//   (value + offsets[row]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 by `shift`, then narrows with saturation to int16 and to
// uint8.
//   source / stride: first input row; each following row starts `stride`
//     bytes after the previous one.
//   offsets: three consecutive int32 values, one per row.
//   destination / destination_stride: first output row; each following row
//     starts `destination_stride` bytes after the previous one.
//   count: values per row. Four is subtracted up front; if the result is zero
//     the main loop is skipped, otherwise the loop runs until the counter
//     reaches exactly zero, so count - 4 must be a non-negative multiple
//     of 8.
// The ":64" address qualifiers request 64-bit aligned accesses (presumably
// guaranteed by callers -- confirm).
void qnt_3x8_4_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters: q0 = multiplier, q1 = rounding term,
// q2 = shift count.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Broadcast one offset per row: q3, q4, q5 for rows 0, 1, 2.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: source pointers of rows 1 and 2; r1 / r3: destination pointers
// of rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Set aside the four leftover values; skip the main loop if nothing else
// remains.
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
// Load eight int32 values from each row, then prefetch at the advanced
// pointers.
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offsets.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplier held in q0.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term held in q1.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16, then int16 -> uint8.
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store eight bytes to each output row.
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
// Repeat until the counter decremented by the subs above reaches zero.
"bne 1b\n"
"2:"
// Tail: four values per row, filling q6 (row 0), q8 (row 1) and q10
// (row 2).
"vld1.32 {d12, d13}, [%[source]:64]\n"
"vld1.32 {d16, d17}, [r0:64]\n"
"vld1.32 {d20, d21}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store the first four result bytes of each row as one 32-bit lane.
"vst1.32 {d24[0]}, [%[destination]]\n"
"vst1.32 {d26[0]}, [r1]\n"
"vst1.32 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// qnt_3x8_5_aligned: converts three rows of 32-bit integers into three rows
// of 8-bit unsigned values, eight columns per loop iteration followed by a
// 5-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; rows two and three are addressed
//                         as source + stride and source + 2 * stride. Row
//                         loads use a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 5) reaches exactly zero in steps of 8, so
//                         count must be 5 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                three consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; rows two and three are addressed
//                         as destination + destination_stride and
//                         destination + 2 * destination_stride. Main-loop
//                         stores use a :64 alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_3x8_5_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4, q5 <- offsets[0], offsets[1], offsets[2], each replicated to all
// four lanes; the pointer is post-incremented after every load.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: input rows two and three. r1 / r3: output rows two and three.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Reserve the 5 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are the ones
// tested by the bne that closes the iteration.
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offset.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Multiply by the common multiplier.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// 32 -> 16 bits with signed saturation; each row ends up in one q register
// (q12, q13, q14).
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
"bne 1b\n"
"2:"
// Tail: 5 columns per row, loaded as 4 values plus one single lane. The other
// lanes of q7, q9 and q11 keep earlier register contents; they are processed
// but never stored.
"vld1.32 {d12, d13}, [%[source]:64]!\n"
"vld1.32 {d16, d17}, [r0:64]!\n"
"vld1.32 {d20, d21}, [r2:64]!\n"
"vld1.32 {d14[0]}, [%[source]]\n"
"vld1.32 {d18[0]}, [r0]\n"
"vld1.32 {d22[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Write exactly 5 bytes per row: bytes 0-3 as one word, then byte 4.
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.8 {d24[4]}, [%[destination]]\n"
"vst1.8 {d26[4]}, [r1]\n"
"vst1.8 {d28[4]}, [r3]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the four scratch address registers, d0-d29, flags and memory.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// qnt_3x8_6_aligned: converts three rows of 32-bit integers into three rows
// of 8-bit unsigned values, eight columns per loop iteration followed by a
// 6-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; rows two and three are addressed
//                         as source + stride and source + 2 * stride. Row
//                         loads use a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 6) reaches exactly zero in steps of 8, so
//                         count must be 6 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                three consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; rows two and three are addressed
//                         as destination + destination_stride and
//                         destination + 2 * destination_stride. Main-loop
//                         stores use a :64 alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_3x8_6_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4, q5 <- offsets[0], offsets[1], offsets[2], each replicated to all
// four lanes; the pointer is post-incremented after every load.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: input rows two and three. r1 / r3: output rows two and three.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Reserve the 6 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are the ones
// tested by the bne that closes the iteration.
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offset.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Multiply by the common multiplier.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// 32 -> 16 bits with signed saturation; each row ends up in one q register
// (q12, q13, q14).
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
"bne 1b\n"
"2:"
// Tail: 6 columns per row loaded as three d registers. The upper halves of
// q7, q9 and q11 (d15, d19, d23) keep earlier register contents; they are
// processed but never stored.
"vld1.32 {d12, d13, d14}, [%[source]:64]\n"
"vld1.32 {d16, d17, d18}, [r0:64]\n"
"vld1.32 {d20, d21, d22}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Write exactly 6 bytes per row: bytes 0-3 as one word, then bytes 4-5 as
// one halfword.
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.16 {d24[2]}, [%[destination]]\n"
"vst1.16 {d26[2]}, [r1]\n"
"vst1.16 {d28[2]}, [r3]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the four scratch address registers, d0-d29, flags and memory.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// qnt_3x8_7_aligned: converts three rows of 32-bit integers into three rows
// of 8-bit unsigned values, eight columns per loop iteration followed by a
// 7-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; rows two and three are addressed
//                         as source + stride and source + 2 * stride. Row
//                         loads use a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 7) reaches exactly zero in steps of 8, so
//                         count must be 7 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                three consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; rows two and three are addressed
//                         as destination + destination_stride and
//                         destination + 2 * destination_stride. Main-loop
//                         stores use a :64 alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_3x8_7_aligned(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination,
std::int32_t destination_stride,
std::int32_t multiplicative_offset,
std::int32_t rounding_offset, std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4, q5 <- offsets[0], offsets[1], offsets[2], each replicated to all
// four lanes; the pointer is post-incremented after every load.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0 / r2: input rows two and three. r1 / r3: output rows two and three.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Reserve the 7 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are the ones
// tested by the bne that closes the iteration.
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offset.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Multiply by the common multiplier.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding term.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// 32 -> 16 bits with signed saturation; each row ends up in one q register
// (q12, q13, q14).
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]:64]!\n"
"vst1.8 {d26}, [r1:64]!\n"
"vst1.8 {d28}, [r3:64]!\n"
"bne 1b\n"
"2:"
// Tail: 7 columns per row, loaded as 6 values plus one single lane. The last
// lane of q7, q9 and q11 keeps earlier register contents; it is processed
// but never stored.
"vld1.32 {d12, d13, d14}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18}, [r0:64]!\n"
"vld1.32 {d20, d21, d22}, [r2:64]!\n"
"vld1.32 {d15[0]}, [%[source]]\n"
"vld1.32 {d19[0]}, [r0]\n"
"vld1.32 {d23[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Write exactly 7 bytes per row: bytes 0-3 as one word, bytes 4-5 as one
// halfword, then byte 6.
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.16 {d24[2]}, [%[destination]]!\n"
"vst1.16 {d26[2]}, [r1]!\n"
"vst1.16 {d28[2]}, [r3]!\n"
"vst1.8 {d24[6]}, [%[destination]]!\n"
"vst1.8 {d26[6]}, [r1]!\n"
"vst1.8 {d28[6]}, [r3]!\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the four scratch address registers, d0-d29, flags and memory.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// qnt_1x8: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration, with no tail handling.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row, loaded with a :64 alignment qualifier.
// count:                  number of columns. The loop body always runs at
//                         least once and exits only when count reaches
//                         exactly zero in steps of 8, so count must be a
//                         positive multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"1:"
// The flags produced by this subs are tested by the closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation) into q6, then 16 -> 8 bits (unsigned
// saturation) into d12.
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_1: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 1-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row; main-loop loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 1) reaches exactly zero in steps of 8, so
//                         count must be 1 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_1(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 1 tail column; if nothing else remains go straight to the tail.
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: a single value goes into lane 0 of q4. The other lanes keep earlier
// register contents; they are processed but never stored.
"vld1.32 {d8[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 1 byte.
"vst1.8 {d12[0]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_2: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 2-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row, loaded with a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 2) reaches exactly zero in steps of 8, so
//                         count must be 2 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_2(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 2 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: two values go into the low half of q4 (d8). The upper half keeps
// earlier register contents; it is processed but never stored.
"vld1.32 {d8}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 2 bytes as one halfword.
"vst1.16 {d12[0]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_3: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 3-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row; vector loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 3) reaches exactly zero in steps of 8, so
//                         count must be 3 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_3(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 3 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: three values go into q4 (two via d8, one via lane 0 of d9). The last
// lane keeps earlier register contents; it is processed but never stored.
"vld1.32 {d8}, [%[source]:64]!\n"
"vld1.32 {d9[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 3 bytes: bytes 0-1 as one halfword, then byte 2.
"vst1.16 {d12[0]}, [%[destination]]!\n"
"vst1.8 {d12[2]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_4: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 4-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row, loaded with a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 4) reaches exactly zero in steps of 8, so
//                         count must be 4 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_4(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 4 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: four values fill q4 completely.
"vld1.32 {d8, d9}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vmul.i32 q4, q4, q0\n"
"vadd.i32 q4, q4, q1\n"
"vshl.s32 q4, q4, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 4 bytes as one word.
"vst1.32 {d12[0]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_5: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 5-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row; vector loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 5) reaches exactly zero in steps of 8, so
//                         count must be 5 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_5(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 5 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: four values fill q4 and one goes into lane 0 of q5. The other lanes
// of q5 keep earlier register contents; they are processed but never stored.
"vld1.32 {d8, d9}, [%[source]:64]!\n"
"vld1.32 {d10[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 5 bytes: bytes 0-3 as one word, then byte 4.
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.8 {d12[4]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_6: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 6-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row, loaded with a :64 alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 6) reaches exactly zero in steps of 8, so
//                         count must be 6 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_6(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 6 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: six values fill q4 and the low half of q5. The upper half of q5
// (d11) keeps earlier register contents; it is processed but never stored.
"vld1.32 {d8, d9, d10}, [%[source]:64]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 6 bytes: bytes 0-3 as one word, then bytes 4-5 as one
// halfword.
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.16 {d12[2]}, [%[destination]]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_1x8_7: converts one row of 32-bit integers into 8-bit unsigned values,
// eight columns per loop iteration followed by a 7-column tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + offsets[0]) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 input row; vector loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 7) reaches exactly zero in steps of 8, so
//                         count must be 7 plus a non-negative multiple of 8.
// offsets:                only the first 32-bit value is read (:32 qualifier).
// destination:            output row; stores carry no alignment qualifier.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
// stride, destination_stride: bound as operands but not referenced by the
//                         assembly template.
void qnt_1x8_7(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3 <- offsets[0] replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
// Reserve the 7 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns. The flags produced by this subs are tested by the
// closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
"pld [%[source]]\n"
// offset, multiply, round, shift.
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
// 32 -> 16 bits (signed saturation), then 16 -> 8 bits (unsigned saturation).
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
"vst1.8 {d12}, [%[destination]]!\n"
"bne 1b\n"
"2:"
// Tail: six values via d8-d10 plus one value into lane 0 of d11. The last
// lane of q5 keeps earlier register contents; it is processed but never
// stored.
"vld1.32 {d8, d9, d10}, [%[source]:64]!\n"
"vld1.32 {d11[0]}, [%[source]]\n"
"vadd.i32 q4, q4, q3\n"
"vadd.i32 q5, q5, q3\n"
"vmul.i32 q4, q4, q0\n"
"vmul.i32 q5, q5, q0\n"
"vadd.i32 q4, q4, q1\n"
"vadd.i32 q5, q5, q1\n"
"vshl.s32 q4, q4, q2\n"
"vshl.s32 q5, q5, q2\n"
"vqmovn.s32 d12, q4\n"
"vqmovn.s32 d13, q5\n"
"vqmovun.s16 d12, q6\n"
// Write exactly 7 bytes: bytes 0-3 as one word, bytes 4-5 as one halfword,
// then byte 6.
"vst1.32 {d12[0]}, [%[destination]]!\n"
"vst1.16 {d12[2]}, [%[destination]]!\n"
"vst1.8 {d12[6]}, [%[destination]]!\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "cc", "memory");
}
// qnt_2x8: converts two rows of 32-bit integers into two rows of 8-bit
// unsigned values, eight columns per loop iteration, with no tail handling.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; the second row is addressed as
//                         source + stride. Row loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The loop body always runs at
//                         least once and exits only when count reaches
//                         exactly zero in steps of 8, so count must be a
//                         positive multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                two consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; the second row is addressed as
//                         destination + destination_stride. Stores carry no
//                         alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_2x8(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4 <- offsets[0], offsets[1], each replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0: second input row. r1: second output row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"1:"
// The flags produced by this subs are tested by the closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offset.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Multiply by the common multiplier.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// 32 -> 16 bits with signed saturation; rows end up in q9 and q10.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the two scratch address registers, d0-d21, flags and memory.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// qnt_2x8_1: converts two rows of 32-bit integers into two rows of 8-bit
// unsigned values, eight columns per loop iteration followed by a 1-column
// tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; the second row is addressed as
//                         source + stride. Main-loop loads use a :64
//                         alignment qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 1) reaches exactly zero in steps of 8, so
//                         count must be 1 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                two consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; the second row is addressed as
//                         destination + destination_stride. Stores carry no
//                         alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_2x8_1(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4 <- offsets[0], offsets[1], each replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0: second input row. r1: second output row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Reserve the 1 tail column; if nothing else remains go straight to the tail.
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are tested
// by the closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offset.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Multiply by the common multiplier.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// 32 -> 16 bits with signed saturation; rows end up in q9 and q10.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
// Tail: one value per row goes into lane 0 of q5 / q7. The other lanes keep
// earlier register contents; they are processed but never stored.
"vld1.32 {d10[0]}, [%[source]]\n"
"vld1.32 {d14[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Write exactly 1 byte per row.
"vst1.8 {d18[0]}, [%[destination]]\n"
"vst1.8 {d20[0]}, [r1]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the two scratch address registers, d0-d21, flags and memory.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// qnt_2x8_2: converts two rows of 32-bit integers into two rows of 8-bit
// unsigned values, eight columns per loop iteration followed by a 2-column
// tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; the second row is addressed as
//                         source + stride. Row loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 2) reaches exactly zero in steps of 8, so
//                         count must be 2 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                two consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; the second row is addressed as
//                         destination + destination_stride. Stores carry no
//                         alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_2x8_2(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4 <- offsets[0], offsets[1], each replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0: second input row. r1: second output row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Reserve the 2 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are tested
// by the closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offset.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Multiply by the common multiplier.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// 32 -> 16 bits with signed saturation; rows end up in q9 and q10.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
// Tail: two values per row go into the low half of q5 / q7. The upper halves
// keep earlier register contents; they are processed but never stored.
"vld1.32 {d10}, [%[source]:64]\n"
"vld1.32 {d14}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Write exactly 2 bytes per row as one halfword.
"vst1.16 {d18[0]}, [%[destination]]\n"
"vst1.16 {d20[0]}, [r1]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the two scratch address registers, d0-d21, flags and memory.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
// qnt_2x8_3: converts two rows of 32-bit integers into two rows of 8-bit
// unsigned values, eight columns per loop iteration followed by a 3-column
// tail.
//
// Per element, in 32-bit lanes, the assembly computes
//   (value + row_offset) * multiplicative_offset + rounding_offset,
// applies vshl.s32 with `shift` as the per-lane count (in the ARM ISA a
// negative count shifts right), then narrows with signed saturation to
// 16 bits and with unsigned saturation to 8 bits.
//
// source:                 first input row; the second row is addressed as
//                         source + stride. Vector loads use a :64 alignment
//                         qualifier.
// count:                  number of columns. The main loop exits only when
//                         (count - 3) reaches exactly zero in steps of 8, so
//                         count must be 3 plus a non-negative multiple of 8.
// stride:                 added directly to the source address, i.e. a
//                         distance in bytes between input rows.
// offsets:                two consecutive 32-bit values, one per row, read
//                         with a :32 alignment qualifier.
// destination:            first output row; the second row is addressed as
//                         destination + destination_stride. Stores carry no
//                         alignment qualifier.
// destination_stride:     distance in bytes between output rows.
// multiplicative_offset:  per-element multiplier.
// rounding_offset:        value added after the multiplication.
// shift:                  per-lane shift count passed to vshl.s32.
void qnt_2x8_3(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// q0, q1, q2 <- multiplier, rounding term and shift count in every lane.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// q3, q4 <- offsets[0], offsets[1], each replicated to all four lanes.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
// r0: second input row. r1: second output row.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
// Reserve the 3 tail columns; if nothing else remains go straight to the tail.
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
// Main loop, 8 columns per row. The flags produced by this subs are tested
// by the closing bne.
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
// Prefetch hints for the data the next iteration will read.
"pld [%[source]]\n"
"pld [r0]\n"
// Add the per-row offset.
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
// Multiply by the common multiplier.
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
// Add the rounding term.
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
// Shift every lane by the count held in q2.
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
// 32 -> 16 bits with signed saturation; rows end up in q9 and q10.
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
// 16 -> 8 bits with unsigned saturation: 8 result bytes per row.
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
// Tail: three values per row (two via d10 / d14, one via lane 0 of d11 /
// d15). The last lane of q5 / q7 keeps earlier register contents; it is
// processed but never stored.
"vld1.32 {d10}, [%[source]:64]!\n"
"vld1.32 {d14}, [r0:64]!\n"
"vld1.32 {d11[0]}, [%[source]]\n"
"vld1.32 {d15[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
// Write exactly 3 bytes per row: bytes 0-1 as one halfword, then byte 2.
"vst1.16 {d18[0]}, [%[destination]]!\n"
"vst1.16 {d20[0]}, [r1]!\n"
"vst1.8 {d18[2]}, [%[destination]]\n"
"vst1.8 {d20[2]}, [r1]\n"
// All nine parameters are bound as read-write register operands.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// Clobbers: the two scratch address registers, d0-d21, flags and memory.
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
void qnt_2x8_4(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d10, d11}, [%[source]:64]\n"
"vld1.32 {d14, d15}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q7, q7, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q7, q7, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q7, q7, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q7, q7, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d20, q7\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.32 {d18[0]}, [%[destination]]\n"
"vst1.32 {d20[0]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
void qnt_2x8_5(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d10, d11}, [%[source]:64]!\n"
"vld1.32 {d14, d15}, [r0:64]!\n"
"vld1.32 {d12[0]}, [%[source]]\n"
"vld1.32 {d16[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.8 {d18[4]}, [%[destination]]\n"
"vst1.8 {d20[4]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
void qnt_2x8_6(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d10, d11, d12}, [%[source]:64]\n"
"vld1.32 {d14, d15, d16}, [r0:64]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.16 {d18[2]}, [%[destination]]\n"
"vst1.16 {d20[2]}, [r1]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
void qnt_2x8_7(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.8 {d18}, [%[destination]]!\n"
"vst1.8 {d20}, [r1]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d10, d11, d12}, [%[source]:64]!\n"
"vld1.32 {d14, d15, d16}, [r0:64]!\n"
"vld1.32 {d13[0]}, [%[source]]\n"
"vld1.32 {d17[0]}, [r0]\n"
"vadd.i32 q5, q5, q3\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q4\n"
"vadd.i32 q8, q8, q4\n"
"vmul.i32 q5, q5, q0\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vadd.i32 q5, q5, q1\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vshl.s32 q5, q5, q2\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vqmovn.s32 d18, q5\n"
"vqmovn.s32 d19, q6\n"
"vqmovn.s32 d20, q7\n"
"vqmovn.s32 d21, q8\n"
"vqmovun.s16 d18, q9\n"
"vqmovun.s16 d20, q10\n"
"vst1.32 {d18[0]}, [%[destination]]!\n"
"vst1.32 {d20[0]}, [r1]!\n"
"vst1.16 {d18[2]}, [%[destination]]!\n"
"vst1.16 {d20[2]}, [r1]!\n"
"vst1.8 {d18[6]}, [%[destination]]!\n"
"vst1.8 {d20[6]}, [r1]!\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "cc", "memory");
}
void qnt_3x8(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_1(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #1\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12[0]}, [%[source]]\n"
"vld1.32 {d16[0]}, [r0]\n"
"vld1.32 {d20[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24[0]}, [%[destination]]\n"
"vst1.8 {d26[0]}, [r1]\n"
"vst1.8 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_2(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #2\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12}, [%[source]:64]\n"
"vld1.32 {d16}, [r0:64]\n"
"vld1.32 {d20}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.16 {d24[0]}, [%[destination]]\n"
"vst1.16 {d26[0]}, [r1]\n"
"vst1.16 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_3(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #3\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12}, [%[source]:64]!\n"
"vld1.32 {d16}, [r0:64]!\n"
"vld1.32 {d20}, [r2:64]!\n"
"vld1.32 {d13[0]}, [%[source]]\n"
"vld1.32 {d17[0]}, [r0]\n"
"vld1.32 {d21[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.16 {d24[0]}, [%[destination]]!\n"
"vst1.16 {d26[0]}, [r1]!\n"
"vst1.16 {d28[0]}, [r3]!\n"
"vst1.8 {d24[2]}, [%[destination]]\n"
"vst1.8 {d26[2]}, [r1]\n"
"vst1.8 {d28[2]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_4(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #4\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12, d13}, [%[source]:64]\n"
"vld1.32 {d16, d17}, [r0:64]\n"
"vld1.32 {d20, d21}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q10, q10, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q10, q10, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q10, q10, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q10, q10, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d28, q10\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.32 {d24[0]}, [%[destination]]\n"
"vst1.32 {d26[0]}, [r1]\n"
"vst1.32 {d28[0]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_5(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #5\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12, d13}, [%[source]:64]!\n"
"vld1.32 {d16, d17}, [r0:64]!\n"
"vld1.32 {d20, d21}, [r2:64]!\n"
"vld1.32 {d14[0]}, [%[source]]\n"
"vld1.32 {d18[0]}, [r0]\n"
"vld1.32 {d22[0]}, [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.8 {d24[4]}, [%[destination]]\n"
"vst1.8 {d26[4]}, [r1]\n"
"vst1.8 {d28[4]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
void qnt_3x8_6(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
"subs %[count], %[count], #6\n"
"beq 2f\n"
"1:"
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
"vld1.32 {d12, d13, d14}, [%[source]:64]\n"
"vld1.32 {d16, d17, d18}, [r0:64]\n"
"vld1.32 {d20, d21, d22}, [r2:64]\n"
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.16 {d24[2]}, [%[destination]]\n"
"vst1.16 {d26[2]}, [r1]\n"
"vst1.16 {d28[2]}, [r3]\n"
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Quantizes three rows of int32 values to uint8 for the case count % 8 == 7
// (multi_qnt_3x8 selects this kernel for that remainder).
//
// Every element is transformed as
//   ((value + row_offset) * multiplicative_offset + rounding_offset)
// then shifted with vshl.s32 by `shift` and saturated to uint8. With vshl a
// negative shift count shifts right.
//
// source / stride: first source row and the byte distance between rows (the
//   stride is added directly to the row address).
// offsets: three consecutive int32 values, one per row.
// destination / destination_stride: first output row and the byte distance
//   between output rows.
// count: number of values per row; the loop structure below only terminates
//   correctly for positive counts with count % 8 == 7.
void qnt_3x8_7(const std::int32_t* source, std::int32_t count,
std::int32_t stride, const std::int32_t* offsets,
std::uint8_t* destination, std::int32_t destination_stride,
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
std::int32_t shift) {
asm volatile(
// Broadcast the scalar parameters to all four lanes of q0, q1 and q2.
"vdup.32 q0, %[multiplicative_offset]\n"
"vdup.32 q1, %[rounding_offset]\n"
"vdup.32 q2, %[shift]\n"
// Load one offset per row and replicate it across q3, q4 and q5.
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
// r0/r2 address source rows 1 and 2, r1/r3 address destination rows 1 and 2.
"add r0, %[source], %[stride]\n"
"add r1, %[destination], %[destination_stride]\n"
"add r2, r0, %[stride]\n"
"add r3, r1, %[destination_stride]\n"
// Set the 7 leftover values aside; if nothing else remains go to the tail.
"subs %[count], %[count], #7\n"
"beq 2f\n"
"1:"
// Main loop: 8 values per row per iteration. The flags set by this subs are
// the ones tested by the bne at the bottom of the loop.
"subs %[count], %[count], #8\n"
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
// Prefetch hints for the data following what was just loaded.
"pld [%[source]]\n"
"pld [r0]\n"
"pld [r2]\n"
// Add the per-row offset.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
// Scale by the multiplicative offset.
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
// Add the rounding offset.
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
// Shift by the per-lane count held in q2.
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
// Saturating narrow: int32 -> int16 ...
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
// ... then int16 -> uint8, leaving 8 bytes per row in d24, d26 and d28.
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
"vst1.8 {d24}, [%[destination]]!\n"
"vst1.8 {d26}, [r1]!\n"
"vst1.8 {d28}, [r3]!\n"
"bne 1b\n"
"2:"
// Tail: load the last 7 values per row, six with the multi-register load and
// the seventh into lane 0 of d15/d19/d23. Lane 1 of those registers is not
// loaded here; whatever it yields is computed but never stored.
"vld1.32 {d12, d13, d14}, [%[source]:64]!\n"
"vld1.32 {d16, d17, d18}, [r0:64]!\n"
"vld1.32 {d20, d21, d22}, [r2:64]!\n"
"vld1.32 {d15[0]}, [%[source]]\n"
"vld1.32 {d19[0]}, [r0]\n"
"vld1.32 {d23[0]}, [r2]\n"
// Same add / multiply / round / shift / narrow sequence as the main loop.
"vadd.i32 q6, q6, q3\n"
"vadd.i32 q7, q7, q3\n"
"vadd.i32 q8, q8, q4\n"
"vadd.i32 q9, q9, q4\n"
"vadd.i32 q10, q10, q5\n"
"vadd.i32 q11, q11, q5\n"
"vmul.i32 q6, q6, q0\n"
"vmul.i32 q7, q7, q0\n"
"vmul.i32 q8, q8, q0\n"
"vmul.i32 q9, q9, q0\n"
"vmul.i32 q10, q10, q0\n"
"vmul.i32 q11, q11, q0\n"
"vadd.i32 q6, q6, q1\n"
"vadd.i32 q7, q7, q1\n"
"vadd.i32 q8, q8, q1\n"
"vadd.i32 q9, q9, q1\n"
"vadd.i32 q10, q10, q1\n"
"vadd.i32 q11, q11, q1\n"
"vshl.s32 q6, q6, q2\n"
"vshl.s32 q7, q7, q2\n"
"vshl.s32 q8, q8, q2\n"
"vshl.s32 q9, q9, q2\n"
"vshl.s32 q10, q10, q2\n"
"vshl.s32 q11, q11, q2\n"
"vqmovn.s32 d24, q6\n"
"vqmovn.s32 d25, q7\n"
"vqmovn.s32 d26, q8\n"
"vqmovn.s32 d27, q9\n"
"vqmovn.s32 d28, q10\n"
"vqmovn.s32 d29, q11\n"
"vqmovun.s16 d24, q12\n"
"vqmovun.s16 d26, q13\n"
"vqmovun.s16 d28, q14\n"
// Store exactly 7 bytes per row: bytes 0-3, then bytes 4-5, then byte 6.
"vst1.32 {d24[0]}, [%[destination]]!\n"
"vst1.32 {d26[0]}, [r1]!\n"
"vst1.32 {d28[0]}, [r3]!\n"
"vst1.16 {d24[2]}, [%[destination]]!\n"
"vst1.16 {d26[2]}, [r1]!\n"
"vst1.16 {d28[2]}, [r3]!\n"
"vst1.8 {d24[6]}, [%[destination]]!\n"
"vst1.8 {d26[6]}, [r1]!\n"
"vst1.8 {d28[6]}, [r3]!\n"
// All operands are read-write ("+r"): the asm advances the pointers and
// decrements count in place.
: [count] "+r"(count),
[multiplicative_offset] "+r"(multiplicative_offset),
[stride] "+r"(stride), [shift] "+r"(shift),
[destination] "+r"(destination), [offsets] "+r"(offsets),
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
[rounding_offset] "+r"(rounding_offset)
:
// r0-r3 are used as scratch row pointers; d0-d29 cover q0-q14.
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
"d28", "d29", "cc", "memory");
}
// Selects the qnt_1x8*_aligned kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_1x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_1x8_aligned,   qnt_1x8_1_aligned, qnt_1x8_2_aligned,
      qnt_1x8_3_aligned, qnt_1x8_4_aligned, qnt_1x8_5_aligned,
      qnt_1x8_6_aligned, qnt_1x8_7_aligned};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// Selects the qnt_2x8*_aligned kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_2x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_2x8_aligned,   qnt_2x8_1_aligned, qnt_2x8_2_aligned,
      qnt_2x8_3_aligned, qnt_2x8_4_aligned, qnt_2x8_5_aligned,
      qnt_2x8_6_aligned, qnt_2x8_7_aligned};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// Selects the qnt_3x8*_aligned kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_3x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_3x8_aligned,   qnt_3x8_1_aligned, qnt_3x8_2_aligned,
      qnt_3x8_3_aligned, qnt_3x8_4_aligned, qnt_3x8_5_aligned,
      qnt_3x8_6_aligned, qnt_3x8_7_aligned};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// Selects the qnt_1x8* kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_1x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_1x8,   qnt_1x8_1, qnt_1x8_2, qnt_1x8_3,
      qnt_1x8_4, qnt_1x8_5, qnt_1x8_6, qnt_1x8_7};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// Selects the qnt_2x8* kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_2x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_2x8,   qnt_2x8_1, qnt_2x8_2, qnt_2x8_3,
      qnt_2x8_4, qnt_2x8_5, qnt_2x8_6, qnt_2x8_7};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// Selects the qnt_3x8* kernel that matches count % 8 and forwards all
// arguments to it unchanged. A negative remainder matches no kernel, in which
// case nothing is called.
void multi_qnt_3x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  typedef void (*QuantizeKernel)(const std::int32_t*, std::int32_t,
                                 std::int32_t, const std::int32_t*,
                                 std::uint8_t*, std::int32_t, std::int32_t,
                                 std::int32_t, std::int32_t);
  // Entry i handles count % 8 == i.
  static const QuantizeKernel kKernelByLeftover[8] = {
      qnt_3x8,   qnt_3x8_1, qnt_3x8_2, qnt_3x8_3,
      qnt_3x8_4, qnt_3x8_5, qnt_3x8_6, qnt_3x8_7};
  const std::int32_t leftover = count % 8;
  if (leftover < 0) {
    return;
  }
  kKernelByLeftover[leftover](source, count, stride, offsets, destination,
                              destination_stride, multiplicative_offset,
                              rounding_offset, shift);
}
// 8-bit quantized GEMM variant built on the zip_3x8_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_1_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_2_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_3_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_4_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_5_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_6_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_7_aligned packer.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) once, then for each of the
// m / 3 chunks of lhs: packs it, multiplies it against every packed rhs chunk
// into an int32 temporary, and quantizes that temporary into three rows of
// result. Data beyond the m / 3 and n / 3 full chunks is not packed or
// multiplied here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// 8-bit quantized GEMM variant built on the zip_3x8_aligned and
// zip_1x8_aligned packers.
//
// Packs n / 3 chunks of rhs (3 * k bytes each) plus one further single-row
// piece of rhs, then for each of the m / 3 chunks of lhs: packs it, multiplies
// it against every packed rhs chunk and against the extra single piece into
// an int32 temporary, and quantizes that temporary into three rows of result.
// lhs data beyond the m / 3 full chunks is not processed here.
// NOTE(review): the numeric suffix presumably encodes m % 3, n % 3 and
// k % 8 - confirm against the generator.
//
// scratch is carved up, in bytes from its start, as: one zipped lhs chunk
// (zipped_chunk_size), the zipped rhs (zipped_rhs_size), then the int32
// temporary result.
void gemm_q8_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // The per-row int32 offsets are read from just past the padded_k * 3 bytes
  // of zipped lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term for the shift applied during quantization. Guarded because
  // 1 << (shift - 1) is undefined behavior for shift <= 0; no rounding term is
  // added in that case.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one temporary result row: n int32 values rounded up to a
  // multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  // Pack every full rhs chunk once; the packed data is reused for all lhs
  // chunks.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the single rhs piece that follows the full chunks.
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Multiply against the extra single rhs piece packed after the full
    // chunks.
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    // The quantization kernels receive the negated shift.
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_q8_0_1_1_aligned: quantized uint8 GEMM driver for the "0_1_1" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 1 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_2_aligned: quantized uint8 GEMM driver for the "0_1_2" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 2 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_3_aligned: quantized uint8 GEMM driver for the "0_1_3" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 3 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_4_aligned: quantized uint8 GEMM driver for the "0_1_4" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 4 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_5_aligned: quantized uint8 GEMM driver for the "0_1_5" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 5 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_6_aligned: quantized uint8 GEMM driver for the "0_1_6" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 6 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_1_7_aligned: quantized uint8 GEMM driver for the "0_1_7" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 1 and
// k % 8 == 7 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 1-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 1-row RHS block.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_0_aligned: quantized uint8 GEMM driver for the "0_2_0" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 0 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_1_aligned: quantized uint8 GEMM driver for the "0_2_1" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 1 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_2_aligned: quantized uint8 GEMM driver for the "0_2_2" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 2 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_3_aligned: quantized uint8 GEMM driver for the "0_2_3" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 3 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_4_aligned: quantized uint8 GEMM driver for the "0_2_4" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 4 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_5_aligned: quantized uint8 GEMM driver for the "0_2_5" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 5 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// gemm_q8_0_2_6_aligned: quantized uint8 GEMM driver for the "0_2_6" variant.
//
// The scratch buffer is carved into three consecutive regions: one zipped
// 3-row LHS block, the zipped RHS, and an int32 accumulator area. The RHS is
// zipped once up front; afterwards each group of three LHS rows is zipped,
// multiplied against every zipped RHS block and quantized into `result`.
// NOTE(review): the suffix presumably encodes m % 3 == 0, n % 3 == 2 and
// k % 8 == 6 (no leftover LHS rows are handled here) -- confirm against the
// dispatching code.
void gemm_q8_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Depth rounded up to the next multiple of 8.
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  // Every zipped row gets the padded depth plus 16 extra bytes.
  const std::int32_t zipped_row_bytes = depth_padded + 16;
  const std::int32_t zipped_block_bytes = zipped_row_bytes * 3;
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;

  // Scratch layout: [zipped LHS block][zipped RHS][int32 accumulators].
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;
  std::int32_t* const accumulators = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_row_bytes * n);
  // The pointer handed to the quantization step as row offsets sits right
  // after the 3 * depth_padded bytes of the zipped LHS block.
  std::int32_t* const lhs_row_offsets =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);

  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Zip the RHS once: all full 3-row blocks, then the trailing 2-row block.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* output_rows = result;
  for (std::int32_t remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, combined_offset);

    // Walk the zipped RHS, advancing the accumulator cursor by three int32
    // entries per full block.
    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* acc_cursor = accumulators;
    for (std::int32_t block = rhs_blocks; block > 0; --block) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded,
                               acc_cursor, accumulator_stride_bytes);
      rhs_cursor += zipped_block_bytes;
      acc_cursor += 3;
    }
    // Trailing 2-row RHS block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, depth_padded, acc_cursor,
                             accumulator_stride_bytes);

    // Quantize the int32 accumulators for these three rows into the output.
    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          lhs_row_offsets, output_rows, result_stride,
                          multiplicative_offset, rounding_term, -shift);

    lhs_src += source_block_bytes;
    output_rows += output_block_stride;
  }
}
// Quantized uint8 GEMM driver, specialization "0_2_7" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks with no
// leftover-row pass, n / 3 full 3-column RHS chunks followed by one 2-column
// remainder chunk, and the `_7` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 0, n % 3 == 2, k % 8 == 7 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once: full 3-column chunks, then the 2-column remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    // Remainder: the last 2 RHS columns.
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized uint8 GEMM driver, specialization "1_0_0" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the unsuffixed zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 0 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_1" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_1` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 1 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_2" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_2` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 2 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_3" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_3` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 3 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_4" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_4` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 4 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_5" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_5` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 5 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_6" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_6` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 6 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_0_7" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks with no column
// remainder, and the `_7` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 0, k % 8 == 7 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once, 3 columns at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_1_0" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks followed by one
// single leftover column, and the unsuffixed zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 1, k % 8 == 0 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once: full 3-column chunks, then the 1-column remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    // Remainder: the last RHS column.
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Corner: leftover row against the leftover column.
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver, specialization "1_1_1" (aligned variant).
//
// Shape handling visible in this body: m / 3 full 3-row LHS chunks followed by
// one single leftover row, n / 3 full 3-column RHS chunks followed by one
// single leftover column, and the `_1` zip helpers for packing.
// NOTE(review): presumably this means m % 3 == 1, n % 3 == 1, k % 8 == 1 and
// the dispatcher guarantees it -- confirm against the caller.
//
// Parameters:
//   scratch               work buffer, carved up in this order:
//                           zipped_chunk_size bytes: packed LHS chunk (the
//                             int32 values handed to the quantizer start at
//                             byte padded_k * 3 for a 3-row chunk and at byte
//                             padded_k for the single leftover row),
//                           zipped_rhs_size bytes: packed RHS,
//                           remainder: int32 multiplication results.
//   lhs                   LHS source, consumed k bytes per row.
//   rhs                   RHS source, consumed k bytes per column.
//   m, n, k               row count, column count and depth.
//   lhs_offset            multiplicative offset used while packing RHS.
//   rhs_offset            multiplicative offset used while packing LHS.
//   result_offset         folded into the additive offset used for LHS packing.
//   multiplicative_offset forwarded unchanged to the quantizer.
//   shift                 quantizer shift; forwarded negated.
//   result                output buffer.
//   result_stride         distance in bytes between consecutive output rows.
void gemm_q8_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes of source consumed by one 3-row / 3-column chunk.
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // `1 << (shift - 1)` is undefined for shift <= 0; no rounding term is needed
  // when nothing is shifted out.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Byte stride of one int32 temp row: n * 4 rounded up to a multiple of 8.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack the whole RHS once: full 3-column chunks, then the 1-column remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row LHS chunk: pack, multiply against all RHS chunks, quantize.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next 3 int32 columns of the temp row.
    }
    // Remainder: the last RHS column.
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover pass: the single remaining LHS row.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Corner: leftover row against the leftover column.
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_2" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 2; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_3" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 3; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_4" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 4; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_5" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 5; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_6" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 6; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and one trailing RHS row, using
// the "_7" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 1 and
// k % 8 == 7; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 1-row RHS tile.
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and two trailing RHS rows, using
// the unsuffixed zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 2 and
// k % 8 == 0; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 2-row RHS tile.
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and two trailing RHS rows, using
// the "_1" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 2 and
// k % 8 == 1; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 2-row RHS tile.
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and two trailing RHS rows, using
// the "_2" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 2 and
// k % 8 == 2; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 2-row RHS tile.
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized uint8 GEMM driver that tiles both operands in groups of three
// rows and then handles one trailing LHS row and two trailing RHS rows, using
// the "_3" zip kernels.
// NOTE(review): the name suffix presumably encodes m % 3 == 1, n % 3 == 2 and
// k % 8 == 3; nothing here checks that -- confirm at the dispatcher.
//
// lhs and rhs are each walked as consecutive rows of k bytes (m rows and n
// rows respectively). scratch is carved up as:
//   [0, zipped_chunk_size)                  zipped LHS tile being processed
//   [zipped_chunk_size, + zipped_rhs_size)  every zipped RHS tile
//   everything after                        int32 accumulators (temp_result)
// The multi_qnt_* kernels turn the int32 accumulators into the uint8 result;
// result advances by 3 * result_stride bytes per 3-row tile.
void gemm_q8_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Each zipped row reserves 16 bytes on top of padded_k.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  // Per-row int32 offsets are read from directly behind the zipped LHS data;
  // where that is depends on whether 3 rows or 1 row were zipped.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term handed to multi_qnt_* together with -shift. Guarded so that
  // shift < 1 gives 0 rather than evaluating 1 << (shift - 1) with a negative
  // shift count, which is undefined behaviour.
  const std::int32_t rounding_offset = (shift < 1) ? 0 : (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // Accumulator row pitch: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip every RHS tile once up front; they are reused for each LHS tile.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing 2-row RHS tile.
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS tiles.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // step over the 3 columns just produced
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 1-row LHS tile against all RHS tiles.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_4 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors followed by one block of two. The lhs is consumed as
// m / 3 blocks of three followed by one single vector. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 1_2_4 suffix presumably encodes m % 3 == 1, n % 3 == 2,
// k % 8 == 4 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the rhs: full blocks of three, then the trailing block of two.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single lhs vector.
  zip_1x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                           staging_stride_bytes);
  multi_qnt_1x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_1, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_5 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors followed by one block of two. The lhs is consumed as
// m / 3 blocks of three followed by one single vector. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 1_2_5 suffix presumably encodes m % 3 == 1, n % 3 == 2,
// k % 8 == 5 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the rhs: full blocks of three, then the trailing block of two.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single lhs vector.
  zip_1x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                           staging_stride_bytes);
  multi_qnt_1x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_1, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_6 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors followed by one block of two. The lhs is consumed as
// m / 3 blocks of three followed by one single vector. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 1_2_6 suffix presumably encodes m % 3 == 1, n % 3 == 2,
// k % 8 == 6 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the rhs: full blocks of three, then the trailing block of two.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single lhs vector.
  zip_1x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                           staging_stride_bytes);
  multi_qnt_1x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_1, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_7 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors followed by one block of two. The lhs is consumed as
// m / 3 blocks of three followed by one single vector. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 1_2_7 suffix presumably encodes m % 3 == 1, n % 3 == 2,
// k % 8 == 7 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the rhs: full blocks of three, then the trailing block of two.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single lhs vector.
  zip_1x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                           staging_stride_bytes);
  multi_qnt_1x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_1, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// unsuffixed zip_3x8_aligned / zip_2x8_aligned packing kernels.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_0 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 0 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_1 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_1 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 1 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_2 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_2 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 2 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_3 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_3 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 3 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_4 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_4 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 4 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_5 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_5 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 5 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_6 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_6 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 6 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// Single-threaded 8-bit GEMM driver using the *_aligned kernels and the
// zip_*_7 packing variants.
//
// Work is tiled in threes. The rhs is packed once into scratch as n / 3 blocks
// of three k-byte vectors, with no trailing rhs block. The lhs is consumed as
// m / 3 blocks of three followed by one block of two. Each lhs block is
// packed, multiplied against every packed rhs block into an int32 staging
// buffer, and that buffer is handed to multi_qnt_* with the output pointer.
// NOTE(review): the 2_0_7 suffix presumably encodes m % 3 == 2, n % 3 == 0,
// k % 8 == 7 -- confirm against the dispatcher that selects this variant.
//
// Scratch is carved up in this order: one packed lhs block of
// (k_rounded + 16) * 3 bytes, the packed rhs of (k_rounded + 16) * n bytes,
// then the int32 staging buffer.
// NOTE(review): the rounding term is 1 << (shift - 1); shift < 1 is not
// guarded against here.
void gemm_q8_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t out_block_stride = result_stride * 3;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the packing and quantization kernels.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* staging = reinterpret_cast<std::int32_t*>(
      scratch + packed_block_bytes + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes; multi_qnt_* takes
  // them as its offsets argument.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 2);

  // Pack the rhs: full blocks of three only.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t left = rhs_blocks; left > 0; --left) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Full lhs blocks of three.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (std::int32_t rows_left = lhs_blocks; rows_left > 0; --rows_left) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* acc = staging;
    for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                               staging_stride_bytes);
      rhs_block += packed_block_bytes;
      acc += 3;
    }
    multi_qnt_3x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_3, out,
                          result_stride, multiplicative_offset, rounding,
                          -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing lhs block of two.
  zip_2x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* acc = staging;
  for (std::int32_t cols_left = rhs_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, acc,
                             staging_stride_bytes);
    rhs_block += packed_block_bytes;
    acc += 3;
  }
  multi_qnt_2x8_aligned(staging, n, staging_stride_bytes, lhs_offsets_2, out,
                        result_stride, multiplicative_offset, rounding,
                        -shift);
}
// gemm_q8_2_1_0_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch.
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 0 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_1_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_1" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 1 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_2_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_2" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 2 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_3_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_3" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 3 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_4_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_4" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 4 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_5_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_5" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 5 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_6_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_6" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 6 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_1_7_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing single RHS row are
//     zipped into scratch (the "_7" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing RHS row, then passed to
//     multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 1, k % 8 == 7 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the single leftover row.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_1x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover RHS row sits right after the last full group.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_1x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_2_0_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing group of two RHS rows
//     are zipped into scratch.
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing two-row RHS group, then
//     passed to multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 2, k % 8 == 0 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the two leftover rows.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_2x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover two-row RHS group sits right after the last full group.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_2x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_2_1_aligned: uint8 GEMM driver assembled from 3-row tiles.
//
// Work performed:
//   * n / 3 groups of three RHS rows plus one trailing group of two RHS rows
//     are zipped into scratch (the "_1" zip variants).
//   * m / 3 groups of three LHS rows are zipped one at a time, multiplied
//     against every zipped RHS group and the trailing two-row RHS group, then
//     passed to multi_qnt_3x8_aligned which writes into `result`.
//   * One trailing group of two LHS rows gets the same treatment through the
//     2x8 helpers.
// The name suggests m % 3 == 2, n % 3 == 2, k % 8 == 1 -- confirm against the
// dispatcher.
//
// Scratch layout, in order:
//   [zipped LHS group : (padded_k + 16) * 3 bytes]
//   [zipped RHS       : (padded_k + 16) * n bytes]
//   [int32 accumulators read by multi_qnt_*]
//
// NOTE(review): the rounding term is 1 << (shift - 1), so shift is assumed to
// be >= 1.
void gemm_q8_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Tiling parameters.
  const std::int32_t lhs_groups = m / 3;
  const std::int32_t rhs_groups = n / 3;
  const std::int32_t depth_padded = ((k + 7) / 8) * 8;
  const std::int32_t raw_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (depth_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (depth_padded + 16) * n;
  const std::int32_t result_group_stride = result_stride * 3;
  const std::int32_t accumulator_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Constants forwarded to the zip / quantize helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Carve up the scratch buffer.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_group_bytes;
  std::int32_t* accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots located right after 3 (resp. 2) zipped rows of depth_padded
  // bytes; they are handed to multi_qnt_* as its offsets argument.
  std::int32_t* packed_lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 3);
  std::int32_t* packed_lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(packed_lhs + depth_padded * 2);

  // Zip every full group of three RHS rows, then the two leftover rows.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (std::int32_t group = 0; group < rhs_groups;
       ++group, rhs_src += raw_group_bytes, rhs_dst += packed_group_bytes) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }
  zip_2x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full groups of three LHS rows.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;
  for (std::int32_t row_group = 0; row_group < lhs_groups;
       ++row_group, lhs_src += raw_group_bytes,
       out_rows += result_group_stride) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    // Leftover two-row RHS group sits right after the last full group.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);

    multi_qnt_3x8_aligned(accumulators, n, accumulator_stride_bytes,
                          packed_lhs_offsets_3, out_rows, result_stride,
                          multiplicative_offset, rounding, -shift);
  }

  // Trailing group of two LHS rows.
  zip_2x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_panel = packed_rhs;
    std::int32_t* acc = accumulators;
    for (std::int32_t col_group = 0; col_group < rhs_groups;
         ++col_group, rhs_panel += packed_group_bytes, acc += 3) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                               accumulator_stride_bytes);
    }
    mul_2x8_2x8_int32_rhsadd(packed_lhs, rhs_panel, depth_padded, acc,
                             accumulator_stride_bytes);
  }
  multi_qnt_2x8_aligned(accumulators, n, accumulator_stride_bytes,
                        packed_lhs_offsets_2, out_rows, result_stride,
                        multiplicative_offset, rounding, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_2_aligned /
// zip_2x8_2_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 2 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_3_aligned /
// zip_2x8_3_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 3 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_4_aligned /
// zip_2x8_4_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 4 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_5_aligned /
// zip_2x8_5_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 5 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_6_aligned /
// zip_2x8_6_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 6 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8_7_aligned /
// zip_2x8_7_aligned helpers. lhs and rhs are consumed in chunks of 3 rows
// (3 * k bytes each), followed by one trailing 2-row chunk on each side.
// NOTE(review): the name suggests m % 3 == 2, n % 3 == 2 and k % 8 == 7 with
// aligned buffers; confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per full chunk.
void gemm_q8_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the zipped data: 3 * padded_k bytes in for a
  // 3-row chunk, 2 * padded_k bytes in for the 2-row chunk.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* const zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);

  // Zip the whole rhs once: full 3-row chunks, then the 2-row remainder.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs chunk.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}
// Quantized 8-bit GEMM driver built on the zip_3x8 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 0;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized 8-bit GEMM driver built on the zip_3x8_1 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 1;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized 8-bit GEMM driver built on the zip_3x8_2 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 2;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized 8-bit GEMM driver built on the zip_3x8_3 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 3;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized 8-bit GEMM driver built on the zip_3x8_4 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 4;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Quantized 8-bit GEMM driver built on the zip_3x8_5 helper. lhs and rhs are
// consumed only in full chunks of 3 rows (3 * k bytes each); rows past
// 3 * (m / 3) and columns past 3 * (n / 3) are not processed here.
// NOTE(review): the name suggests m % 3 == 0, n % 3 == 0 and k % 8 == 5;
// confirm the exact preconditions against the dispatcher.
//
// scratch layout, in bytes:
//   [0, zipped_chunk_size)           zipped lhs chunk followed by its offsets
//   [zipped_chunk_size, +rhs size)   every zipped rhs chunk, back to back
//   [after the zipped rhs ...)       int32 accumulators for one row chunk
//
// result is written one row chunk at a time, advancing by 3 * result_stride
// bytes per chunk.
void gemm_q8_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // One accumulator row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Shifting by a negative count is undefined behaviour, so a shift of zero
  // or less gets no rounding term.
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::int32_t* const temp_result =
      reinterpret_cast<std::int32_t*>(zipped_rhs + zipped_rhs_size);
  // Offsets sit right after the 3 * padded_k bytes of zipped lhs data.
  std::int32_t* const zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);

  // Zip the whole rhs once, one 3-row chunk at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each lhs chunk: zip, multiply against every rhs chunk, quantize.
  const std::uint8_t* lhs_chunk = lhs;
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk,
                               mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_q8_0_0_6: generated quantized GEMM driver built on zip_3x8_6.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk into the int32 staging area, then hand the
//      staging area to multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Phase 2: one lhs chunk at a time against all zipped rhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_0_7: generated quantized GEMM driver built on zip_3x8_7.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk into the int32 staging area, then hand the
//      staging area to multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }

  // Phase 2: one lhs chunk at a time against all zipped rhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_0: generated quantized GEMM driver built on zip_3x8 / zip_1x8.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_1: generated quantized GEMM driver built on zip_3x8_1 and
// zip_1x8_1.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_1.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_2: generated quantized GEMM driver built on zip_3x8_2 and
// zip_1x8_2.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_2.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_3: generated quantized GEMM driver built on zip_3x8_3 and
// zip_1x8_3.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_3.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_4: generated quantized GEMM driver built on zip_3x8_4 and
// zip_1x8_4.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_4.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_5: generated quantized GEMM driver built on zip_3x8_5 and
// zip_1x8_5.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_5.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_6: generated quantized GEMM driver built on zip_3x8_6 and
// zip_1x8_6.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_6.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_1_7: generated quantized GEMM driver built on zip_3x8_7 and
// zip_1x8_7.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_1x8_7.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_1x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_1x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_2_0: generated quantized GEMM driver built on zip_3x8 / zip_2x8.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_2x8.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_2x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_2_1: generated quantized GEMM driver built on zip_3x8_1 and
// zip_2x8_1.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_2x8_1.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_2x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_2_2: generated quantized GEMM driver built on zip_3x8_2 and
// zip_2x8_2.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_2x8_2.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_2x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_2_3: generated quantized GEMM driver built on zip_3x8_3 and
// zip_2x8_3.
//
// Flow, as visible in this function:
//   1. Each of the n / 3 rhs chunks (k * 3 source bytes apiece) is zipped
//      into scratch, passing lhs_offset and 0 as the trailing offsets; the
//      source bytes that follow the last whole chunk go through zip_2x8_3.
//   2. For each of the m / 3 lhs chunks: zip it (trailing offsets rhs_offset
//      and lhs_offset * rhs_offset * k + result_offset), multiply it against
//      every zipped rhs chunk, make one extra mul_3x8_2x8_int32_rhsadd call
//      against the zipped remainder, then hand the int32 staging area to
//      multi_qnt_3x8 together with the destination pointer,
//      multiplicative_offset, the rounding term and -shift.
//
// Scratch layout: zipped lhs chunk | zipped rhs chunks | int32 staging area.
// NOTE(review): the name suffix presumably encodes (m % 3, n % 3, k % 8);
// confirm against the dispatcher that selects this variant.
void gemm_q8_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_blocks = m / 3;
  const std::int32_t rhs_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t staging_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t dst_block_stride = result_stride * 3;

  // Constants forwarded to the helpers.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): this shift is undefined for shift < 1; confirm callers
  // never pass such a value.
  const std::int32_t round_term = (1 << (shift - 1));

  // Scratch partitioning.
  std::uint8_t* packed_lhs = scratch;
  std::uint8_t* packed_rhs = scratch + packed_block_bytes;
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* staging =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);

  // Phase 1: zip every whole rhs chunk, then the trailing remainder.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int remaining = rhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += packed_block_bytes;
  }
  zip_2x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Phase 2: one lhs chunk at a time against all zipped rhs data.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* dst = result;
  for (int remaining = lhs_blocks; remaining > 0; --remaining) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_cursor = packed_rhs;
    std::int32_t* staging_cursor = staging;
    for (int col = 0; col < rhs_blocks; ++col) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                               staging_cursor, staging_stride_bytes);
      rhs_cursor += packed_block_bytes;
      staging_cursor += 3;
    }
    // Cursors now sit on the zipped remainder and its staging columns.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_cursor, k_rounded,
                             staging_cursor, staging_stride_bytes);

    multi_qnt_3x8(staging, n, staging_stride_bytes, lhs_offsets_3, dst,
                  result_stride, multiplicative_offset, round_term, -shift);
    lhs_src += src_block_bytes;
    dst += dst_block_stride;
  }
}
// gemm_q8_0_2_4: quantized 8-bit GEMM variant built on the zip_3x8_4 /
// zip_2x8_4 packing helpers.
//
// The rhs is packed once up front: n / 3 chunks of 3 * k source bytes via
// zip_3x8_4, followed by one trailing chunk via zip_2x8_4. The lhs is then
// consumed in m / 3 chunks of 3 * k bytes; each chunk is packed, multiplied
// against every packed rhs chunk into int32 accumulators, and handed to
// multi_qnt_3x8, which writes into `result`. The output pointer advances by
// 3 * result_stride per lhs chunk. Only the m / 3 full lhs chunks are
// processed here.
// NOTE(review): the "_0_2_4" suffix presumably encodes m % 3 == 0,
// n % 3 == 2, k % 8 == 4 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the 3 * k_padded packed lhs bytes.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);

  // Pack the whole rhs: full chunks first, then the trailing 2-wide chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }
  zip_2x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Process the lhs one chunk at a time.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }
    // Trailing 2-wide rhs chunk.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }
}
// gemm_q8_0_2_5: quantized 8-bit GEMM variant built on the zip_3x8_5 /
// zip_2x8_5 packing helpers.
//
// The rhs is packed once up front: n / 3 chunks of 3 * k source bytes via
// zip_3x8_5, followed by one trailing chunk via zip_2x8_5. The lhs is then
// consumed in m / 3 chunks of 3 * k bytes; each chunk is packed, multiplied
// against every packed rhs chunk into int32 accumulators, and handed to
// multi_qnt_3x8, which writes into `result`. The output pointer advances by
// 3 * result_stride per lhs chunk. Only the m / 3 full lhs chunks are
// processed here.
// NOTE(review): the "_0_2_5" suffix presumably encodes m % 3 == 0,
// n % 3 == 2, k % 8 == 5 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the 3 * k_padded packed lhs bytes.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);

  // Pack the whole rhs: full chunks first, then the trailing 2-wide chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }
  zip_2x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Process the lhs one chunk at a time.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }
    // Trailing 2-wide rhs chunk.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }
}
// gemm_q8_0_2_6: quantized 8-bit GEMM variant built on the zip_3x8_6 /
// zip_2x8_6 packing helpers.
//
// The rhs is packed once up front: n / 3 chunks of 3 * k source bytes via
// zip_3x8_6, followed by one trailing chunk via zip_2x8_6. The lhs is then
// consumed in m / 3 chunks of 3 * k bytes; each chunk is packed, multiplied
// against every packed rhs chunk into int32 accumulators, and handed to
// multi_qnt_3x8, which writes into `result`. The output pointer advances by
// 3 * result_stride per lhs chunk. Only the m / 3 full lhs chunks are
// processed here.
// NOTE(review): the "_0_2_6" suffix presumably encodes m % 3 == 0,
// n % 3 == 2, k % 8 == 6 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the 3 * k_padded packed lhs bytes.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);

  // Pack the whole rhs: full chunks first, then the trailing 2-wide chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }
  zip_2x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Process the lhs one chunk at a time.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }
    // Trailing 2-wide rhs chunk.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }
}
// gemm_q8_0_2_7: quantized 8-bit GEMM variant built on the zip_3x8_7 /
// zip_2x8_7 packing helpers.
//
// The rhs is packed once up front: n / 3 chunks of 3 * k source bytes via
// zip_3x8_7, followed by one trailing chunk via zip_2x8_7. The lhs is then
// consumed in m / 3 chunks of 3 * k bytes; each chunk is packed, multiplied
// against every packed rhs chunk into int32 accumulators, and handed to
// multi_qnt_3x8, which writes into `result`. The output pointer advances by
// 3 * result_stride per lhs chunk. Only the m / 3 full lhs chunks are
// processed here.
// NOTE(review): the "_0_2_7" suffix presumably encodes m % 3 == 0,
// n % 3 == 2, k % 8 == 7 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the 3 * k_padded packed lhs bytes.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);

  // Pack the whole rhs: full chunks first, then the trailing 2-wide chunk.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }
  zip_2x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Process the lhs one chunk at a time.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }
    // Trailing 2-wide rhs chunk.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }
}
// gemm_q8_1_0_0: quantized 8-bit GEMM variant built on the zip_3x8 /
// zip_1x8 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_0" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 0 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_1: quantized 8-bit GEMM variant built on the zip_3x8_1 /
// zip_1x8_1 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_1. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_1, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_1" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 1 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_1(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_2: quantized 8-bit GEMM variant built on the zip_3x8_2 /
// zip_1x8_2 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_2. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_2, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_2" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 2 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_2(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_3: quantized 8-bit GEMM variant built on the zip_3x8_3 /
// zip_1x8_3 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_3. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_3, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_3" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 3 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_3(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_4: quantized 8-bit GEMM variant built on the zip_3x8_4 /
// zip_1x8_4 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_4. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_4, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_4" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 4 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_4(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_5: quantized 8-bit GEMM variant built on the zip_3x8_5 /
// zip_1x8_5 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_5. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_5, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_5" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 5 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_5(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_6: quantized 8-bit GEMM variant built on the zip_3x8_6 /
// zip_1x8_6 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_6. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_6, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_6" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 6 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_6(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_0_7: quantized 8-bit GEMM variant built on the zip_3x8_7 /
// zip_1x8_7 packing helpers.
//
// The rhs is packed once up front as n / 3 chunks of 3 * k source bytes via
// zip_3x8_7. The lhs is then consumed in m / 3 chunks of 3 * k bytes; each
// chunk is packed, multiplied against every packed rhs chunk into int32
// accumulators, and handed to multi_qnt_3x8, which writes into `result`.
// The output pointer advances by 3 * result_stride per lhs chunk. After the
// full chunks, one more lhs piece is packed with zip_1x8_7, multiplied with
// mul_1x8_3x8_int32_rhsadd and emitted through multi_qnt_1x8.
// NOTE(review): the "_1_0_7" suffix presumably encodes m % 3 == 1,
// n % 3 == 0, k % 8 == 7 -- confirm against the dispatch code.
//
// scratch is carved into three consecutive regions:
//   [packed lhs chunk : (k_padded + 16) * 3 bytes]
//   [packed rhs       : (k_padded + 16) * n bytes]
//   [int32 accumulators]
void gemm_q8_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Chunk counts and sizes.
  const std::int32_t lhs_tiles = m / 3;
  const std::int32_t rhs_tiles = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_tile_bytes = k * 3;
  const std::int32_t packed_row_bytes = k_padded + 16;
  const std::int32_t packed_tile_bytes = packed_row_bytes * 3;
  const std::int32_t packed_rhs_bytes = packed_row_bytes * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_tile_stride = result_stride * 3;

  // Constants forwarded to the packing / requantization helpers.
  const std::int32_t lhs_additive =
      lhs_offset * rhs_offset * k + result_offset;
  // NOTE(review): requires shift >= 1; a smaller shift makes this shift
  // count negative.
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch regions.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_tile_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 values located right after the packed lhs bytes, for the 3-row
  // and the 1-row packing respectively.
  std::int32_t* const offsets_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_1_row =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 1);

  // Pack the whole rhs.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int tile = 0; tile < rhs_tiles; ++tile) {
    zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_tile_bytes;
    rhs_dst += packed_tile_bytes;
  }

  // Full lhs chunks.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int row_tile = 0; row_tile < lhs_tiles; ++row_tile) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

    std::uint8_t* rhs_tile = packed_rhs;
    std::int32_t* acc = accumulators;
    for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                               acc_stride_bytes);
      rhs_tile += packed_tile_bytes;
      acc += 3;
    }

    multi_qnt_3x8(accumulators, n, acc_stride_bytes, offsets_3_rows, out,
                  result_stride, multiplicative_offset, rounding, -shift);

    lhs_src += src_tile_bytes;
    out += out_tile_stride;
  }

  // Trailing 1-row lhs piece.
  zip_1x8_7(lhs_src, k, k, packed_lhs, rhs_offset, lhs_additive);

  std::uint8_t* rhs_tile = packed_rhs;
  std::int32_t* acc = accumulators;
  for (int col_tile = 0; col_tile < rhs_tiles; ++col_tile) {
    mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_tile, k_padded, acc,
                             acc_stride_bytes);
    rhs_tile += packed_tile_bytes;
    acc += 3;
  }

  multi_qnt_1x8(accumulators, n, acc_stride_bytes, offsets_1_row, out,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_0: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The unsuffixed zip helpers are used (presumably the
// k % 8 == 0 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_1: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_1" zip helpers are used throughout (presumably
// the k % 8 == 1 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_2: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_2" zip helpers are used throughout (presumably
// the k % 8 == 2 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_2(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_3: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_3" zip helpers are used throughout (presumably
// the k % 8 == 3 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_3(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_4: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_4" zip helpers are used throughout (presumably
// the k % 8 == 4 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_4(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_5: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_5" zip helpers are used throughout (presumably
// the k % 8 == 5 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_5(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_6: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_6" zip helpers are used throughout (presumably
// the k % 8 == 6 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_6(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_1_7: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra single column. The "_7" zip helpers are used throughout (presumably
// the k % 8 == 7 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the single leftover
  // column.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_1x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_7(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover rhs column.
    mul_1x8_1x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_0: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra 2-column block. The unsuffixed zip helpers are used (presumably the
// k % 8 == 0 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the leftover
  // 2-column block.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_2x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover 2-column rhs block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover 2-column rhs block.
    mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_1: quantized 8-bit GEMM driver.
//
// Processes m / 3 full 3-row blocks of lhs and then always one extra single
// row; rhs is consumed as n / 3 full 3-column blocks followed by always one
// extra 2-column block. The "_1" zip helpers are used throughout (presumably
// the k % 8 == 1 variant -- confirm against the zip_* definitions and the
// dispatcher).
//
// scratch holds, in order: one packed lhs block, the packed rhs, and the
// int32 accumulators. lhs and rhs are both read in steps of 3 * k bytes per
// block. result is advanced by 3 * result_stride per 3-row block.
void gemm_q8_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes derived from the problem dimensions.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t source_block_bytes = k * 3;
  const std::int32_t packed_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_block_stride = result_stride * 3;

  // Constants forwarded to the zip and quantization helpers.
  // NOTE(review): the shift below is undefined for shift <= 0 -- confirm
  // callers always pass shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_block_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  // int32 slots directly after the packed lhs data; handed to multi_qnt_* as
  // its offsets argument (presumably written by the zip_* helpers -- confirm).
  std::int32_t* const row_terms_3 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 3);
  std::int32_t* const row_terms_1 =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_rounded * 1);

  // Pack the whole rhs once: full 3-column blocks, then the leftover
  // 2-column block.
  {
    const std::uint8_t* rhs_src = rhs;
    std::uint8_t* rhs_dst = packed_rhs;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
      rhs_src += source_block_bytes;
      rhs_dst += packed_block_bytes;
    }
    zip_2x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
  }

  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out_rows = result;

  // Full 3-row blocks: pack, multiply against every rhs block, quantize.
  for (std::int32_t rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);

    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover 2-column rhs block.
    mul_3x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, row_terms_3, out_rows,
                  result_stride, multiplicative_offset, rounding, -shift);
    lhs_src += source_block_bytes;
    out_rows += output_block_stride;
  }

  // Trailing single row.
  zip_1x8_1(lhs_src, k, k, packed_lhs, rhs_offset, bias);
  {
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* accum = accumulators;
    for (std::int32_t left = full_col_blocks; left > 0; --left) {
      mul_1x8_3x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                               accum_row_bytes);
      rhs_block += packed_block_bytes;
      accum += 3;
    }
    // Leftover 2-column rhs block.
    mul_1x8_2x8_int32_rhsadd(packed_lhs, rhs_block, k_rounded, accum,
                             accum_row_bytes);
  }
  multi_qnt_1x8(accumulators, n, accum_row_bytes, row_terms_1, out_rows,
                result_stride, multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_2: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_2" zip variants are used throughout
// (presumably the k % 8 == 2 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_2(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_2(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_3: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_3" zip variants are used throughout
// (presumably the k % 8 == 3 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_3(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_3(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_4: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_4" zip variants are used throughout
// (presumably the k % 8 == 4 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_4(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_4(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_5: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_5" zip variants are used throughout
// (presumably the k % 8 == 5 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_5(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_5(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_5(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_6: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_6" zip variants are used throughout
// (presumably the k % 8 == 6 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_6(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_6(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_6(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_1_2_7: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// single-row block; rhs is consumed in n / 3 full three-column blocks followed
// by one trailing two-column block. The "_7" zip variants are used throughout
// (presumably the k % 8 == 7 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 1 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_1 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 1);

  // Zip the whole rhs once: full blocks first, then the trailing pair.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_7(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_7(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    mul_3x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing single-row lhs block.
  zip_1x8_7(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_1x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  mul_1x8_2x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                           acc_stride_bytes);
  multi_qnt_1x8(acc, n, acc_stride_bytes, lhs_offsets_1, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_0_0: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// two-row block; rhs is consumed in n / 3 full three-column blocks and no
// trailing rhs block is processed. The unsuffixed zip_3x8 / zip_2x8 kernels
// are used (presumably the k % 8 == 0 case -- confirm against the zip
// kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 2 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 2);

  // Zip every full three-column rhs block once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing two-row lhs block.
  zip_2x8(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  multi_qnt_2x8(acc, n, acc_stride_bytes, lhs_offsets_2, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_0_1: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// two-row block; rhs is consumed in n / 3 full three-column blocks and no
// trailing rhs block is processed. The "_1" zip variants are used throughout
// (presumably the k % 8 == 1 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 2 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 2);

  // Zip every full three-column rhs block once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_1(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_1(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing two-row lhs block.
  zip_2x8_1(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  multi_qnt_2x8(acc, n, acc_stride_bytes, lhs_offsets_2, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_0_2: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// two-row block; rhs is consumed in n / 3 full three-column blocks and no
// trailing rhs block is processed. The "_2" zip variants are used throughout
// (presumably the k % 8 == 2 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 2 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 2);

  // Zip every full three-column rhs block once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_2(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_2(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing two-row lhs block.
  zip_2x8_2(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  multi_qnt_2x8(acc, n, acc_stride_bytes, lhs_offsets_2, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_0_3: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// two-row block; rhs is consumed in n / 3 full three-column blocks and no
// trailing rhs block is processed. The "_3" zip variants are used throughout
// (presumably the k % 8 == 3 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 2 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 2);

  // Zip every full three-column rhs block once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_3(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_3(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing two-row lhs block.
  zip_2x8_3(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  multi_qnt_2x8(acc, n, acc_stride_bytes, lhs_offsets_2, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// gemm_q8_2_0_4: 8-bit quantized GEMM driver assembled from the zip_*,
// mul_*_int32_rhsadd and multi_qnt_* kernels defined elsewhere in this file.
//
// lhs is consumed in m / 3 full three-row blocks followed by one trailing
// two-row block; rhs is consumed in n / 3 full three-column blocks and no
// trailing rhs block is processed. The "_4" zip variants are used throughout
// (presumably the k % 8 == 4 leftover case -- confirm against the zip kernels).
//
// scratch is partitioned, in order, into: one zipped lhs block, the zipped
// rhs, and the int32 intermediate results. This function does not allocate;
// the caller supplies a buffer large enough for all three regions.
void gemm_q8_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Block counts and byte sizes.
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  const std::int32_t k_rounded = ((k + 7) / 8) * 8;  // k rounded up to 8.
  const std::int32_t src_block_bytes = k * 3;
  const std::int32_t zipped_block_bytes = (k_rounded + 16) * 3;
  const std::int32_t zipped_rhs_bytes = (k_rounded + 16) * n;
  const std::int32_t acc_stride_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t out_block_stride = result_stride * 3;

  // Constants handed to the zip and quantization kernels.
  // NOTE(review): the rounding term is only well defined for shift >= 1.
  const std::int32_t bias = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* lhs_zipped = scratch;
  std::uint8_t* rhs_zipped = scratch + zipped_block_bytes;
  std::int32_t* acc = reinterpret_cast<std::int32_t*>(
      scratch + zipped_block_bytes + zipped_rhs_bytes);
  // int32 offsets consumed by multi_qnt_*: they start right after the zipped
  // lhs bytes, i.e. after 3 (full block) or 2 (trailing block) rows.
  std::int32_t* lhs_offsets_3 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 3);
  std::int32_t* lhs_offsets_2 =
      reinterpret_cast<std::int32_t*>(lhs_zipped + k_rounded * 2);

  // Zip every full three-column rhs block once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = rhs_zipped;
  for (int left = full_col_blocks; left > 0; --left) {
    zip_3x8_4(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += src_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Full three-row lhs blocks: zip, multiply against every rhs block, then
  // quantize the accumulated int32 values into the result.
  const std::uint8_t* lhs_src = lhs;
  std::uint8_t* out = result;
  for (int rows_left = full_row_blocks; rows_left > 0; --rows_left) {
    zip_3x8_4(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
    std::uint8_t* rhs_block = rhs_zipped;
    std::int32_t* acc_cursor = acc;
    for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
      mul_3x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                               acc_stride_bytes);
      rhs_block += zipped_block_bytes;
      acc_cursor += 3;
    }
    multi_qnt_3x8(acc, n, acc_stride_bytes, lhs_offsets_3, out, result_stride,
                  multiplicative_offset, rounding, -shift);
    lhs_src += src_block_bytes;
    out += out_block_stride;
  }

  // Trailing two-row lhs block.
  zip_2x8_4(lhs_src, k, k, lhs_zipped, rhs_offset, bias);
  std::uint8_t* rhs_block = rhs_zipped;
  std::int32_t* acc_cursor = acc;
  for (int cols_left = full_col_blocks; cols_left > 0; --cols_left) {
    mul_2x8_3x8_int32_rhsadd(lhs_zipped, rhs_block, k_rounded, acc_cursor,
                             acc_stride_bytes);
    rhs_block += zipped_block_bytes;
    acc_cursor += 3;
  }
  multi_qnt_2x8(acc, n, acc_stride_bytes, lhs_offsets_2, out, result_stride,
                multiplicative_offset, rounding, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_0_5".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns (no
// column remainder is handled here). The "_5" zip helpers are presumably the
// variants for k % 8 == 5 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once, three columns at a time.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_5(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_5(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_5(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_0_6".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns (no
// column remainder is handled here). The "_6" zip helpers are presumably the
// variants for k % 8 == 6 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once, three columns at a time.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_6(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_6(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_6(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_0_7".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns (no
// column remainder is handled here). The "_7" zip helpers are presumably the
// variants for k % 8 == 7 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once, three columns at a time.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_7(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_7(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_7(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_0".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The unsuffixed zip helpers are presumably
// the variants for k % 8 == 0 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_1".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_1" zip helpers are presumably the
// variants for k % 8 == 1 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_1(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_1(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_1(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_1(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_2".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_2" zip helpers are presumably the
// variants for k % 8 == 2 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_2(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_2(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_2(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_2(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_3".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_3" zip helpers are presumably the
// variants for k % 8 == 3 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_3(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_3(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_3(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_3(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_4".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_4" zip helpers are presumably the
// variants for k % 8 == 4 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_4(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_4(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_4(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_4(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_5".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_5" zip helpers are presumably the
// variants for k % 8 == 5 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_5(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_5(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_5(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_5(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_6".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_6" zip helpers are presumably the
// variants for k % 8 == 6 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_6(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_6(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_6(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_6(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// 8-bit quantized GEMM kernel, variant "2_1_7".
//
// Processes m / 3 full groups of three LHS rows followed by exactly one
// trailing group of two rows, against n / 3 groups of three RHS columns plus
// exactly one leftover RHS column. The "_7" zip helpers are presumably the
// variants for k % 8 == 7 -- TODO confirm against the zip_* definitions.
//
// scratch is carved into three consecutive regions: one packed LHS group,
// the packed RHS ((padded k + 16) * n bytes), and the int32 accumulators.
// The int32 values found right behind the packed LHS rows are handed to
// multi_qnt_* as per-row offsets.
void gemm_q8_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  // Problem geometry.
  const std::int32_t full_row_groups = m / 3;
  const std::int32_t full_col_groups = n / 3;
  const std::int32_t k_padded = ((k + 7) / 8) * 8;
  const std::int32_t source_group_bytes = k * 3;
  const std::int32_t packed_group_bytes = (k_padded + 16) * 3;
  const std::int32_t packed_rhs_bytes = (k_padded + 16) * n;
  const std::int32_t accum_row_bytes = ((n * 4 + 7) / 8) * 8;
  const std::int32_t output_group_stride = result_stride * 3;

  // Quantization constants.
  const std::int32_t combined_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_term = (1 << (shift - 1));

  // Scratch layout.
  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + packed_group_bytes;
  std::int32_t* const accumulators =
      reinterpret_cast<std::int32_t*>(packed_rhs + packed_rhs_bytes);
  std::int32_t* const offsets_for_3_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 3);
  std::int32_t* const offsets_for_2_rows =
      reinterpret_cast<std::int32_t*>(packed_lhs + k_padded * 2);

  // Pack the whole RHS once: three columns at a time, then the leftover one.
  {
    const std::uint8_t* rhs_cursor = rhs;
    std::uint8_t* packed_cursor = packed_rhs;
    for (std::int32_t left = full_col_groups; left > 0; --left) {
      zip_3x8_7(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
      rhs_cursor += source_group_bytes;
      packed_cursor += packed_group_bytes;
    }
    zip_1x8_7(rhs_cursor, k, k, packed_cursor, lhs_offset, 0);
  }

  const std::uint8_t* lhs_cursor = lhs;
  std::uint8_t* output_cursor = result;

  // Full groups of three LHS rows.
  for (std::int32_t rows_left = full_row_groups; rows_left > 0; --rows_left) {
    zip_3x8_7(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);

    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_3x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_3x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);

    multi_qnt_3x8(accumulators, n, accum_row_bytes, offsets_for_3_rows,
                  output_cursor, result_stride, multiplicative_offset,
                  rounding_term, -shift);
    lhs_cursor += source_group_bytes;
    output_cursor += output_group_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_7(lhs_cursor, k, k, packed_lhs, rhs_offset, combined_offset);
  {
    std::uint8_t* packed_cursor = packed_rhs;
    std::int32_t* accum_cursor = accumulators;
    for (std::int32_t cols_left = full_col_groups; cols_left > 0;
         --cols_left) {
      mul_2x8_3x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded,
                               accum_cursor, accum_row_bytes);
      packed_cursor += packed_group_bytes;
      accum_cursor += 3;
    }
    // Leftover single RHS column.
    mul_2x8_1x8_int32_rhsadd(packed_lhs, packed_cursor, k_padded, accum_cursor,
                             accum_row_bytes);
  }
  multi_qnt_2x8(accumulators, n, accum_row_bytes, offsets_for_2_rows,
                output_cursor, result_stride, multiplicative_offset,
                rounding_term, -shift);
}
// gemm_q8_2_2_0: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_0" presumably selects
// the zip variant for k % 8 == 0 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_1: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_1" presumably selects
// the zip variant for k % 8 == 1 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_2: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_2" presumably selects
// the zip variant for k % 8 == 2 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_3: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_3" presumably selects
// the zip variant for k % 8 == 3 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_4: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_4" presumably selects
// the zip variant for k % 8 == 4 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_5: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_5" presumably selects
// the zip variant for k % 8 == 5 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_6: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_6" presumably selects
// the zip variant for k % 8 == 6 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_q8_2_2_7: uint8 x uint8 GEMM with quantized uint8 output.
//
// Processes m / 3 chunks of three lhs rows and n / 3 chunks of three rhs rows,
// then unconditionally handles exactly two leftover lhs rows and two leftover
// rhs rows. Callers must therefore only use it when that matches m and n
// (presumably m % 3 == 2 and n % 3 == 2; the trailing "_7" presumably selects
// the zip variant for k % 8 == 7 -- confirm against the dispatcher).
//
// scratch is carved into three consecutive areas and must be large enough for
// all of them:
//   1. zipped lhs chunk, (padded_k + 16) * 3 bytes; the int32 values read via
//      zipped_lhs_{3,2}_offsets sit right after the packed rows (presumably
//      the per-row sums written by the zip helpers).
//   2. zipped rhs, (padded_k + 16) * n bytes.
//   3. int32 temporaries for one row chunk, rows spaced
//      mul_result_chunk_stride_bytes apart.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset, result_offset: offsets handed to the zip helpers
//   (rhs is zipped with lhs_offset, lhs with rhs_offset and const_offset).
// multiplicative_offset, shift: forwarded to multi_qnt_*; shift is passed
//   negated together with rounding_offset.
// result, result_stride: uint8 destination; advances result_stride * 3 per
//   three-row chunk.
void gemm_q8_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  // Rounding term 2^(shift - 1) for the final shift. Guarded so shift <= 0
  // never evaluates a shift by a negative count (undefined behaviour).
  const std::int32_t rounding_offset = (shift > 0) ? (1 << (shift - 1)) : 0;
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  // One temp row: n int32 values, rounded up to a multiple of 8 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Zip the whole rhs once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip, multiply against every rhs chunk into the
  // int32 temporaries, then quantize the three rows into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows, same pipeline with the 2-row kernels.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}
// gemm_i32_0_0_0_aligned: uint8 x uint8 GEMM writing raw int32 results.
//
// Only full chunks are processed: m / 3 chunks of three lhs rows against
// n / 3 chunks of three rhs rows. Any remainder rows or columns are not
// touched, so callers presumably only dispatch here when m % 3 == 0 and
// n % 3 == 0 (the trailing "_0" presumably selects the zip variant for
// k % 8 == 0, and "_aligned" presumably requires suitably aligned inputs --
// confirm against the dispatcher).
//
// scratch holds the zipped lhs chunk ((padded_k + 16) * 3 bytes) followed by
// the zipped rhs chunks; the caller must size it accordingly.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset: offsets handed to the zip helpers (rhs is zipped
//   with lhs_offset, lhs with rhs_offset and lhs_offset * rhs_offset * k).
// result, result_stride: int32 destination written directly by the multiply
//   kernel; result_stride is counted in int32 elements (it is scaled by 4 to
//   obtain the byte stride).
void gemm_i32_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: zip it, then multiply against every rhs chunk,
  // writing 3x3 int32 tiles straight into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_0_1_aligned: uint8 x uint8 GEMM writing raw int32 results.
//
// Only full chunks are processed: m / 3 chunks of three lhs rows against
// n / 3 chunks of three rhs rows. Any remainder rows or columns are not
// touched, so callers presumably only dispatch here when m % 3 == 0 and
// n % 3 == 0 (the trailing "_1" presumably selects the zip variant for
// k % 8 == 1, and "_aligned" presumably requires suitably aligned inputs --
// confirm against the dispatcher).
//
// scratch holds the zipped lhs chunk ((padded_k + 16) * 3 bytes) followed by
// the zipped rhs chunks; the caller must size it accordingly.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset: offsets handed to the zip helpers (rhs is zipped
//   with lhs_offset, lhs with rhs_offset and lhs_offset * rhs_offset * k).
// result, result_stride: int32 destination written directly by the multiply
//   kernel; result_stride is counted in int32 elements (it is scaled by 4 to
//   obtain the byte stride).
void gemm_i32_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: zip it, then multiply against every rhs chunk,
  // writing 3x3 int32 tiles straight into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_0_2_aligned: uint8 x uint8 GEMM writing raw int32 results.
//
// Only full chunks are processed: m / 3 chunks of three lhs rows against
// n / 3 chunks of three rhs rows. Any remainder rows or columns are not
// touched, so callers presumably only dispatch here when m % 3 == 0 and
// n % 3 == 0 (the trailing "_2" presumably selects the zip variant for
// k % 8 == 2, and "_aligned" presumably requires suitably aligned inputs --
// confirm against the dispatcher).
//
// scratch holds the zipped lhs chunk ((padded_k + 16) * 3 bytes) followed by
// the zipped rhs chunks; the caller must size it accordingly.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset: offsets handed to the zip helpers (rhs is zipped
//   with lhs_offset, lhs with rhs_offset and lhs_offset * rhs_offset * k).
// result, result_stride: int32 destination written directly by the multiply
//   kernel; result_stride is counted in int32 elements (it is scaled by 4 to
//   obtain the byte stride).
void gemm_i32_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: zip it, then multiply against every rhs chunk,
  // writing 3x3 int32 tiles straight into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_0_3_aligned: uint8 x uint8 GEMM writing raw int32 results.
//
// Only full chunks are processed: m / 3 chunks of three lhs rows against
// n / 3 chunks of three rhs rows. Any remainder rows or columns are not
// touched, so callers presumably only dispatch here when m % 3 == 0 and
// n % 3 == 0 (the trailing "_3" presumably selects the zip variant for
// k % 8 == 3, and "_aligned" presumably requires suitably aligned inputs --
// confirm against the dispatcher).
//
// scratch holds the zipped lhs chunk ((padded_k + 16) * 3 bytes) followed by
// the zipped rhs chunks; the caller must size it accordingly.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset: offsets handed to the zip helpers (rhs is zipped
//   with lhs_offset, lhs with rhs_offset and lhs_offset * rhs_offset * k).
// result, result_stride: int32 destination written directly by the multiply
//   kernel; result_stride is counted in int32 elements (it is scaled by 4 to
//   obtain the byte stride).
void gemm_i32_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: zip it, then multiply against every rhs chunk,
  // writing 3x3 int32 tiles straight into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_0_4_aligned: uint8 x uint8 GEMM writing raw int32 results.
//
// Only full chunks are processed: m / 3 chunks of three lhs rows against
// n / 3 chunks of three rhs rows. Any remainder rows or columns are not
// touched, so callers presumably only dispatch here when m % 3 == 0 and
// n % 3 == 0 (the trailing "_4" presumably selects the zip variant for
// k % 8 == 4, and "_aligned" presumably requires suitably aligned inputs --
// confirm against the dispatcher).
//
// scratch holds the zipped lhs chunk ((padded_k + 16) * 3 bytes) followed by
// the zipped rhs chunks; the caller must size it accordingly.
//
// lhs, rhs: sources, advanced k * 3 bytes per chunk and zipped with stride k.
// lhs_offset, rhs_offset: offsets handed to the zip helpers (rhs is zipped
//   with lhs_offset, lhs with rhs_offset and lhs_offset * rhs_offset * k).
// result, result_stride: int32 destination written directly by the multiply
//   kernel; result_stride is counted in int32 elements (it is scaled by 4 to
//   obtain the byte stride).
void gemm_i32_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;

  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;

  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: zip it, then multiply against every rhs chunk,
  // writing 3x3 int32 tiles straight into result.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_5_aligned and
// mul_3x8_3x8_int32_lhsadd_rhsadd.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; each zipped chunk takes (padded_k + 16) * 3 bytes, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks and n / 3 RHS chunks (3 * k source bytes each) are
// processed. result advances 3 elements per RHS chunk and 3 * result_stride
// elements per LHS chunk.
// NOTE(review): the "_0_0_5" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernel: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_6_aligned and
// mul_3x8_3x8_int32_lhsadd_rhsadd.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; each zipped chunk takes (padded_k + 16) * 3 bytes, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks and n / 3 RHS chunks (3 * k source bytes each) are
// processed. result advances 3 elements per RHS chunk and 3 * result_stride
// elements per LHS chunk.
// NOTE(review): the "_0_0_6" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernel: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_7_aligned and
// mul_3x8_3x8_int32_lhsadd_rhsadd.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; each zipped chunk takes (padded_k + 16) * 3 bytes, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks and n / 3 RHS chunks (3 * k source bytes each) are
// processed. result advances 3 elements per RHS chunk and 3 * result_stride
// elements per LHS chunk.
// NOTE(review): the "_0_0_7" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernel: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_aligned,
// zip_1x8_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_aligned, then one trailing chunk
// is zipped with zip_1x8_aligned and multiplied with the 3x8_1x8 kernel.
// NOTE(review): the "_0_1_0" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_1_aligned,
// zip_1x8_1_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_1_aligned, then one trailing
// chunk is zipped with zip_1x8_1_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_1" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_2_aligned,
// zip_1x8_2_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_2_aligned, then one trailing
// chunk is zipped with zip_1x8_2_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_2" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_3_aligned,
// zip_1x8_3_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_3_aligned, then one trailing
// chunk is zipped with zip_1x8_3_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_3" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_4_aligned,
// zip_1x8_4_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_4_aligned, then one trailing
// chunk is zipped with zip_1x8_4_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_4" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_5_aligned,
// zip_1x8_5_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_5_aligned, then one trailing
// chunk is zipped with zip_1x8_5_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_5" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_6_aligned,
// zip_1x8_6_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_6_aligned, then one trailing
// chunk is zipped with zip_1x8_6_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_6" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_7_aligned,
// zip_1x8_7_aligned and the mul_3x8_{3x8,1x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_7_aligned, then one trailing
// chunk is zipped with zip_1x8_7_aligned and multiplied with the 3x8_1x8
// kernel.
// NOTE(review): the "_0_1_7" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_aligned,
// zip_2x8_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_aligned, then one trailing chunk
// is zipped with zip_2x8_aligned and multiplied with the 3x8_2x8 kernel.
// NOTE(review): the "_0_2_0" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_1_aligned,
// zip_2x8_1_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_1_aligned, then one trailing
// chunk is zipped with zip_2x8_1_aligned and multiplied with the 3x8_2x8
// kernel.
// NOTE(review): the "_0_2_1" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_2_aligned,
// zip_2x8_2_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_2_aligned, then one trailing
// chunk is zipped with zip_2x8_2_aligned and multiplied with the 3x8_2x8
// kernel.
// NOTE(review): the "_0_2_2" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_3_aligned,
// zip_2x8_3_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_3_aligned, then one trailing
// chunk is zipped with zip_2x8_3_aligned and multiplied with the 3x8_2x8
// kernel.
// NOTE(review): the "_0_2_3" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_4_aligned,
// zip_2x8_4_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_4_aligned, then one trailing
// chunk is zipped with zip_2x8_4_aligned and multiplied with the 3x8_2x8
// kernel.
// NOTE(review): the "_0_2_4" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// uint8 GEMM driver producing int32 results, built on zip_3x8_5_aligned,
// zip_2x8_5_aligned and the mul_3x8_{3x8,2x8}_int32_lhsadd_rhsadd kernels.
//
// scratch layout: one zipped LHS chunk at offset 0, followed by the zipped
// RHS chunks; full chunks are spaced (padded_k + 16) * 3 bytes apart, with
// padded_k being k rounded up to a multiple of 8.
// Only m / 3 LHS chunks are processed. On the RHS side, n / 3 chunks of
// 3 * k source bytes are zipped with zip_3x8_5_aligned, then one trailing
// chunk is zipped with zip_2x8_5_aligned and multiplied with the 3x8_2x8
// kernel.
// NOTE(review): the "_0_2_5" suffix presumably encodes m % 3, n % 3 and
// k % 8 -- confirm against the generator.
void gemm_i32_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Stride handed to the multiply kernels: result_stride scaled by 4.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Zip every full RHS chunk once, up front, into consecutive scratch slots.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing RHS chunk goes into the slot right after the full chunks.
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    // The single LHS slot at the start of scratch is overwritten each pass.
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing chunk and its output position.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_2_6_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and closes every tile row with one 3x2 tile.
//
// Covers 3 * (m / 3) rows and 3 * (n / 3) + 2 columns; no leftover-row pass is
// performed. Presumably dispatched for m % 3 == 0, n % 3 == 2, k % 8 == 6
// (inferred from the name and the *_6_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the two trailing lines.
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide RHS chunk.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_2_7_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and closes every tile row with one 3x2 tile.
//
// Covers 3 * (m / 3) rows and 3 * (n / 3) + 2 columns; no leftover-row pass is
// performed. Presumably dispatched for m % 3 == 0, n % 3 == 2, k % 8 == 7
// (inferred from the name and the *_7_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the two trailing lines.
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide RHS chunk.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_1_0_0_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 0
// (inferred from the name and the unsuffixed zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_1_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 1
// (inferred from the name and the *_1_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_2_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 2
// (inferred from the name and the *_2_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_3_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 3
// (inferred from the name and the *_3_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_4_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 4
// (inferred from the name and the *_4_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_5_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 5
// (inferred from the name and the *_5_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_6_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 6
// (inferred from the name and the *_6_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_7_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles and then finishes with one extra row of 1x3 tiles.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) columns; no leftover-column pass
// is performed. Presumably dispatched for m % 3 == 1, n % 3 == 0, k % 8 == 7
// (inferred from the name and the *_7_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once, three lines per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_1_0_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles, closes every tile row with a 3x1 tile, and finishes
// with one extra row of 1x3 tiles plus a 1x1 corner.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) + 1 columns. Presumably
// dispatched for m % 3 == 1, n % 3 == 1, k % 8 == 0 (inferred from the name
// and the unsuffixed zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the single trailing line.
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 1-wide RHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: last LHS line against the trailing RHS line.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_1_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles, closes every tile row with a 3x1 tile, and finishes
// with one extra row of 1x3 tiles plus a 1x1 corner.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) + 1 columns. Presumably
// dispatched for m % 3 == 1, n % 3 == 1, k % 8 == 1 (inferred from the name
// and the *_1_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the single trailing line.
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 1-wide RHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: last LHS line against the trailing RHS line.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_2_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles, closes every tile row with a 3x1 tile, and finishes
// with one extra row of 1x3 tiles plus a 1x1 corner.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) + 1 columns. Presumably
// dispatched for m % 3 == 1, n % 3 == 1, k % 8 == 2 (inferred from the name
// and the *_2_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the single trailing line.
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 1-wide RHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: last LHS line against the trailing RHS line.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_3_aligned: uint8 GEMM driver with int32 output that walks the
// result in 3x3 tiles, closes every tile row with a 3x1 tile, and finishes
// with one extra row of 1x3 tiles plus a 1x1 corner.
//
// Covers 3 * (m / 3) + 1 rows and 3 * (n / 3) + 1 columns. Presumably
// dispatched for m % 3 == 1, n % 3 == 1, k % 8 == 3 (inferred from the name
// and the *_3_aligned zip helpers) -- TODO confirm.
//
// scratch: work area; one zipped LHS chunk first, then the zipped RHS chunks,
//          spaced zipped_chunk_size bytes apart.
// lhs, rhs: uint8 inputs, consumed as lines of k contiguous bytes.
// lhs_offset, rhs_offset: forwarded to the zip helpers (the RHS zip receives
//          lhs_offset, the LHS zip receives rhs_offset and, as its last
//          argument, lhs_offset * rhs_offset * k).
// result: int32 output whose rows are result_stride elements apart.
void gemm_i32_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three zipped lines of padded_k bytes plus 16 spare bytes each (presumably
  // room for the sums appended by the zip helpers -- TODO confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // Row pitch of result in bytes (4 bytes per int32 element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  // Zip the whole RHS once: col_chunks groups of three lines...
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // ...followed by the single trailing line.
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
  // Zip three LHS lines at a time and multiply them against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 1-wide RHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
  // Final single LHS line, reusing the zipped LHS slot.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: last LHS line against the trailing RHS line.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_4_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 1-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_4_aligned zip variants throughout.
// NOTE(review): the 1_1_4 suffix presumably means m % 3 == 1, n % 3 == 1 and
// k % 8 == 4, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_1_5_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 1-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_5_aligned zip variants throughout.
// NOTE(review): the 1_1_5 suffix presumably means m % 3 == 1, n % 3 == 1 and
// k % 8 == 5, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_1_6_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 1-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_6_aligned zip variants throughout.
// NOTE(review): the 1_1_6 suffix presumably means m % 3 == 1, n % 3 == 1 and
// k % 8 == 6, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_1_7_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 1-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_7_aligned zip variants throughout.
// NOTE(review): the 1_1_7 suffix presumably means m % 3 == 1, n % 3 == 1 and
// k % 8 == 7, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_1x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_0_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the unsuffixed *_aligned zip variants throughout.
// NOTE(review): the 1_2_0 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 0, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_1_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_1_aligned zip variants throughout.
// NOTE(review): the 1_2_1 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 1, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_1_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_1_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_2_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_2_aligned zip variants throughout.
// NOTE(review): the 1_2_2 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 2, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_2_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_2_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_3_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_3_aligned zip variants throughout.
// NOTE(review): the 1_2_3 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 3, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_3_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_3_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_4_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_4_aligned zip variants throughout.
// NOTE(review): the 1_2_4 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 4, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_4_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_4_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_5_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_5_aligned zip variants throughout.
// NOTE(review): the 1_2_5 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 5, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_5_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_5_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_6_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_6_aligned zip variants throughout.
// NOTE(review): the 1_2_6 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 6, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_6_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_6_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_1_2_7_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, then a
// single 2-wide leftover RHS block and a single 1-wide leftover LHS block,
// using the *_7_aligned zip variants throughout.
// NOTE(review): the 1_2_7 suffix presumably means m % 3 == 1, n % 3 == 2 and
// k % 8 == 7, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once - full blocks first, then the leftover.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }
  zip_2x8_7_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the leftover LHS block against the same zipped RHS.
  zip_1x8_7_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                  result_stride_bytes);
}
// gemm_i32_2_0_0_aligned: generated GEMM driver that zips uint8 |lhs| / |rhs|
// into |scratch| and runs the mul_* int32 kernels over every block pair.
//
// Work split visible in the body: m / 3 and n / 3 full 3-wide blocks, no
// leftover RHS block, and a single 2-wide leftover LHS block, using the
// unsuffixed *_aligned zip variants throughout.
// NOTE(review): the 2_0_0 suffix presumably means m % 3 == 2, n % 3 == 0 and
// k % 8 == 0, with the caller dispatching on those remainders - confirm.
//
// scratch: zipped LHS block at offset 0, zipped RHS blocks from offset
//          (padded_k + 16) * 3 onwards; the required size is not checked here.
// lhs_offset / rhs_offset: forwarded to the RHS / LHS zip calls respectively;
//          lhs_offset * rhs_offset * k is also forwarded to the LHS zip calls.
// result / result_stride: int32 output; the stride is in int32 elements.
void gemm_i32_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t full_row_blocks = m / 3;
  const std::int32_t full_col_blocks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-wide block.
  const std::int32_t source_block_bytes = k * 3;
  // Distance between consecutive zipped blocks inside scratch.
  const std::int32_t zipped_block_bytes = (padded_k + 16) * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Output advance (in int32 elements) after one 3-row block.
  const std::int32_t result_block_step = result_stride * 3;
  // Byte stride handed to the mul kernels (4 bytes per int32 element).
  const std::int32_t result_stride_bytes = result_stride * 4;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + zipped_block_bytes;

  // Step 1: zip the whole RHS once; only full 3-wide blocks exist here.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_dst = packed_rhs;
  for (int c = 0; c < full_col_blocks; ++c) {
    zip_3x8_aligned(rhs_src, k, k, rhs_dst, lhs_offset, 0);
    rhs_src += source_block_bytes;
    rhs_dst += zipped_block_bytes;
  }

  // Step 2: for every full LHS block, zip it and multiply it against each
  // zipped RHS block in turn.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int r = 0; r < full_row_blocks; ++r) {
    zip_3x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    std::uint8_t* rhs_block = packed_rhs;
    std::int32_t* out = out_row;
    for (int c = 0; c < full_col_blocks; ++c) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                      result_stride_bytes);
      rhs_block += zipped_block_bytes;
      out += 3;
    }
    lhs_src += source_block_bytes;
    out_row += result_block_step;
  }

  // Step 3: the 2-wide leftover LHS block against the same zipped RHS.
  zip_2x8_aligned(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
  std::uint8_t* rhs_block = packed_rhs;
  std::int32_t* out = out_row;
  for (int c = 0; c < full_col_blocks; ++c) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_block, padded_k, out,
                                    result_stride_bytes);
    rhs_block += zipped_block_bytes;
    out += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_1_aligned / zip_2x8_1_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 1 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_2_aligned / zip_2x8_2_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 2 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_3_aligned / zip_2x8_3_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 3 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_4_aligned / zip_2x8_4_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 4 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_5_aligned / zip_2x8_5_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 5 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_6_aligned / zip_2x8_6_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 6 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns with no leftover columns. Packing uses
// the zip_3x8_7_aligned / zip_2x8_7_aligned variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 0 and
// k % 8 == 7 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped RHS group.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_aligned / zip_2x8_aligned / zip_1x8_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 0 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_1_aligned / zip_2x8_1_aligned / zip_1x8_1_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 1 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_2_aligned / zip_2x8_2_aligned / zip_1x8_2_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 2 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_3_aligned / zip_2x8_3_aligned / zip_1x8_3_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 3 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_4_aligned / zip_2x8_4_aligned / zip_1x8_4_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 4 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// Fills the int32 `result` matrix using the zip/mul kernels for a shape with
// m / 3 full groups of three LHS rows followed by one group of two rows, and
// n / 3 groups of three RHS columns followed by one single leftover column.
// Packing uses the zip_3x8_5_aligned / zip_2x8_5_aligned / zip_1x8_5_aligned
// variants.
// NOTE(review): presumably selected when m % 3 == 2, n % 3 == 1 and
// k % 8 == 5 -- confirm against the dispatcher.
//
//   scratch: work buffer; the first (padded_k + 16) * 3 bytes hold the
//     current zipped LHS group, followed by one equally sized slot per
//     zipped three-column RHS group and then the zipped leftover column.
//   lhs, rhs: uint8 inputs; each group is zipped with count k and stride k.
//   lhs_offset, rhs_offset: handed to the zip kernels as the multiplicative
//     offset of the opposite operand; lhs_offset * rhs_offset * k is the
//     additive offset used while zipping LHS.
//   result, result_stride: output pointer and its row stride in int32
//     elements (the mul kernels receive result_stride * 4 as a byte stride).
void gemm_i32_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: one zipped LHS slot, then the zipped RHS slots.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip every group of three RHS columns once; each LHS group reuses them.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Leftover single RHS column goes into the slot after the full groups.
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three LHS rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover column for this group of rows.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing group of two LHS rows.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: two rows by one column.
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 1 leftover rhs row once m and n are split into chunks of 3.
// Uses the "_6_aligned" zip helpers; presumably selected when k % 8 == 6 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 1 leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs row.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 1 leftover rhs row once m and n are split into chunks of 3.
// Uses the "_7_aligned" zip helpers; presumably selected when k % 8 == 7 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 1 leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs row.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the unsuffixed "_aligned" zip helpers; presumably selected when
// k % 8 == 0 and the inputs satisfy those helpers' alignment needs -- confirm
// at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_1_aligned" zip helpers; presumably selected when k % 8 == 1 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_2_aligned" zip helpers; presumably selected when k % 8 == 2 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_3_aligned" zip helpers; presumably selected when k % 8 == 3 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_4_aligned" zip helpers; presumably selected when k % 8 == 4 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_5_aligned" zip helpers; presumably selected when k % 8 == 5 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_6_aligned" zip helpers; presumably selected when k % 8 == 6 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) for the case with 2 leftover lhs
// rows and 2 leftover rhs rows once m and n are split into chunks of 3.
// Uses the "_7_aligned" zip helpers; presumably selected when k % 8 == 7 and
// the inputs satisfy those helpers' alignment needs -- confirm at the caller.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helpers (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip rhs once up front: every full 3-row chunk, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs chunks: zip each one, then sweep it across all of rhs.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover rhs rows.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 lhs rows against all of rhs.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver (uint8 inputs, int32 result) that processes only whole chunks of
// 3 lhs rows and 3 rhs rows; any remainder of m or n modulo 3 is not touched
// (presumably the caller guarantees m % 3 == 0 and n % 3 == 0 -- confirm).
// Uses the unsuffixed zip_3x8 helper; presumably selected when k % 8 == 0.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helper (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip each 3-row lhs chunk, then sweep it across all zipped rhs chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM driver (uint8 inputs, int32 result) that processes only whole chunks of
// 3 lhs rows and 3 rhs rows; any remainder of m or n modulo 3 is not touched
// (presumably the caller guarantees m % 3 == 0 and n % 3 == 0 -- confirm).
// Uses the zip_3x8_1 helper; presumably selected when k % 8 == 1.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helper (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip each 3-row lhs chunk, then sweep it across all zipped rhs chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM driver (uint8 inputs, int32 result) that processes only whole chunks of
// 3 lhs rows and 3 rhs rows; any remainder of m or n modulo 3 is not touched
// (presumably the caller guarantees m % 3 == 0 and n % 3 == 0 -- confirm).
// Uses the zip_3x8_2 helper; presumably selected when k % 8 == 2.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helper (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip each 3-row lhs chunk, then sweep it across all zipped rhs chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM driver (uint8 inputs, int32 result) that processes only whole chunks of
// 3 lhs rows and 3 rhs rows; any remainder of m or n modulo 3 is not touched
// (presumably the caller guarantees m % 3 == 0 and n % 3 == 0 -- confirm).
// Uses the zip_3x8_3 helper; presumably selected when k % 8 == 3.
//
// scratch: work buffer. Bytes [0, zipped_chunk_size) hold the zipped lhs chunk
//   currently being multiplied; the zipped rhs chunks follow. Its size is not
//   checked here.
// lhs, rhs: sources, walked in steps of 3 * k bytes per chunk.
// lhs_offset, rhs_offset: forwarded to the zip helper (rhs is zipped with
//   lhs_offset; lhs with rhs_offset plus lhs_offset * rhs_offset * k).
// result, result_stride: destination and its row pitch in int32 elements.
void gemm_i32_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed per 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved per zipped 3-row chunk (padded_k + 16 per row).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row pitch in bytes; result elements are int32.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;

  // Zip every 3-row rhs chunk once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip each 3-row lhs chunk, then sweep it across all zipped rhs chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_0_4: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 rhs groups; this
// variant contains no code for leftover rows or columns. A group is read as
// three runs of k bytes (the zip routine is called with count = k and
// stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply kernel
// as its stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_4" suffix presumably means k % 8 == 4, matching the
// zip_3x8_4 packer used here - confirm against the generator.
void gemm_i32_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_4(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_0_5: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 rhs groups; this
// variant contains no code for leftover rows or columns. A group is read as
// three runs of k bytes (the zip routine is called with count = k and
// stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply kernel
// as its stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_5" suffix presumably means k % 8 == 5, matching the
// zip_3x8_5 packer used here - confirm against the generator.
void gemm_i32_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_5(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_0_6: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 rhs groups; this
// variant contains no code for leftover rows or columns. A group is read as
// three runs of k bytes (the zip routine is called with count = k and
// stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply kernel
// as its stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_6" suffix presumably means k % 8 == 6, matching the
// zip_3x8_6 packer used here - confirm against the generator.
void gemm_i32_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_6(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_0_7: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 rhs groups; this
// variant contains no code for leftover rows or columns. A group is read as
// three runs of k bytes (the zip routine is called with count = k and
// stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply kernel
// as its stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_7" suffix presumably means k % 8 == 7, matching the
// zip_3x8_7 packer used here - confirm against the generator.
void gemm_i32_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_7(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_0: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_0" suffix presumably means n % 3 == 1 and
// k % 8 == 0 - confirm against the generator.
void gemm_i32_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_1: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_1" suffix presumably means n % 3 == 1 and
// k % 8 == 1 - confirm against the generator.
void gemm_i32_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_1(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_1(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_2: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_2" suffix presumably means n % 3 == 1 and
// k % 8 == 2 - confirm against the generator.
void gemm_i32_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_2(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_2(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_3: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_3" suffix presumably means n % 3 == 1 and
// k % 8 == 3 - confirm against the generator.
void gemm_i32_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_3(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_3(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_4: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_4" suffix presumably means n % 3 == 1 and
// k % 8 == 4 - confirm against the generator.
void gemm_i32_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_4(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_4(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_5: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_5" suffix presumably means n % 3 == 1 and
// k % 8 == 5 - confirm against the generator.
void gemm_i32_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_5(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_5(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_6: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_6" suffix presumably means n % 3 == 1 and
// k % 8 == 6 - confirm against the generator.
void gemm_i32_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_6(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_6(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_6(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_1_7: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing single-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the single-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_1_7" suffix presumably means n % 3 == 1 and
// k % 8 == 7 - confirm against the generator.
void gemm_i32_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_7(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // single rhs row into the next slot.
  zip_1x8_7(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_7(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_0: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_0" suffix presumably means n % 3 == 2 and
// k % 8 == 0 - confirm against the generator.
void gemm_i32_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_1: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_1" suffix presumably means n % 3 == 2 and
// k % 8 == 1 - confirm against the generator.
void gemm_i32_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_1(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8_1(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_1(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_2: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_2" suffix presumably means n % 3 == 2 and
// k % 8 == 2 - confirm against the generator.
void gemm_i32_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_2(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8_2(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_2(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_3: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_3" suffix presumably means n % 3 == 2 and
// k % 8 == 3 - confirm against the generator.
void gemm_i32_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_3(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8_3(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_3(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_4: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_4" suffix presumably means n % 3 == 2 and
// k % 8 == 4 - confirm against the generator.
void gemm_i32_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_4(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8_4(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_4(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_5: single-threaded GEMM driver, uint8 inputs, int32 output.
//
// Work done: m / 3 lhs groups are multiplied against n / 3 full rhs groups
// plus one trailing two-row rhs group; there is no code for leftover lhs
// rows. A full group is read as three runs of k bytes (zip is called with
// count = k and stride = k) and consecutive groups are k * 3 bytes apart.
//
// scratch: carved into slots of (padded_k + 16) * 3 bytes, where padded_k is
// k rounded up to a multiple of 8. Slot 0 receives the packed lhs group
// (repacked for every row group); slots 1.. receive the packed rhs groups,
// with the two-row leftover packed right after the last full group.
//
// result / result_stride: result is indexed in int32 elements. The output
// pointer advances 3 elements per full multiply call and result_stride * 3
// elements per row group; result_stride * 4 is handed to the multiply
// kernels as their stride argument.
//
// Offsets: rhs groups are zipped with (lhs_offset, 0) and lhs groups with
// (rhs_offset, lhs_offset * rhs_offset * k) as the last two zip arguments.
// NOTE(review): the "_2_5" suffix presumably means n % 3 == 2 and
// k % 8 == 5 - confirm against the generator.
void gemm_i32_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_groups = m / 3;
  const std::int32_t col_groups = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t group_bytes = k * 3;
  const std::int32_t slot_bytes = (padded_k + 16) * 3;
  const std::int32_t result_group_step = result_stride * 3;
  const std::int32_t result_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  std::uint8_t* const packed_lhs = scratch;
  std::uint8_t* const packed_rhs = scratch + slot_bytes;

  // Pack every full rhs group once, up front.
  const std::uint8_t* rhs_src = rhs;
  std::uint8_t* rhs_slot = packed_rhs;
  for (int j = 0; j < col_groups; ++j) {
    zip_3x8_5(rhs_src, k, k, rhs_slot, lhs_offset, 0);
    rhs_src += group_bytes;
    rhs_slot += slot_bytes;
  }
  // Both cursors now sit just past the last full group: pack the leftover
  // two rhs rows into the next slot.
  zip_2x8_5(rhs_src, k, k, rhs_slot, lhs_offset, 0);

  // For each lhs group: pack it, then multiply it against all rhs slots.
  const std::uint8_t* lhs_src = lhs;
  std::int32_t* out_row = result;
  for (int i = 0; i < row_groups; ++i) {
    zip_3x8_5(lhs_src, k, k, packed_lhs, rhs_offset, const_offset);
    rhs_slot = packed_rhs;
    std::int32_t* out_tile = out_row;
    for (int j = 0; j < col_groups; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                      result_stride_bytes);
      rhs_slot += slot_bytes;
      out_tile += 3;
    }
    // rhs_slot / out_tile now address the leftover slot and output columns.
    mul_3x8_2x8_int32_lhsadd_rhsadd(packed_lhs, rhs_slot, padded_k, out_tile,
                                    result_stride_bytes);
    lhs_src += group_bytes;
    out_row += result_group_step;
  }
}
// gemm_i32_0_2_6: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk of two leftover rows). It then walks the LHS three rows at a time:
// each LHS chunk is zipped and multiplied against every zipped RHS chunk,
// advancing the output by three columns per RHS chunk and by three rows per
// LHS chunk.
//
// Only `m / 3` full LHS chunks are processed; this variant has no LHS
// leftover step.
// NOTE(review): the `0_2_6` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 2-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (two rows) against the current LHS chunk.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_0_2_7: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk of two leftover rows). It then walks the LHS three rows at a time:
// each LHS chunk is zipped and multiplied against every zipped RHS chunk,
// advancing the output by three columns per RHS chunk and by three rows per
// LHS chunk.
//
// Only `m / 3` full LHS chunks are processed; this variant has no LHS
// leftover step.
// NOTE(review): the `0_2_7` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 2-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (two rows) against the current LHS chunk.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_i32_1_0_0: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_0` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_1: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_1` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_2: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_2` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_3: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_3` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_4: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_4` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_5: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_5` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_6: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_6` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_0_7: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once, three rows per chunk (this variant has
// no RHS leftover step, so only `n / 3` full chunks are handled). It then
// walks the LHS three rows at a time: each LHS chunk is zipped and
// multiplied against every zipped RHS chunk, advancing the output by three
// columns per RHS chunk and by three rows per LHS chunk. Finally the single
// leftover LHS row is zipped and multiplied against every RHS chunk.
//
// NOTE(review): the `1_0_7` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_i32_1_1_0: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk holding the single leftover row). It then walks the LHS three rows
// at a time: each LHS chunk is zipped and multiplied against every zipped
// RHS chunk, advancing the output by three columns per RHS chunk and by
// three rows per LHS chunk. Finally the single leftover LHS row is zipped
// and multiplied against every RHS chunk, including the leftover one.
//
// NOTE(review): the `1_1_0` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 1-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (one row) against the current LHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Last step: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_1: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk holding the single leftover row). It then walks the LHS three rows
// at a time: each LHS chunk is zipped and multiplied against every zipped
// RHS chunk, advancing the output by three columns per RHS chunk and by
// three rows per LHS chunk. Finally the single leftover LHS row is zipped
// and multiplied against every RHS chunk, including the leftover one.
//
// NOTE(review): the `1_1_1` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 1-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (one row) against the current LHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Last step: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_2: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk holding the single leftover row). It then walks the LHS three rows
// at a time: each LHS chunk is zipped and multiplied against every zipped
// RHS chunk, advancing the output by three columns per RHS chunk and by
// three rows per LHS chunk. Finally the single leftover LHS row is zipped
// and multiplied against every RHS chunk, including the leftover one.
//
// NOTE(review): the `1_1_2` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 1-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (one row) against the current LHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Last step: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_3: uint8 GEMM driver producing int32 results.
//
// Zips the RHS into `scratch` once (three rows per chunk, then one final
// chunk holding the single leftover row). It then walks the LHS three rows
// at a time: each LHS chunk is zipped and multiplied against every zipped
// RHS chunk, advancing the output by three columns per RHS chunk and by
// three rows per LHS chunk. Finally the single leftover LHS row is zipped
// and multiplied against every RHS chunk, including the leftover one.
//
// NOTE(review): the `1_1_3` suffix presumably encodes m % 3, n % 3 and k % 8,
// with a dispatcher picking the matching variant -- confirm against the
// generator.
//
//   scratch        Workspace: one zipped LHS chunk of (padded_k + 16) * 3
//                  bytes, followed by the zipped RHS chunks.
//   lhs, rhs       Source bytes; consecutive chunks are k * 3 bytes apart and
//                  the zip helpers receive `k` as both count and stride.
//   m, n, k        Problem sizes; k is rounded up to a multiple of 8 before
//                  being handed to the multiply kernels.
//   lhs_offset     Passed to the RHS zips as their multiplicative offset.
//   rhs_offset     Passed to the LHS zips as their multiplicative offset,
//                  together with the additive term lhs_offset*rhs_offset*k.
//   result         int32 output buffer.
//   result_stride  Distance between output rows, in int32 elements.
void gemm_i32_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // The multiply kernels take the row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;

  // Scratch layout: [zipped LHS chunk][zipped RHS chunks...].
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Zip the whole RHS up front: full chunks first, then the 1-row leftover.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Zip one LHS chunk at a time and multiply it against every RHS chunk.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    std::int32_t* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Leftover RHS chunk (one row) against the current LHS chunk.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and sweep the RHS chunks once more.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  std::int32_t* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Last step: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_4: uint8 inputs, int32 output GEMM driver built on the
// zip_*_4 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one extra column. The name
// suggests this is the variant for m % 3 == 1, n % 3 == 1, k % 8 == 4, but
// nothing in this function checks that -- confirm against the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the single trailing one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_5: uint8 inputs, int32 output GEMM driver built on the
// zip_*_5 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one extra column. The name
// suggests this is the variant for m % 3 == 1, n % 3 == 1, k % 8 == 5, but
// nothing in this function checks that -- confirm against the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the single trailing one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_6: uint8 inputs, int32 output GEMM driver built on the
// zip_*_6 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one extra column. The name
// suggests this is the variant for m % 3 == 1, n % 3 == 1, k % 8 == 6, but
// nothing in this function checks that -- confirm against the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the single trailing one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_1_7: uint8 inputs, int32 output GEMM driver built on the
// zip_*_7 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one extra column. The name
// suggests this is the variant for m % 3 == 1, n % 3 == 1, k % 8 == 7, but
// nothing in this function checks that -- confirm against the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the single trailing one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_0: uint8 inputs, int32 output GEMM driver built on the
// unsuffixed zip_* helpers and the mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 0, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_1: uint8 inputs, int32 output GEMM driver built on the
// zip_*_1 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 1, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_2: uint8 inputs, int32 output GEMM driver built on the
// zip_*_2 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 2, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_3: uint8 inputs, int32 output GEMM driver built on the
// zip_*_3 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 3, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_4: uint8 inputs, int32 output GEMM driver built on the
// zip_*_4 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 4, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_5: uint8 inputs, int32 output GEMM driver built on the
// zip_*_5 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 5, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_6: uint8 inputs, int32 output GEMM driver built on the
// zip_*_6 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 6, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_1_2_7: uint8 inputs, int32 output GEMM driver built on the
// zip_*_7 and mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one extra row and one 2-wide extra column
// chunk. The name suggests this is the variant for m % 3 == 1, n % 3 == 2,
// k % 8 == 7, but nothing in this function checks that -- confirm against
// the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: every 3-wide chunk, then the trailing 2-wide one.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each 3-wide chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against the same zipped RHS chunks.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_0_0: uint8 inputs, int32 output GEMM driver built on the
// unsuffixed zip_* helpers and the mul_*_int32_lhsadd_rhsadd helpers.
//
// Work is done in 3-wide chunks: m / 3 row chunks and n / 3 column chunks,
// followed unconditionally by one 2-row chunk; no trailing columns are
// processed. The name suggests this is the variant for m % 3 == 2,
// n % 3 == 0, k % 8 == 0, but nothing in this function checks that --
// confirm against the caller.
//
// scratch        Working buffer. The zipped LHS chunk lives at its start and
//                the zipped RHS chunks start (padded_k + 16) * 3 bytes in.
//                Its size is not validated here.
// lhs, rhs       Source bytes, walked in steps of 3 * k and handed to the zip
//                helpers with count = k and stride = k.
// m, n, k        Problem dimensions.
// lhs_offset     Passed to the zip helpers when zipping the RHS.
// rhs_offset     Passed to the zip helpers when zipping the LHS, together
//                with lhs_offset * rhs_offset * k.
// result         int32 output.
// result_stride  Row stride of result in int32 elements.
void gemm_i32_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The mul kernels take the output row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, one 3-wide chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply it against every zipped
  // RHS chunk, moving 3 int32 columns to the right after each one.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against the same zipped RHS chunks.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_1: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 1
// (the latter inferred from the _1 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_2: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 2
// (the latter inferred from the _2 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_3: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 3
// (the latter inferred from the _3 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_4: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 4
// (the latter inferred from the _4 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_5: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 5
// (the latter inferred from the _5 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_6: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 6
// (the latter inferred from the _6 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_0_7: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows; no trailing rhs rows are
// handled. Presumably dispatched when m % 3 == 2, n % 3 == 0 and k % 8 == 7
// (the latter inferred from the _7 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM driver, variant 2_1_0: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 0
// (the latter inferred from the un-suffixed zip_* helper names) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_1: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 1
// (the latter inferred from the _1 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_2: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 2
// (the latter inferred from the _2 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_3: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 3
// (the latter inferred from the _3 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_4: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 4
// (the latter inferred from the _4 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_5: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 5
// (the latter inferred from the _5 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// GEMM driver, variant 2_1_6: packs ("zips") uint8 lhs/rhs rows into scratch
// and runs the mul_*_int32_lhsadd_rhsadd kernels tile by tile into an int32
// result.
//
// Coverage: (m / 3) groups of three lhs rows plus exactly two trailing lhs
// rows, against (n / 3) groups of three rhs rows plus exactly one trailing
// rhs row. Presumably dispatched when m % 3 == 2, n % 3 == 1 and k % 8 == 6
// (the latter inferred from the _6 suffix of the zip_* helpers) -- confirm
// against the caller.
//
// scratch: the first zipped_chunk_size bytes are reused for the current lhs
// group; the packed rhs groups follow, one zipped_chunk_size slot each, and
// the trailing rhs row is packed right after the last group.
// lhs/rhs are handed to the zip helpers with count k and stride k. result
// rows are result_stride int32 elements apart. rhs rows are packed with
// lhs_offset, lhs rows with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_i32_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Three packed rows, each with 16 extra bytes (presumably room for the
  // per-row sums the zip helpers append -- confirm).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  // The kernels take the row stride in bytes; result elements are 4 bytes.
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every group of three rhs rows once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Trailing rhs row goes into the slot after the last full group.
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full groups of three lhs rows against every packed rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Both cursors now sit on the trailing rhs row / last result column.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing two lhs rows.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_1_7: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 1-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 1 (and, judging by the zip_*_7 helpers,
// k % 8 == 7) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 1-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 1-wide rhs strip for this row strip.
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_0: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the unsuffixed zip_* helpers,
// k % 8 == 0) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_1: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_1 helpers,
// k % 8 == 1) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_2: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_2 helpers,
// k % 8 == 2) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_3: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_3 helpers,
// k % 8 == 3) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_4: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_4 helpers,
// k % 8 == 4) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_5: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_5 helpers,
// k % 8 == 5) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_6: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_6 helpers,
// k % 8 == 6) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_i32_2_2_7: uint8 x uint8 GEMM driver producing int32 results.
//
// Packs ("zips") the whole rhs into `scratch` once, then packs each lhs strip
// in turn into the front of `scratch` and runs the mul_* kernels against
// every packed rhs strip, writing into `result`.
//
// Work is tiled in strips of 3 lhs rows and 3 rhs rows. After the m / 3 and
// n / 3 full strips this variant unconditionally processes a trailing 2-row
// lhs strip and a trailing 2-row rhs strip, so it presumably expects
// m % 3 == 2 and n % 3 == 2 (and, judging by the zip_*_7 helpers,
// k % 8 == 7) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result         destination for the int32 output.
//   result_stride  distance between result rows, in int32 elements.
void gemm_i32_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels take the result row stride in bytes (4 per int32).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack the rhs: full 3-row strips first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row lhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing 2-wide rhs strip for this row strip.
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row lhs strip.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
// gemm_f_0_0_0_aligned: uint8 x uint8 GEMM driver producing float results.
//
// Packs ("zips") every 3-row rhs strip into `scratch` once, then packs each
// 3-row lhs strip in turn into the front of `scratch` and multiplies it
// against every packed rhs strip with mul_3x8_3x8_float_lhsadd_rhsadd,
// forwarding `result_scale` to the kernel.
//
// Only the m / 3 and n / 3 full strips are processed; there is no tail
// handling, so this variant presumably expects m % 3 == 0 and n % 3 == 0
// (and, judging by the zip_3x8_aligned helper, k % 8 == 0 with suitably
// aligned buffers) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result_scale   passed unchanged to the float multiply kernel.
//   result         destination for the float output.
//   result_stride  distance between result rows, in float elements.
void gemm_f_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernel takes the result row stride in bytes (4 per element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every full 3-row rhs strip.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Pack each 3-row lhs strip and multiply it against all rhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_0_1_aligned: uint8 x uint8 GEMM driver producing float results.
//
// Packs ("zips") every 3-row rhs strip into `scratch` once, then packs each
// 3-row lhs strip in turn into the front of `scratch` and multiplies it
// against every packed rhs strip with mul_3x8_3x8_float_lhsadd_rhsadd,
// forwarding `result_scale` to the kernel.
//
// Only the m / 3 and n / 3 full strips are processed; there is no tail
// handling, so this variant presumably expects m % 3 == 0 and n % 3 == 0
// (and, judging by the zip_3x8_1_aligned helper, k % 8 == 1 with suitably
// aligned buffers) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result_scale   passed unchanged to the float multiply kernel.
//   result         destination for the float output.
//   result_stride  distance between result rows, in float elements.
void gemm_f_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernel takes the result row stride in bytes (4 per element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every full 3-row rhs strip.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Pack each 3-row lhs strip and multiply it against all rhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_0_2_aligned: uint8 x uint8 GEMM driver producing float results.
//
// Packs ("zips") every 3-row rhs strip into `scratch` once, then packs each
// 3-row lhs strip in turn into the front of `scratch` and multiplies it
// against every packed rhs strip with mul_3x8_3x8_float_lhsadd_rhsadd,
// forwarding `result_scale` to the kernel.
//
// Only the m / 3 and n / 3 full strips are processed; there is no tail
// handling, so this variant presumably expects m % 3 == 0 and n % 3 == 0
// (and, judging by the zip_3x8_2_aligned helper, k % 8 == 2 with suitably
// aligned buffers) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result_scale   passed unchanged to the float multiply kernel.
//   result         destination for the float output.
//   result_stride  distance between result rows, in float elements.
void gemm_f_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernel takes the result row stride in bytes (4 per element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every full 3-row rhs strip.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Pack each 3-row lhs strip and multiply it against all rhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_0_3_aligned: uint8 x uint8 GEMM driver producing float results.
//
// Packs ("zips") every 3-row rhs strip into `scratch` once, then packs each
// 3-row lhs strip in turn into the front of `scratch` and multiplies it
// against every packed rhs strip with mul_3x8_3x8_float_lhsadd_rhsadd,
// forwarding `result_scale` to the kernel.
//
// Only the m / 3 and n / 3 full strips are processed; there is no tail
// handling, so this variant presumably expects m % 3 == 0 and n % 3 == 0
// (and, judging by the zip_3x8_3_aligned helper, k % 8 == 3 with suitably
// aligned buffers) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result_scale   passed unchanged to the float multiply kernel.
//   result         destination for the float output.
//   result_stride  distance between result rows, in float elements.
void gemm_f_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernel takes the result row stride in bytes (4 per element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every full 3-row rhs strip.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Pack each 3-row lhs strip and multiply it against all rhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_0_4_aligned: uint8 x uint8 GEMM driver producing float results.
//
// Packs ("zips") every 3-row rhs strip into `scratch` once, then packs each
// 3-row lhs strip in turn into the front of `scratch` and multiplies it
// against every packed rhs strip with mul_3x8_3x8_float_lhsadd_rhsadd,
// forwarding `result_scale` to the kernel.
//
// Only the m / 3 and n / 3 full strips are processed; there is no tail
// handling, so this variant presumably expects m % 3 == 0 and n % 3 == 0
// (and, judging by the zip_3x8_4_aligned helper, k % 8 == 4 with suitably
// aligned buffers) -- confirm against the dispatching caller.
//
//   scratch        working buffer: one packed lhs strip, then all packed rhs
//                  strips.
//   lhs, rhs       sources, each walked as consecutive rows of k bytes.
//   m, n, k        problem dimensions.
//   lhs_offset     multiplicative offset handed to the rhs zip calls.
//   rhs_offset     multiplicative offset handed to the lhs zip calls.
//   result_scale   passed unchanged to the float multiply kernel.
//   result         destination for the float output.
//   result_stride  distance between result rows, in float elements.
void gemm_f_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row strip.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes reserved for one packed 3-row strip.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term passed to the lhs zip calls only.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernel takes the result row stride in bytes (4 per element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every full 3-row rhs strip.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Pack each 3-row lhs strip and multiply it against all rhs strips.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_0_5 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks and runs the 3x3 float multiply kernel per tile.
// NOTE(review): the 0_0_5 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, with no leftover row or column handling.
//   lhs_offset     forwarded to the zip routine when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernel.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once; the packed data is reused for each lhs
  // chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: pack it, then multiply it against all packed rhs
  // chunks, writing 3 result columns per step.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_0_6 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks and runs the 3x3 float multiply kernel per tile.
// NOTE(review): the 0_0_6 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, with no leftover row or column handling.
//   lhs_offset     forwarded to the zip routine when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernel.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once; the packed data is reused for each lhs
  // chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: pack it, then multiply it against all packed rhs
  // chunks, writing 3 result columns per step.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_0_7 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks and runs the 3x3 float multiply kernel per tile.
// NOTE(review): the 0_0_7 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, with no leftover row or column handling.
//   lhs_offset     forwarded to the zip routine when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernel.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernel takes the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once; the packed data is reused for each lhs
  // chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: pack it, then multiply it against all packed rhs
  // chunks, writing 3 result columns per step.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_0 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_0 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_1 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_1 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_2 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_2 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_3 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_3 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_4 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_4 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_5 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_5 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_6 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_6 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_1_7 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x1 kernel for one leftover rhs row.
// NOTE(review): the 0_1_7 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs row.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the single leftover column.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the leftover row right after them;
  // the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover row.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover row.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_2_0 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x2 kernel for two leftover rhs rows.
// NOTE(review): the 0_2_0 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs rows.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the two leftover columns.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the two leftover rows right after
  // them; the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover rows.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover rows.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_2_1 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x2 kernel for two leftover rhs rows.
// NOTE(review): the 0_2_1 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs rows.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the two leftover columns.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the two leftover rows right after
  // them; the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover rows.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover rows.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_2_2 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x2 kernel for two leftover rhs rows.
// NOTE(review): the 0_2_2 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs rows.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the two leftover columns.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the two leftover rows right after
  // them; the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover rows.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover rows.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_2_3 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x2 kernel for two leftover rhs rows.
// NOTE(review): the 0_2_3 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs rows.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the two leftover columns.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the two leftover rows right after
  // them; the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover rows.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover rows.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver, variant 0_2_4 (aligned). Packs ("zips") the uint8
// operands in 3-row chunks, runs the 3x3 float multiply kernel per tile, and
// finishes each row chunk with a 3x2 kernel for two leftover rhs rows.
// NOTE(review): the 0_2_4 suffix presumably encodes m % 3, n % 3 and k % 8 for
// which this variant is selected -- confirm against the dispatcher.
//
//   scratch        work buffer: one zipped lhs chunk at the start, followed by
//                  col_chunks zipped rhs chunks of (padded_k + 16) * 3 bytes,
//                  followed by the zipped leftover rhs rows.
//   lhs, rhs       uint8 inputs, consumed in chunks of k * 3 bytes.
//   m, n, k        problem sizes; m / 3 row chunks and n / 3 column chunks are
//                  processed, plus the two leftover columns.
//   lhs_offset     forwarded to the zip routines when packing rhs.
//   rhs_offset     forwarded to the zip routine when packing lhs.
//   result_scale   forwarded to the multiply kernels.
//   result         float output, advanced result_stride * 3 per row chunk.
//   result_stride  output row stride in elements.
void gemm_f_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Presumably the extra 16 bytes per row hold the sums the zip routines
  // append after the packed data -- confirm against zip_3x8_*.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The kernels take the row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;

  // Pack every 3-row rhs chunk once, then the two leftover rows right after
  // them; the packed data is reused for each lhs chunk below.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row lhs chunk: pack it, multiply it against all packed 3-row
  // rhs chunks (3 result columns per step), then against the leftover rows.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the packed leftover rows.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM with float output for the case of no leftover LHS rows and two
// leftover RHS rows (presumably m % 3 == 0, n % 3 == 2, k % 8 == 5, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 2-row leftover RHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM with float output for the case of no leftover LHS rows and two
// leftover RHS rows (presumably m % 3 == 0, n % 3 == 2, k % 8 == 6, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 2-row leftover RHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM with float output for the case of no leftover LHS rows and two
// leftover RHS rows (presumably m % 3 == 0, n % 3 == 2, k % 8 == 7, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2 leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 2-row leftover RHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 0, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 1, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 2, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 3, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 4, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 5, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 6, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and no
// leftover RHS rows (presumably m % 3 == 1, n % 3 == 0, k % 8 == 7, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, three rows per chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// GEMM with float output for the case of one leftover LHS row and one
// leftover RHS row (presumably m % 3 == 1, n % 3 == 1, k % 8 == 0, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1 leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 1-row leftover RHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Final corner: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// GEMM with float output for the case of one leftover LHS row and one
// leftover RHS row (presumably m % 3 == 1, n % 3 == 1, k % 8 == 1, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1 leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 1-row leftover RHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Final corner: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// GEMM with float output for the case of one leftover LHS row and one
// leftover RHS row (presumably m % 3 == 1, n % 3 == 1, k % 8 == 2, going by
// the generated name -- not checked here).
//
// lhs and rhs are read as consecutive rows of k bytes and processed three rows
// at a time. scratch receives the zipped LHS chunk first, followed by every
// zipped RHS chunk; the caller must provide a large enough buffer. The zip
// kernels take a multiplicative and an additive offset as their last two
// arguments. result is written in 3-wide tiles; result_stride is the distance
// between consecutive result rows, in floats.
void gemm_f_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to a multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // Row stride handed to the mul kernels (4 is presumably sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1 leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // zipped_rhs_chunk now points at the 1-row leftover RHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row: zip it alone and multiply against every RHS chunk.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Final corner: leftover LHS row against the leftover RHS row.
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and one leftover RHS row beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 1 (not checked here). It uses the
// "_3" zip kernels -- presumably the k % 8 == 3 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the single leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS row against the current LHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and one leftover RHS row beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 1 (not checked here). It uses the
// "_4" zip kernels -- presumably the k % 8 == 4 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the single leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS row against the current LHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and one leftover RHS row beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 1 (not checked here). It uses the
// "_5" zip kernels -- presumably the k % 8 == 5 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the single leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS row against the current LHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and one leftover RHS row beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 1 (not checked here). It uses the
// "_6" zip kernels -- presumably the k % 8 == 6 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the single leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS row against the current LHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and one leftover RHS row beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 1 (not checked here). It uses the
// "_7" zip kernels -- presumably the k % 8 == 7 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the single leftover row.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS row against the current LHS chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// unsuffixed zip kernels -- presumably the k % 8 == 0 variants; confirm
// against the generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_1" zip kernels -- presumably the k % 8 == 1 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_2" zip kernels -- presumably the k % 8 == 2 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_3" zip kernels -- presumably the k % 8 == 3 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_4" zip kernels -- presumably the k % 8 == 4 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_5" zip kernels -- presumably the k % 8 == 5 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_6" zip kernels -- presumably the k % 8 == 6 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// 8-bit GEMM with float output for the case of exactly one leftover LHS row
// and two leftover RHS rows beyond the full 3-row chunks. The function always
// processes m / 3 and n / 3 full chunks plus those leftovers, so callers are
// expected to pass m % 3 == 1 and n % 3 == 2 (not checked here). It uses the
// "_7" zip kernels -- presumably the k % 8 == 7 variants; confirm against the
// generator.
//
// scratch: work buffer laid out as one zipped LHS chunk followed by every
//          zipped RHS chunk; each 3-row chunk slot is (padded_k + 16) * 3
//          bytes, where padded_k is k rounded up to a multiple of 8.
// lhs/rhs: source bytes, consumed as consecutive rows of k bytes each.
// lhs_offset/rhs_offset: passed to the zip kernels; note that RHS rows are
//          zipped with lhs_offset and LHS rows with rhs_offset.
// result_scale: forwarded unchanged to the multiply kernels.
// result/result_stride: output floats; result_stride is counted in floats and
//          converted to bytes (* 4) for the multiply kernels.
void gemm_f_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three source rows of k bytes each.
  const std::int32_t chunk_size = k * 3;
  // Slot size of one zipped 3-row chunk inside scratch.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // The first slot of scratch is reused for every LHS chunk; RHS slots follow.
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the two leftover rows.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next three output floats in this row band.
    }
    // Leftover RHS rows against the current LHS chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover LHS row against every RHS chunk, then against the leftover RHS.
  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_2_0_0_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 0 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_1_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_1_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 1 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_2_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_2_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 2 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_3_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_3_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 3 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_4_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_4_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 4 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_5_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_5_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 5 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_6_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_6_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 6 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_0_7_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_7_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups are zipped into scratch; then
// m / 3 three-row lhs groups plus one trailing zip_2x8 lhs group are each
// zipped and multiplied against every zipped rhs group. No partial rhs group
// is handled here.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 0 and k % 8 == 7 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front, one three-row group at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_2_1_0_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups plus one trailing zip_1x8 rhs
// group are zipped into scratch; then m / 3 three-row lhs groups plus one
// trailing zip_2x8 lhs group are each zipped and multiplied against every
// zipped rhs group.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 1 and k % 8 == 0 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front: full three-row groups, then the trailing group.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_2_1_1_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_1_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups plus one trailing zip_1x8 rhs
// group are zipped into scratch; then m / 3 three-row lhs groups plus one
// trailing zip_2x8 lhs group are each zipped and multiplied against every
// zipped rhs group.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 1 and k % 8 == 1 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front: full three-row groups, then the trailing group.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_2_1_2_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_2_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups plus one trailing zip_1x8 rhs
// group are zipped into scratch; then m / 3 three-row lhs groups plus one
// trailing zip_2x8 lhs group are each zipped and multiplied against every
// zipped rhs group.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 1 and k % 8 == 2 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front: full three-row groups, then the trailing group.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_2_1_3_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_3_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups plus one trailing zip_1x8 rhs
// group are zipped into scratch; then m / 3 three-row lhs groups plus one
// trailing zip_2x8 lhs group are each zipped and multiplied against every
// zipped rhs group.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 1 and k % 8 == 3 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front: full three-row groups, then the trailing group.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_2_1_4_aligned: uint8 x uint8 GEMM with float output, assembled from
// the zip_*_4_aligned packers and the mul_*_float_lhsadd_rhsadd kernels.
//
// Work performed: n / 3 three-row rhs groups plus one trailing zip_1x8 rhs
// group are zipped into scratch; then m / 3 three-row lhs groups plus one
// trailing zip_2x8 lhs group are each zipped and multiplied against every
// zipped rhs group.
// NOTE(review): presumably the dispatcher selects this variant when
// m % 3 == 2, n % 3 == 1 and k % 8 == 4 -- confirm against the caller.
// NOTE(review): the *_aligned packers presumably require suitably aligned
// source/scratch pointers -- confirm.
//
//   scratch       work buffer; the zipped lhs group lives at its start and the
//                 zipped rhs groups follow at scratch + zipped_chunk_size.
//   lhs, rhs      uint8 inputs; k is handed to the packers as both count and
//                 stride.
//   m, n, k       problem dimensions.
//   lhs_offset    multiplicative offset used while zipping rhs.
//   rhs_offset    multiplicative offset used while zipping lhs.
//   result_scale  forwarded unchanged to the mul_* kernels.
//   result        float output buffer.
//   result_stride output row stride in floats (converted to bytes as
//                 result_stride * 4 for the kernels).
void gemm_f_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Bytes spanned by three consecutive k-byte source rows.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes stepped per zipped three-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Additive term handed to the lhs packers.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip rhs once up front: full three-row groups, then the trailing group.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full three-row lhs groups: zip, then multiply against every rhs group.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing lhs group.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 1 trailing
// RHS row.
// NOTE(review): the "_2_1_5" suffix presumably encodes m % 3 == 2, n % 3 == 1
// and k % 8 == 5 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 1-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing single result column of this row chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 1 column.
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 1 trailing
// RHS row.
// NOTE(review): the "_2_1_6" suffix presumably encodes m % 3 == 2, n % 3 == 1
// and k % 8 == 6 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 1-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing single result column of this row chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 1 column.
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 1 trailing
// RHS row.
// NOTE(review): the "_2_1_7" suffix presumably encodes m % 3 == 2, n % 3 == 1
// and k % 8 == 7 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 1-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing single result column of this row chunk.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 1 column.
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_0" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 0 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_1" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 1 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_2" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 2 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_3" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 3 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_4" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 4 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_5" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 5 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_6" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 6 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them tile by tile with the mul_*_float_lhsadd_rhsadd kernels.
//
// As written, this variant processes m / 3 full 3-row LHS chunks plus exactly
// 2 trailing LHS rows, and n / 3 full 3-row RHS chunks plus exactly 2 trailing
// RHS rows.
// NOTE(review): the "_2_2_7" suffix presumably encodes m % 3 == 2, n % 3 == 2
// and k % 8 == 7 as chosen by a dispatcher elsewhere -- confirm.
// NOTE(review): the zip_*_aligned kernels presumably assume aligned source and
// destination pointers -- confirm against their definitions.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernels as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernels as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernels.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernels append).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks first, then the 2-row leftover.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each full 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing two result columns of this row chunk.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows: same sweep with the 2-row kernels.
  zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  // Bottom-right corner: 2 rows by 2 columns.
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them in 3x3 tiles with mul_3x8_3x8_float_lhsadd_rhsadd.
//
// As written, this variant processes only the m / 3 full 3-row LHS chunks and
// the n / 3 full 3-row RHS chunks; it has no leftover-row or leftover-column
// handling.
// NOTE(review): the "_0_0_0" suffix presumably encodes m % 3 == 0, n % 3 == 0
// and k % 8 == 0 as chosen by a dispatcher elsewhere -- confirm.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernel as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernel as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernel.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernel appends).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernel takes the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, one 3-row chunk per scratch slot.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them in 3x3 tiles with mul_3x8_3x8_float_lhsadd_rhsadd.
//
// As written, this variant processes only the m / 3 full 3-row LHS chunks and
// the n / 3 full 3-row RHS chunks; it has no leftover-row or leftover-column
// handling.
// NOTE(review): the "_0_0_1" suffix presumably encodes m % 3 == 0, n % 3 == 0
// and k % 8 == 1 as chosen by a dispatcher elsewhere -- confirm.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernel as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernel as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernel.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernel appends).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernel takes the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, one 3-row chunk per scratch slot.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver: zips both operands into scratch, then multiplies
// them in 3x3 tiles with mul_3x8_3x8_float_lhsadd_rhsadd.
//
// As written, this variant processes only the m / 3 full 3-row LHS chunks and
// the n / 3 full 3-row RHS chunks; it has no leftover-row or leftover-column
// handling.
// NOTE(review): the "_0_0_2" suffix presumably encodes m % 3 == 0, n % 3 == 0
// and k % 8 == 2 as chosen by a dispatcher elsewhere -- confirm.
//
//   scratch      - work buffer. The first zipped_chunk_size bytes hold the
//                  current zipped LHS chunk; zipped RHS chunks follow in
//                  consecutive zipped_chunk_size slots.
//   lhs          - read as consecutive rows of k bytes.
//   rhs          - read as consecutive rows of k bytes, one per result column.
//   m, n, k      - result rows, result columns and depth.
//   lhs_offset   - passed to the zip kernel as the multiplier for RHS chunks.
//   rhs_offset   - passed to the zip kernel as the multiplier for LHS chunks.
//   result_scale - forwarded unchanged to the mul kernel.
//   result       - float output; rows are result_stride floats apart.
void gemm_f_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Source bytes consumed by one 3-row chunk.
  const std::int32_t chunk_size = k * 3;
  // Scratch bytes per zipped chunk: 3 padded rows plus 48 extra bytes
  // (presumably room for the per-row sums the zip kernel appends).
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  // Floats to skip in result when moving down by one 3-row chunk.
  const std::int32_t result_chunk_stride = result_stride * 3;
  // Constant term handed to the LHS zip as its additive offset.
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernel takes the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once, one 3-row chunk per scratch slot.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row LHS chunk: zip it, then sweep across all RHS chunks.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// Float-output GEMM driver built on zip_3x8_3 and
// mul_3x8_3x8_float_lhsadd_rhsadd.
//
// Processes m / 3 lhs groups and n / 3 rhs groups of 3 rows each (k bytes per
// row, row stride k); rows left over by the integer division are not touched
// by this function. NOTE(review): the name suffix suggests this variant is
// selected for m % 3 == 0, n % 3 == 0, k % 8 == 3 -- confirm against the
// dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups, each zipped_chunk_size bytes apart.
//                 Must be sized by the caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zip as its 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernel.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_4 and
// mul_3x8_3x8_float_lhsadd_rhsadd.
//
// Processes m / 3 lhs groups and n / 3 rhs groups of 3 rows each (k bytes per
// row, row stride k); rows left over by the integer division are not touched
// by this function. NOTE(review): the name suffix suggests this variant is
// selected for m % 3 == 0, n % 3 == 0, k % 8 == 4 -- confirm against the
// dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups, each zipped_chunk_size bytes apart.
//                 Must be sized by the caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zip as its 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernel.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_5 and
// mul_3x8_3x8_float_lhsadd_rhsadd.
//
// Processes m / 3 lhs groups and n / 3 rhs groups of 3 rows each (k bytes per
// row, row stride k); rows left over by the integer division are not touched
// by this function. NOTE(review): the name suffix suggests this variant is
// selected for m % 3 == 0, n % 3 == 0, k % 8 == 5 -- confirm against the
// dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups, each zipped_chunk_size bytes apart.
//                 Must be sized by the caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zip as its 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernel.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_6 and
// mul_3x8_3x8_float_lhsadd_rhsadd.
//
// Processes m / 3 lhs groups and n / 3 rhs groups of 3 rows each (k bytes per
// row, row stride k); rows left over by the integer division are not touched
// by this function. NOTE(review): the name suffix suggests this variant is
// selected for m % 3 == 0, n % 3 == 0, k % 8 == 6 -- confirm against the
// dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups, each zipped_chunk_size bytes apart.
//                 Must be sized by the caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zip as its 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernel.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_7 and
// mul_3x8_3x8_float_lhsadd_rhsadd.
//
// Processes m / 3 lhs groups and n / 3 rhs groups of 3 rows each (k bytes per
// row, row stride k); rows left over by the integer division are not touched
// by this function. NOTE(review): the name suffix suggests this variant is
// selected for m % 3 == 0, n % 3 == 0, k % 8 == 7 -- confirm against the
// dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups, each zipped_chunk_size bytes apart.
//                 Must be sized by the caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zip as its 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernel.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8 / zip_1x8 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 0 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_1 / zip_1x8_1 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 1 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_2 / zip_1x8_2 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 2 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_3 / zip_1x8_3 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 3 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_4 / zip_1x8_4 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 4 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_5 / zip_1x8_5 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 5 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_6 / zip_1x8_6 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 6 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_7 / zip_1x8_7 and the
// mul_3x8_3x8 / mul_3x8_1x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing single rhs row (k bytes per row, row stride k). Lhs rows left over
// by m / 3 are not touched by this function. NOTE(review): the name suffix
// suggests this variant is selected for m % 3 == 0, n % 3 == 1, k % 8 == 7 --
// confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing rhs row. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing single rhs row right after the full groups.
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs row / its output column.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8 / zip_2x8 and the
// mul_3x8_3x8 / mul_3x8_2x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing group of 2 rhs rows (k bytes per row, row stride k). Lhs rows left
// over by m / 3 are not touched by this function. NOTE(review): the name
// suffix suggests this variant is selected for m % 3 == 0, n % 3 == 2,
// k % 8 == 0 -- confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing 2-row group. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing 2-row rhs group right after the full groups.
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs group / its output columns.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_1 / zip_2x8_1 and the
// mul_3x8_3x8 / mul_3x8_2x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing group of 2 rhs rows (k bytes per row, row stride k). Lhs rows left
// over by m / 3 are not touched by this function. NOTE(review): the name
// suffix suggests this variant is selected for m % 3 == 0, n % 3 == 2,
// k % 8 == 1 -- confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing 2-row group. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing 2-row rhs group right after the full groups.
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs group / its output columns.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_2 / zip_2x8_2 and the
// mul_3x8_3x8 / mul_3x8_2x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing group of 2 rhs rows (k bytes per row, row stride k). Lhs rows left
// over by m / 3 are not touched by this function. NOTE(review): the name
// suffix suggests this variant is selected for m % 3 == 0, n % 3 == 2,
// k % 8 == 2 -- confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing 2-row group. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing 2-row rhs group right after the full groups.
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs group / its output columns.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// Float-output GEMM driver built on zip_3x8_3 / zip_2x8_3 and the
// mul_3x8_3x8 / mul_3x8_2x8 float kernels.
//
// Processes m / 3 lhs groups of 3 rows, n / 3 rhs groups of 3 rows, plus one
// trailing group of 2 rhs rows (k bytes per row, row stride k). Lhs rows left
// over by m / 3 are not touched by this function. NOTE(review): the name
// suffix suggests this variant is selected for m % 3 == 0, n % 3 == 2,
// k % 8 == 3 -- confirm against the dispatcher.
//
//   scratch       work buffer: one packed lhs group at the start, followed by
//                 the packed rhs groups (zipped_chunk_size bytes apart) and
//                 then the packed trailing 2-row group. Must be sized by the
//                 caller.
//   lhs, rhs      uint8 inputs, read in consecutive 3 * k byte groups.
//   lhs_offset    passed to the rhs zips as their 5th argument.
//   rhs_offset    passed to the lhs zip as its 5th argument, together with
//                 lhs_offset * rhs_offset * k as the 6th.
//   result_scale  forwarded unchanged to the multiply kernels.
//   result        float output; result_stride is its row stride in floats.
void gemm_f_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  // Bytes reserved in scratch for each packed 3-row group.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Pack every full rhs group once, up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  // Pack the trailing 2-row rhs group right after the full groups.
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    // Pack the current lhs group, then pair it with each packed rhs group.
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;  // Next group of 3 output columns.
    }
    // Both cursors now point at the trailing rhs group / its output columns.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;  // Next group of 3 output rows.
  }
}
// gemm_f_0_2_4: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "0_2_4" suffix presumably stands for m % 3 == 0,
// n % 3 == 2, k % 8 == 4, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 2-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_2_5: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "0_2_5" suffix presumably stands for m % 3 == 0,
// n % 3 == 2, k % 8 == 5, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 2-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_2_6: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "0_2_6" suffix presumably stands for m % 3 == 0,
// n % 3 == 2, k % 8 == 6, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 2-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_0_2_7: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "0_2_7" suffix presumably stands for m % 3 == 0,
// n % 3 == 2, k % 8 == 7, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 2-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 2-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
// gemm_f_1_0_0: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_0" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 0, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_1: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_1" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 1, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_2: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_2" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 2, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_3: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_3" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 3, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_4: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_4" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 4, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_5: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_5" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 5, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_6: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_6" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 6, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_0_7: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_0_7" suffix presumably stands for m % 3 == 1,
// n % 3 == 0, k % 8 == 7, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once, three rows at a time.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it and sweep the packed rhs again.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// gemm_f_1_1_0: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_1_0" suffix presumably stands for m % 3 == 1,
// n % 3 == 1, k % 8 == 0, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 1-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 1-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it, sweep the packed rhs chunks, then
  // handle the 1-row rhs tail.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// gemm_f_1_1_1: single-threaded GEMM driver over uint8 inputs producing float
// output, built from the 3-row packing (zip_*) and multiply (mul_*) kernels.
//
// NOTE(review): the "1_1_1" suffix presumably stands for m % 3 == 1,
// n % 3 == 1, k % 8 == 1, which matches the leftover kernels selected below;
// confirm against the generator. This file is generated, so any change made
// here should be mirrored there.
//
//   scratch        work buffer: the packed lhs chunk lives at its start and
//                  the packed rhs starts zipped_chunk_size bytes in.
//   lhs, rhs       uint8 inputs, read as consecutive rows of k bytes.
//   m, n           number of k-byte rows in lhs and rhs respectively.
//   k              row length in bytes.
//   lhs_offset     passed to the zip_* calls that pack rhs.
//   rhs_offset     passed to the zip_* calls that pack lhs, together with
//                  lhs_offset * rhs_offset * k.
//   result_scale   forwarded unchanged to every mul_* kernel.
//   result         float output; advanced by 3 elements per rhs chunk and by
//                  3 * result_stride elements per lhs chunk.
//   result_stride  output row pitch in float elements.
void gemm_f_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // Depth rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Three unpacked source rows.
  const std::int32_t chunk_size = k * 3;
  // Three packed rows; the extra 16 bytes per row are presumably room for the
  // per-row sums appended by the zip_* helpers -- TODO confirm.
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The mul_* kernels receive the row pitch in bytes (4 per float element).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  std::uint8_t* const zipped_lhs = scratch;
  std::uint8_t* const zipped_rhs = scratch + zipped_chunk_size;

  // Pack all of rhs once: full 3-row chunks first, then the 1-row tail.
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every 3-row lhs chunk: pack it, then multiply against each packed
  // rhs chunk, finishing with the 1-row rhs tail.
  const std::uint8_t* lhs_chunk = lhs;
  float* result_chunk = result;
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    float* mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single lhs row: pack it, sweep the packed rhs chunks, then
  // handle the 1-row rhs tail.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  float* mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_2 /
// mul_*_1x8 kernels). Uses the "_2" zip helpers -- presumably the k % 8 == 2
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_3 /
// mul_*_1x8 kernels). Uses the "_3" zip helpers -- presumably the k % 8 == 3
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_4 /
// mul_*_1x8 kernels). Uses the "_4" zip helpers -- presumably the k % 8 == 4
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_5 /
// mul_*_1x8 kernels). Uses the "_5" zip helpers -- presumably the k % 8 == 5
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_6 /
// mul_*_1x8 kernels). Uses the "_6" zip helpers -- presumably the k % 8 == 6
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles exactly one trailing LHS row and one trailing RHS row (zip_1x8_7 /
// mul_*_1x8 kernels). Uses the "_7" zip helpers -- presumably the k % 8 == 7
// leftover variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8) and a trailing RHS group zipped with
// zip_2x8 (two RHS rows, going by the helper name) via the mul_*_2x8 kernels.
// Uses the unsuffixed zip helpers -- presumably the k % 8 == 0 variant;
// confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_1) and a trailing RHS group zipped
// with zip_2x8_1 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_1" zip helpers -- presumably the k % 8 == 1 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_2) and a trailing RHS group zipped
// with zip_2x8_2 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_2" zip helpers -- presumably the k % 8 == 2 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_3) and a trailing RHS group zipped
// with zip_2x8_3 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_3" zip helpers -- presumably the k % 8 == 3 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_4) and a trailing RHS group zipped
// with zip_2x8_4 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_4" zip helpers -- presumably the k % 8 == 4 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_5) and a trailing RHS group zipped
// with zip_2x8_5 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_5" zip helpers -- presumably the k % 8 == 5 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver that tiles the result into 3x3 blocks and then
// handles one trailing LHS row (zip_1x8_6) and a trailing RHS group zipped
// with zip_2x8_6 (two RHS rows, going by the helper name) via the mul_*_2x8
// kernels. Uses the "_6" zip helpers -- presumably the k % 8 == 6 leftover
// variant; confirm against the generator.
//
// lhs and rhs are both consumed as consecutive k-byte rows (count = k,
// stride = k). scratch receives one zipped LHS chunk followed by every zipped
// RHS chunk; its capacity is not checked here. result rows are result_stride
// floats apart.
void gemm_f_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // The multiply kernels take the result row stride in bytes (4 per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the whole RHS once: full 3-row chunks, then the trailing chunk.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For each 3-row LHS chunk: zip it, then multiply against every RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // Trailing RHS chunk for this row block.
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing single LHS row against every RHS chunk.
  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver specialised for 1 leftover LHS row and 2 leftover
// RHS rows, built on the "_7" zip variants. (The name suggests it serves
// m % 3 == 1, n % 3 == 2, k % 8 == 7 -- confirm against the dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front: every full 3-row chunk, then the 2-row tail.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // RHS tail chunk (2 rows).
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover single LHS row, swept across the RHS with the 1x8 kernels.
  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the unsuffixed zip variants. (The name suggests
// it serves m % 3 == 2, n % 3 == 0, k % 8 == 0 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_1" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 1 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_2" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 2 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_3" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 3 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_4" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 4 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_5" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 5 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_6" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 6 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and no
// leftover RHS rows, built on the "_7" zip variants. (The name suggests it
// serves m % 3 == 2, n % 3 == 0, k % 8 == 7 -- confirm against the
// dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front, one 3-row chunk at a time.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernel.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and 1 leftover
// RHS row, built on the unsuffixed zip variants. (The name suggests it serves
// m % 3 == 2, n % 3 == 1, k % 8 == 0 -- confirm against the dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front: every full 3-row chunk, then the 1-row tail.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // RHS tail chunk (1 row).
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernels.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and 1 leftover
// RHS row, built on the "_1" zip variants. (The name suggests it serves
// m % 3 == 2, n % 3 == 1, k % 8 == 1 -- confirm against the dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front: every full 3-row chunk, then the 1-row tail.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // RHS tail chunk (1 row).
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernels.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and 1 leftover
// RHS row, built on the "_2" zip variants. (The name suggests it serves
// m % 3 == 2, n % 3 == 1, k % 8 == 2 -- confirm against the dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front: every full 3-row chunk, then the 1-row tail.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // RHS tail chunk (1 row).
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernels.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM driver specialised for 2 leftover LHS rows and 1 leftover
// RHS row, built on the "_3" zip variants. (The name suggests it serves
// m % 3 == 2, n % 3 == 1, k % 8 == 3 -- confirm against the dispatcher.)
//
// scratch holds one zipped LHS chunk followed by every zipped RHS chunk. lhs
// and rhs are walked as consecutive rows of k bytes; result is addressed with
// result_stride floats between rows. Note the crossed offsets: RHS is zipped
// with lhs_offset, LHS with rhs_offset plus lhs_offset * rhs_offset * k.
void gemm_f_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  // Row stride in bytes handed to the kernels (4 presumably == sizeof(float)).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  float* result_chunk = result;
  float* mul_result_chunk = result;

  // Zip the RHS once up front: every full 3-row chunk, then the 1-row tail.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // Full 3-row LHS chunks: zip one chunk, then sweep it across the zipped RHS.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    // RHS tail chunk (1 row).
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Leftover 2 LHS rows, swept across the RHS with the 2x8 kernels.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_4" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 1-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 1 and k % 8 == 4 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_5" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 1-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 1 and k % 8 == 5 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_6" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 1-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 1 and k % 8 == 6 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_7" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 1-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 1 and k % 8 == 7 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 1-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the unsuffixed zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 0 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_1" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 1 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_2" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 2 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_3" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 3 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_4" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 4 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_5" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 5 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_6" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 6 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
// Float-output GEMM variant built on the "_7" zip kernels.
//
// The body unconditionally processes m / 3 full 3-row LHS chunks plus one
// trailing 2-row LHS chunk, and n / 3 full 3-row RHS chunks plus one trailing
// 2-row RHS chunk. Callers presumably dispatch here only when m % 3 == 2,
// n % 3 == 2 and k % 8 == 7 -- TODO confirm against the dispatcher.
//
// scratch:       work buffer; one zipped LHS chunk lives at its start and the
//                zipped RHS chunks follow at byte offset (padded_k + 16) * 3.
// lhs, rhs:      uint8 inputs, consumed in chunks of k * 3 bytes.
// lhs_offset:    passed to the zip kernels as the multiplicative offset for
//                RHS rows.
// rhs_offset:    passed to the zip kernels as the multiplicative offset for
//                LHS rows (together with lhs_offset * rhs_offset * k as the
//                additive offset).
// result_scale:  forwarded unchanged to the mul kernels.
// result:        float output; consecutive rows are result_stride floats
//                apart.
void gemm_f_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  // k rounded up to the next multiple of 8.
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;
  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  // The mul kernels take the row stride in bytes (4 bytes per float).
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  // Zip the whole RHS once: full 3-row chunks, then the 2-row remainder.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  // For every full 3-row LHS chunk: zip it, then multiply it against each
  // zipped RHS chunk, writing 3 result columns per full RHS chunk.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  // Trailing 2-row LHS chunk against all RHS chunks.
  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}
} // namespace internal
// Quantized 8-bit GEMM entry point writing to a strided uint8 result.
//
// Selects one of the generated internal::gemm_q8_<M>_<N>_<K>[_aligned]
// kernels, where <M> = m % 3, <N> = n % 3 and <K> = k % 8, and forwards every
// argument to it unchanged.
//
// The "_aligned" kernels are used only when lhs, rhs and result all sit at
// addresses that are multiples of 8 and both k and result_stride are
// multiples of 8. Since that condition already requires k % 8 == 0, only the
// <K> == 0 aligned kernels are reachable from this function, so the aligned
// path dispatches on (<M>, <N>) alone.
//
// If m % 3, n % 3 or k % 8 is negative (i.e. a negative dimension was passed)
// no kernel matches and the function returns without doing any work.
void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
                     const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                     std::int32_t k, std::int32_t lhs_offset,
                     std::int32_t rhs_offset, std::int32_t result_offset,
                     std::int32_t multiplicative_offset, std::int32_t shift,
                     std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t m_rem = m % 3;
  const std::int32_t n_rem = n % 3;
  const std::int32_t k_rem = k % 8;

  // The built-in % keeps the sign of the dividend. A negative remainder has
  // no kernel; bail out early so it cannot alias a valid dispatch key below.
  if (m_rem < 0 || n_rem < 0 || k_rem < 0) {
    return;
  }

  const bool pointers_aligned =
      (reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0 &&
      (reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0 &&
      (reinterpret_cast<std::uintptr_t>(result) % 8) == 0;
  const bool aligned =
      pointers_aligned && (k_rem == 0) && ((result_stride % 8) == 0);

  if (aligned) {
    // k_rem is known to be 0 here. The key's decimal digits are <M><N>,
    // matching the first two numbers of the kernel name.
    switch (m_rem * 10 + n_rem) {
      case 0:
        internal::gemm_q8_0_0_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 1:
        internal::gemm_q8_0_1_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 2:
        internal::gemm_q8_0_2_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 10:
        internal::gemm_q8_1_0_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 11:
        internal::gemm_q8_1_1_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 12:
        internal::gemm_q8_1_2_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 20:
        internal::gemm_q8_2_0_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 21:
        internal::gemm_q8_2_1_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
      case 22:
        internal::gemm_q8_2_2_0_aligned(
            scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset,
            multiplicative_offset, shift, result, result_stride);
        break;
    }
    return;
  }

  // Unaligned path. The key's decimal digits are <M><N><K>, matching the
  // kernel-name suffix (e.g. key 127 -> gemm_q8_1_2_7).
  switch (m_rem * 100 + n_rem * 10 + k_rem) {
    // m % 3 == 0, n % 3 == 0.
    case 0:
      internal::gemm_q8_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 1:
      internal::gemm_q8_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 2:
      internal::gemm_q8_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 3:
      internal::gemm_q8_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 4:
      internal::gemm_q8_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 5:
      internal::gemm_q8_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 6:
      internal::gemm_q8_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 7:
      internal::gemm_q8_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 0, n % 3 == 1.
    case 10:
      internal::gemm_q8_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 11:
      internal::gemm_q8_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 12:
      internal::gemm_q8_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 13:
      internal::gemm_q8_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 14:
      internal::gemm_q8_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 15:
      internal::gemm_q8_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 16:
      internal::gemm_q8_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 17:
      internal::gemm_q8_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 0, n % 3 == 2.
    case 20:
      internal::gemm_q8_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 21:
      internal::gemm_q8_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 22:
      internal::gemm_q8_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 23:
      internal::gemm_q8_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 24:
      internal::gemm_q8_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 25:
      internal::gemm_q8_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 26:
      internal::gemm_q8_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 27:
      internal::gemm_q8_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 1, n % 3 == 0.
    case 100:
      internal::gemm_q8_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 101:
      internal::gemm_q8_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 102:
      internal::gemm_q8_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 103:
      internal::gemm_q8_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 104:
      internal::gemm_q8_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 105:
      internal::gemm_q8_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 106:
      internal::gemm_q8_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 107:
      internal::gemm_q8_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 1, n % 3 == 1.
    case 110:
      internal::gemm_q8_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 111:
      internal::gemm_q8_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 112:
      internal::gemm_q8_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 113:
      internal::gemm_q8_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 114:
      internal::gemm_q8_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 115:
      internal::gemm_q8_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 116:
      internal::gemm_q8_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 117:
      internal::gemm_q8_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 1, n % 3 == 2.
    case 120:
      internal::gemm_q8_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 121:
      internal::gemm_q8_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 122:
      internal::gemm_q8_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 123:
      internal::gemm_q8_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 124:
      internal::gemm_q8_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 125:
      internal::gemm_q8_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 126:
      internal::gemm_q8_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 127:
      internal::gemm_q8_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 2, n % 3 == 0.
    case 200:
      internal::gemm_q8_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 201:
      internal::gemm_q8_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 202:
      internal::gemm_q8_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 203:
      internal::gemm_q8_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 204:
      internal::gemm_q8_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 205:
      internal::gemm_q8_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 206:
      internal::gemm_q8_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 207:
      internal::gemm_q8_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 2, n % 3 == 1.
    case 210:
      internal::gemm_q8_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 211:
      internal::gemm_q8_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 212:
      internal::gemm_q8_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 213:
      internal::gemm_q8_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 214:
      internal::gemm_q8_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 215:
      internal::gemm_q8_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 216:
      internal::gemm_q8_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 217:
      internal::gemm_q8_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;

    // m % 3 == 2, n % 3 == 2.
    case 220:
      internal::gemm_q8_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 221:
      internal::gemm_q8_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 222:
      internal::gemm_q8_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 223:
      internal::gemm_q8_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 224:
      internal::gemm_q8_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 225:
      internal::gemm_q8_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 226:
      internal::gemm_q8_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
    case 227:
      internal::gemm_q8_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, multiplicative_offset,
                              shift, result, result_stride);
      break;
  }
}
void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
std::int32_t k, std::int32_t lhs_offset,
std::int32_t rhs_offset, std::int32_t* result,
std::int32_t result_stride) {
const bool lhs_aligned = ((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0);
const bool rhs_aligned = ((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0);
const bool k_aligned = ((k % 8) == 0);
const bool aligned = lhs_aligned && rhs_aligned && k_aligned;
if (aligned) {
switch (m % 3) {
case 0:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_0_0_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_0_0_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_0_0_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_0_0_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_0_0_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_0_0_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_0_0_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_0_0_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_0_1_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_0_1_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_0_1_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_0_1_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_0_1_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_0_1_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_0_1_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_0_1_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_0_2_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_0_2_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_0_2_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_0_2_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_0_2_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_0_2_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_0_2_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_0_2_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
}
break;
case 1:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_1_0_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_1_0_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_1_0_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_1_0_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_1_0_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_1_0_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_1_0_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_1_0_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_1_1_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_1_1_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_1_1_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_1_1_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_1_1_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_1_1_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_1_1_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_1_1_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_1_2_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_1_2_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_1_2_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_1_2_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_1_2_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_1_2_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_1_2_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_1_2_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
}
break;
case 2:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_2_0_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_2_0_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_2_0_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_2_0_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_2_0_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_2_0_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_2_0_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_2_0_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_2_1_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_2_1_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_2_1_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_2_1_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_2_1_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_2_1_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_2_1_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_2_1_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_2_2_0_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 1:
internal::gemm_i32_2_2_1_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 2:
internal::gemm_i32_2_2_2_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 3:
internal::gemm_i32_2_2_3_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 4:
internal::gemm_i32_2_2_4_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 5:
internal::gemm_i32_2_2_5_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 6:
internal::gemm_i32_2_2_6_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
case 7:
internal::gemm_i32_2_2_7_aligned(scratch, lhs, rhs, m, n, k,
lhs_offset, rhs_offset, result,
result_stride);
break;
}
break;
}
break;
}
} else {
switch (m % 3) {
case 0:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
}
break;
case 1:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
}
break;
case 2:
switch (n % 3) {
case 0:
switch (k % 8) {
case 0:
internal::gemm_i32_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 1:
switch (k % 8) {
case 0:
internal::gemm_i32_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
case 2:
switch (k % 8) {
case 0:
internal::gemm_i32_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 1:
internal::gemm_i32_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 2:
internal::gemm_i32_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 3:
internal::gemm_i32_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 4:
internal::gemm_i32_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 5:
internal::gemm_i32_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 6:
internal::gemm_i32_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
case 7:
internal::gemm_i32_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
rhs_offset, result, result_stride);
break;
}
break;
}
break;
}
}
}
// Float-output GEMM entry point with an explicit result stride.
//
// Selects one of the generated internal::gemm_f_<m % 3>_<n % 3>_<k % 8>
// kernels (or its *_aligned twin) and forwards every argument to it
// unchanged. The *_aligned kernels are used only when both lhs and rhs are
// 8-byte aligned and k is a multiple of 8.
//
// Because the aligned path already requires k % 8 == 0, only the
// gemm_f_<m % 3>_<n % 3>_0_aligned kernels can ever be reached there, so
// that branch dispatches on the m and n residues alone.
//
// If a residue is negative (possible only for a negative m, n or k), no case
// matches and no kernel is called.
//
// Parameters (all passed through to the selected kernel as-is):
//   scratch       - working buffer handed to the kernel.
//   lhs, rhs      - 8-bit input operands; their addresses are checked for
//                   8-byte alignment here.
//   m, n, k       - problem dimensions; only their residues are inspected.
//   lhs_offset, rhs_offset, result_scale
//                 - kernel parameters, not interpreted by this function.
//   result        - float output buffer.
//   result_stride - stride argument for the output, not interpreted here.
void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, float result_scale, float* result,
                    std::int32_t result_stride) {
  const std::int32_t m_rem = m % 3;
  const std::int32_t n_rem = n % 3;
  const std::int32_t k_rem = k % 8;

  const bool lhs_aligned = ((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0);
  const bool rhs_aligned = ((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0);
  const bool k_aligned = (k_rem == 0);
  const bool aligned = lhs_aligned && rhs_aligned && k_aligned;

  if (aligned) {
    // k_rem is known to be 0 here, so only the *_0_aligned kernels apply.
    switch (m_rem) {
      case 0:
        switch (n_rem) {
          case 0:
            internal::gemm_f_0_0_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 1:
            internal::gemm_f_0_1_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 2:
            internal::gemm_f_0_2_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
        }
        break;
      case 1:
        switch (n_rem) {
          case 0:
            internal::gemm_f_1_0_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 1:
            internal::gemm_f_1_1_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 2:
            internal::gemm_f_1_2_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
        }
        break;
      case 2:
        switch (n_rem) {
          case 0:
            internal::gemm_f_2_0_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 1:
            internal::gemm_f_2_1_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
          case 2:
            internal::gemm_f_2_2_0_aligned(
                scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_scale, result, result_stride);
            break;
        }
        break;
    }
  } else {
    // Unaligned path: every (m % 3, n % 3, k % 8) combination has a kernel.
    switch (m_rem) {
      case 0:
        switch (n_rem) {
          case 0:
            switch (k_rem) {
              case 0:
                internal::gemm_f_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k_rem) {
              case 0:
                internal::gemm_f_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k_rem) {
              case 0:
                internal::gemm_f_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
      case 1:
        switch (n_rem) {
          case 0:
            switch (k_rem) {
              case 0:
                internal::gemm_f_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k_rem) {
              case 0:
                internal::gemm_f_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k_rem) {
              case 0:
                internal::gemm_f_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
      case 2:
        switch (n_rem) {
          case 0:
            switch (k_rem) {
              case 0:
                internal::gemm_f_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k_rem) {
              case 0:
                internal::gemm_f_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k_rem) {
              case 0:
                internal::gemm_f_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
    }
  }
}
// Convenience wrapper around gemm_q8_strided: forwards every argument
// unchanged and supplies n as the trailing argument (presumably the result
// stride, by analogy with the other *_strided entry points).
void gemm_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
             const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
             std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
             std::int32_t result_offset, std::int32_t multiplicative_offset,
             std::int32_t shift, std::uint8_t* result) {
  const std::int32_t trailing_stride = n;
  gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                  result_offset, multiplicative_offset, shift, result,
                  trailing_stride);
}
// Convenience wrapper around gemm_i32_strided: forwards every argument
// unchanged and passes n as the result stride.
void gemm_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
              const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
              std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
              std::int32_t* result) {
  const std::int32_t stride_for_result = n;
  gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
                   stride_for_result);
}
// Convenience wrapper around gemm_f_strided: forwards every argument
// unchanged and passes n as result_stride.
void gemm_f(std::uint8_t* scratch, const std::uint8_t* lhs,
            const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
            std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
            float result_scale, float* result) {
  const std::int32_t stride_for_result = n;
  gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                 result_scale, result, stride_for_result);
}
} // namespace meta
} // namespace gemmlowp
#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif
#endif // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_