Rebase gemmlowp to a227af1fdb47f250b5df07d6936366b0f8113b65 am: 70ba50cbca
am: 36f90a2b7a

Change-Id: I07a51a36e8beb9a632f43543cf700fb04c0da065
diff --git a/doc/kernel.md b/doc/kernel.md
index 261cb92..f3f2138 100644
--- a/doc/kernel.md
+++ b/doc/kernel.md
@@ -40,11 +40,15 @@
 
 The meaning of these terms is explained in the lengthy comment at the top of
 internal/kernel.h. Here, they mean that this kernel handles at each iteration
-(along the depth dimension): - 3 'cells' of size 4x2 each of the lhs, so a total
-lhs block of size 12x2 - 1 'cell' of size 2x4 of the rhs. In other words, this
-kernel handles 12 rows of the lhs and 4 columns of the rhs, and handles two
-levels of depth at once. The 'cells' and `CellFormat` detail the layout of these
-12x2 and 2x4 blocks.
+(along the depth dimension):
+
+- 3 'cells' of size 4x2 each of the lhs, so a total lhs block of size 12x2
+
+- 1 'cell' of size 2x4 of the rhs.
+
+In other words, this kernel handles 12 rows of the lhs and 4 columns of the
+rhs, and handles two levels of depth at once. The 'cells' and `CellFormat`
+detail the layout of these 12x2 and 2x4 blocks.
 
 This kernel then loads these 12x2 and 2x4 blocks and computes the corresponding
 12x4 GEMM; for ease of reference let us paste the critical comment and code
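
For reference, the format described in this hunk is spelled with the internal/kernel.h templates roughly as follows (a sketch; the actual typedef for the NEON 12x4 kernel lives in internal/kernel_neon.h):

```cpp
// A 12x4, depth-2 kernel format in internal/kernel.h terms (a sketch):
// 3 lhs cells of 4x2 (a 12x2 lhs block) and 1 rhs cell of width 4, depth 2
// (a 2x4 rhs block).
typedef KernelFormat<
    KernelSideFormat<CellFormat<4, 2>, 3>,  // lhs side
    KernelSideFormat<CellFormat<4, 2>, 1>   // rhs side
    > Format;
```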
diff --git a/doc/public.md b/doc/public.md
index 935f6db..7739b85 100644
--- a/doc/public.md
+++ b/doc/public.md
@@ -14,7 +14,7 @@
 multiplication is explained in [low-precision.md](low-precision.md). The
 rationale for a specific quantization paradigm is given in
 [quantization.md](quantization.md). That specific quantization paradigm is
-implemented at two different stages of the computation: as pre-processing ont
+implemented at two different stages of the computation: as pre-processing on
 the operands and as post-processing on the result:
 
 *   Pre-processing on the LHS, RHS operands, in the form of adding constant
@@ -56,7 +56,7 @@
 
 *   `InputScalar`: The scalar type of the LHS and RHS operands. At the moment,
     this must be `std::uint8_t`.
-*   `OutputScalar`: The scalar type of the LHS and RHS operands. At the moment,
+*   `OutputScalar`: The scalar type of the result. At the moment,
     this must be `std::uint8_t`.
 *   `BitDepthParams`: Defines the bit format of the input and output matrices
     and the required accuracy of the computation. At the moment, the only
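
To make these template parameters concrete, here is a minimal sketch of a `GemmWithOutputPipeline` call; the matrix maps, offsets and `output_pipeline` tuple are assumed to be set up elsewhere (as in doc/quantization_example.cc):

```cpp
gemmlowp::GemmContext context;
// InputScalar = OutputScalar = std::uint8_t, BitDepthParams = DefaultL8R8BitDepthParams.
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                                 gemmlowp::DefaultL8R8BitDepthParams>(
    &context, lhs, rhs, &result,  // MatrixMap views of the uint8 operands/result (assumed)
    lhs_offset, rhs_offset,       // the constant offsets added to the operands
    output_pipeline);             // tuple of output stages, see output.md
```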
diff --git a/doc/quantization.md b/doc/quantization.md
index 3a8f72b..e5055e7 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -13,7 +13,7 @@
 perform, specifically, it affects how one goes from internal 32bit accumulator
 to final 8bit outputs.
 
-The part of gemmlowp transforming internal internal 32bit accumulator to final
+The part of gemmlowp transforming internal 32bit accumulator to final
 8bit outputs is the "output pipeline" described in [output.md](output.md).
 
 gemmlowp's `GemmWithOutputPipeline` entry point allows specifying an arbitrary
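
As a concrete sketch of such an output pipeline, the 32-bit accumulators are scaled by a fixed-point multiplier, offset, and saturating-cast down to uint8; the three `result_*` values are assumed to be precomputed:

```cpp
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint quantize_down_stage;
quantize_down_stage.result_offset_after_shift = result_offset_after_shift;
quantize_down_stage.result_fixedpoint_multiplier = result_fixedpoint_multiplier;
quantize_down_stage.result_shift = result_shift;
gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
const auto output_pipeline =  // requires <tuple>
    std::make_tuple(quantize_down_stage, saturating_cast_stage);
```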
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc
index 512c483..a8d9b43 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.cc
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc
@@ -12,9 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#endif
 #include "eight_bit_int_gemm.h"
 
 #include <memory>
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index d39341b..58e8050 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -18,10 +18,13 @@
 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
 
+#include <algorithm>
 #include <cassert>
+#include <cmath>
+#include <cstdint>
 #include <limits>
 
-#include "../internal/common.h"
+#include "../internal/detect_platform.h"
 
 namespace gemmlowp {
 
@@ -47,13 +50,13 @@
 template <>
 struct FixedPointRawTypeTraits<std::int32_t> {
   typedef std::int32_t ScalarRawType;
-  static const int kLanes = 1;
+  static constexpr int kLanes = 1;
 };
 
 template <>
 struct FixedPointRawTypeTraits<std::int16_t> {
   typedef std::int16_t ScalarRawType;
-  static const int kLanes = 1;
+  static constexpr int kLanes = 1;
 };
 
 // Returns a SIMD value duplicating a scalar value across all lanes.
@@ -109,11 +112,25 @@
   return -a;
 }
 
-// Integer arithmetic left-shift, equivalent to multiplying with a
-// power of two. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType ShiftLeft(tIntegerType a, int offset) {
-  return a << offset;
+// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
+// Negative values are OK. In case of overflow there is no Undefined
+// Behavior, but the results are implementation-defined (in practice,
+// they are currently saturated, but we make no commitment to that). The idea
+// is that the caller will want to implement the overflowing cases with
+// saturation with compare-and-mask, so we don't care about the results
+// in the overflow case; we just want to avoid undefined behavior.
+//
+// tIntegerType may be int32 or any narrower signed type.
+template <typename tIntegerType, typename OffsetType>
+tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
+  const std::int64_t wide_a = static_cast<std::int64_t>(a);
+  const std::int64_t wide_shifted = wide_a * (1 << offset);
+  const auto min = std::numeric_limits<tIntegerType>::min();
+  const auto max = std::numeric_limits<tIntegerType>::max();
+  return wide_shifted < min
+             ? min
+             : wide_shifted > max ? max
+                                  : static_cast<tIntegerType>(wide_shifted);
 }
 
 // Integer arithmetic right-shift. Not rounding.
@@ -137,7 +154,7 @@
 // input scalar is non-zero.
 template <typename tIntegerType>
 tIntegerType MaskIfNonZero(tIntegerType a) {
-  static const tIntegerType zero = 0;
+  static constexpr tIntegerType zero = 0;
   return a ? BitNot(zero) : zero;
 }
 
@@ -211,6 +228,7 @@
 template <typename IntegerType>
 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -235,6 +253,7 @@
 template <typename IntegerType>
 IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -244,7 +263,9 @@
   std::int32_t a32 = a;
   std::int32_t b32 = b;
   std::int32_t sum = a32 + b32;
-  return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+  return static_cast<std::int16_t>(
+      std::min(static_cast<std::int32_t>(32767),
+               std::max(static_cast<std::int32_t>(-32768), sum)));
 }
 
 // Returns a+b, saturating if the integers are 16bit or narrower,
@@ -298,6 +319,7 @@
 template <typename IntegerType>
 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -331,8 +353,8 @@
 
 // Correctly-rounded-to-nearest division by a power-of-two.
 // Also known as a rounding arithmetic right shift.
-template <typename IntegerType>
-inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
+template <typename IntegerType, typename ExponentType>
+inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
   assert(exponent >= 0);
   assert(exponent <= 31);
   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -432,9 +454,9 @@
   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
 
-  static const int kTotalBits = 8 * sizeof(ScalarRawType);
-  static const int kIntegerBits = tIntegerBits;
-  static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+  static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
+  static constexpr int kIntegerBits = tIntegerBits;
+  static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
                 "bad IntegerBits");
 
@@ -474,7 +496,7 @@
 
   template <int Exponent>
   static FixedPoint ConstantPOT() {
-    static const int kOffset = kFractionalBits + Exponent;
+    static constexpr int kOffset = kFractionalBits + Exponent;
     static_assert(
         kOffset < 31,
         "Constant not exactly representable in this fixed-point format");
@@ -645,7 +667,7 @@
 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
     FixedPoint<tRawType, tIntegerBitsSrc> x) {
-  static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+  static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
   FixedPoint<tRawType, tIntegerBitsDst> result;
   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
   return result;
@@ -725,9 +747,9 @@
     FixedPoint<tRawType, tIntegerBits> a) {
   typedef FixedPoint<tRawType, tIntegerBits> InputF;
   typedef FixedPoint<tRawType, 0> ResultF;
-  static const int kFractionalBits = InputF::kFractionalBits;
-  static const int kIntegerBits = InputF::kIntegerBits;
-  static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+  static constexpr int kFractionalBits = InputF::kFractionalBits;
+  static constexpr int kIntegerBits = InputF::kIntegerBits;
+  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
@@ -755,10 +777,10 @@
 
 #undef GEMMLOWP_EXP_BARREL_SHIFTER
 
+  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
   if (kIntegerBits > 5) {
-    static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
     const InputF clamp =
-        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
+        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
   }
 
@@ -867,6 +889,8 @@
 
 #ifdef GEMMLOWP_NEON
 #include "./fixedpoint_neon.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "./fixedpoint_avx.h"
 #elif defined(GEMMLOWP_SSE4)
 #include "./fixedpoint_sse.h"
 #elif defined(GEMMLOWP_MSA)
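
A few scalar examples of the semantics touched by this hunk (a minimal sketch exercising only the generic, non-SIMD paths):

```cpp
#include <cstdint>
#include "fixedpoint/fixedpoint.h"

void FixedpointSketch() {
  // ShiftLeft now saturates instead of overflowing into undefined behavior:
  std::int32_t a = gemmlowp::ShiftLeft(std::int32_t{3}, 4);            // == 48
  std::int32_t b = gemmlowp::ShiftLeft(std::int32_t{1} << 30, 2);      // saturates to INT32_MAX
  // RoundingDivideByPOT rounds to nearest, ties away from zero:
  std::int32_t c = gemmlowp::RoundingDivideByPOT(std::int32_t{5}, 1);  // == 3  (2.5 rounds up)
  std::int32_t d = gemmlowp::RoundingDivideByPOT(std::int32_t{-5}, 1); // == -3 (-2.5 rounds down)
  (void)a; (void)b; (void)c; (void)d;
}
```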
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
new file mode 100644
index 0000000..1816386
--- /dev/null
+++ b/fixedpoint/fixedpoint_avx.h
@@ -0,0 +1,218 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_avx.h: optimized avx specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+
+#include <smmintrin.h>
+#include "fixedpoint.h"
+#include "fixedpoint_sse.h"
+
+namespace gemmlowp {
+
+template <>
+struct FixedPointRawTypeTraits<__m256i> {
+  typedef std::int32_t ScalarRawType;
+  static const int kLanes = 4;
+};
+
+template <>
+inline __m256i BitAnd(__m256i a, __m256i b) {
+  return _mm256_and_si256(a, b);
+}
+
+template <>
+inline __m256i BitOr(__m256i a, __m256i b) {
+  return _mm256_or_si256(a, b);
+}
+
+template <>
+inline __m256i BitXor(__m256i a, __m256i b) {
+  return _mm256_xor_si256(a, b);
+}
+
+template <>
+inline __m256i BitNot(__m256i a) {
+  return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i Add(__m256i a, __m256i b) {
+  return _mm256_add_epi32(a, b);
+}
+
+template <>
+inline __m256i Mul(__m256i a, __m256i b) {
+  return _mm256_mullo_epi32(a, b);
+}
+
+template <>
+inline __m256i Sub(__m256i a, __m256i b) {
+  return _mm256_sub_epi32(a, b);
+}
+
+template <>
+inline __m256i Neg(__m256i a) {
+  return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i ShiftLeft(__m256i a, int offset) {
+  return _mm256_slli_epi32(a, offset);
+}
+
+template <>
+inline __m256i ShiftRight(__m256i a, int offset) {
+  return _mm256_srai_epi32(a, offset);
+}
+
+template <>
+inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
+                               __m256i else_val) {
+  return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
+                                              _mm256_castsi256_ps(then_val),
+                                              _mm256_castsi256_ps(if_mask)));
+}
+
+template <>
+inline __m256i MaskIfEqual(__m256i a, __m256i b) {
+  return _mm256_cmpeq_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
+  return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline __m256i MaskIfZero(__m256i a) {
+  return MaskIfEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfNonZero(__m256i a) {
+  return MaskIfNotEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
+  return _mm256_cmpgt_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
+  return _mm256_cmpgt_epi32(b, a);
+}
+
+template <>
+inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
+  return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
+inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
+  return BitNot(MaskIfGreaterThan(a, b));
+}
+
+/* Assumptions:
+   - All and Any are used on masks.
+   - masks are all_ones for true lanes, all_zeroes otherwise.
+Hence, All means all 256 bits set, and Any means any bit set.
+*/
+
+template <>
+inline bool All(__m256i a) {
+  return _mm256_testc_si256(a, a);
+}
+
+template <>
+inline bool Any(__m256i a) {
+  return BitNot(_mm256_testz_si256(a, a));
+}
+
+template <>
+inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
+  /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
+  /* We divide the inputs before the add to avoid overflow and the costly test
+     of checking if an overflow occurred on signed add. */
+  /* round_bit_mask = _mm_set1_epi32(1); */
+  /* a_over_2 = _mm_srai_epi32(a, 1); */
+  /* b_over_2 = _mm_srai_epi32(b, 1); */
+  /* sum = Add(a_over_2, b_over_2); */
+  /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
+  /* return Add(sum, round_bit); */
+
+  /* Other possibility: detect overflow and xor the sign if an overflow
+   * happened. */
+  __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
+  one = _mm256_set1_epi32(1);
+  sign_bit_mask = _mm256_set1_epi32(0x80000000);
+  sum = Add(a, b);
+  rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1);
+  overflow =
+      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
+             sign_bit_mask);
+  result = BitXor(rounded_half_sum, overflow);
+  return result;
+}
+
+template <>
+inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
+  __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
+  __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
+  __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
+  __m256i nudge;
+
+  // saturation only happens if a == b == INT_MIN
+  min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min());
+  saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));
+
+  // a = a0 | a1 | a2 | a3
+  // b = b0 | b1 | b2 | b3
+  a0_a2 = a;
+  a1_a3 = _mm256_srli_si256(a, 4);
+  b0_b2 = b;
+  b1_b3 = _mm256_srli_si256(b, 4);
+
+  a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2);
+  a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3);
+
+  // do the rounding and take into account that it will be doubled
+  nudge = _mm256_set1_epi64x(1 << 30);
+  a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge);
+  a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge);
+
+  // do the doubling
+  a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1);
+  a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1);
+
+  // get the high part of the products
+  result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4),
+                              a1b1_a3b3_rounded_2x, 0xcc);
+
+  // saturate those which overflowed
+  return SelectUsingMask(saturation_mask, min, result);
+}
+
+template <>
+inline __m256i Dup<__m256i>(std::int32_t x) {
+  return _mm256_set1_epi32(x);
+}
+
+}  // end namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
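
A minimal sketch of how these `__m256i` specializations are used: they plug into the same generic names (`Dup`, `Add`, `MaskIfLessThan`, `SelectUsingMask`, ...) that the scalar, SSE and NEON paths use, now dispatched on `__m256i` (assumes a build with `-mavx2` and `GEMMLOWP_ENABLE_AVX2` so that GEMMLOWP_AVX2 is defined):

```cpp
#include <immintrin.h>
#include "fixedpoint/fixedpoint.h"  // pulls in fixedpoint_avx.h when GEMMLOWP_AVX2 is defined

__m256i AvxSketch(__m256i x, __m256i y) {
  const __m256i zero = gemmlowp::Dup<__m256i>(0);
  const __m256i sum  = gemmlowp::Add(x, y);
  // Per-lane select: where x < 0 take sum, otherwise take y.
  const __m256i mask = gemmlowp::MaskIfLessThan(x, zero);
  return gemmlowp::SelectUsingMask(mask, sum, y);
}
```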
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
index c7a110c..b17f32a 100644
--- a/fixedpoint/fixedpoint_msa.h
+++ b/fixedpoint/fixedpoint_msa.h
@@ -25,13 +25,13 @@
 template <>
 struct FixedPointRawTypeTraits<v4i32> {
   typedef std::int32_t ScalarRawType;
-  static const int kLanes = 4;
+  static constexpr int kLanes = 4;
 };
 
 template <>
 struct FixedPointRawTypeTraits<v8i16> {
   typedef std::int16_t ScalarRawType;
-  static const int kLanes = 8;
+  static constexpr int kLanes = 8;
 };
 
 template <>
@@ -326,11 +326,71 @@
   }
 };
 
-// TODO: possibly implement:
-// template <> v4i32 RoundingDivideByPOT(v4i32, int)
-// template <> v8i16 RoundingDivideByPOT(v8i16, int)
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
+  static v4i32 eval(v4i32 x) {
+    static_assert(-31 <= Exponent && Exponent <= -1, "");
+    // Isolate the sign bits.
+    v4i32 sign = __builtin_msa_srli_w(x, 31);
+    // Decrement the negative elements by 1 (with saturation).
+    x = __builtin_msa_subs_s_w(x, sign);
+    // Arithmetic shift right with rounding.
+    // The srari instruction rounds all midpoint values towards +infinity.
+    // It will correctly round negative midpoint values as we just
+    // decremented the negative values by 1.
+    return __builtin_msa_srari_w(x, -Exponent);
+  }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
+  static v8i16 eval(v8i16 x) {
+    static_assert(-15 <= Exponent && Exponent <= -1, "");
+    // Isolate the sign bits.
+    v8i16 sign = __builtin_msa_srli_h(x, 15);
+    // Decrement the negative elements by 1 (with saturation).
+    x = __builtin_msa_subs_s_h(x, sign);
+    // Arithmetic shift right with rounding.
+    // The srari instruction rounds all midpoint values towards +infinity.
+    // It will correctly round negative midpoint values as we just
+    // decremented the negative values by 1.
+    return __builtin_msa_srari_h(x, -Exponent);
+  }
+};
+
+template <>
+inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
+  v4i32 e = __builtin_msa_fill_w(exponent);
+  // Isolate the sign bits.
+  v4i32 sign = __builtin_msa_srli_w(x, 31);
+  // Reset them to 0 if exponent is 0.
+  sign = __builtin_msa_min_s_w(sign, e);
+  // Decrement the negative elements by 1 (with saturation)
+  // if exponent is non-zero.
+  x = __builtin_msa_subs_s_w(x, sign);
+  // Arithmetic shift right with rounding.
+  // The srar instruction rounds all midpoint values towards +infinity.
+  // It will correctly round negative midpoint values as we just
+  // decremented the negative values by 1.
+  return __builtin_msa_srar_w(x, e);
+}
+
+template <>
+inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
+  v8i16 e = __builtin_msa_fill_h(exponent);
+  // Isolate the sign bits.
+  v8i16 sign = __builtin_msa_srli_h(x, 15);
+  // Reset them to 0 if exponent is 0.
+  sign = __builtin_msa_min_s_h(sign, e);
+  // Decrement the negative elements by 1 (with saturation)
+  // if exponent is non-zero.
+  x = __builtin_msa_subs_s_h(x, sign);
+  // Arithmetic shift right with rounding.
+  // The srar instruction rounds all midpoint values towards +infinity.
+  // It will correctly round negative midpoint values as we just
+  // decremented the negative values by 1.
+  return __builtin_msa_srar_h(x, e);
+}
 
 template <>
 inline v4i32 Dup<v4i32>(std::int32_t x) {
@@ -346,7 +406,6 @@
 template <>
 inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
   return __builtin_msa_adds_s_h(a, b);
-  return a;
 }
 
 }  // end namespace gemmlowp
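
For reference, the per-lane arithmetic these MSA sequences implement is the same round-to-nearest, ties-away-from-zero division as the generic `RoundingDivideByPOT`; a scalar sketch:

```cpp
#include <cstdint>

// Scalar sketch of the per-lane behavior of the vector RoundingDivideByPOT above.
std::int32_t RoundingDivideByPOTRef(std::int32_t x, int exponent) {
  const std::int32_t mask = (1 << exponent) - 1;
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
// Examples: RoundingDivideByPOTRef(5, 1) == 3, RoundingDivideByPOTRef(-5, 1) == -3.
```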
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 92b349b..4dab6c9 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -25,13 +25,13 @@
 template <>
 struct FixedPointRawTypeTraits<int32x4_t> {
   typedef std::int32_t ScalarRawType;
-  static const int kLanes = 4;
+  static constexpr int kLanes = 4;
 };
 
 template <>
 struct FixedPointRawTypeTraits<int16x8_t> {
   typedef std::int16_t ScalarRawType;
-  static const int kLanes = 8;
+  static constexpr int kLanes = 8;
 };
 
 template <>
@@ -115,6 +115,16 @@
 }
 
 template <>
+inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
+  return vshlq_s32(a, offset);
+}
+
+template <>
+inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
+  return vshlq_s16(a, offset);
+}
+
+template <>
 inline int32x4_t ShiftRight(int32x4_t a, int offset) {
   return vshlq_s32(a, vdupq_n_s32(-offset));
 }
@@ -282,6 +292,22 @@
   return vrshlq_s16(fixed_up_x, shift_vec);
 }
 
+template <>
+inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
+  const int32x4_t shift_vec = vnegq_s32(exponent);
+  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
+  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
+  return vrshlq_s32(fixed_up_x, shift_vec);
+}
+
+template <>
+inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
+  const int16x8_t shift_vec = vnegq_s16(exponent);
+  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
+  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+  return vrshlq_s16(fixed_up_x, shift_vec);
+}
+
 template <int Exponent>
 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
   static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
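
A small sketch of the new per-lane-exponent overloads added in this hunk (assumes a NEON build):

```cpp
#include <arm_neon.h>
#include <cstdint>
#include "fixedpoint/fixedpoint.h"

int32x4_t NeonPerLaneSketch() {
  const int32x4_t x = vdupq_n_s32(5);
  const std::int32_t exp_data[4] = {1, 2, 3, 4};  // one exponent per lane
  const int32x4_t exponents = vld1q_s32(exp_data);
  // Per-lane rounding divide by 2^exponent: the lanes become {3, 1, 1, 0}.
  return gemmlowp::RoundingDivideByPOT(x, exponents);
}
```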
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index ba990f0..a1fae32 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -42,13 +42,13 @@
 template <>
 struct FixedPointRawTypeTraits<__m128i> {
   typedef std::int32_t ScalarRawType;
-  static const int kLanes = 4;
+  static constexpr int kLanes = 4;
 };
 
 template <>
 struct FixedPointRawTypeTraits<int16x8_m128i> {
   typedef std::int16_t ScalarRawType;
-  static const int kLanes = 8;
+  static constexpr int kLanes = 8;
 };
 
 template <>
diff --git a/internal/common.h b/internal/common.h
index 26b6713..332ad07 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -26,144 +26,9 @@
 #include <cmath>
 #include <cstdlib>
 
+#include "../internal/detect_platform.h"
 #include "../profiling/instrumentation.h"
 
-// Our inline assembly path assume GCC/Clang syntax.
-// Native Client doesn't seem to support inline assembly(?).
-#if defined(__GNUC__) && !defined(__native_client__)
-#define GEMMLOWP_ALLOW_INLINE_ASM
-#endif
-
-// Define macro statement that avoids inlining for GCC.
-// For non-GCC, define as empty macro.
-#if defined(__GNUC__)
-#define GEMMLOWP_NOINLINE __attribute__((noinline))
-#else
-#define GEMMLOWP_NOINLINE
-#endif
-
-// Detect ARM, 32-bit or 64-bit
-#ifdef __arm__
-#define GEMMLOWP_ARM_32
-#endif
-
-#ifdef __aarch64__
-#define GEMMLOWP_ARM_64
-#endif
-
-#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_ARM
-#endif
-
-// Detect MIPS, 32-bit or 64-bit
-#if defined(__mips) && !defined(__LP64__)
-#define GEMMLOWP_MIPS_32
-#endif
-
-#if defined(__mips) && defined(__LP64__)
-#define GEMMLOWP_MIPS_64
-#endif
-
-#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MIPS
-#endif
-
-// Detect x86, 32-bit or 64-bit
-#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
-#define GEMMLOWP_X86_32
-#endif
-
-#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
-#define GEMMLOWP_X86_64
-#endif
-
-#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_X86
-#endif
-
-// Some of our optimized paths use inline assembly and for
-// now we don't bother enabling some other optimized paths using intrinddics
-// where we can't use inline assembly paths.
-#ifdef GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect NEON. It's important to check for both tokens.
-#if (defined __ARM_NEON) || (defined __ARM_NEON__)
-#define GEMMLOWP_NEON
-#endif
-
-// Convenience NEON tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
-#define GEMMLOWP_NEON_32
-#endif
-
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_NEON_64
-#endif
-
-// Detect MIPS MSA.
-// Limit MSA optimizations to little-endian CPUs for now.
-// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
-#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
-    defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
-#define GEMMLOWP_MSA
-#endif
-
-// Convenience MIPS MSA tokens for 32-bit or 64-bit.
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
-#define GEMMLOWP_MSA_32
-#endif
-
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MSA_64
-#endif
-
-// Detect SSE.
-#ifdef __SSE4_1__
-#define GEMMLOWP_SSE4
-#endif
-
-#ifdef __SSE3__
-#define GEMMLOWP_SSE3
-#endif
-
-// Convenience SSE4 tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
-   !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_32
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
-#define GEMMLOWP_SSE3_32
-#endif
-
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
-   !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_64
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_SSE3_64
-#endif
-
-#if defined(__has_feature)
-#if __has_feature(memory_sanitizer)
-#include <sanitizer/msan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
-#elif __has_feature(address_sanitizer)
-#include <sanitizer/asan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
-#endif
-#endif
-
-#endif  // GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect Android. Don't conflate with ARM - we care about tuning
-// for non-ARM Android devices too. This can be used in conjunction
-// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
-#if defined(__ANDROID__) || defined(ANDROID)
-#define GEMMLOWP_ANDROID
-#endif
-
 namespace gemmlowp {
 
 // Standard cache line size. Useful to optimize alignment and
@@ -242,7 +107,12 @@
 // size, so any size would work there. Different platforms may set this
 // to different values but must ensure that their own optimized packing paths
 // are consistent with this value.
+
+#ifdef GEMMLOWP_AVX2
+const int kRegisterSize = 32;
+#else
 const int kRegisterSize = 16;
+#endif
 
 // Hints the CPU to prefetch the cache line containing ptr.
 inline void Prefetch(const void* ptr) {
diff --git a/internal/detect_platform.h b/internal/detect_platform.h
new file mode 100644
index 0000000..6f06d19
--- /dev/null
+++ b/internal/detect_platform.h
@@ -0,0 +1,166 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// detect_platform.h: Sets up macros that control architecture-specific
+// features of gemmlowp's implementation.
+
+#ifndef GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
+#define GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
+
+// Our inline assembly path assume GCC/Clang syntax.
+// Native Client doesn't seem to support inline assembly(?).
+#if defined(__GNUC__) && !defined(__native_client__)
+#define GEMMLOWP_ALLOW_INLINE_ASM
+#endif
+
+// Define macro statement that avoids inlining for GCC.
+// For non-GCC, define as empty macro.
+#if defined(__GNUC__)
+#define GEMMLOWP_NOINLINE __attribute__((noinline))
+#else
+#define GEMMLOWP_NOINLINE
+#endif
+
+// Detect ARM, 32-bit or 64-bit
+#ifdef __arm__
+#define GEMMLOWP_ARM_32
+#endif
+
+#ifdef __aarch64__
+#define GEMMLOWP_ARM_64
+#endif
+
+#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_ARM
+#endif
+
+// Detect MIPS, 32-bit or 64-bit
+#if defined(__mips) && !defined(__LP64__)
+#define GEMMLOWP_MIPS_32
+#endif
+
+#if defined(__mips) && defined(__LP64__)
+#define GEMMLOWP_MIPS_64
+#endif
+
+#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MIPS
+#endif
+
+// Detect x86, 32-bit or 64-bit
+#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
+#define GEMMLOWP_X86_32
+#endif
+
+#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
+#define GEMMLOWP_X86_64
+#endif
+
+#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_X86
+#endif
+
+// Some of our optimized paths use inline assembly, and for
+// now we don't bother enabling other optimized paths using intrinsics
+// where we can't use inline assembly paths.
+#ifdef GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect NEON. It's important to check for both tokens.
+#if (defined __ARM_NEON) || (defined __ARM_NEON__)
+#define GEMMLOWP_NEON
+#endif
+
+// Convenience NEON tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
+#define GEMMLOWP_NEON_32
+#endif
+
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_NEON_64
+#endif
+
+// Detect MIPS MSA.
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
+    defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define GEMMLOWP_MSA
+#endif
+
+// Convenience MIPS MSA tokens for 32-bit or 64-bit.
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
+#define GEMMLOWP_MSA_32
+#endif
+
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MSA_64
+#endif
+
+// Detect AVX2. The AVX2 path must additionally be requested with the
+// compiler define -D GEMMLOWP_ENABLE_AVX2.
+#if defined(__AVX2__) && defined(GEMMLOWP_ENABLE_AVX2)
+#define GEMMLOWP_AVX2
+// Detect SSE4.
+// MSVC does not have __SSE4_1__ macro, but will enable SSE4
+// when AVX is turned on.
+#elif defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
+#define GEMMLOWP_SSE4
+// Detect SSE3.
+#elif defined(__SSE3__)
+#define GEMMLOWP_SSE3
+#endif
+
+// Convenience SSE4 tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
+    !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_32
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
+#define GEMMLOWP_SSE3_32
+#endif
+
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
+    !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_64
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_SSE3_64
+#endif
+
+#if defined(GEMMLOWP_AVX2) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_AVX2_64
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#include <sanitizer/msan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
+#elif __has_feature(address_sanitizer)
+#include <sanitizer/asan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
+#endif
+#endif
+
+#endif  // GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect Android. Don't conflate with ARM - we care about tuning
+// for non-ARM Android devices too. This can be used in conjunction
+// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
+#if defined(__ANDROID__) || defined(ANDROID)
+#define GEMMLOWP_ANDROID
+#endif
+
+#endif  // GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
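
A small sketch of how a translation unit can observe which SIMD path these macros selected; note that the AVX2 path additionally needs `-mavx2` and `-DGEMMLOWP_ENABLE_AVX2` on the compile line:

```cpp
#include "internal/detect_platform.h"

const char* GemmlowpPathSketch() {
#if defined(GEMMLOWP_AVX2)
  return "AVX2";  // fixedpoint_avx.h / kernel_avx.h paths
#elif defined(GEMMLOWP_SSE4)
  return "SSE4";
#elif defined(GEMMLOWP_NEON)
  return "NEON";
#elif defined(GEMMLOWP_MSA)
  return "MSA";
#else
  return "reference";
#endif
}
```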
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
index 0be0bf3..ba4f341 100644
--- a/internal/dispatch_gemm_shape.h
+++ b/internal/dispatch_gemm_shape.h
@@ -85,6 +85,22 @@
   }
 };
 
+template <VectorShape Shape>
+struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
+  typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
+  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+  typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
+      DstType;
+  static DstType Run(const SrcType& src) {
+    DstType dst;
+    dst.result_fixedpoint_multiplier =
+        Transpose(src.result_fixedpoint_multiplier);
+    dst.result_exponent = Transpose(src.result_exponent);
+    dst.result_offset_after_shift = src.result_offset_after_shift;
+    return dst;
+  }
+};
+
 template <typename VectorMapType>
 struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
   typedef OutputStageBiasAddition<VectorMapType> SrcType;
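
A minimal sketch of what the new specialization provides: transposing the GEMM shape flips the vector shape of the per-channel fields from `Col` to `Row` (and vice versa) while keeping the scalar offset; `VectorShape` comes from public/map.h and `DstType` is the typedef defined in the hunk above:

```cpp
#include <type_traits>

using SrcStage = gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
    gemmlowp::VectorShape::Col>;
using DstStage = gemmlowp::TransposeImpl<SrcStage>::DstType;
static_assert(
    std::is_same<DstStage,
                 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
                     gemmlowp::VectorShape::Row>>::value,
    "Transposing the shape flips Col-shaped per-channel vectors to Row");
```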
diff --git a/internal/kernel.h b/internal/kernel.h
index 825a7f3..3120216 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -145,12 +145,24 @@
   static const int kCells = tCells;
   static const int kWidth = kCells * Cell::kWidth;
   static const int kDepth = Cell::kDepth;
-  typedef std::uint8_t Scalar;
+  typedef std::uint8_t Scalar;       // The scalar type of the Format.
+  typedef std::uint8_t InputScalar;  // The scalar type of the original input.
 };
 
+// KernelSideFormat for the int8 fast kernel trick. The original input is
+// uint8, but the packing stage converts it to int8.
 template <typename tCellFormat, int tCells>
 struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
   typedef std::int8_t Scalar;
+  typedef std::uint8_t InputScalar;
+};
+
+// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
+// pack conversion.
+template <typename tCellFormat, int tCells>
+struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
+  typedef std::int8_t Scalar;
+  typedef std::int8_t InputScalar;
 };
 
 // KernelFormat describes fully the input data layout that a kernel expects.
@@ -216,19 +228,24 @@
   virtual ~KernelBase() {}
 };
 
-template <typename KernelScalarType>
+template <typename InputKernelScalarType, typename KernelScalarType>
 struct ZeroPointInputValue {};
 
 template <>
-struct ZeroPointInputValue<std::uint8_t> {
+struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
   static constexpr std::uint8_t kValue = 0;
 };
 
 template <>
-struct ZeroPointInputValue<std::int8_t> {
+struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
   static constexpr std::uint8_t kValue = 128;
 };
 
+template <>
+struct ZeroPointInputValue<std::int8_t, std::int8_t> {
+  static constexpr std::uint8_t kValue = 0;
+};
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_KERNEL_H_
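
A short sketch of why these zero-point values are what they are: the int8 fast-kernel trick packs uint8 inputs into int8 by flipping the sign bit, so the uint8 input that packs to int8 zero is 128, whereas already-int8 inputs (the new `KernelSideFormatInt8Inputs` case) need no shift:

```cpp
#include <cstdint>

// Sketch: the int8 fast-kernel packing maps uint8 inputs to int8 by flipping
// the sign bit (equivalently, subtracting 128 mod 256).
std::int8_t PackUint8ToInt8Sketch(std::uint8_t input_u8) {
  return static_cast<std::int8_t>(input_u8 ^ 0x80);  // 128 -> 0, 0 -> -128, 255 -> 127
}
// Hence ZeroPointInputValue<std::uint8_t, std::int8_t>::kValue == 128 (the uint8
// value that packs to int8 zero), while the uint8->uint8 and the new
// int8->int8 (KernelSideFormatInt8Inputs) cases keep kValue == 0.
```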
diff --git a/internal/kernel_avx.h b/internal/kernel_avx.h
new file mode 100644
index 0000000..2fe1249
--- /dev/null
+++ b/internal/kernel_avx.h
@@ -0,0 +1,361 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// kernel_avx.h: a collection of Intel AVX2 optimized kernels.
+// Check in kernel_default.h which one(s) are actually used by default.
+// Others are mere experiments; they are still covered by tests
+// in case they might be useful some day.
+//
+
+#ifndef GEMMLOWP_INTERNAL_KERNEL_AVX_H_
+#define GEMMLOWP_INTERNAL_KERNEL_AVX_H_
+
+#include "kernel.h"
+
+#include <string.h>
+#include <cassert>
+
+namespace gemmlowp {
+
+#ifdef GEMMLOWP_AVX2_64
+struct AVX2_64_Kernel24x8Depth2 : KernelBase {
+  typedef KernelFormat<KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, 3>,
+                       KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>>
+      Format;
+
+  const char *Name() const override { return "AVX, 24x8, depth 2"; }
+
+  void Run(std::int32_t *dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
+           const std::uint8_t *lhs_ptr, const std::uint8_t *rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+    ScopedProfilingLabel label("optimized kernel");
+    assert(dst_row_stride == 1);
+    const std::int64_t run_depth_cells = run_depth / Format::kDepth;
+    const std::int64_t dst_col_stride_q = dst_col_stride;
+
+    /* Main loop */
+
+    // A 2x8 cell of Rhs is stored in 16bit in ymm1.
+    // A 24x2 block of 3 8x2 cells of Lhs is stored in 16bit in ymm0, replaced
+    // every iteration.
+    // An 8x8 block of accumulators is stored in 32bit in ymm4--ymm15.
+    //
+    //                   +-------+-------+-------+-------+
+    //                   |ymm1[0]        |ymm2[2]        |
+    //              Rhs  +-------+---------------+-------+
+    //                   |ymm1[1]        |ymm1[4]        |
+    //                   +-------+-------+-------+-------+
+    //
+    //                   |       |       |       |       |
+    //
+    //    Lhs            |       |       |       |       |
+    //
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
+    //  |ymm0 | (Iter1)  | ymm4  | ymm5  | ymm6  | ymm7  |
+    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
+    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
+    //  |ymm0 | (Iter2)  | ymm8  | ymm9  | ymm10 | ymm11 |
+    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
+    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
+    //  |ymm0 | (Iter3)  | ymm12 | ymm13 | ymm14 | ymm15 |
+    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
+    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //
+    //                              Accumulator
+
+    asm volatile(
+        // Set registers for destination
+        "movq  %[dst_col_stride_q], %%r12\n\t"  // stride is r12
+        "shlq $2, %%r12\n\t"                    // set stride dword
+        "leaq (%%r12,%%r12,0x2), %%r13\n\t"     // load stride aligned r13
+
+        // Set accumulators to zero.
+        "vpxor %%ymm4, %%ymm4, %%ymm4 \n\t"    // zero accumulators
+        "vpxor %%ymm5, %%ymm5, %%ymm5 \n\t"    // zero accumulators
+        "vpxor %%ymm6, %%ymm6, %%ymm6 \n\t"    // zero accumulators
+        "vpxor %%ymm7, %%ymm7, %%ymm7 \n\t"    // zero accumulators
+        "vpxor %%ymm8, %%ymm8, %%ymm8 \n\t"    // zero accumulators
+        "vpxor %%ymm9, %%ymm9, %%ymm9 \n\t"    // zero accumulators
+        "vpxor %%ymm10, %%ymm10, %%ymm10\n\t"  // zero accumulators
+        "vpxor %%ymm11, %%ymm11, %%ymm11\n\t"  // zero accumulators
+        "vpxor %%ymm12, %%ymm12, %%ymm12\n\t"  // zero accumulators
+        "vpxor %%ymm13, %%ymm13, %%ymm13\n\t"  // zero accumulators
+        "vpxor %%ymm14, %%ymm14, %%ymm14\n\t"  // zero accumulators
+        "vpxor %%ymm15, %%ymm15, %%ymm15\n\t"  // zero accumulators
+
+        "movq  %[run_depth_cells], %%r14 \n\t"  // load cell depth r14
+        "subq $2, %%r14 \n\t"                   // cell depth is 2
+        "js outerLoop1%= \n\t"                  // outerloop for matrix
+
+        // Loop for K unrolled by 4
+        "outerLoop2%=: \n\t"  // outer loop unroll
+
+        // K = 0,1,2,3
+        // RHS cell to ymm1
+
+        // lower half
+        "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+        // LHS cell elements 0 and 1
+        "vpmovzxbw 0x00(%[lhs_ptr]), %%ymm0\n\t"  // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2     \n\t"    // move rhs 0 element to all ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3     \n\t"    // move rhs 1 element to all ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t"    // mul add lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t"    // mul add lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm4, %%ymm4   \n\t"    // add muladd lhs + rhs0 into ymm4
+        "vpaddd %%ymm3, %%ymm5, %%ymm5   \n\t"    // add muladd lhs + rhs1 into ymm5
+        // LHS cell elements 2 and 3
+        "vpshufd $0xaa, %%ymm1, %%ymm2   \n\t"  // move rhs 2 element to all ymm2
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t"  // mul add lhs rh3 into ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3     \n\t"  // mov rhs 3 element into all ymm3
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t"  // mul add lhs rh4 into ymm3
+        "vpaddd %%ymm2, %%ymm6, %%ymm6   \n\t"  // add muladd lhs + rhs2 into ymm6
+        "vpaddd %%ymm3, %%ymm7, %%ymm7   \n\t"  // add muladd lhs + rhs3 into ymm7
+
+        // cache prefetch lhs // see if it works better?
+        //"prefetcht0 0x80(%[lhs_ptr]) \n\t" //prefetch cache lines
+        "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+        // K = 5,6,7,8
+        // next LHS cell elements 0 and 1
+        "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t"  // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2        \n\t"  // mov rhs 0 element to all ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3        \n\t"  // mov rhs 1 element to all ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mul add lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm8, %%ymm8      \n\t"  // add muladd lhs + rhs0 into ymm8
+        "vpaddd %%ymm3, %%ymm9, %%ymm9      \n\t"  // add muladd lhs + rhs1 into ymm9
+        // next LHS cell elements 2 and 3
+        "vpshufd $0xaa,%%ymm1,%%ymm2        \n\t"  // mov rhs 2 element to all ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3        \n\t"  // mov rhs 3 element to all ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs2 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mul add lhs rhs3 into ymm3
+        "vpaddd %%ymm2, %%ymm10, %%ymm10    \n\t"  // add muladd lhs + rhs2 into ymm10
+        "vpaddd %%ymm3, %%ymm11, %%ymm11    \n\t"  // add muladd lhs + rhs3 into ymm11
+
+        // rhs lower half
+        "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"     // duplcate lower 16
+
+        // next LHS cell elements 0 and 1
+        "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t"    // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2        \n\t"    // mov rhs 0 element to all ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3        \n\t"    // mov rhs 1 element to all ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"    // mul add lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"    // mul add lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm12, %%ymm12      \n\t"  // add muladd lhs + rhs0 into ymm8
+        "vpaddd %%ymm3, %%ymm13, %%ymm13      \n\t"  // add muladd lhs + rhs1 into ymm9
+
+        // cache prefetch rhs //see if it works better?
+        //"prefetcht0 0x80(%[rhs_ptr]) \n\t"
+
+        // next LHS cell elements 2 and 3
+        "vpshufd $0xaa,%%ymm1,%%ymm2        \n\t"  // mov rhs 2 element to all ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3        \n\t"  // mov rhs 3 element to all ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs2 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mul add lhs rhs3 into ymm3
+        "vpaddd %%ymm2, %%ymm14, %%ymm14    \n\t"  // add muladd lhs + rhs2 into ymm10
+        "vpaddd %%ymm3, %%ymm15, %%ymm15    \n\t"  // add muladd lhs + rhs3 into ymm11
+
+        // current result in ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10 ymm11 ymm12 ymm13 ymm14 ymm15
+
+        // rhs+10 lower half
+        "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+        // next LHS cell elements 0 and 1
+        "vpmovzxbw 0x30(%[lhs_ptr]), %%ymm0 \n\t"  // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2        \n\t"  // move rhs 0 element to ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3        \n\t"  // move rhs 1 element to ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // muladd lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // muladd lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm4, %%ymm4      \n\t"  // accumulate to ymm4
+        "vpaddd %%ymm3, %%ymm5, %%ymm5      \n\t"  // accumulate to ymm5
+        // next LHS cell elements 2 and 3
+        "vpshufd $0xaa,%%ymm1,%%ymm2        \n\t"  // mov rhs 2 element to ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3        \n\t"  // mov rhs 3 element to ymm2
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs2 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mull add lhs rhs3 into ymm3
+        "vpaddd %%ymm2, %%ymm6, %%ymm6      \n\t"  // add lhs rhs2 to ymm6
+        "vpaddd %%ymm3, %%ymm7, %%ymm7      \n\t"  // add lhs rhs3 to ymm7
+
+        // rhs+10 lower half
+        "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+        // next LHS cell elements 4 and 5
+        "vpmovzxbw 0x40(%[lhs_ptr]), %%ymm0 \n\t"  // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2        \n\t"  // move rhs 0 element to ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3        \n\t"  // move rhs 1 element to ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // muladd lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // muladd lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm8, %%ymm8      \n\t"  // accumulate to ymm8
+        "vpaddd %%ymm3, %%ymm9, %%ymm9      \n\t"  // accumulate to ymm9
+        // next LHS cell elements 6 and 7
+        "vpshufd $0xaa,%%ymm1,%%ymm2        \n\t"  // mov rhs 2 element to ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3        \n\t"  // mov rhs 3 element to ymm2
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs2 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mull add lhs rhs3 into ymm3
+        "vpaddd %%ymm2, %%ymm10, %%ymm10    \n\t"  // add lhs rhs2 to ymm10
+        "vpaddd %%ymm3, %%ymm11, %%ymm11    \n\t"  // add lhs rhs3 to ymm11
+
+        "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t"  // mov rhs to ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+        // next LHS cell elements 9 and 10
+        "vpmovzxbw 0x50(%[lhs_ptr]), %%ymm0 \n\t"  // mov lhs to ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2        \n\t"  // move rhs 0 element to ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3        \n\t"  // move rhs 1 element to ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // muladd lhs rhs0 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // muladd lhs rhs1 into ymm3
+        "vpaddd %%ymm2, %%ymm12, %%ymm12    \n\t"  // accumulate to ymm12
+        "vpaddd %%ymm3, %%ymm13, %%ymm13    \n\t"  // accumulate to ymm13
+
+        // next LHS cell elements 11 and 12
+        "vpshufd $0xaa,%%ymm1,%%ymm2        \n\t"  // mov rhs 2 element to ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3        \n\t"  // mov rhs 3 element to ymm2
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2    \n\t"  // mul add lhs rhs2 into ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3    \n\t"  // mull add lhs rhs3 into ymm3
+        "vpaddd %%ymm2, %%ymm14, %%ymm14    \n\t"  // add lhs rhs2 to ymm14
+        "vpaddd %%ymm3, %%ymm15, %%ymm15    \n\t"  // add lhs rhs3 to ymm15
+
+        // completed rhs+10
+        "addq $0x60, %[lhs_ptr]             \n\t"  // increment stride lhs
+        "addq $0x10, %[rhs_ptr]             \n\t"  // increment stride rhs
+
+        "subq $2, %[run_depth_cells] \n\t"
+        "ja outerLoop2%= \n\t"
+
+        "movq %[run_depth_cells], %%r14 \n\t"
+        "decq %%r14 \n\t"
+        "js finish%= \n\t"
+
+        // Loop for K unrolled by 2
+        "outerLoop1%=: \n\t"
+
+        // rhs lower
+        "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t"  // get rhs into ymm1
+        "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+        // LHS cell
+        "vpmovzxbw (%[lhs_ptr]), %%ymm0  \n\t"      // lhs in into ymm0
+        "vpshufd $0x00,%%ymm1,%%ymm2         \n\t"  // rhs element 0 into ymm2
+        "vpshufd $0x55,%%ymm1,%%ymm3         \n\t"  // rhs element 1 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 0 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 1 ymm3
+        "vpaddd %%ymm2, %%ymm4, %%ymm4       \n\t"  // acc element 0 ymm4
+        "vpaddd %%ymm3, %%ymm5, %%ymm5       \n\t"  // acc element 1 ymm5
+        "vpshufd $0xaa,%%ymm1,%%ymm2         \n\t"  // rhs element 2 into ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3         \n\t"  // rhs element 3 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 2 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 3 ymm3
+        "vpaddd %%ymm2, %%ymm6, %%ymm6       \n\t"  // acc element 2 into ymm6
+        "vpaddd %%ymm3, %%ymm7, %%ymm7       \n\t"  // acc element 3 into ymm7
+
+        // lhs+10
+        "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0  \n\t"  // lhs in into ymm0
+        "vpshufd $0x00, %%ymm1, %%ymm2       \n\t"  // rhs element 0 into ymm2
+        "vpshufd $0x55, %%ymm1, %%ymm3       \n\t"  // rhs element 1 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 0 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 1 ymm3
+        "vpaddd %%ymm2, %%ymm8, %%ymm8       \n\t"  // acc element 0 ymm8
+        "vpaddd %%ymm3, %%ymm9, %%ymm9       \n\t"  // acc element 1 ymm9
+        "vpshufd $0xaa,%%ymm1,%%ymm2         \n\t"  // rhs element 2 into ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3         \n\t"  // rhs element 3 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 2 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 3 ymm3
+        "vpaddd %%ymm2, %%ymm10, %%ymm10     \n\t"  // acc element 2 into ymm10
+        "vpaddd %%ymm3, %%ymm11, %%ymm11     \n\t"  // acc element 3 into ymm11
+
+        "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0  \n\t"
+        "vpshufd $0x00, %%ymm1, %%ymm2       \n\t"  // rhs element 0 into ymm2
+        "vpshufd $0x55, %%ymm1, %%ymm3       \n\t"  // rhs element 1 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 0 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 1 ymm3
+        "vpaddd %%ymm2, %%ymm12, %%ymm12     \n\t"  // acc element 0 ymm12
+        "vpaddd %%ymm3, %%ymm13, %%ymm13     \n\t"  // acc element 1 ymm13
+        "vpshufd $0xaa,%%ymm1,%%ymm2         \n\t"  // rhs element 2 into ymm2
+        "vpshufd $0xff,%%ymm1,%%ymm3         \n\t"  // rhs element 3 into ymm3
+        "vpmaddwd %%ymm0, %%ymm2, %%ymm2     \n\t"  // muladd lhs rhs element 2 ymm2
+        "vpmaddwd %%ymm0, %%ymm3, %%ymm3     \n\t"  // muladd lhs rhs element 3 ymm3
+        "vpaddd %%ymm2, %%ymm14, %%ymm14     \n\t"  // acc element 2 into ymm14
+        "vpaddd %%ymm3, %%ymm15, %%ymm15     \n\t"  // acc element 3 into ymm15
+
+        // update matrix pointers
+        "addq $0x30, %[lhs_ptr]              \n\t"
+        "addq $0x08, %[rhs_ptr]              \n\t"
+
+        "decq %[run_depth_cells]             \n\t"
+        "jnz outerLoop1%=                    \n\t"
+
+        "finish%=:\n\t"
+
+        "test %[start_depth], %[start_depth] \n\t"
+        "jz storeDst%= \n\t"
+
+        "vpaddd 0x00(%[dst_ptr]), %%ymm4, %%ymm4 \n\t"    // rhs0
+        "vpaddd 0x20(%[dst_ptr]), %%ymm8, %%ymm8 \n\t"    // rhs0
+        "vpaddd 0x40(%[dst_ptr]), %%ymm12, %%ymm12 \n\t"  // rhs0
+
+        "vpaddd 0x00(%[dst_ptr], %%r12, 1) , %%ymm5, %%ymm5   \n\t"  // rhs1
+        "vpaddd 0x20(%[dst_ptr], %%r12, 1) , %%ymm9, %%ymm9   \n\t"  // rhs1
+        "vpaddd 0x40(%[dst_ptr], %%r12, 1) , %%ymm13, %%ymm13 \n\t"  // rhs1
+
+        "vpaddd 0x00(%[dst_ptr], %%r12, 2) , %%ymm6, %%ymm6   \n\t"  // rhs2
+        "vpaddd 0x20(%[dst_ptr], %%r12, 2) , %%ymm10, %%ymm10 \n\t"  // rhs2
+        "vpaddd 0x40(%[dst_ptr], %%r12, 2) , %%ymm14, %%ymm14 \n\t"  // rhs2
+
+        "vpaddd 0x00(%[dst_ptr], %%r13, 1) , %%ymm7, %%ymm7   \n\t"  // rhs3
+        "vpaddd 0x20(%[dst_ptr], %%r13, 1) , %%ymm11, %%ymm11 \n\t"  // rhs3
+        "vpaddd 0x40(%[dst_ptr], %%r13, 1) , %%ymm15, %%ymm15 \n\t"  // rhs3
+
+        "storeDst%=:\n\t"
+
+        "vmovdqu %%ymm4, 0x00(%[dst_ptr])            \n\t"  // rhs0
+        "vmovdqu %%ymm8, 0x20(%[dst_ptr])            \n\t"  // rhs0
+        "vmovdqu %%ymm12, 0x40(%[dst_ptr])           \n\t"  // rhs0
+
+        "vmovdqu %%ymm5, 0x00(%[dst_ptr], %%r12, 1)  \n\t"  // rhs1
+        "vmovdqu %%ymm9, 0x20(%[dst_ptr], %%r12, 1)  \n\t"  // rhs1
+        "vmovdqu %%ymm13, 0x40(%[dst_ptr], %%r12, 1) \n\t"  // rhs1
+
+        "vmovdqu %%ymm6, 0x00(%[dst_ptr], %%r12, 2)  \n\t"  // rhs2
+        "vmovdqu %%ymm10, 0x20(%[dst_ptr], %%r12, 2) \n\t"  // rhs2
+        "vmovdqu %%ymm14, 0x40(%[dst_ptr], %%r12, 2) \n\t"  // rhs2
+
+        "vmovdqu %%ymm7, 0x00(%[dst_ptr], %%r13, 1)  \n\t"  // rhs3
+        "vmovdqu %%ymm11, 0x20(%[dst_ptr], %%r13, 1) \n\t"  // rhs3
+        "vmovdqu %%ymm15, 0x40(%[dst_ptr], %%r13, 1) \n\t"  // rhs3
+
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [dst_ptr] "+r"(dst_ptr)
+        :  // inputs
+        [start_depth] "r"(start_depth), [dst_col_stride_q] "r"(dst_col_stride_q),
+        [run_depth_cells] "r"(run_depth_cells)
+        :  // clobbers
+        "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7",
+        "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "%r12",
+        "%r13", "%r14");
+  }
+};
+#endif
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_KERNEL_AVX_H_
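
For readers less familiar with AVX2, a scalar sketch of what each `vpmaddwd`/`vpaddd` pair in the loop above contributes to one 32-bit accumulator lane (the bytes were zero-extended to 16 bit by `vpmovzxbw`; values are illustrative):

```cpp
#include <cstdint>

// Scalar model of one depth-2 step for a single 32-bit accumulator lane.
std::int32_t Depth2StepSketch(std::int32_t acc) {
  const std::uint8_t lhs_pair[2] = {3, 7};  // two depth levels of one lhs row
  const std::uint8_t rhs_pair[2] = {5, 2};  // the matching two depth levels of one rhs column
  acc += static_cast<std::int32_t>(lhs_pair[0]) * rhs_pair[0] +
         static_cast<std::int32_t>(lhs_pair[1]) * rhs_pair[1];  // adds 3*5 + 7*2 = 29
  return acc;
}
```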
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index a919ffe..29b0991 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -20,66 +20,84 @@
 
 #include "../public/bit_depth.h"
 #include "common.h"
+#include "kernel.h"
 #include "kernel_reference.h"
 
 namespace gemmlowp {
 
-template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
+template <bool MaxProductIsLessThan4096, bool IsUnsigned, bool LhsNonZero>
 struct DefaultKernelImpl {};
 
 // Partial specialization implementing the logic that if we want to use
-// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall
-// back to a generic kernel not taking advantage of LhsAlwaysNonzero.
-template <bool LhsAlwaysNonzero>
-struct DefaultKernelImpl<true, LhsAlwaysNonzero>
-    : DefaultKernelImpl<false, LhsAlwaysNonzero> {};
-
-// Partial specialization implementing the logic that if we want to use
 // a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we
 // fall back to a generic kernel not taking advantage of
 // MaxProductIsLessThan4096.
+template <bool LhsNonZero>
+struct DefaultKernelImpl<true, true, LhsNonZero>
+    : DefaultKernelImpl<false, true, LhsNonZero> {};
+
+// Partial specialization implementing the logic that if we want to use
+// a kernel for LhsNonZero but do not have such a kernel, then we fall
+// back to a generic kernel not taking advantage of LhsNonZero.
 template <bool MaxProductIsLessThan4096>
-struct DefaultKernelImpl<MaxProductIsLessThan4096, true>
-    : DefaultKernelImpl<MaxProductIsLessThan4096, false> {};
+struct DefaultKernelImpl<MaxProductIsLessThan4096, true, true>
+    : DefaultKernelImpl<MaxProductIsLessThan4096, true, false> {};
 
 template <typename BitDepthParams>
 struct DefaultKernel
     : DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue *
                              BitDepthParams::RhsRange::kMaxValue <
                          4096),
-                        (BitDepthParams::LhsRange::kMinValue > 0)> {};
+                        (BitDepthParams::LhsRange::kMinValue >= 0),
+                        (BitDepthParams::LhsRange::kMinValue > 0 ||
+                         (BitDepthParams::LhsRange::kMaxValue <= 127 &&
+                          BitDepthParams::LhsRange::kMinValue > -128))> {};
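+
+// As a worked example (for a hypothetical BitDepthParams whose LhsRange and
+// RhsRange are both the full unsigned 8-bit range [0, 255]), the three flags
+// above would evaluate to:
+//   MaxProductIsLessThan4096: 255 * 255 = 65025 < 4096        -> false
+//   IsUnsigned:               kMinValue = 0 >= 0              -> true
+//   LhsNonZero:               0 > 0 and 255 <= 127 both false -> false
+// so DefaultKernel would resolve to DefaultKernelImpl<false, true, false>.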
 
 }  // end namespace gemmlowp
 
-#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096,          \
-                                    LhsAlwaysNonzero, Kernel)          \
-  namespace gemmlowp {                                                 \
-  template <>                                                          \
-  struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
-      : Kernel {};                                                     \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \
+                                    LhsAlwaysNonZero, Kernel)             \
+  namespace gemmlowp {                                                    \
+  template <>                                                             \
+  struct DefaultKernelImpl<MaxProductIsLessThan4096, IsUnsigned,          \
+                           LhsAlwaysNonZero> : Kernel {};                 \
   }
 
+// User-provided int8 inputs are currently supported only in the NEON path.
 #if defined GEMMLOWP_NEON_32
 #include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(true, false,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false,
                             NEON_32_Kernel12x4Depth2Assuming12BitProducts)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
                             NEON_32bit_GEMM_Int8Operands_LhsNonzero)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
+                            NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
 #elif defined GEMMLOWP_NEON_64
 #include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+#if defined GEMMLOWP_DOTPROD_KERNEL
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false,
+                            NEON_64_Kernel12x8Depth4_dotprod)
+#else
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
                             NEON_64bit_GEMM_Int8Operands_LhsNonzero)
+#endif
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
+                            NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
 #elif defined(GEMMLOWP_MSA)
 #include "kernel_msa.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero)
 #elif defined GEMMLOWP_SSE4_32
 #include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2)
 #elif defined GEMMLOWP_SSE4_64
 #include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2)
+#elif defined GEMMLOWP_AVX2_64
+#include "kernel_avx.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2)
 #else
 #include "kernel_reference.h"
 namespace gemmlowp {
@@ -88,7 +106,7 @@
     KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > >
     DefaultReferenceKernel;
 }
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel)
 #endif
 
 #endif  // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h
index 4985b73..a9205f6 100644
--- a/internal/kernel_msa.h
+++ b/internal/kernel_msa.h
@@ -42,8 +42,8 @@
 
 // Our main GEMM kernel.
 struct MSA_Kernel12x8Depth2 : KernelBase {
-  typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
-                       KernelSideFormat<CellFormat<4, 2>, 2> >
+  typedef KernelFormat<KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
+                       KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
       Format;
 
   const char* Name() const override { return "MSA, 12x8, depth 2"; }
@@ -62,9 +62,6 @@
 
     assert(dst_row_stride == 1);
     asm volatile(
-        // Set a temp to all zeroes.
-        "ldi.b  $w31, 0\n"
-
         // Multiply dst_col_stride by 4 == sizeof(int32) to use
         // it as a byte offset below.
         GEMMLOWP_MIPS_XSLL
@@ -75,32 +72,25 @@
         "beqz   %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
 
         // Load accumulators (start_depth != 0).
-        GEMMLOWP_MIPS_XADDU
-        " $a0, %[dst_ptr], %[dst_col_stride]\n"
+        GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
         "ld.w   $w0,  (0*16)(%[dst_ptr])\n"
         "ld.w   $w4,  (1*16)(%[dst_ptr])\n"
-        "ld.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "ld.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "ld.w   $w1,  (0*16)($a0)\n"
         "ld.w   $w5,  (1*16)($a0)\n"
-        "ld.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "ld.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "ld.w   $w2,  (0*16)($a1)\n"
         "ld.w   $w6,  (1*16)($a1)\n"
-        "ld.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "ld.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "ld.w   $w3,  (0*16)($a0)\n"
         "ld.w   $w7,  (1*16)($a0)\n"
-        "ld.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "ld.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "ld.w   $w12, (0*16)($a1)\n"
         "ld.w   $w16, (1*16)($a1)\n"
-        "ld.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "ld.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "ld.w   $w13, (0*16)($a0)\n"
         "ld.w   $w17, (1*16)($a0)\n"
-        "ld.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "ld.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "ld.w   $w14, (0*16)($a1)\n"
         "ld.w   $w18, (1*16)($a1)\n"
         "ld.w   $w22, (2*16)($a1)\n"
@@ -109,8 +99,7 @@
         "ld.w   $w23, (2*16)($a0)\n"
         "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
 
-        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
-        ":\n"
+        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
         // Clear accumulators (start_depth == 0).
         "ldi.w  $w0,  0\n"
         "ldi.w  $w4,  0\n"
@@ -139,17 +128,16 @@
 
         GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
 
-        GEMMLOWP_LABEL_LOOP
-        ":\n"
+        GEMMLOWP_LABEL_LOOP ":\n"
         // Overview of register layout:
         //
-        // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+        // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
         // (each register contains 4 replicas of a pair of elements).
         // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
         // A 12x8 block of accumulators is stored in 32bit in w0-w23.
         //
         //                    +------+------+------+------+
-        //               Rhs  |w27   |w28   |w29   |w30   |
+        //               Rhs  |w28   |w29   |w30   |w31   |
         //                    +------+------+------+------+
         //
         //                    |      |      |      |      |
@@ -179,128 +167,86 @@
         "ld.b   $w24, 0(%[lhs_ptr])\n"
         "ld.b   $w25, 8(%[lhs_ptr])\n"
 
-        // Load 4 bytes of rhs[] for the first half of depth 0.
-        "lbu    $a0, 0(%[rhs_ptr])\n"
-        "lbu    $a1, 1(%[rhs_ptr])\n"
-        "lbu    $a2, 2(%[rhs_ptr])\n"
-        "lbu    $a3, 3(%[rhs_ptr])\n"
-        // Load 4 bytes of rhs[] for the first half of depth 1.
-        "lbu    $v0, 4(%[rhs_ptr])\n"
-        "lbu    $v1, 5(%[rhs_ptr])\n"
-        "lbu    $t8, 6(%[rhs_ptr])\n"
-        "lbu    $t9, 7(%[rhs_ptr])\n"
+        // Load 2 x 8 bytes of rhs[].
+        "ld.b   $w27, 0(%[rhs_ptr])\n"
 
         // Zero-extend 8-bit elements of lhs[] to 16 bits.
+        "ldi.b  $w31, 0\n"
         "ilvr.b $w24, $w31, $w24\n"
         "ilvl.b $w26, $w31, $w25\n"
         "ilvr.b $w25, $w31, $w25\n"
-        // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
-        "ilvl.d $w27, $w31, $w24\n"
-        "ilvl.d $w28, $w31, $w25\n"
-        "ilvl.d $w29, $w31, $w26\n"
-        "ilvr.h $w24, $w27, $w24\n"
-        "ilvr.h $w25, $w28, $w25\n"
-        "ilvr.h $w26, $w29, $w26\n"
-
-        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
-        // dpadd_u.w (for the first half).
-        "ins    $a0, $v0, 16, 8\n"
-        "ins    $a1, $v1, 16, 8\n"
-        "ins    $a2, $t8, 16, 8\n"
-        "ins    $a3, $t9, 16, 8\n"
-        // Make 4 replicas of every pair of rhs[] elements.
-        "fill.w $w27, $a0\n"
-        "fill.w $w28, $a1\n"
-        "fill.w $w29, $a2\n"
-        "fill.w $w30, $a3\n"
-
-        // Load 4 bytes of rhs[] for the second half of depth 0.
-        "lbu    $a0, 8(%[rhs_ptr])\n"
-        "lbu    $a1, 9(%[rhs_ptr])\n"
-        "lbu    $a2, 10(%[rhs_ptr])\n"
-        "lbu    $a3, 11(%[rhs_ptr])\n"
-        // Load 4 bytes of rhs[] for the second half of depth 1.
-        "lbu    $v0, 12(%[rhs_ptr])\n"
-        "lbu    $v1, 13(%[rhs_ptr])\n"
-        "lbu    $t8, 14(%[rhs_ptr])\n"
-        "lbu    $t9, 15(%[rhs_ptr])\n"
 
         // First half of depths 0 and 1.
-        // Dot-product-(and)-add doubles multiplicand width.
-        "dpadd_u.w $w0, $w24, $w27\n"
-        "dpadd_u.w $w4, $w25, $w27\n"
-        "dpadd_u.w $w8, $w26, $w27\n"
-        "dpadd_u.w $w1, $w24, $w28\n"
-        "dpadd_u.w $w5, $w25, $w28\n"
-        "dpadd_u.w $w9, $w26, $w28\n"
-        "dpadd_u.w $w2, $w24, $w29\n"
-        "dpadd_u.w $w6, $w25, $w29\n"
-        "dpadd_u.w $w10, $w26, $w29\n"
-        "dpadd_u.w $w3, $w24, $w30\n"
-        "dpadd_u.w $w7, $w25, $w30\n"
-        "dpadd_u.w $w11, $w26, $w30\n"
-
-        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
-        // dpadd_u.w (for the second half).
-        "ins    $a0, $v0, 16, 8\n"
-        "ins    $a1, $v1, 16, 8\n"
-        "ins    $a2, $t8, 16, 8\n"
-        "ins    $a3, $t9, 16, 8\n"
+        // Zero-extend 8-bit elements of rhs[] to 16 bits.
+        "ilvr.b    $w31, $w31, $w27\n"
         // Make 4 replicas of every pair of rhs[] elements.
-        "fill.w $w27, $a0\n"
-        "fill.w $w28, $a1\n"
-        "fill.w $w29, $a2\n"
-        "fill.w $w30, $a3\n"
+        "splati.w  $w28, $w31[0]\n"
+        "splati.w  $w29, $w31[1]\n"
+        "splati.w  $w30, $w31[2]\n"
+        "splati.w  $w31, $w31[3]\n"
+        // Dot-product-(and)-add doubles multiplicand width.
+        "dpadd_u.w  $w0, $w24, $w28\n"
+        "dpadd_u.w  $w4, $w25, $w28\n"
+        "dpadd_u.w  $w8, $w26, $w28\n"
+        "dpadd_u.w  $w1, $w24, $w29\n"
+        "dpadd_u.w  $w5, $w25, $w29\n"
+        "dpadd_u.w  $w9, $w26, $w29\n"
+        "dpadd_u.w  $w2, $w24, $w30\n"
+        "dpadd_u.w  $w6, $w25, $w30\n"
+        "dpadd_u.w $w10, $w26, $w30\n"
+        "dpadd_u.w  $w3, $w24, $w31\n"
+        "dpadd_u.w  $w7, $w25, $w31\n"
+        "dpadd_u.w $w11, $w26, $w31\n"
 
         // Second half of depths 0 and 1.
+        // Zero-extend 8-bit elements of rhs[] to 16 bits.
+        "ldi.b     $w31, 0\n"
+        "ilvl.b    $w31, $w31, $w27\n"
+        // Make 4 replicas of every pair of rhs[] elements.
+        "splati.w  $w28, $w31[0]\n"
+        "splati.w  $w29, $w31[1]\n"
+        "splati.w  $w30, $w31[2]\n"
+        "splati.w  $w31, $w31[3]\n"
         // Dot-product-(and)-add doubles multiplicand width.
-        "dpadd_u.w $w12, $w24, $w27\n"
-        "dpadd_u.w $w16, $w25, $w27\n"
-        "dpadd_u.w $w20, $w26, $w27\n"
-        "dpadd_u.w $w13, $w24, $w28\n"
-        "dpadd_u.w $w17, $w25, $w28\n"
-        "dpadd_u.w $w21, $w26, $w28\n"
-        "dpadd_u.w $w14, $w24, $w29\n"
-        "dpadd_u.w $w18, $w25, $w29\n"
-        "dpadd_u.w $w22, $w26, $w29\n"
-        "dpadd_u.w $w15, $w24, $w30\n"
-        "dpadd_u.w $w19, $w25, $w30\n"
-        "dpadd_u.w $w23, $w26, $w30\n"
+        "dpadd_u.w $w12, $w24, $w28\n"
+        "dpadd_u.w $w16, $w25, $w28\n"
+        "dpadd_u.w $w20, $w26, $w28\n"
+        "dpadd_u.w $w13, $w24, $w29\n"
+        "dpadd_u.w $w17, $w25, $w29\n"
+        "dpadd_u.w $w21, $w26, $w29\n"
+        "dpadd_u.w $w14, $w24, $w30\n"
+        "dpadd_u.w $w18, $w25, $w30\n"
+        "dpadd_u.w $w22, $w26, $w30\n"
+        "dpadd_u.w $w15, $w24, $w31\n"
+        "dpadd_u.w $w19, $w25, $w31\n"
+        "dpadd_u.w $w23, $w26, $w31\n"
 
         GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
-        " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
-        " %[rhs_ptr], 16\n"
+        " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
         "bnez   %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
 
         GEMMLOWP_LABEL_AFTER_LOOP ":\n"
 
         // Store accumulators.
-        GEMMLOWP_MIPS_XADDU
-        " $a0, %[dst_ptr], %[dst_col_stride]\n"
+        GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
         "st.w   $w0,  (0*16)(%[dst_ptr])\n"
         "st.w   $w4,  (1*16)(%[dst_ptr])\n"
-        "st.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "st.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "st.w   $w1,  (0*16)($a0)\n"
         "st.w   $w5,  (1*16)($a0)\n"
-        "st.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "st.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "st.w   $w2,  (0*16)($a1)\n"
         "st.w   $w6,  (1*16)($a1)\n"
-        "st.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "st.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "st.w   $w3,  (0*16)($a0)\n"
         "st.w   $w7,  (1*16)($a0)\n"
-        "st.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "st.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "st.w   $w12, (0*16)($a1)\n"
         "st.w   $w16, (1*16)($a1)\n"
-        "st.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
-        " $a1, $a0, %[dst_col_stride]\n"
+        "st.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
         "st.w   $w13, (0*16)($a0)\n"
         "st.w   $w17, (1*16)($a0)\n"
-        "st.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
-        " $a0, $a1, %[dst_col_stride]\n"
+        "st.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
         "st.w   $w14, (0*16)($a1)\n"
         "st.w   $w18, (1*16)($a1)\n"
         "st.w   $w22, (2*16)($a1)\n"
@@ -308,18 +254,15 @@
         "st.w   $w19, (1*16)($a0)\n"
         "st.w   $w23, (2*16)($a0)\n"
         :  // outputs
-        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
-        [run_depth] "+r"(run_depth),
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_depth] "+r"(run_depth),
         [dst_col_stride] "+r"(dst_col_stride)
         :  // inputs
         [dst_ptr] "r"(dst_ptr),
         [start_depth] "r"(start_depth)
         :  // clobbers
-        "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
-        "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
-        "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
-        "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
-        "$f30", "$f31");
+        "memory", "a0", "a1", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9",
+        "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
+        "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
 
 #undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
 #undef GEMMLOWP_LABEL_BEFORE_LOOP
@@ -328,6 +271,303 @@
   }
 };
 
+// Fast kernel operating on int8 operands.
+// It is assumed that one of the two int8 operands only takes values
+// in [-127, 127], while the other may freely range in [-128, 127].
+// The issue with both operands taking the value -128 is that:
+// -128*-128 + -128*-128 == 32768, which overflows int16 (max 32767).
+// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
+// range. That is the basic idea of this kernel.
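+//
+// A minimal scalar sketch of that idea (the helper below is illustrative
+// only and is not referenced by the kernel): two int8 products are first
+// summed in an int16, which is safe under the [-127, 127] restriction on one
+// operand, and only then widened into the int32 accumulator.
+inline std::int32_t AccumulateInt8ProductPairSketch(std::int32_t acc,
+                                                    std::int8_t a0,
+                                                    std::int8_t b0,
+                                                    std::int8_t a1,
+                                                    std::int8_t b1) {
+  // |a0*b0 + a1*b1| <= 127*128 + 127*128 = 32512, which fits in int16.
+  std::int16_t pair = static_cast<std::int16_t>(
+      static_cast<std::int16_t>(a0) * b0 + static_cast<std::int16_t>(a1) * b1);
+  return acc + pair;  // Widening add into the 32-bit accumulator.
+}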
+struct MSA_GEMM_Int8Operands_LhsNonzero : KernelBase {
+  typedef KernelFormat<
+      KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+      KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+      Format;
+
+  const char* Name() const override {
+    return "MSA, 4x4, depth 16, accumulating two within signed int16";
+  }
+
+  // TODO(benoitjacob): reorder function arguments so dst comes last
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+           const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+    (void)dst_row_stride;
+#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
+#define GEMMLOWP_LABEL_LOOP "2"
+#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
+#define GEMMLOWP_LABEL_STORE "4"
+    asm volatile(
+        GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+        // Load lhs[] and rhs[], zero out internal accumulators.
+        "ld.b       $w16, 0(%[lhs_ptr])\n"
+        "ldi.b      $w0, 0\n"
+        "ld.b       $w20, 0(%[rhs_ptr])\n"
+        "ldi.b      $w1, 0\n"
+        "ld.b       $w17, 16(%[lhs_ptr])\n"
+        "ldi.b      $w2, 0\n"
+        "ld.b       $w21, 16(%[rhs_ptr])\n"
+        "ldi.b      $w3, 0\n"
+        "ld.b       $w18, 32(%[lhs_ptr])\n"
+        "ldi.b      $w4, 0\n"
+        "ld.b       $w19, 48(%[lhs_ptr])\n"
+        "ldi.b      $w5, 0\n"
+        "ld.b       $w22, 32(%[rhs_ptr])\n"
+        "ldi.b      $w6, 0\n"
+        "ld.b       $w23, 48(%[rhs_ptr])\n"
+        "ldi.b      $w7, 0\n"
+        "ldi.b      $w8, 0\n"
+        "ldi.b      $w9, 0\n"
+        "ldi.b      $w10, 0\n"
+        "ldi.b      $w11, 0\n"
+        "ldi.b      $w12, 0\n"
+        "ldi.b      $w13, 0\n"
+        "ldi.b      $w14, 0\n"
+        "ldi.b      $w15, 0\n"
+        "ldi.h      $w31, 1\n"
+        // If the loop depth is only 16, then we can skip the general loop
+        // and go straight to the final part of the code.
+        "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
+
+        GEMMLOWP_LABEL_LOOP ":\n"
+        // Overview of register layout:
+        //
+        // A 4x16 block of Lhs is stored in 8 bit in w16-w19.
+        // A 4x16 block of Rhs is stored in 8 bit in w20-w23.
+        //
+        // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
+        // components which need to be horizontally added at the end).
+        //
+        // Dot products of Lhs and Rhs are 16-bit values, which can't
+        // immediately be accumulated in 32-bit accumulators by that
+        // same instruction that calculates them.
+        // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
+        // sums in w25 (note, the 16 products have already been reduced to
+        // 8 sums by the horizontal addition of the dotp instruction).
+        // They are then sign-extended to 32 bits, horizontally added
+        // (again) to form 4 32-bit sums and then they are finally added
+        // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
+        //
+        //                    +-----+-----+-----+-----+
+        //               Rhs  | w20 | w21 | w22 | w23 |
+        //                    +-----+-----+-----+-----+
+        //
+        //                    |     |     |     |     |
+        //
+        //       Lhs          |     |     |     |     |
+        //
+        //      +---+ - - - - +-----+-----+-----+-----+
+        //      |w16|         | w0  | w4  | w8  | w12 |
+        //      |w17|         | w1  | w5  | w9  | w13 |
+        //      |w18|         | w2  | w6  | w10 | w14 |
+        //      |w19|         | w3  | w7  | w11 | w15 |
+        //      +---+ - - - - +-----+-----+-----+-----+
+        //
+        //                           Accumulators
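+        //
+        // In index terms (using informal names l0..l15 and r0..r15 for the
+        // bytes currently held in w16 and w20): "dotp_s.h $w25, $w16, $w20"
+        // sets 16-bit lane k of w25 to l[2k]*r[2k] + l[2k+1]*r[2k+1], and
+        // "dpadd_s.w $w0, $w25, $w31" (w31 holding 16-bit ones) then adds
+        // lanes 2k and 2k+1 of w25 into 32-bit lane k of w0.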
+
+        // Calculate the results for 16 depths and load
+        // lhs[] and rhs[] for the next iteration.
+        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
+        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
+        GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w25, $w16, $w20\n"
+        "dotp_s.h   $w26, $w17, $w20\n"
+        "dotp_s.h   $w27, $w16, $w21\n"
+        "dotp_s.h   $w28, $w17, $w21\n"
+        "dotp_s.h   $w29, $w18, $w20\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w0, $w25, $w31\n"
+        "dpadd_s.w  $w1, $w26, $w31\n"
+        "dpadd_s.w  $w4, $w27, $w31\n"
+        "dpadd_s.w  $w5, $w28, $w31\n"
+        "dpadd_s.w  $w2, $w29, $w31\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w24, $w16, $w22\n"
+        "dotp_s.h   $w25, $w19, $w20\n"
+        "dotp_s.h   $w26, $w16, $w23\n"
+        "dotp_s.h   $w27, $w17, $w22\n"
+        "ld.b       $w20, 0(%[rhs_ptr])\n"
+        "dotp_s.h   $w28, $w17, $w23\n"
+        "ld.b       $w16, 0(%[lhs_ptr])\n"
+        "dotp_s.h   $w29, $w18, $w21\n"
+        "ld.b       $w17, 16(%[lhs_ptr])\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w8, $w24, $w31\n"
+        "dpadd_s.w  $w3, $w25, $w31\n"
+        "dpadd_s.w  $w12, $w26, $w31\n"
+        "dpadd_s.w  $w9, $w27, $w31\n"
+        "dpadd_s.w  $w13, $w28, $w31\n"
+        "dpadd_s.w  $w6, $w29, $w31\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w25, $w19, $w21\n"
+        "dotp_s.h   $w26, $w18, $w22\n"
+        "dotp_s.h   $w27, $w18, $w23\n"
+        "ld.b       $w21, 16(%[rhs_ptr])\n"
+        "dotp_s.h   $w28, $w19, $w22\n"
+        "ld.b       $w18, 32(%[lhs_ptr])\n"
+        "dotp_s.h   $w29, $w19, $w23\n"
+        "ld.b       $w22, 32(%[rhs_ptr])\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w7, $w25, $w31\n"
+        "ld.b       $w19, 48(%[lhs_ptr])\n"
+        "dpadd_s.w  $w10, $w26, $w31\n"
+        "ld.b       $w23, 48(%[rhs_ptr])\n"
+        "dpadd_s.w  $w14, $w27, $w31\n"
+        "dpadd_s.w  $w11, $w28, $w31\n"
+        "dpadd_s.w  $w15, $w29, $w31\n"
+
+        "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"
+
+        GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
+        // Calculate the results for the last 16 depths.
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w25, $w16, $w20\n"
+        "dotp_s.h   $w26, $w17, $w20\n"
+        "dotp_s.h   $w27, $w16, $w21\n"
+        "dotp_s.h   $w28, $w17, $w21\n"
+        "dotp_s.h   $w29, $w18, $w20\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w0, $w25, $w31\n"
+        "dpadd_s.w  $w1, $w26, $w31\n"
+        "dpadd_s.w  $w4, $w27, $w31\n"
+        "dpadd_s.w  $w5, $w28, $w31\n"
+        "dpadd_s.w  $w2, $w29, $w31\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w24, $w16, $w22\n"
+        "dotp_s.h   $w25, $w19, $w20\n"
+        "dotp_s.h   $w26, $w16, $w23\n"
+        "dotp_s.h   $w27, $w17, $w22\n"
+        "dotp_s.h   $w28, $w17, $w23\n"
+        "dotp_s.h   $w29, $w18, $w21\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w8, $w24, $w31\n"
+        "dpadd_s.w  $w3, $w25, $w31\n"
+        "dpadd_s.w  $w12, $w26, $w31\n"
+        "dpadd_s.w  $w9, $w27, $w31\n"
+        "dpadd_s.w  $w13, $w28, $w31\n"
+        "dpadd_s.w  $w6, $w29, $w31\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h   $w25, $w19, $w21\n"
+        "dotp_s.h   $w26, $w18, $w22\n"
+        "dotp_s.h   $w27, $w18, $w23\n"
+        "dotp_s.h   $w28, $w19, $w22\n"
+        "dotp_s.h   $w29, $w19, $w23\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w  $w7, $w25, $w31\n"
+        "dpadd_s.w  $w10, $w26, $w31\n"
+        "dpadd_s.w  $w14, $w27, $w31\n"
+        "dpadd_s.w  $w11, $w28, $w31\n"
+        "dpadd_s.w  $w15, $w29, $w31\n"
+
+        // Horizontal-add internal accumulators.
+        "hadd_s.d   $w0, $w0, $w0\n"
+        "hadd_s.d   $w1, $w1, $w1\n"
+        "hadd_s.d   $w2, $w2, $w2\n"
+        "hadd_s.d   $w3, $w3, $w3\n"
+        "hadd_s.d   $w4, $w4, $w4\n"
+        "hadd_s.d   $w5, $w5, $w5\n"
+        "hadd_s.d   $w6, $w6, $w6\n"
+        "hadd_s.d   $w7, $w7, $w7\n"
+        "hadd_s.d   $w8, $w8, $w8\n"
+        "hadd_s.d   $w9, $w9, $w9\n"
+        "hadd_s.d   $w10, $w10, $w10\n"
+        "hadd_s.d   $w11, $w11, $w11\n"
+        "hadd_s.d   $w12, $w12, $w12\n"
+        "hadd_s.d   $w13, $w13, $w13\n"
+        "hadd_s.d   $w14, $w14, $w14\n"
+        "hadd_s.d   $w15, $w15, $w15\n"
+        "pckev.w    $w0, $w1, $w0\n"
+        "pckev.w    $w2, $w3, $w2\n"
+        "pckev.w    $w4, $w5, $w4\n"
+        "pckev.w    $w6, $w7, $w6\n"
+        "pckev.w    $w8, $w9, $w8\n"
+        "pckev.w    $w10, $w11, $w10\n"
+        "pckev.w    $w12, $w13, $w12\n"
+        "pckev.w    $w14, $w15, $w14\n"
+        "hadd_s.d   $w0, $w0, $w0\n"
+        "hadd_s.d   $w2, $w2, $w2\n"
+        "hadd_s.d   $w4, $w4, $w4\n"
+        "hadd_s.d   $w6, $w6, $w6\n"
+        "hadd_s.d   $w8, $w8, $w8\n"
+        "hadd_s.d   $w10, $w10, $w10\n"
+        "hadd_s.d   $w12, $w12, $w12\n"
+        "hadd_s.d   $w14, $w14, $w14\n"
+        // 4 more pckev instructions follow in both paths below.
+
+        // Check if start_depth==0 to decide whether we will load
+        // existing accumulators from memory.
+        "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"
+
+        "pckev.w    $w0, $w2, $w0\n"
+        "pckev.w    $w1, $w6, $w4\n"
+        "pckev.w    $w2, $w10, $w8\n"
+        "pckev.w    $w3, $w14, $w12\n"
+
+        "b " GEMMLOWP_LABEL_STORE "f\n"
+
+        GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
+        // Load accumulators from memory.
+        "ld.w       $w16, 0(%[dst_ptr0])\n"
+        "pckev.w    $w0, $w2, $w0\n"
+        "ld.w       $w17, 0(%[dst_ptr1])\n"
+        "pckev.w    $w1, $w6, $w4\n"
+        "ld.w       $w18, 0(%[dst_ptr2])\n"
+        "pckev.w    $w2, $w10, $w8\n"
+        "ld.w       $w19, 0(%[dst_ptr3])\n"
+        "pckev.w    $w3, $w14, $w12\n"
+
+        // Add them to internal accumulators.
+        "addv.w     $w0, $w0, $w16\n"
+        "addv.w     $w1, $w1, $w17\n"
+        "addv.w     $w2, $w2, $w18\n"
+        "addv.w     $w3, $w3, $w19\n"
+
+        GEMMLOWP_LABEL_STORE ":\n"
+        // Store accumulators.
+        "st.w       $w0, 0(%[dst_ptr0])\n"
+        "st.w       $w1, 0(%[dst_ptr1])\n"
+        "st.w       $w2, 0(%[dst_ptr2])\n"
+        "st.w       $w3, 0(%[dst_ptr3])\n"
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [run_depth] "+r"(run_depth)
+        :  // inputs
+        [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
+        [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
+        [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
+        [start_depth] "r"(start_depth)
+        :  // clobbers
+        "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
+        "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
+        "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
+        "$f27", "$f28", "$f29", "$f30", "$f31");
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
+#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+#undef GEMMLOWP_LABEL_STORE
+  }
+};
+
 #undef GEMMLOWP_MIPS_XADDU
 #undef GEMMLOWP_MIPS_XADDIU
 #undef GEMMLOWP_MIPS_XSLL
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 3cd48f4..9859637 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -55,6 +55,7 @@
 #define GEMMLOWP_LABEL_AFTER_LOOP "4"
 
     assert(dst_row_stride == 1);
+    (void)dst_row_stride;
     asm volatile(
         // Overview of register layout:
         //
@@ -308,6 +309,7 @@
     ScopedProfilingLabel label(
         "optimized kernel (NEON 12x4, assuming 12-bit products)");
     assert(dst_row_stride == 1);
+    (void)dst_row_stride;
 
 // See comments above for why we need local numerical labels in our asm.
 #define GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS "1"
@@ -678,6 +680,7 @@
            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
            const std::uint8_t* rhs_ptr, std::size_t start_depth,
            std::size_t run_depth) const override {
+    (void)dst_row_stride;
 #define GEMMLOWP_LABEL_AFTER_LOOP "1"
 #define GEMMLOWP_LABEL_LOOP "2"
 #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -921,6 +924,17 @@
   }
 };
 
+// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
+// requires that user inputs were originally int8. This avoids the uint8->int8
+// conversion in the pack step.
+struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
+    : NEON_32bit_GEMM_Int8Operands_LhsNonzero {
+  typedef KernelFormat<
+      KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+      KernelSideFormatInt8Inputs<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
+      Format;
+};
+
 #endif  // GEMMLOWP_NEON_32
 
 // The kernels here are specifically arm 64bit assembly, not arm 32bit.
@@ -940,6 +954,7 @@
            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
            const std::uint8_t* rhs_ptr, std::size_t start_depth,
            std::size_t run_depth) const override {
+    (void)dst_row_stride;
 #define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
 #define GEMMLOWP_LABEL_LOOP "2"
 #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -1261,6 +1276,17 @@
   }
 };
 
+// Same as NEON_64bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
+// requires that user inputs were originally int8. This avoids the uint8->int8
+// conversion in the pack step.
+struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
+    : NEON_64bit_GEMM_Int8Operands_LhsNonzero {
+  typedef KernelFormat<
+      KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+      KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+      Format;
+};
+
 // Our main GEMM kernel.
 struct NEON_64_Kernel12x8Depth2 : KernelBase {
   typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
@@ -1274,6 +1300,7 @@
            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
            const std::uint8_t* rhs_ptr, std::size_t start_depth,
            std::size_t run_depth) const override {
+    (void)dst_row_stride;
     ScopedProfilingLabel label("optimized kernel (NEON 12x8)");
 // See comments above for why we need local numerical labels in our asm.
 #define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
@@ -1611,6 +1638,274 @@
   }
 };
 
+#ifdef GEMMLOWP_DOTPROD_KERNEL
+#ifndef __ARM_FEATURE_DOTPROD
+#error This kernel requires ARM dot-product instructions. Enable them by \
+  adding '+dotprod' to a compiler flag, e.g. -march=armv8.2-a+dotprod . \
+  Note that Clang up to version 7 fails to define the corresponding \
+  preprocessor token __ARM_FEATURE_DOTPROD, so you will still have to define \
+  it manually.
+#endif
+// Kernels utilizing the Armv8.2 Dot Product extension.
+//
+// The dot product instructions work by taking 4 consecutive 8-bit depth
+// values from each operand, multiplying the 4 pairs together and
+// accumulating all the results into the corresponding 32-bit accumulator
+// lane.  As such, the operation is identical to a 32-bit instruction (like
+// FMLA used in SGEMM), except that 4 depth values are processed at a time
+// instead of 1.
+
+// Thus, this first kernel is a carbon copy of
+// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
+// performance for most processors) below with the opcode (fmla -> udot) and
+// types (float32 -> uint8/uint32) changed.
+//
+// A signed version of this kernel could be produced by replacing "udot"
+// with "sdot" - performance should be identical to this udot kernel.
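+//
+// As a scalar model (the helper below is purely illustrative and is not used
+// by the kernel), each 32-bit accumulator lane updated by udot receives the
+// dot product of 4 consecutive 8-bit depth values from each operand:
+inline std::uint32_t UdotLaneModelSketch(std::uint32_t acc,
+                                         const std::uint8_t* lhs_vals,
+                                         const std::uint8_t* rhs_vals) {
+  for (int i = 0; i < 4; i++) {
+    acc += static_cast<std::uint32_t>(lhs_vals[i]) *
+           static_cast<std::uint32_t>(rhs_vals[i]);
+  }
+  return acc;
+}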
+struct NEON_64_Kernel12x8Depth4_dotprod : KernelBase {
+  typedef KernelFormat<KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+                       KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+      Format;
+
+  const char* Name() const override { return "NEON, 12x8, depth 4, dotprod"; }
+
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
+           const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t depth) const override {
+    (void)dst_row_stride;
+    ScopedProfilingLabel label("optimized kernel (NEON 12x8, depth 4, dotprod)");
+// See comments above for why we need local numerical labels in our asm.
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
+
+    assert(dst_row_stride == 1);
+    asm volatile(
+        // Multiply dst_col_stride by 4 == sizeof(int32) to use
+        // it as a byte offset below.
+        "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
+
+        "cmp %[start_depth], #0\n"
+        "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
+
+        // Load accumulators
+        "mov x1, %[dst_ptr]\n"
+        "mov x0, x1\n"
+        "ld1 {v8.16b}, [x0], #16\n"
+        "ld1 {v16.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v24.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v9.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v17.16b}, [x0], #16\n"
+        "ld1 {v25.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v10.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v18.16b}, [x0], #16\n"
+        "ld1 {v26.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v11.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v19.16b}, [x0], #16\n"
+        "ld1 {v27.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v12.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v20.16b}, [x0], #16\n"
+        "ld1 {v28.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v13.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v21.16b}, [x0], #16\n"
+        "ld1 {v29.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v14.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v22.16b}, [x0], #16\n"
+        "ld1 {v30.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v15.16b}, [x0], #16\n"
+        "ld1 {v23.16b}, [x0], #16\n"
+        "ld1 {v31.16b}, [x0]\n"
+
+        "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
+
+        // Clear accumulator registers (see layout below)
+        "dup v8.4s, wzr\n"
+        "dup v9.4s, wzr\n"
+        "dup v10.4s, wzr\n"
+        "dup v11.4s, wzr\n"
+        "dup v12.4s, wzr\n"
+        "dup v13.4s, wzr\n"
+        "dup v14.4s, wzr\n"
+        "dup v15.4s, wzr\n"
+        "dup v16.4s, wzr\n"
+        "dup v17.4s, wzr\n"
+        "dup v18.4s, wzr\n"
+        "dup v19.4s, wzr\n"
+        "dup v20.4s, wzr\n"
+        "dup v21.4s, wzr\n"
+        "dup v22.4s, wzr\n"
+        "dup v23.4s, wzr\n"
+        "dup v24.4s, wzr\n"
+        "dup v25.4s, wzr\n"
+        "dup v26.4s, wzr\n"
+        "dup v27.4s, wzr\n"
+        "dup v28.4s, wzr\n"
+        "dup v29.4s, wzr\n"
+        "dup v30.4s, wzr\n"
+        "dup v31.4s, wzr\n"
+
+        GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
+
+        "subs %w[depth], %w[depth], #4\n"
+
+        // The start of the loop assumes the first Rhs cell is already loaded, so
+        // do it here for first iteration.
+        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+
+        // And the same for the first Lhs cell.
+        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+
+        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
+
+        GEMMLOWP_LABEL_LOOP ":\n"
+
+        // Start the MACs at the head of the loop - 1st cell from each side
+        // already loaded.
+        ".word 0x6f80e048  // udot v8.4s, v2.16b, v0.4b[0]\n"
+        ".word 0x6fa0e049  // udot v9.4s, v2.16b, v0.4b[1]\n"
+        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"  // Load second Rhs cell.
+        ".word 0x6f80e84a  // udot v10.4s, v2.16b, v0.4b[2]\n"
+        ".word 0x6fa0e84b  // udot v11.4s, v2.16b, v0.4b[3]\n"
+        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"  // Load second Lhs cell.
+        ".word 0x6f81e04c  // udot v12.4s, v2.16b, v1.4b[0]\n"
+        ".word 0x6fa1e04d  // udot v13.4s, v2.16b, v1.4b[1]\n"
+        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"  // Load third Lhs cell.
+        ".word 0x6f81e84e  // udot v14.4s, v2.16b, v1.4b[2]\n"
+        ".word 0x6fa1e84f  // udot v15.4s, v2.16b, v1.4b[3]\n"
+        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"  // Done with first Lhs cell - load
+        // for the next iteration early.
+        ".word 0x6f80e070  // udot v16.4s, v3.16b, v0.4b[0]\n"
+        ".word 0x6fa0e071  // udot v17.4s, v3.16b, v0.4b[1]\n"
+        ".word 0x6f80e872  // udot v18.4s, v3.16b, v0.4b[2]\n"
+        ".word 0x6fa0e873  // udot v19.4s, v3.16b, v0.4b[3]\n"
+        ".word 0x6f81e074  // udot v20.4s, v3.16b, v1.4b[0]\n"
+        ".word 0x6fa1e075  // udot v21.4s, v3.16b, v1.4b[1]\n"
+        ".word 0x6f81e876  // udot v22.4s, v3.16b, v1.4b[2]\n"
+        ".word 0x6fa1e877  // udot v23.4s, v3.16b, v1.4b[3]\n"
+        ".word 0x6f80e098  // udot v24.4s, v4.16b, v0.4b[0]\n"
+        ".word 0x6fa0e099  // udot v25.4s, v4.16b, v0.4b[1]\n"
+        ".word 0x6f80e89a  // udot v26.4s, v4.16b, v0.4b[2]\n"
+        ".word 0x6fa0e89b  // udot v27.4s, v4.16b, v0.4b[3]\n"
+        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"  // Done with the first Rhs cell -
+        // load for the next iteration early.
+        ".word 0x6f81e09c  // udot v28.4s, v4.16b, v1.4b[0]\n"
+        ".word 0x6fa1e09d  // udot v29.4s, v4.16b, v1.4b[1]\n"
+
+        // Loop.  Decrement loop index (depth) by 4 as udot processes 4
+        // depth values.
+        "subs %w[depth], %w[depth], #4\n"
+        ".word 0x6f81e89e  // udot v30.4s, v4.16b, v1.4b[2]\n"
+        ".word 0x6fa1e89f  // udot v31.4s, v4.16b, v1.4b[3]\n"
+
+        "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+        GEMMLOWP_LABEL_AFTER_LOOP ":\n"
+
+        // Final iteration. v0 and v2 were already loaded, don't load
+        // them again, don't read past the end of buffers.
+        ".word 0x6f80e048  // udot v8.4s, v2.16b, v0.4b[0]\n"
+        ".word 0x6fa0e049  // udot v9.4s, v2.16b, v0.4b[1]\n"
+        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"  // Load second Rhs cell.
+        ".word 0x6f80e84a  // udot v10.4s, v2.16b, v0.4b[2]\n"
+        ".word 0x6fa0e84b  // udot v11.4s, v2.16b, v0.4b[3]\n"
+        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"  // Load second Lhs cell.
+        ".word 0x6f81e04c  // udot v12.4s, v2.16b, v1.4b[0]\n"
+        ".word 0x6fa1e04d  // udot v13.4s, v2.16b, v1.4b[1]\n"
+        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"  // Load third Lhs cell.
+        ".word 0x6f81e84e  // udot v14.4s, v2.16b, v1.4b[2]\n"
+        ".word 0x6fa1e84f  // udot v15.4s, v2.16b, v1.4b[3]\n"
+        ".word 0x6f80e070  // udot v16.4s, v3.16b, v0.4b[0]\n"
+        ".word 0x6fa0e071  // udot v17.4s, v3.16b, v0.4b[1]\n"
+        ".word 0x6f80e872  // udot v18.4s, v3.16b, v0.4b[2]\n"
+        ".word 0x6fa0e873  // udot v19.4s, v3.16b, v0.4b[3]\n"
+        ".word 0x6f81e074  // udot v20.4s, v3.16b, v1.4b[0]\n"
+        ".word 0x6fa1e075  // udot v21.4s, v3.16b, v1.4b[1]\n"
+        ".word 0x6f81e876  // udot v22.4s, v3.16b, v1.4b[2]\n"
+        ".word 0x6fa1e877  // udot v23.4s, v3.16b, v1.4b[3]\n"
+        ".word 0x6f80e098  // udot v24.4s, v4.16b, v0.4b[0]\n"
+        ".word 0x6fa0e099  // udot v25.4s, v4.16b, v0.4b[1]\n"
+        ".word 0x6f80e89a  // udot v26.4s, v4.16b, v0.4b[2]\n"
+        ".word 0x6fa0e89b  // udot v27.4s, v4.16b, v0.4b[3]\n"
+        ".word 0x6f81e09c  // udot v28.4s, v4.16b, v1.4b[0]\n"
+        ".word 0x6fa1e09d  // udot v29.4s, v4.16b, v1.4b[1]\n"
+
+        // Loop.  Decrement loop index (depth) by 4 as udot processes 4
+        // depth values.
+        "subs %w[depth], %w[depth], #4\n"
+        ".word 0x6f81e89e  // udot v30.4s, v4.16b, v1.4b[2]\n"
+        ".word 0x6fa1e89f  // udot v31.4s, v4.16b, v1.4b[3]\n"
+
+        // Store accumulators
+        "mov x1, %[dst_ptr]\n"
+        "mov x0, x1\n"
+        "st1 {v8.16b}, [x0], #16\n"
+        "st1 {v16.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v24.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v9.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v17.16b}, [x0], #16\n"
+        "st1 {v25.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v10.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v18.16b}, [x0], #16\n"
+        "st1 {v26.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v11.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v19.16b}, [x0], #16\n"
+        "st1 {v27.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v12.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v20.16b}, [x0], #16\n"
+        "st1 {v28.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v13.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v21.16b}, [x0], #16\n"
+        "st1 {v29.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v14.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v22.16b}, [x0], #16\n"
+        "st1 {v30.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v15.16b}, [x0], #16\n"
+        "st1 {v23.16b}, [x0], #16\n"
+        "st1 {v31.16b}, [x0]\n"
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [depth] "+r"(depth)
+        :  // inputs
+        [dst_ptr] "r"(dst_ptr), [dst_col_stride] "r"(dst_col_stride), [start_depth] "r"(start_depth)
+        :  // clobbers
+        "cc", "memory", "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+        "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
+        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+  }
+};
+#endif  // GEMMLOWP_DOTPROD_KERNEL
+
 #endif  // GEMMLOWP_NEON_64
 
 }  // namespace gemmlowp
diff --git a/internal/kernel_sse.h b/internal/kernel_sse.h
index b879fd7..ba7959b 100644
--- a/internal/kernel_sse.h
+++ b/internal/kernel_sse.h
@@ -43,6 +43,7 @@
            std::size_t run_depth) const override {
     ScopedProfilingLabel label("optimized kernel");
     assert(dst_row_stride == 1);
+    (void)dst_row_stride;
     std::int32_t run_depth_cells = run_depth / Format::kDepth;
     /* Main loop */
 
@@ -217,6 +218,7 @@
            std::size_t run_depth) const override {
     ScopedProfilingLabel label("optimized kernel");
     assert(dst_row_stride == 1);
+    (void)dst_row_stride;
     const std::int64_t run_depth_cells = run_depth / Format::kDepth;
     const std::int64_t dst_col_stride_q = dst_col_stride;
 
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index 791402f..97183e7 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -19,23 +19,43 @@
 #ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
 #define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
 
+#include <atomic>  // NOLINT
+#include <chrono>  // NOLINT
+#include <thread>  // NOLINT
 #include <vector>
 
 #include "single_thread_gemm.h"
 
 namespace gemmlowp {
 
-// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a
-// pthread conditional variable. In order to implement that correctly we need
-// to put some explicit memory load/store barriers.
+// This value was empirically derived on an end-to-end application benchmark.
+// That this number of cycles means that we may be busy-waiting for
+// substantially longer than a scheduler timeslice's duration is not
+// necessarily surprising. The idea is to quickly pick up new work after
+// having finished the previous workload. When the new work belongs to the
+// same GEMM as the previous work, the time interval that we might be
+// busy-waiting is very small, so for that purpose busy-waiting for 1 million
+// cycles would be more than enough.
+// That is all we would observe on a GEMM benchmark. However, in a real
+// application, after having finished a GEMM, we might do unrelated work for
+// a little while, then start on a new GEMM. Think of a neural network
+// application performing inference, where many but not all layers are
+// implemented by a GEMM. In such cases, our worker threads might be idle for
+// longer periods of time before having work again. If we let them passively
+// wait, on a mobile device, the CPU scheduler might aggressively clock down
+// or even turn off the CPU cores that they were running on. That would result
+// in a long delay the next time those cores need to be turned back on for the
+// next GEMM. So we need to strike a balance that reflects typical time
+// intervals between consecutive GEMM invocations, not just intra-GEMM
+// considerations. Of course, we also need to balance keeping CPUs spinning
+// longer to resume work faster against passively waiting to conserve power.
+const int kMaxBusyWaitNOPs = 4 * 1000 * 1000;
+
+// On X86 and ARM platforms we may use NOP instructions to gauge how long we
+// have been busy-waiting.
 
 #if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
     (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))
 
-#define GEMMLOWP_USE_BUSYWAIT
-
-const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
-
 #define GEMMLOWP_NOP "nop\n"
 
 #define GEMMLOWP_STRING_CONCAT_4(X) X X X X
@@ -43,46 +63,26 @@
 #define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
 #define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
 
-inline int Do256NOPs() {
+inline int DoSomeNOPs() {
   asm volatile(GEMMLOWP_NOP64);
   return 64;
 }
 
 #undef GEMMLOWP_STRING_CONCAT_4
-#undef GEMMLOWP_NOP256
 #undef GEMMLOWP_NOP64
 #undef GEMMLOWP_NOP16
 #undef GEMMLOWP_NOP4
 #undef GEMMLOWP_NOP
 
-inline void WriteBarrier() {
-#if defined(_MSC_VER)
-  MemoryBarrier();
-#elif defined(GEMMLOWP_ARM_32)
-  asm volatile("" ::: "memory");
-#elif defined(GEMMLOWP_ARM_64)
-  asm volatile("dmb ishst" ::: "memory");
-#elif defined(GEMMLOWP_X86)
-  asm volatile("sfence" ::: "memory");
-#else
-#error "Unsupported architecture for WriteBarrier."
-#endif
-}
+#else  // May not use asm NOP.
 
-inline void ReadBarrier() {
-#if defined(_MSC_VER)
-  MemoryBarrier();
-#elif defined(GEMMLOWP_ARM_32)
-  asm volatile("" ::: "memory");
-#elif defined(GEMMLOWP_ARM_64)
-  asm volatile("dmb ishld" ::: "memory");
-#elif defined(GEMMLOWP_X86)
-  asm volatile("lfence" ::: "memory");
-#else
-#error "Unsupported architecture for ReadBarrier."
-#endif
+// If we can't use NOPs, fall back to a non-inline function call as a basic
+// unit of work with a roughly known, nonzero cost.
+GEMMLOWP_NOINLINE
+inline int DoSomeNOPs() {
+  // Pretend that calling an empty function takes as long as 16 NOPs...
+  return 16;
 }
-
 #endif
 
 // Waits until *var != initial_value.
@@ -108,37 +108,29 @@
 // so as to avoid permanently spinning.
 //
 template <typename T>
-T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
-                        pthread_mutex_t* mutex) {
-#ifdef GEMMLOWP_USE_BUSYWAIT
-  // If we are on a platform that supports it, spin for some time.
-  {
-    int nops = 0;
-    // First, trivial case where the variable already changed value.
-    T new_value = *var;
+T WaitForVariableChange(std::atomic<T>* var, T initial_value,
+                        pthread_cond_t* cond, pthread_mutex_t* mutex) {
+  // First, trivial case where the variable already changed value.
+  T new_value = var->load(std::memory_order_acquire);
+  if (new_value != initial_value) {
+    return new_value;
+  }
+  // Then try busy-waiting.
+  int nops = 0;
+  while (nops < kMaxBusyWaitNOPs) {
+    nops += DoSomeNOPs();
+    new_value = var->load(std::memory_order_acquire);
     if (new_value != initial_value) {
-      ReadBarrier();
       return new_value;
     }
-    // Then try busy-waiting.
-    while (nops < kMaxBusyWaitNOPs) {
-      nops += Do256NOPs();
-      new_value = *var;
-      if (new_value != initial_value) {
-        ReadBarrier();
-        return new_value;
-      }
-    }
   }
-#endif
 
   // Finally, do real passive waiting.
   pthread_mutex_lock(mutex);
-  T new_value = *var;
-  if (new_value == initial_value) {
+  new_value = var->load(std::memory_order_acquire);
+  while (new_value == initial_value) {
     pthread_cond_wait(cond, mutex);
-    new_value = *var;
-    assert(new_value != initial_value);
+    new_value = var->load(std::memory_order_acquire);
   }
   pthread_mutex_unlock(mutex);
   return new_value;
@@ -147,73 +139,74 @@
 // A BlockingCounter lets one thread wait for N events to occur.
 // This is how the master thread waits for all the worker threads
 // to have finished working.
+// The waiting is done with a naive spin-wait on the atomic count_ reaching
+// the value 0. This is acceptable because in our usage
+// pattern, BlockingCounter is used only to synchronize threads after
+// short-lived tasks (performing parts of the same GEMM). It is not used
+// for synchronizing longer waits (resuming work on the next GEMM).
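+// Typical usage sketch (informal): the master thread calls Reset(N), hands
+// one task to each of the N workers and then calls Wait(); each worker calls
+// DecrementCount() exactly once when it finishes, and the final decrement
+// lets Wait() return on the master thread.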
 class BlockingCounter {
  public:
-  BlockingCounter() : count_(0), initial_count_(0) {
-    pthread_cond_init(&cond_, nullptr);
-    pthread_mutex_init(&mutex_, nullptr);
-  }
-
-  ~BlockingCounter() {
-    pthread_cond_destroy(&cond_);
-    pthread_mutex_destroy(&mutex_);
-  }
+  BlockingCounter() : count_(0) {}
 
   // Sets/resets the counter; initial_count is the number of
   // decrementing events that the Wait() call will be waiting for.
   void Reset(std::size_t initial_count) {
-    pthread_mutex_lock(&mutex_);
-    assert(count_ == 0);
-    initial_count_ = initial_count;
-    count_ = initial_count_;
-    pthread_mutex_unlock(&mutex_);
+    std::size_t old_count_value = count_.load(std::memory_order_relaxed);
+    assert(old_count_value == 0);
+    (void)old_count_value;
+    count_.store(initial_count, std::memory_order_release);
   }
 
   // Decrements the counter; if the counter hits zero, signals
-  // the thread that was waiting for that, and returns true.
+  // the threads that were waiting for that, and returns true.
   // Otherwise (if the decremented count is still nonzero),
   // returns false.
   bool DecrementCount() {
-    pthread_mutex_lock(&mutex_);
-    assert(count_ > 0);
-    count_--;
-#ifdef GEMMLOWP_USE_BUSYWAIT
-    WriteBarrier();
-#endif
-    if (count_ == 0) {
-      pthread_cond_signal(&cond_);
-    }
-    bool retval = count_ == 0;
-    pthread_mutex_unlock(&mutex_);
-    return retval;
+    std::size_t old_count_value =
+        count_.fetch_sub(1, std::memory_order_acq_rel);
+    assert(old_count_value > 0);
+    std::size_t count_value = old_count_value - 1;
+    return count_value == 0;
   }
 
   // Waits for the N other threads (N having been set by Reset())
   // to hit the BlockingCounter.
   void Wait() {
     ScopedProfilingLabel label("BlockingCounter::Wait");
-    while (count_) {
-#ifdef GEMMLOWP_USE_BUSYWAIT
-      ReadBarrier();
-#else
-      // This is likely unnecessary, but is kept to ensure regressions are not
-      // introduced.
-#ifndef _WIN32
-      asm volatile("" ::: "memory");
-#endif
-#endif
-      const std::size_t count_value = count_;
-      if (count_value) {
-        WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
+    // Busy-wait until the count value is 0.
+    int nops = 0;
+    while (count_.load(std::memory_order_acquire)) {
+      nops += DoSomeNOPs();
+      if (nops > kMaxBusyWaitNOPs) {
+        nops = 0;
+        // If we are unlucky, the blocking thread (that calls DecrementCount)
+        // and the blocked thread (here, calling Wait) may be scheduled on
+        // the same CPU, so the busy-waiting of the present thread may prevent
+        // the blocking thread from resuming and unblocking.
+        // If we are even unluckier, the priority of the present thread
+        // might be higher than that of the blocking thread, so just yielding
+        // wouldn't allow the blocking thread to resume. So we sleep for
+        // a substantial amount of time in that case. Notice that we only
+        // do so after having busy-waited for kMaxBusyWaitNOPs, which is
+        // typically several milliseconds, so sleeping 1 more millisecond
+        // isn't terrible at that point.
+        //
+        // How this is mitigated in practice:
+        // it is well known that applications should be conservative in
+        // choosing how many threads to tell gemmlowp to use, as it is hard
+        // to know how many CPU cores they will get to run on, on typical
+        // mobile devices.
+        // It seems impossible for gemmlowp to make this choice automatically,
+        // which is why gemmlowp's default is to use only 1 thread; applications
+        // may override that when they know they can count on using more.
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
       }
     }
   }
 
  private:
-  pthread_cond_t cond_;
-  pthread_mutex_t mutex_;
-  std::size_t count_;
-  std::size_t initial_count_;
+  std::atomic<std::size_t> count_;
 };
 
 // A workload for a worker.
@@ -253,11 +246,15 @@
   // Changes State; may be called from either the worker thread
   // or the master thread; however, not all state transitions are legal,
   // which is guarded by assertions.
-  void ChangeState(State new_state) {
+  //
+  // The Task argument is to be used only with new_state==HasWork.
+  // It specifies the Task being handed to this Worker.
+  void ChangeState(State new_state, Task* task = nullptr) {
     ScopedProfilingLabel label("Worker::ChangeState");
     pthread_mutex_lock(&state_mutex_);
-    assert(new_state != state_);
-    switch (state_) {
+    State old_state = state_.load(std::memory_order_relaxed);
+    assert(old_state != new_state);
+    switch (old_state) {
       case State::ThreadStartup:
         assert(new_state == State::Ready);
         break;
@@ -272,18 +269,33 @@
       default:
         abort();
     }
-    state_ = new_state;
-    pthread_cond_signal(&state_cond_);
-    if (state_ == State::Ready) {
+    switch (new_state) {
+      case State::Ready:
+        if (task_) {
+          // Doing work is part of reverting to 'ready' state.
+          task_->Run();
+          task_ = nullptr;
+        }
+        break;
+      case State::HasWork:
+        assert(!task_);
+        task->local_allocator = &local_allocator_;
+        task_ = task;
+        break;
+      default:
+        break;
+    }
+    state_.store(new_state, std::memory_order_relaxed);
+    pthread_cond_broadcast(&state_cond_);
+    pthread_mutex_unlock(&state_mutex_);
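+    // The DecrementCount() below only touches an atomic, so it is done after
+    // releasing state_mutex_ to keep the critical section minimal.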
+    if (new_state == State::Ready) {
       counter_to_decrement_when_ready_->DecrementCount();
     }
-    pthread_mutex_unlock(&state_mutex_);
   }
 
   // Thread entry point.
   void ThreadFunc() {
     ScopedProfilingLabel label("Worker::ThreadFunc");
-    RegisterCurrentThreadForProfiling();
 
     ChangeState(State::Ready);
 
@@ -299,9 +311,6 @@
       switch (state_to_act_upon) {
         case State::HasWork:
           // Got work to do! So do it, and then revert to 'Ready' state.
-          assert(task_);
-          task_->Run();
-          task_ = nullptr;
           ChangeState(State::Ready);
           break;
         case State::ExitAsSoonAsPossible:
@@ -318,17 +327,7 @@
   }
 
   // Called by the master thread to give this worker work to do.
-  // It is only legal to call this if the worker
-  void StartWork(Task* task) {
-    assert(!task_);
-    task->local_allocator = &local_allocator_;
-    task_ = task;
-#ifdef GEMMLOWP_USE_BUSYWAIT
-    WriteBarrier();
-#endif
-    assert(state_ == State::Ready);
-    ChangeState(State::HasWork);
-  }
+  void StartWork(Task* task) { ChangeState(State::HasWork, task); }
 
  private:
   // The underlying thread.
@@ -342,7 +341,10 @@
   pthread_mutex_t state_mutex_;
 
   // The state enum tells if we're currently working, waiting for work, etc.
-  State state_;
+  // Its concurrent accesses by the worker and main threads are guarded by
+  // state_mutex_, and can thus use memory_order_relaxed. This still needs
+  // to be a std::atomic because we use WaitForVariableChange.
+  std::atomic<State> state_;
 
   // Each thread has a local allocator so that threads can allocate temporary
   // buffers without blocking each other.
@@ -359,9 +361,7 @@
 // waits for all of them to finish.
 //
 // See MultiThreadGemmContextBase for how other WorkersPool implementations can
-// be used. Note that in those implementations, StartWorker can be free to
-// ignore the <index> value; that is, the caller of WorkersPool does not rely on
-// <index> to order tasks with equal <index>.
+// be used.
 class WorkersPool {
  public:
   WorkersPool() {}
@@ -372,18 +372,41 @@
     }
   }
 
-  void Execute(const std::vector<Task*>& tasks) {
-    assert(tasks.size() >= 1);
+  // Just executes the tasks. Does not destroy them. Similar to
+  // ruy::ThreadPool::Execute.
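+  //
+  // A minimal usage sketch (MyTask is a hypothetical Task subclass that
+  // overrides Run()):
+  //   std::vector<MyTask> tasks(num_threads);
+  //   pool->Execute(tasks.size(), tasks.data());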
+  template <typename TaskType>
+  void Execute(int tasks_count, TaskType* tasks) {
+    assert(tasks_count >= 1);
     // One of the tasks will be run on the current thread.
-    std::size_t workers_count = tasks.size() - 1;
+    std::size_t workers_count = tasks_count - 1;
     CreateWorkers(workers_count);
     assert(workers_count <= workers_.size());
     counter_to_decrement_when_ready_.Reset(workers_count);
-    int n = 0;
-    std::for_each(tasks.begin(), --tasks.end(),
-                  [this, &n](Task* task) { workers_[n++]->StartWork(task); });
+    for (std::size_t i = 0; i < workers_count; i++) {
+      workers_[i]->StartWork(&tasks[i]);
+    }
     // Execute the remaining workload immediately on the current thread.
-    Task* task = tasks.back();
+    Task* task = &tasks[tasks_count - 1];
+    task->local_allocator = &main_thread_task_allocator_;
+    task->Run();
+    // Wait for the workers submitted above to finish.
+    counter_to_decrement_when_ready_.Wait();
+  }
+
+  // Legacy: executes the tasks and destroys them.
+  void LegacyExecuteAndDestroyTasks(const std::vector<Task*>& tasks) {
+    std::size_t tasks_count = tasks.size();
+    assert(tasks_count >= 1);
+    // One of the tasks will be run on the current thread.
+    std::size_t workers_count = tasks_count - 1;
+    CreateWorkers(workers_count);
+    assert(workers_count <= workers_.size());
+    counter_to_decrement_when_ready_.Reset(workers_count);
+    for (std::size_t i = 0; i < workers_count; i++) {
+      workers_[i]->StartWork(tasks[i]);
+    }
+    // Execute the remaining workload immediately on the current thread.
+    Task* task = tasks[tasks_count - 1];
     task->local_allocator = &main_thread_task_allocator_;
     task->Run();
     // Wait for the workers submitted above to finish.
@@ -393,6 +416,11 @@
     std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; });
   }
 
+  // Legacy: old name of LegacyExecuteAndDestroyTasks.
+  void Execute(const std::vector<Task*>& tasks) {
+    LegacyExecuteAndDestroyTasks(tasks);
+  }
+
  private:
   // Ensures that the pool has at least the given count of workers.
   // If any new worker has to be created, this function waits for it to
diff --git a/internal/output.h b/internal/output.h
index dcfe2b5..92bf7b9 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -22,6 +22,7 @@
 #include <cmath>
 #include <tuple>
 #include <type_traits>
+#include <typeinfo>
 
 #include "../fixedpoint/fixedpoint.h"
 #include "../public/output_stages.h"
@@ -179,7 +180,47 @@
   int right_shift;
 };
 
-// Implementation of OutputStageSaturatingCastToUint8 for scalar data
+template <int Rows, int Cols, VectorShape Shape>
+struct OutputStageEvalImpl<
+    OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>,
+    RegisterBlock<std::int32_t, Rows, Cols>> {
+  typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
+  typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
+
+  typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage;
+
+  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
+
+  OutputType Eval(InputType input, int row, int col) const {
+    OutputType output;
+    const int pos = Shape == VectorShape::Row ? col : row;
+    using RegisterType = typename InputType::RegisterType;
+    const RegisterType result_offset_after_shift =
+        Dup<RegisterType>(output_stage.result_offset_after_shift);
+    auto left_shift =
+        LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
+    auto right_shift =
+        LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
+    const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>(
+        output_stage.result_fixedpoint_multiplier, pos);
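+    // Both shift vectors are loaded from the same per-channel exponent:
+    // below, positive exponents become left shifts and negative exponents
+    // are negated into right shifts.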
+    for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) {
+      left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0);
+      right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0);
+    }
+    const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul(
+        BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier);
+    const auto rdpot_val =
+        BroadcastRoundingDivideByPOT(mulhigh_val, right_shift);
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift);
+    }
+    return output;
+  }
+
+  const OutputStage& output_stage;
+};
+
+// Implementation of OutputStageSaturatingCastToUint8 for scalar data.
 template <int Size>
 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                  RegisterBuffer<std::int32_t, Size>> {
@@ -202,7 +243,30 @@
   }
 };
 
-// Implementation of OutputStageSaturatingCastToInt16 for scalar data
+// Implementation of OutputStageSaturatingCastToInt8 for scalar data.
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::int8_t, Size> OutputType;
+  static_assert(InputType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+
+  typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      std::int32_t data = input.reg[i];
+      output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data;
+    }
+    return output;
+  }
+};
+
+// Implementation of OutputStageSaturatingCastToInt16 for scalar data.
 template <int Size>
 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                  RegisterBuffer<std::int32_t, Size>> {
@@ -225,6 +289,28 @@
   }
 };
 
+// Implementation of OutputStageTruncatingCastToUint8 for scalar data.
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::uint8_t, Size> OutputType;
+  static_assert(InputType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+
+  typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
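+      // Truncating cast: the implicit int32 -> uint8 conversion keeps only
+      // the low 8 bits.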
+      output.reg[i] = input.reg[i];
+    }
+    return output;
+  }
+};
+
 template <int Rows, int Cols, typename VectorType>
 struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
                            RegisterBlock<std::int32_t, Rows, Cols>> {
@@ -452,7 +538,7 @@
   OutputPipelineExecutor(const OutputPipelineType& output_pipeline)
       : output_pipeline_eval_impl_(output_pipeline) {}
 
-  // RunOutputPipeline is the entry point into the output pipeline evaluation
+  // Execute is the entry point into the output pipeline evaluation code.
   // It should be the only thing that unpack code calls. It takes the result
   // of the unpack stage and stores it into the destination matrix.
diff --git a/internal/output_avx.h b/internal/output_avx.h
new file mode 100644
index 0000000..b8f94fb
--- /dev/null
+++ b/internal/output_avx.h
@@ -0,0 +1,19 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// output_avx.h: optimized AVX2 specializations of the templates in output.h.
+
+#ifndef GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
+#define GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
+
+#endif  // GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
diff --git a/internal/output_msa.h b/internal/output_msa.h
index 4c8eb5d..0540bb3 100644
--- a/internal/output_msa.h
+++ b/internal/output_msa.h
@@ -38,18 +38,14 @@
     // Signed saturate each 32-bit element to 9 bits
     // (this takes full care of non-negative elements).
     v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
+    // Zero out negative elements.
+    tmp = __builtin_msa_maxi_s_w(tmp, 0);
     // Pack every 32-bit element into 16 bits.
     tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
         reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
-    // Detect negative elements with arithmetic shift right (we
-    // get a 16-bit mask of all zeroes or all ones for every element).
-    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
-    // Zero out negative elements.
-    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
-        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
     // Pack every element into 8 bits.
     tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
-        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
     // Return 4 uint8_t elements as uint32_t.
     output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
     return output;
@@ -76,15 +72,12 @@
     // combining all 8 elements into one vector.
     tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
         reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
-    // Detect negative elements with arithmetic shift right (we
-    // get a 16-bit mask of all zeroes or all ones for every element).
-    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
     // Zero out negative elements.
-    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
-        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
+    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h(
+        reinterpret_cast<v8i16>(tmp_lo), 0));
     // Pack every element into 8 bits.
     tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
-        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+        reinterpret_cast<v16i8>(tmp_lo), reinterpret_cast<v16i8>(tmp_lo)));
     // Return 8 uint8_t elements as 2 uint32_t's.
     output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
     output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
@@ -102,15 +95,13 @@
         reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)));      \
     tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
         reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2)));      \
-    v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15);  \
-    v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15);  \
-    signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
-        reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
-    signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
-        reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
-    signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b(                  \
-        reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0)));  \
-    out = reinterpret_cast<v16i8>(signs0);                                   \
+    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h(                   \
+        reinterpret_cast<v8i16>(tmp0), 0));                                  \
+    tmp2 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h(                   \
+        reinterpret_cast<v8i16>(tmp2), 0));                                  \
+    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(                    \
+        reinterpret_cast<v16i8>(tmp2), reinterpret_cast<v16i8>(tmp0)));      \
+    out = reinterpret_cast<v16i8>(tmp0);                                     \
   }
 
 template <>
@@ -166,8 +157,8 @@
   OutputType Eval(InputType input) const {
     OutputType output;
     // Signed saturate each 32-bit element to 16 bits.
-    v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
-        input.reg[0], 15));
+    v8i16 tmp =
+        reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15));
     output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
     output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
     output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
@@ -176,12 +167,12 @@
   }
 };
 
-#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                         \
-  {                                                                    \
-    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                       \
-    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                       \
-    out = __builtin_msa_pckev_h(                                       \
-        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
+#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                  \
+  {                                                             \
+    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                \
+    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                \
+    out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1),  \
+                                reinterpret_cast<v8i16>(tmp0)); \
   }
 
 template <>
@@ -241,6 +232,117 @@
 
 #undef GEMMLOWP_MIPS_SAT_I16_8
 
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+                                 RegBufferInt32<4>> {
+  typedef RegBufferInt32<4> InputType;
+  typedef RegBufferUint8<4> OutputType;
+
+  typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    // Pack every 32-bit element into 16 bits.
+    v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[0]),
+        reinterpret_cast<v8i16>(input.reg[0])));
+    // Pack every element into 8 bits.
+    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
+    // Return 4 uint8_t elements as uint32_t.
+    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+                                 RegBufferInt32<8>> {
+  typedef RegBufferInt32<8> InputType;
+  typedef RegBufferUint8<8> OutputType;
+
+  typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    // Pack every 32-bit element into 16 bits.
+    v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[1]),
+        reinterpret_cast<v8i16>(input.reg[0])));
+    // Pack every element into 8 bits.
+    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
+    // Return 8 uint8_t elements as 2 uint32_t's.
+    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+    output.reg[1] = __builtin_msa_copy_s_w(tmp, 1);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+                                 RegBufferInt32<16>> {
+  typedef RegBufferInt32<16> InputType;
+  typedef RegBufferUint8<16> OutputType;
+
+  typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    // Pack every 32-bit element into 16 bits.
+    v8i16 tmp0 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[1]),
+        reinterpret_cast<v8i16>(input.reg[0]));
+    v8i16 tmp1 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[3]),
+        reinterpret_cast<v8i16>(input.reg[2]));
+    // Pack every element into 8 bits.
+    output.reg[0] = __builtin_msa_pckev_b(
+        reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+                                 RegBufferInt32<32>> {
+  typedef RegBufferInt32<32> InputType;
+  typedef RegBufferUint8<32> OutputType;
+
+  typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    // Pack every 32-bit element into 16 bits.
+    v8i16 tmp0 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[1]),
+        reinterpret_cast<v8i16>(input.reg[0]));
+    v8i16 tmp1 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[3]),
+        reinterpret_cast<v8i16>(input.reg[2]));
+    v8i16 tmp2 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[5]),
+        reinterpret_cast<v8i16>(input.reg[4]));
+    v8i16 tmp3 = __builtin_msa_pckev_h(
+        reinterpret_cast<v8i16>(input.reg[7]),
+        reinterpret_cast<v8i16>(input.reg[6]));
+    // Pack every element into 8 bits.
+    output.reg[0] = __builtin_msa_pckev_b(
+        reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
+    output.reg[1] = __builtin_msa_pckev_b(
+        reinterpret_cast<v16i8>(tmp3), reinterpret_cast<v16i8>(tmp2));
+    return output;
+  }
+};
+
 template <typename DstType>
 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
@@ -474,50 +576,50 @@
       }
     } else {
       // top-left 4x4
-      v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
-          src.buf.reg[0]));
-      v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
-          src.buf.reg[2]));
+      v4i32 t0 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0]));
+      v4i32 t1 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2]));
       v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
       v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
       // top-right 4x4
-      v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
-          src.buf.reg[4]));
-      v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
-          src.buf.reg[6]));
+      v4i32 t2 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4]));
+      v4i32 t3 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6]));
       v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
       v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
       // bottom-left 4x4
-      v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
-          src.buf.reg[0]));
-      v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
-          src.buf.reg[2]));
+      v4i32 t4 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0]));
+      v4i32 t5 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2]));
       v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
       v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
       // bottom-right 4x4
-      v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
-          src.buf.reg[4]));
-      v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
-          src.buf.reg[6]));
+      v4i32 t6 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4]));
+      v4i32 t7 = reinterpret_cast<v4i32>(
+          __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6]));
       v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
       v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
 
-      StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvr_d(u2, u0)));
-      StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvl_d(u2, u0)));
-      StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvr_d(u3, u1)));
-      StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvl_d(u3, u1)));
-      StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvr_d(u6, u4)));
-      StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvl_d(u6, u4)));
-      StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvr_d(u7, u5)));
-      StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
-          __builtin_msa_ilvl_d(u7, u5)));
+      StoreInt16x8(dst->data(row + 0, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0)));
+      StoreInt16x8(dst->data(row + 1, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0)));
+      StoreInt16x8(dst->data(row + 2, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1)));
+      StoreInt16x8(dst->data(row + 3, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1)));
+      StoreInt16x8(dst->data(row + 4, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4)));
+      StoreInt16x8(dst->data(row + 5, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4)));
+      StoreInt16x8(dst->data(row + 6, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5)));
+      StoreInt16x8(dst->data(row + 7, col),
+                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5)));
     }
   }
 };
@@ -585,6 +687,391 @@
   }
 };
 
+// There's no way to express in C++ the desired machine code for
+// StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> and
+// StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType>.
+// Hence, if we can, we use inline assembly, which takes advantage
+// of little-endian byte order and specifics of different CPU revisions.
+// Note that clang currently can't derive MSA register names from floating-
+// point register names and vice versa in inline assembly.
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) && \
+    !defined(__clang__)
+
+// Instructions for pointer-sized operands.
+#ifdef GEMMLOWP_MIPS_64
+#define GEMMLOWP_MIPS_XADDU "daddu"
+#define GEMMLOWP_MIPS_XLSA "dlsa"
+#else
+#define GEMMLOWP_MIPS_XADDU "addu"
+#define GEMMLOWP_MIPS_XLSA "lsa"
+#endif
+
+// Stores 4 8-byte half-vectors with a stride.
+inline void MipsMsaStore4x8(const RegBlockUint8<8, 4>& src,
+                            std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
+  v16i8 vtmp0, vtmp1;
+  asm volatile(
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    "ilvl.d               %w[vtmp0], %w[src0], %w[src0]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    "ilvl.d               %w[vtmp1], %w[src1], %w[src1]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    "sdc1                 %[src0], 0(%[dst_ptr0])\n"
+    "sdc1                 %[vtmp0], 0(%[dst_ptr1])\n"
+    "sdc1                 %[src1], 0(%[dst_ptr2])\n"
+    "sdc1                 %[vtmp1], 0(%[dst_ptr3])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+    [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#else
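+  // Pre-R6 path: swr/swl pairs perform the potentially unaligned 32-bit
+  // stores.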
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
+  int tmp0, tmp1, tmp2, tmp3;
+  asm volatile(
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    "copy_s.w             %[tmp0], %w[src0][0]\n"
+    "copy_s.w             %[tmp1], %w[src0][1]\n"
+    "copy_s.w             %[tmp2], %w[src0][2]\n"
+    "copy_s.w             %[tmp3], %w[src0][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr0])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr0])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr0])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr0])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr1])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr1])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr1])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr1])\n"
+    "copy_s.w             %[tmp0], %w[src1][0]\n"
+    "copy_s.w             %[tmp1], %w[src1][1]\n"
+    "copy_s.w             %[tmp2], %w[src1][2]\n"
+    "copy_s.w             %[tmp3], %w[src1][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr2])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr2])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr2])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr2])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr3])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr3])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr3])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr3])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), [tmp0] "=&r"(tmp0),
+    [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#endif
+}
+
+// Stores 8 4-byte quarter-vectors with a stride.
+inline void MipsMsaStore8x4(const RegBlockUint8<4, 8>& src,
+                            std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+      *dst_ptr6, *dst_ptr7;
+  int tmp1, tmp2, tmp3;
+  asm volatile(
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+    "copy_s.w             %[tmp1], %w[src0][1]\n"
+    "copy_s.w             %[tmp2], %w[src0][2]\n"
+    "copy_s.w             %[tmp3], %w[src0][3]\n"
+    "swc1                 %[src0], 0(%[dst_ptr0])\n"
+    "sw                   %[tmp1], 0(%[dst_ptr1])\n"
+    "sw                   %[tmp2], 0(%[dst_ptr2])\n"
+    "sw                   %[tmp3], 0(%[dst_ptr3])\n"
+    "copy_s.w             %[tmp1], %w[src1][1]\n"
+    "copy_s.w             %[tmp2], %w[src1][2]\n"
+    "copy_s.w             %[tmp3], %w[src1][3]\n"
+    "swc1                 %[src1], 0(%[dst_ptr4])\n"
+    "sw                   %[tmp1], 0(%[dst_ptr5])\n"
+    "sw                   %[tmp2], 0(%[dst_ptr6])\n"
+    "sw                   %[tmp3], 0(%[dst_ptr7])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+    [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+    [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+    [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#else
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+      *dst_ptr6, *dst_ptr7;
+  int tmp0, tmp1, tmp2, tmp3;
+  asm volatile(
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+    "copy_s.w             %[tmp0], %w[src0][0]\n"
+    "copy_s.w             %[tmp1], %w[src0][1]\n"
+    "copy_s.w             %[tmp2], %w[src0][2]\n"
+    "copy_s.w             %[tmp3], %w[src0][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr0])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr0])\n"
+    "swr                  %[tmp1], 0(%[dst_ptr1])\n"
+    "swl                  %[tmp1], 3(%[dst_ptr1])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr2])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr2])\n"
+    "swr                  %[tmp3], 0(%[dst_ptr3])\n"
+    "swl                  %[tmp3], 3(%[dst_ptr3])\n"
+    "copy_s.w             %[tmp0], %w[src1][0]\n"
+    "copy_s.w             %[tmp1], %w[src1][1]\n"
+    "copy_s.w             %[tmp2], %w[src1][2]\n"
+    "copy_s.w             %[tmp3], %w[src1][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr4])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr4])\n"
+    "swr                  %[tmp1], 0(%[dst_ptr5])\n"
+    "swl                  %[tmp1], 3(%[dst_ptr5])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr6])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr6])\n"
+    "swr                  %[tmp3], 0(%[dst_ptr7])\n"
+    "swl                  %[tmp3], 3(%[dst_ptr7])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+    [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+    [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+    [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
+    [tmp3] "=&r"(tmp3)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#endif
+}
+
+// Stores 8 8-byte half-vectors with a stride.
+inline void MipsMsaStore8x8(const RegBlockUint8<8, 8>& src,
+                            std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+      *dst_ptr6, *dst_ptr7;
+  v16i8 vtmp0, vtmp1, vtmp2, vtmp3;
+  asm volatile(
+    "ilvl.d               %w[vtmp0], %w[src0], %w[src0]\n"
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    "ilvl.d               %w[vtmp1], %w[src1], %w[src1]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    "ilvl.d               %w[vtmp2], %w[src2], %w[src2]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+    "ilvl.d               %w[vtmp3], %w[src3], %w[src3]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+    "sdc1                 %[src0], 0(%[dst_ptr0])\n"
+    "sdc1                 %[vtmp0], 0(%[dst_ptr1])\n"
+    "sdc1                 %[src1], 0(%[dst_ptr2])\n"
+    "sdc1                 %[vtmp1], 0(%[dst_ptr3])\n"
+    "sdc1                 %[src2], 0(%[dst_ptr4])\n"
+    "sdc1                 %[vtmp2], 0(%[dst_ptr5])\n"
+    "sdc1                 %[src3], 0(%[dst_ptr6])\n"
+    "sdc1                 %[vtmp3], 0(%[dst_ptr7])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+    [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+    [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+    [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1), [vtmp2] "=&f"(vtmp2),
+    [vtmp3] "=&f"(vtmp3)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#else
+  // Assembly temporaries that will be handily referred to by their names.
+  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+      *dst_ptr6, *dst_ptr7;
+  int tmp0, tmp1, tmp2, tmp3;
+  asm volatile(
+    GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+    GEMMLOWP_MIPS_XLSA  " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+    "copy_s.w             %[tmp0], %w[src0][0]\n"
+    "copy_s.w             %[tmp1], %w[src0][1]\n"
+    "copy_s.w             %[tmp2], %w[src0][2]\n"
+    "copy_s.w             %[tmp3], %w[src0][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr0])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr0])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr0])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr0])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr1])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr1])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr1])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr1])\n"
+    "copy_s.w             %[tmp0], %w[src1][0]\n"
+    "copy_s.w             %[tmp1], %w[src1][1]\n"
+    "copy_s.w             %[tmp2], %w[src1][2]\n"
+    "copy_s.w             %[tmp3], %w[src1][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr2])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr2])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr2])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr2])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr3])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr3])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr3])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr3])\n"
+    "copy_s.w             %[tmp0], %w[src2][0]\n"
+    "copy_s.w             %[tmp1], %w[src2][1]\n"
+    "copy_s.w             %[tmp2], %w[src2][2]\n"
+    "copy_s.w             %[tmp3], %w[src2][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr4])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr4])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr4])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr4])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr5])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr5])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr5])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr5])\n"
+    "copy_s.w             %[tmp0], %w[src3][0]\n"
+    "copy_s.w             %[tmp1], %w[src3][1]\n"
+    "copy_s.w             %[tmp2], %w[src3][2]\n"
+    "copy_s.w             %[tmp3], %w[src3][3]\n"
+    "swr                  %[tmp0], 0(%[dst_ptr6])\n"
+    "swl                  %[tmp0], 3(%[dst_ptr6])\n"
+    "swr                  %[tmp1], 4(%[dst_ptr6])\n"
+    "swl                  %[tmp1], 7(%[dst_ptr6])\n"
+    "swr                  %[tmp2], 0(%[dst_ptr7])\n"
+    "swl                  %[tmp2], 3(%[dst_ptr7])\n"
+    "swr                  %[tmp3], 4(%[dst_ptr7])\n"
+    "swl                  %[tmp3], 7(%[dst_ptr7])\n"
+    :
+    // Outputs.
+    [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+    [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+    [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+    [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+    [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
+    [tmp3] "=&r"(tmp3)
+    :
+    // Inputs.
+    [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+    [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
+    [stride] "r"(stride)
+    :
+    // Clobbers.
+    "memory");
+#endif
+}
+
+#undef GEMMLOWP_MIPS_XADDU
+#undef GEMMLOWP_MIPS_XLSA
+
+// Transposes a column-major 8x4 block for storage into a row-major matrix.
+inline RegBlockUint8<4, 8> Transpose(const RegBlockUint8<8, 4>& src) {
+  v16i8 tmp0 = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
+  v16i8 tmp1 = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
+  RegBlockUint8<4, 8> result;
+  result.buf.reg[0] = __builtin_msa_ilvr_b(tmp1, tmp0);
+  result.buf.reg[1] = __builtin_msa_ilvl_b(tmp1, tmp0);
+  return result;
+}
+
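+// Transposes a column-major 8x8 block for storage into a row-major matrix.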
+inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
+  v16i8 tmp0[4];
+  tmp0[0] = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
+  tmp0[1] = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
+  tmp0[2] = __builtin_msa_ilvr_b(src.buf.reg[3], src.buf.reg[2]);
+  tmp0[3] = __builtin_msa_ilvl_b(src.buf.reg[3], src.buf.reg[2]);
+  v16i8 tmp1[4];
+  tmp1[0] = __builtin_msa_ilvr_b(tmp0[1], tmp0[0]);
+  tmp1[1] = __builtin_msa_ilvl_b(tmp0[1], tmp0[0]);
+  tmp1[2] = __builtin_msa_ilvr_b(tmp0[3], tmp0[2]);
+  tmp1[3] = __builtin_msa_ilvl_b(tmp0[3], tmp0[2]);
+  RegBlockUint8<8, 8> result;
+  result.buf.reg[0] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
+      reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
+  result.buf.reg[1] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
+      reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
+  result.buf.reg[2] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
+      reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
+  result.buf.reg[3] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
+      reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
+  return result;
+}
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      std::uint8_t* dst_ptr = dst->data(row, col);
+      int col_stride = dst->cols_stride();
+      MipsMsaStore4x8(src, dst_ptr, col_stride);
+    } else {
+      const auto& block = Transpose(src);
+      std::uint8_t* dst_ptr = dst->data(row, col);
+      int row_stride = dst->rows_stride();
+      MipsMsaStore8x4(block, dst_ptr, row_stride);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    const auto& block =
+        (DstType::kOrder == MapOrder::ColMajor) ? src : Transpose(src);
+    std::uint8_t* dst_ptr = dst->data(row, col);
+    int stride = dst->stride();
+    MipsMsaStore8x8(block, dst_ptr, stride);
+  }
+};
+
+#else
+
 template <typename DstType>
 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
@@ -617,6 +1104,8 @@
   }
 };
 
+#endif  // Endianness, compiler.
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
diff --git a/internal/output_neon.h b/internal/output_neon.h
index 911fed0..52ea1bc 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -108,6 +108,90 @@
 };
 
 template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+                                 RegBufferInt32<4>> {
+  typedef RegBufferInt32<4> InputType;
+  typedef RegBufferInt8<4> OutputType;
+
+  typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
+    int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16));
+    output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+                                 RegBufferInt32<8>> {
+  typedef RegBufferInt32<8> InputType;
+  typedef RegBufferInt8<8> OutputType;
+
+  typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16 =
+        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+    output.reg[0] = vqmovn_s16(res_16);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+                                 RegBufferInt32<16>> {
+  typedef RegBufferInt32<16> InputType;
+  typedef RegBufferInt8<16> OutputType;
+
+  typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16_0 =
+        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+    int16x8_t res_16_1 =
+        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+    output.reg[0] = vqmovn_s16(res_16_0);
+    output.reg[1] = vqmovn_s16(res_16_1);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+                                 RegBufferInt32<32>> {
+  typedef RegBufferInt32<32> InputType;
+  typedef RegBufferInt8<32> OutputType;
+
+  typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16[4];
+    for (int i = 0; i < 4; i++) {
+      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
+                               vqmovn_s32(input.reg[2 * i + 1]));
+    }
+    for (int i = 0; i < 4; i++) {
+      output.reg[i] = vqmovn_s16(res_16[i]);
+    }
+    return output;
+  }
+};
+
+template <>
 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                  RegBufferInt32<4>> {
   typedef RegBufferInt32<4> InputType;
@@ -556,8 +640,8 @@
         vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
       }
     } else {
+      int row_stride = dst->rows_stride();
       for (int i = 0; i < 4; i++) {
-        int row_stride = dst->rows_stride();
         std::uint8_t* col_ptr = dst_ptr + i;
         vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
         vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
@@ -623,6 +707,153 @@
 };
 
 template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> {
+  static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row,
+                  int col) {
+    const std::int32_t src_reg = src.buf.reg[0];
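+    // The four int8 lanes are packed into a single int32; store them one
+    // byte at a time by shifting.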
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row + i, col) = (src_reg >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> {
+  static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row,
+                  int col) {
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> {
+  static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row,
+                  int col) {
+    std::int8_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      vst1_s8(dst_ptr, src.buf.reg[0]);
+    } else {
+      const int row_stride = dst->rows_stride();
+      vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
+      vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
+      vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
+      vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
+      vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
+      vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
+      vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
+      vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> {
+  static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::int8_t* dst_ptr = dst->data(row, col);
+    const int row_stride = dst->rows_stride();
+    const int col_stride = dst->cols_stride();
+    for (int i = 0; i < 2; i++) {
+      vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 0);
+      vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 1);
+      vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 2);
+      vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 3);
+      vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 4);
+      vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 5);
+      vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 6);
+      vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 7);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> {
+  static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::int8_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      int col_stride = dst->cols_stride();
+      for (int i = 0; i < 4; i++) {
+        vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]);
+      }
+    } else {
+      int row_stride = dst->rows_stride();
+      for (int i = 0; i < 4; i++) {
+        std::int8_t* col_ptr = dst_ptr + i;
+        vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
+        vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
+        vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
+        vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
+        vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
+        vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
+        vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
+        vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
+      }
+    }
+  }
+};
+
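+// Transposes an 8x8 block of int8 values in three interleaving stages:
+// bytes (vtrn_s8), 16-bit pairs (vtrn_s16), then 32-bit quads (vtrn_s32).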
+inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) {
+  int8x8x2_t a[4];
+  a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]);
+  a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]);
+  a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]);
+  a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]);
+  int16x4x2_t b[4];
+  b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]),
+                  vreinterpret_s16_s8(a[1].val[0]));
+  b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]),
+                  vreinterpret_s16_s8(a[1].val[1]));
+  b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]),
+                  vreinterpret_s16_s8(a[3].val[0]));
+  b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]),
+                  vreinterpret_s16_s8(a[3].val[1]));
+  int32x2x2_t c[4];
+  c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]),
+                  vreinterpret_s32_s16(b[2].val[0]));
+  c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]),
+                  vreinterpret_s32_s16(b[3].val[0]));
+  c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]),
+                  vreinterpret_s32_s16(b[2].val[1]));
+  c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]),
+                  vreinterpret_s32_s16(b[3].val[1]));
+  RegBlockInt8<8, 8> result;
+  result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]);
+  result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]);
+  result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]);
+  result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]);
+  result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]);
+  result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]);
+  result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]);
+  result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]);
+  return result;
+}
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> {
+  static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    const auto& block =
+        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
+    std::int8_t* dst_ptr = dst->data(row, col);
+    int stride = dst->stride();
+    for (int i = 0; i < 8; i++) {
+      vst1_s8(dst_ptr + i * stride, block.buf.reg[i]);
+    }
+  }
+};
+
+template <typename DstType>
 struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
   static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                   int col) {
diff --git a/internal/pack.h b/internal/pack.h
index cb4b93a..7c43d6e 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -72,6 +72,10 @@
     pos_ += n * KernelSideFormat::Cell::kSize;
   }
 
+  // TODO(suharshs): The data type can now be int8 as well. We could introduce
+  // a separate int8 current_data() implementation, but that change would
+  // propagate to all pack impls and the Kernel::Run API, which all assume
+  // uint8. For now we leave this as-is, pending a future refactor.
   const std::uint8_t* current_data() const {
     return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_;
   }
@@ -208,6 +212,7 @@
  public:
   typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
   typedef typename KernelSideFormat::Cell CellFormat;
+  typedef typename KernelSideFormat::InputScalar KernelInputScalar;
   typedef typename KernelSideFormat::Scalar KernelScalar;
   static const int kCells = KernelSideFormat::kCells;
   static const int kCellWidth = CellFormat::kWidth;
@@ -216,7 +221,7 @@
   static const int kCellSize = CellFormat::kSize;
   static const SideMapOrder kSrcOrder = SrcMapType::kOrder;
   static const int kZeroPointInputValue =
-      ZeroPointInputValue<KernelScalar>::kValue;
+      ZeroPointInputValue<KernelInputScalar, KernelScalar>::kValue;
 
   PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {}
 
@@ -233,7 +238,7 @@
   std::uint8_t buf_[kKernelWidth * kRegisterSize];
 
  public:
-  // Selects a block if in-place source data that's already a complete block
+  // Selects a block of in-place source data that's already a complete block.
   void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; }
   // Copies an incomplete block of source data into a local temporary
   // complete block by zero-extending it.
@@ -249,7 +254,10 @@
         memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width());
       }
     }
-    complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize);
+
+    // Since the KernelInputScalar type may not be uint8, we need to cast buf_.
+    complete_src_ = SrcMapType(reinterpret_cast<KernelInputScalar*>(buf_),
+                               kKernelWidth, kRegisterSize);
   }
   // Packs a complete block into the destination. This is the most
   // critical part and the part that we most typically want to
@@ -340,7 +348,7 @@
     }
   }
 
-  // Prefetches the data that will be read by PackL1
+  // Prefetches the data that will be read by PackL1.
   void PrefetchL1(int start_width, int width, int start_depth, int depth) {
     if (SrcMapType::kOrder == SideMapOrder::WidthMajor) {
       for (int d = 0; d < depth; d += kDefaultCacheLineSize) {
@@ -394,7 +402,7 @@
   const SrcMapType& src_map_;
 };
 
-// Packs a block of the input LHS matrix, into a PackedSideBlock
+// Packs a block of the input LHS matrix, into a PackedSideBlock.
 template <typename PackedSideBlock, typename MatrixMapType>
 void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) {
   ScopedProfilingLabel label("pack LHS");
@@ -409,7 +417,7 @@
   impl.PackL2();
 }
 
-// Packs a block of the input RHS matrix, into a PackedSideBlock
+// Packs a block of the input RHS matrix, into a PackedSideBlock.
 template <typename PackedSideBlock, typename MatrixMapType>
 void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
   ScopedProfilingLabel label("pack RHS");
@@ -430,6 +438,8 @@
 #include "pack_neon.h"
 #elif defined(GEMMLOWP_SSE4)
 #include "pack_sse.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "pack_avx.h"
 #elif defined(GEMMLOWP_MSA)
 #include "pack_msa.h"
 #endif
diff --git a/internal/pack_avx.h b/internal/pack_avx.h
new file mode 100644
index 0000000..1ef5ce1
--- /dev/null
+++ b/internal/pack_avx.h
@@ -0,0 +1,282 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// pack_avx.h: optimized AVX specializations of the templates in pack.h.
+
+#ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_
+#define GEMMLOWP_INTERNAL_PACK_AVX_H_
+
+#include <immintrin.h>
+#include "pack.h"
+
+namespace gemmlowp {
+
+// TODO: Add DepthMajorUint8SideMap
+
+typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
+    WidthMajorUint8SideMap;
+
+template <int Cells>
+using WidthMajorSideFormatNCells4x2 =
+    KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+    WidthMajorUint8SideMap,
+    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
+    : public PackingRegisterBlockBase<
+          WidthMajorUint8SideMap,
+          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
+    std::uint8_t *dst_ptr = dst->current_data();
+    const int width_stride = this->complete_src_.width_stride();
+    int depth_step = 16;
+
+    __m256i one = _mm256_set1_epi16(1);
+    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
+         cell_start_depth += depth_step) {
+      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
+           cell_start_width += kCellWidth) {
+        std::int32_t *cell_sums_of_each_slice_ptr =
+            dst->sums_of_each_slice() + start_width + cell_start_width;
+        const std::uint8_t *src_data =
+            this->complete_src_.data(cell_start_width, cell_start_depth);
+
+        __m128i xmm1 =
+            _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0]));
+        __m128i xmm2 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
+        __m128i xmm3 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
+        __m128i xmm4 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
+        __m128i xmm5 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[4 * width_stride]));
+        __m128i xmm6 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[5 * width_stride]));
+        __m128i xmm7 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[6 * width_stride]));
+        __m128i xmm8 = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&src_data[7 * width_stride]));
+
+        __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1);
+        __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2);
+        __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3);
+        __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4);
+
+        __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2);
+        __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4);
+
+        __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2);
+        __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4);
+
+        __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6);
+        __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6);
+
+        __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10);
+        __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10);
+
+        __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8);
+        __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8);
+
+        __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8);
+        __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8);
+
+        __m128i xmm9 = _mm256_castsi256_si128(ymm11);
+        __m128i xmm10 = _mm256_castsi256_si128(ymm12);
+        __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1);
+        __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1);
+
+        xmm1 = _mm256_castsi256_si128(ymm15);
+        xmm2 = _mm256_castsi256_si128(ymm16);
+        xmm3 = _mm256_extracti128_si256(ymm15, 1);
+        xmm4 = _mm256_extracti128_si256(ymm16, 1);
+
+        _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
+            xmm10);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
+            xmm12);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]),
+            xmm1);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]),
+            xmm3);
+
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]),
+            xmm2);
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]),
+            xmm4);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm9);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        __m256i sums_of_each_slice_xmm = _mm256_loadu_si256(
+            reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0]));
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm11);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm10);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm12);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm1);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm3);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm2);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        ymm6 = _mm256_cvtepu8_epi16(xmm4);
+        ymm7 = _mm256_madd_epi16(ymm6, one);
+        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+        _mm256_storeu_si256(
+            reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]),
+            sums_of_each_slice_xmm);
+        dst_ptr += kCellSize;
+      }
+      dst_ptr += 7 * kCellSize * kCells;
+    }
+    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+  }
+};
+
+// Pack format for the 4x2 rhs cell format.
+template <int Cells>
+using RhsWidthMajorSideFormatNCells4x2 =
+    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+    WidthMajorUint8SideMap,
+    PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>>
+    : public PackingRegisterBlockBase<
+          WidthMajorUint8SideMap,
+          PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+  typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
+    std::uint8_t *dst_ptr = dst->current_data();
+    const int width_stride = this->complete_src_.width_stride();
+    int depth_step = 8;
+
+    __m128i one = _mm_set1_epi16(1);
+    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
+         cell_start_depth += depth_step) {
+      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
+           cell_start_width += kCellWidth) {
+        std::int32_t *cell_sums_of_each_slice_ptr =
+            dst->sums_of_each_slice() + start_width + cell_start_width;
+        const std::uint8_t *src_data =
+            this->complete_src_.data(cell_start_width, cell_start_depth);
+
+        __m128i xmm1 =
+            _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0]));
+        __m128i xmm2 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
+        __m128i xmm3 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
+        __m128i xmm4 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
+
+        __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
+        __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
+
+        __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
+        __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
+
+        __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
+        __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
+
+        _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10);
+
+        __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
+        __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
+
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
+            xmm11);
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
+            xmm12);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm9);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
+            reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0]));
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm10);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm11);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm12);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]),
+            sums_of_each_slice_xmm);
+        dst_ptr += kCellSize;
+      }
+      dst_ptr += 3 * kCellSize * kCells;
+    }
+    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+  }
+};
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_PACK_AVX_H_
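For readers of the new pack_avx.h code: the cvtepu8_epi16 + madd_epi16(x, one)
sequence in the Pack() methods above is just a vectorized way of adding each
width-slice's depth values into the running per-slice int32 sums. A minimal
scalar sketch of one such update, assuming a 16-byte packed register holding
8 width-slices x 2 depth levels (the helper name is hypothetical):

#include <cstdint>

// Scalar equivalent of one cvtepu8_epi16 + madd_epi16(x, 1) + add_epi32 step:
// adjacent pairs of packed uint8 values belong to the same width-slice
// (WidthMajor cells of depth 2), so each pair sum is accumulated into that
// slice's int32 running sum.
inline void AccumulateSliceSums(const std::uint8_t packed[16],
                                std::int32_t sums_of_each_slice[8]) {
  for (int w = 0; w < 8; ++w) {
    sums_of_each_slice[w] += static_cast<std::int32_t>(packed[2 * w]) +
                             static_cast<std::int32_t>(packed[2 * w + 1]);
  }
}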
diff --git a/internal/pack_msa.h b/internal/pack_msa.h
index fba8a0f..4072229 100644
--- a/internal/pack_msa.h
+++ b/internal/pack_msa.h
@@ -348,6 +348,84 @@
   }
 };
 
+template <int Width>
+using Int8FastKernelFormat =
+    KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
+
+template <int Width>
+class PackingRegisterBlock<WidthMajorUint8SideMap,
+                           PackedSideBlock<Int8FastKernelFormat<Width>>>
+    : public PackingRegisterBlockBase<
+          WidthMajorUint8SideMap,
+          PackedSideBlock<Int8FastKernelFormat<Width>>> {
+ public:
+  static_assert(Width == 2 || Width == 4, "");
+  typedef Int8FastKernelFormat<Width> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
+    std::uint8_t* dst_ptr = dst->current_data();
+    const std::uint8_t* const src_ptr = this->complete_src_.data();
+    const int stride = this->complete_src_.stride();
+    // Load source WidthMajor data.
+    v16i8 src_lines[Width];
+    for (int i = 0; i < Width; i++) {
+      src_lines[i] = __builtin_msa_ld_b(
+          const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
+    }
+    for (int i = 0; i < Width; i++) {
+      // Subtract 128 by inverting bit 7.
+      src_lines[i] = reinterpret_cast<v16i8>(
+          __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
+    }
+    for (int i = 0; i < Width; i++) {
+      __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
+    }
+    v8i16 sums2[Width];
+    for (int i = 0; i < Width; i++) {
+      sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
+    }
+    v4i32 sums4_wide[Width];
+    for (int i = 0; i < Width; i++) {
+      sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
+    }
+    v8i16 sums4[Width / 2];
+    for (int i = 0; i < Width / 2; i++) {
+      sums4[i] = __builtin_msa_pckev_h(
+          reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
+          reinterpret_cast<v8i16>(sums4_wide[2 * i]));
+    }
+    v4i32 sums8_wide[Width / 2];
+    for (int i = 0; i < Width / 2; i++) {
+      sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
+    }
+    if (Width == 4) {
+      v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
+      v8i16 sums8 = __builtin_msa_pckev_h(
+          reinterpret_cast<v8i16>(sums8_wide[1]),
+          reinterpret_cast<v8i16>(sums8_wide[0]));
+      v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
+      sum = __builtin_msa_addv_w(sum, sums16);
+      __builtin_msa_st_w(sum, sums_ptr, 0);
+    } else {
+      assert(Width == 2);
+      std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
+      v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
+      sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
+      sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
+      sums_ptr[0] = sum[0];
+      sums_ptr[1] = sum[1];
+    }
+    dst->seek_forward_n_cells(1);
+  }
+};
+
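The "subtract 128 by inverting bit 7" step above works because XOR-ing a uint8
value with 0x80 produces the two's-complement bit pattern of that value minus
128, which is exactly the uint8 -> int8 input conversion the signed kernels
expect. A tiny standalone check (illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Flipping bit 7 of a uint8 value v yields the int8 representation of v - 128.
  for (int v = 0; v < 256; ++v) {
    const std::uint8_t flipped = static_cast<std::uint8_t>(v) ^ 0x80;
    assert(static_cast<std::int8_t>(flipped) == v - 128);
  }
  return 0;
}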
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_PACK_MSA_H_
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index 2b08464..f113d9e 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -26,6 +26,9 @@
 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
     WidthMajorUint8SideMap;
 
+typedef SideMap<const std::int8_t, SideMapOrder::WidthMajor>
+    WidthMajorInt8SideMap;
+
 template <int Cells>
 using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
 
@@ -315,6 +318,67 @@
   }
 };
 
+template <int Width>
+using Int8InputsFastKernelFormat =
+    KernelSideFormatInt8Inputs<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
+
+// Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion.
+template <int Width>
+class PackingRegisterBlock<WidthMajorInt8SideMap,
+                           PackedSideBlock<Int8InputsFastKernelFormat<Width>>>
+    : public PackingRegisterBlockBase<
+          WidthMajorInt8SideMap,
+          PackedSideBlock<Int8InputsFastKernelFormat<Width>>> {
+ public:
+  static_assert(Width == 2 || Width == 4, "");
+  typedef Int8InputsFastKernelFormat<Width> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
+    std::int8_t* dst_ptr = reinterpret_cast<std::int8_t*>(dst->current_data());
+    const std::int8_t* const src_ptr = this->complete_src_.data();
+    const int stride = this->complete_src_.stride();
+    // Load source WidthMajor data.
+    int8x16_t src_lines[Width];
+    for (int i = 0; i < Width; i++) {
+      src_lines[i] = vld1q_s8(src_ptr + i * stride);
+    }
+    for (int i = 0; i < Width; i++) {
+      vst1q_s8(dst_ptr + 16 * i, src_lines[i]);
+    }
+    int16x8_t sums2[Width];
+    for (int i = 0; i < Width; i++) {
+      const int8x8_t lo = vget_low_s8(src_lines[i]);
+      const int8x8_t hi = vget_high_s8(src_lines[i]);
+      sums2[i] = vaddl_s8(lo, hi);
+    }
+    int16x8_t sums4[Width / 2];
+    for (int i = 0; i < Width / 2; i++) {
+      sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
+    }
+    if (Width == 4) {
+      int32x4_t sum = vld1q_s32(sums_ptr);
+      int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
+      sum = vpadalq_s16(sum, sums8);
+      vst1q_s32(sums_ptr, sum);
+    } else {
+      assert(Width == 2);
+      int32x2_t sum = vld1_s32(sums_ptr);
+      int16x4_t sums8 =
+          vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
+      sum = vpadal_s16(sum, sums8);
+      vst1_s32(sums_ptr, sum);
+    }
+    dst->seek_forward_n_cells(1);
+  }
+};
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_PACK_NEON_H_
diff --git a/internal/platform.h b/internal/platform.h
index 1114767..ab71414 100644
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -18,6 +18,7 @@
 #define GEMMLOWP_INTERNAL_PLATFORM_H_
 
 #ifdef _WIN32
+#include <malloc.h>
 #include <windows.h>
 #else
 #include <stdlib.h>
@@ -71,8 +72,8 @@
 inline double real_time_in_seconds() {
   __int64 wintime;
   GetSystemTimeAsFileTime((FILETIME *)&wintime);
-  wintime -= 116444736000000000i64;  // 1jan1601 to 1jan1970
-  return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9;
+  wintime -= 116444736000000000LL;  // 1jan1601 to 1jan1970
+  return wintime / 10000000LL + wintime % 10000000LL * 100 * 1e-9;
 }
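As a sanity check on the epoch constant used above: FILETIME counts 100 ns
ticks since 1601-01-01, and the span from 1601-01-01 to 1970-01-01 is 134774
days, which gives exactly the offset in the code. An illustrative assertion
(not part of the patch):

static_assert(134774LL * 86400LL * 10000000LL == 116444736000000000LL,
              "100ns ticks between 1601-01-01 and 1970-01-01");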
 
 #else
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
index d9721c9..4e4cce8 100644
--- a/internal/simd_wrappers.h
+++ b/internal/simd_wrappers.h
@@ -105,10 +105,12 @@
   using FlippedRhsType = RhsType;
   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
                                           const RhsType& rhs) {
+    (void)rhs;
     return lhs;
   }
   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
                                           const RhsType& rhs) {
+    (void)lhs;
     return rhs;
   }
 };
@@ -119,10 +121,12 @@
   using FlippedRhsType = LhsType;
   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
                                           const RhsType& rhs) {
+    (void)lhs;
     return rhs;
   }
   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
                                           const RhsType& rhs) {
+    (void)rhs;
     return lhs;
   }
 };
@@ -192,6 +196,153 @@
 }
 
 template <typename Lhs, typename Rhs>
+struct BroadcastShiftLeftImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    ResultBlockType result;
+    static constexpr int Rows = ResultBlockType::kRows;
+    static constexpr int Cols = ResultBlockType::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        result.buf.reg[r + c * Rows] =
+            ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+                      rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
+    const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  return BroadcastShiftLeftImpl<
+      typename Flip::FlippedLhsType,
+      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+                                          Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    ResultBlockType result;
+    static constexpr int Rows = ResultBlockType::kRows;
+    static constexpr int Cols = ResultBlockType::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
+            lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+            rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
+BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  return BroadcastSaturatingRoundingDoublingHighMulImpl<
+      typename Flip::FlippedLhsType,
+      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+                                          Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
+struct BroadcastRoundingDivideByPOTImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    ResultBlockType result;
+    static constexpr int Rows = ResultBlockType::kRows;
+    static constexpr int Cols = ResultBlockType::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        result.buf.reg[r + c * Rows] =
+            RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+                                rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
+BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  return BroadcastRoundingDivideByPOTImpl<
+      typename Flip::FlippedLhsType,
+      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+                                          Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
 struct BroadcastMulImpl {
   using ResultBlockType =
       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
@@ -494,12 +645,16 @@
 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
 template <int N>
 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
+template <int N>
+using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
 template <int R, int C>
 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
 template <int R, int C>
 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
 template <int R, int C>
 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
+template <int R, int C>
+using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
 
 }  // end namespace gemmlowp
 
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h
index 3830eb1..694bf99 100644
--- a/internal/simd_wrappers_common_neon_sse.h
+++ b/internal/simd_wrappers_common_neon_sse.h
@@ -350,6 +350,210 @@
   }
 };
 
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+                                                      RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+                                                      RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+                                                      RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+                                                      RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+                                                      RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+                                                      RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
+    result.buf.reg[2] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+                                                      RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+                                                      RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] =
+          SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+                                                      RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+    result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+                                                      RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+    result.buf.reg[2] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+                                                      RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] =
+        SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+                                                      RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+        lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
 // 4x1 := 4x1 * 1x1
 template <>
 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h
index cf5e8e9..7de01ff 100644
--- a/internal/simd_wrappers_msa.h
+++ b/internal/simd_wrappers_msa.h
@@ -33,8 +33,7 @@
 
 template <int ScalarCount>
 struct RegisterType<std::int16_t, ScalarCount> {
-  using Type =
-      typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+  using Type = typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
 };
 
 template <int ScalarCount>
@@ -69,13 +68,9 @@
   return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
 }
 
-inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
-  __builtin_msa_st_h(value, dst, 0);
-}
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
 
-inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
-  __builtin_msa_st_h(value, dst, 0);
-}
+inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
 
 inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
   return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index 2949173..6871055 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -25,6 +25,7 @@
 using Int16x4 = int16x4_t;
 using Int16x8 = int16x8_t;
 using Uint8x8 = uint8x8_t;
+using Int8x8 = int8x8_t;
 
 template <int ScalarCount>
 struct RegisterType<std::int32_t, ScalarCount> {
@@ -48,6 +49,14 @@
                                 std::uint8_t>::type>::type;
 };
 
+template <int ScalarCount>
+struct RegisterType<std::int8_t, ScalarCount> {
+  using Type = typename std::conditional<
+      ScalarCount >= 8, Int8x8,
+      typename std::conditional<ScalarCount >= 4, std::int32_t,
+                                std::int8_t>::type>::type;
+};
+
 inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
 inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
 inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }
@@ -92,6 +101,10 @@
 
 inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }
 
+inline Int32x4 Max(Int32x4 a, std::int32_t b) {
+  return vmaxq_s32(a, vdupq_n_s32(b));
+}
+
 inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
   return vqrdmulhq_n_s32(a, b);
 }
@@ -164,6 +177,17 @@
 };
 
 template <>
+struct LoadContiguousImpl<RegBlockInt8<8, 8>> {
+  static RegBlockInt8<8, 8> Run(const std::int8_t* src) {
+    RegBlockInt8<8, 8> result;
+    for (int i = 0; i < 8; i++) {
+      result.buf.reg[i] = vld1_s8(src + 8 * i);
+    }
+    return result;
+  }
+};
+
+template <>
 struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
   static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
     RegBlockInt32<8, 8> result;
@@ -174,6 +198,352 @@
   }
 };
 
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]);
+    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
+    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+                                        RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[2] =
+        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[3] =
+        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+                                        RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
+    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+                                        RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[2] =
+        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[3] =
+        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[4] =
+        RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[5] =
+        RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[6] =
+        RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+    result.buf.reg[7] =
+        RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+                                        RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+                                        RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
 }  // end namespace gemmlowp
 
 #include "simd_wrappers_common_neon_sse.h"
diff --git a/internal/unpack.h b/internal/unpack.h
index 33aee13..021f4aa 100644
--- a/internal/unpack.h
+++ b/internal/unpack.h
@@ -98,12 +98,14 @@
                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                        int depth, int src_row, int src_col, int src_global_row,
                        int src_global_col, int dst_row, int dst_col) {
+  using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
   using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
+  using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
   using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
   static constexpr int KernelLhsZeroPointInput =
-      ZeroPointInputValue<KernelLhsScalar>::kValue;
+      ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
   static constexpr int KernelRhsZeroPointInput =
-      ZeroPointInputValue<KernelRhsScalar>::kValue;
+      ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
   auto acc = Load<RegisterBlockType>(src, src_row, src_col);
   const auto& lhs_sums_of_each_slice_block =
       LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h
index 0b35759..b39c3f2 100644
--- a/meta/multi_thread_common.h
+++ b/meta/multi_thread_common.h
@@ -22,9 +22,15 @@
 
 inline int ResolveMaxThreads(int max_threads) {
   if (max_threads == 0) {
+#ifdef _WIN32
+    SYSTEM_INFO sysinfo;
+    GetSystemInfo(&sysinfo);
+    return sysinfo.dwNumberOfProcessors;
+#else
     static const int hardware_threads_count =
         static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
     return hardware_threads_count;
+#endif
   }
   return max_threads;
 }
diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h
index 437fe54..c1f852e 100644
--- a/profiling/instrumentation.h
+++ b/profiling/instrumentation.h
@@ -108,13 +108,14 @@
 // contains pointers to literal strings that were manually entered
 // in the instrumented code (see ScopedProfilingLabel).
 struct ProfilingStack {
-  static const std::size_t kMaxSize = 14;
+  static const std::size_t kMaxSize = 30;
   typedef const char* LabelsArrayType[kMaxSize];
   LabelsArrayType labels;
   std::size_t size;
   Mutex* lock;
 
   ProfilingStack() { memset(this, 0, sizeof(ProfilingStack)); }
+  ~ProfilingStack() { delete lock; }
 
   void Push(const char* label) {
     ScopedLock sl(lock);
@@ -171,8 +172,6 @@
     ScopedLock sl(GlobalMutexes::Profiler());
     ThreadInfo* self = static_cast<ThreadInfo*>(ptr);
     ThreadsUnderProfiling().erase(self);
-    pthread_key_delete(self->key);
-    delete self->stack.lock;
   }
 };
 
@@ -185,7 +184,11 @@
     }
   };
 
-  static int key_result = pthread_key_create(&key, DeleteThreadInfo);
+  // key_result is unused. The purpose of this 'static' local object is
+  // to have its initializer (the pthread_key_create call) performed exactly
+  // once, in a way that is guaranteed (since C++11) to be reentrant.
+  static const int key_result = pthread_key_create(&key, DeleteThreadInfo);
+  (void)key_result;
 
   ThreadInfo* threadInfo = static_cast<ThreadInfo*>(pthread_getspecific(key));
   if (!threadInfo) {
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index df17c6f..2569bbc 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -60,6 +60,9 @@
   *cond = new std::condition_variable;
 }
 inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
+inline void pthread_cond_broadcast(pthread_cond_t *cond) {
+  (*cond)->notify_all();
+}
 inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
   std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
   (*cond)->wait(lock);
diff --git a/public/bit_depth.h b/public/bit_depth.h
index 6cb4ecf..412944e 100644
--- a/public/bit_depth.h
+++ b/public/bit_depth.h
@@ -24,14 +24,15 @@
 struct OperandRange {
   static const int kMinValue = tMinValue;
   static const int kMaxValue = tMaxValue;
-  static_assert(0 <= kMinValue, "");
   static_assert(kMinValue < kMaxValue, "");
-  static_assert(kMaxValue <= 255, "");
 };
 
 using Uint8Range = OperandRange<0, 255>;
 using Uint8RangeExcludingZero = OperandRange<1, 255>;
 
+using Int8Range = OperandRange<-128, 127>;
+using Int8RangeExcludingLow = OperandRange<-127, 127>;
+
 template <typename tLhsRange, typename tRhsRange>
 struct BitDepthParams {
   using LhsRange = tLhsRange;
@@ -47,6 +48,11 @@
 using L8R8WithLhsNonzeroBitDepthParams =
     BitDepthParams<Uint8RangeExcludingZero, Uint8Range>;
 
+// Signed variant: allows using faster kernels based on signed arithmetic; see
+// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits
+using SignedL8R8WithLhsNonzeroBitDepthParams =
+    BitDepthParams<Int8RangeExcludingLow, Int8Range>;
+
 // Deprecated: when gemmlowp used to allow requantizing 8bit
 // inputs to less-than-8-bit depths, the public setting allowing
 // that was DefaultL7R5BitDepthParams. That requantization
diff --git a/public/map.h b/public/map.h
index 3073e05..fe6bc5c 100644
--- a/public/map.h
+++ b/public/map.h
@@ -131,6 +131,7 @@
     assert(start >= 0);
     assert(start + len <= size_);
 
+    (void)start;
     return VectorDup(data_, len);
   }
 };
diff --git a/public/output_stages.h b/public/output_stages.h
index 1d5fca4..797b662 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -138,12 +138,44 @@
   std::int32_t result_offset_after_shift;
 };
 
+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes a result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this consists of first left-shifting by
+// std::max(result_exponent, 0), before doing the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+//
+// The difference from OutputStageScaleInt32ByFixedPointAndExponent is that
+// each row or column of the output (depending on tShape) has its own
+// result_fixedpoint_multiplier and result_exponent numbers.
+template <VectorShape tShape>
+struct OutputStageScaleInt32ByFixedPointAndExponentPC {
+  VectorMap<const std::int32_t, tShape> result_fixedpoint_multiplier;
+  VectorMap<const std::int32_t, tShape> result_exponent;
+  std::int32_t result_offset_after_shift;
+};
+
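For concreteness, here is a scalar sketch of the per-channel scaling this stage
is documented to perform. It is illustrative only, not the library's
implementation; the helpers follow the usual fixed-point reference formulas,
and overflow of the initial left shift is assumed not to occur:

#include <algorithm>
#include <cstdint>
#include <limits>

// Reference-style fixed-point helpers (scalar).
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  const bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab_64 = static_cast<std::int64_t>(a) * b;
  const std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  const std::int32_t ab_x2_high32 =
      static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}

inline std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
  const std::int32_t mask = (1ll << exponent) - 1;
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One output value, scaled with its row's (or column's) multiplier/exponent.
inline std::int32_t ScalePerChannel(std::int32_t acc, std::int32_t multiplier,
                                    int exponent,
                                    std::int32_t offset_after_shift) {
  const int left_shift = std::max(exponent, 0);    // exponent > 0: scale up
  const int right_shift = std::max(-exponent, 0);  // exponent < 0: shift down
  const std::int32_t scaled =
      SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), multiplier);
  return RoundingDivideByPOT(scaled, right_shift) + offset_after_shift;
}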
 // This output stage takes int32 values that are expected to be already
 // on the final uint8 scale, but not necessarily in the [0..255] range.
 // It clamps them to the [0..255] range and returns them casted to uint8.
 struct OutputStageSaturatingCastToUint8 {};
 
 // This output stage takes int32 values that are expected to be already
+// on the final int8 scale, but not necessarily in the [-128..127] range.
+// It clamps them to the [-128..127] range and returns them casted to int8.
+struct OutputStageSaturatingCastToInt8 {};
+
+// This output stage takes int32 values that are expected to be already
+// in the [0..255] range and returns them casted to uint8.
+// This stage can save time if used instead of the
+// OutputStageSaturatingCastToUint8 stage immediately after the
+// OutputStageClamp stage.
+struct OutputStageTruncatingCastToUint8 {};
+
+// This output stage takes int32 values that are expected to be already
 // on the final int16 scale, but not necessarily in the [-32768..32767] range.
 // It clamps them to the [-32768..32767] range and returns them casted to int16.
 struct OutputStageSaturatingCastToInt16 {};