Snap for 7919077 from 0cdefb826c9c26f407aa89f380987a0eda162ff4 to mainline-permission-release

Change-Id: Ie3f6a0217aaa4135ac90eaed4bcf1c401f386ef2
diff --git a/Android.bp b/Android.bp
index c2445e6..f74fc34 100644
--- a/Android.bp
+++ b/Android.bp
@@ -50,6 +50,7 @@
         "ruy/context_get_ctx.cc",
         "ruy/cpuinfo.cc",
         "ruy/ctx.cc",
+        "ruy/denormal.cc",
         "ruy/frontend.cc",
         "ruy/have_built_path_for_avx.cc",
         "ruy/have_built_path_for_avx2_fma.cc",
diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py
index ba1a38b..8f972ba 100755
--- a/cmake/bazel_to_cmake.py
+++ b/cmake/bazel_to_cmake.py
@@ -49,88 +49,92 @@
     ['selects.config_setting_group', 'config_setting_group'],
     ['@com_google_googletest//:gtest', 'gtest'],
     ['@com_google_googletest//:gtest_main', 'gtest_main'],
-    ['@cpuinfo//:cpuinfo_with_unstripped_include_path', 'cpuinfo'],
+    ['@cpuinfo', 'cpuinfo'],
 ]
 
 
 def preprocess_input_text(text):
-    result = text
-    for replacement in replacements:
-        result = result.replace(replacement[0], replacement[1])
-    return result
+  result = text
+  for replacement in replacements:
+    result = result.replace(replacement[0], replacement[1])
+  return result
 
 
 def set_cmake_list(list_name, values, indent):
-    semicolon_separated = ";".join(values)
-    print(f'{indent}set({list_name} "{semicolon_separated}")')
+  semicolon_separated = ';'.join(values)
+  print(f'{indent}set({list_name} "{semicolon_separated}")')
 
 
 def generate_cmake_select(select_name, dict):
-    new_if_branch_keyword = 'if'
-    default_value = []
-    for key in dict:
-        condition = ''
-        if key == '//conditions:default':
-            default_value = dict[key]
-            continue
-        elif re.search(r':windows$', key):
-            condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows'
-        elif re.search(r':ppc$', key):
-            condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le'
-        elif re.search(r':s390x$', key):
-            condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x'
-        elif re.search(r':fuchsia$', key):
-            condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia'
-        elif re.search(r':arm32_assuming_neon$', key):
-            condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm'
-        elif re.search(r':do_not_want_O3$', key):
-            # Ruy is a specialist library: we always want code to be compiled
-            # with -O3 unless the build type is Debug or the compiler does not
-            # support that flag syntax.
-            condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC'
-        elif re.search(r':x86_64_and_not_msvc$', key):
-            condition = '(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC'
-        elif re.search(r':windows_msvc$', key):
-            condition = 'MSVC'
-        elif re.search(r':ruy_profiler$', key):
-            condition = '${RUY_PROFILER}'
-        else:
-            raise ValueError(f'Unhandled key in select: {key}')
+  new_if_branch_keyword = 'if'
+  default_value = []
+  for key in dict:
+    condition = ''
+    if key == '//conditions:default':
+      default_value = dict[key]
+      continue
+    elif re.search(r':windows$', key):
+      condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows'
+    elif re.search(r':ppc$', key):
+      condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR '
+                   'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le')
+    elif re.search(r':s390x$', key):
+      condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR '
+                   'CMAKE_SYSTEM_PROCESSOR STREQUAL s390x')
+    elif re.search(r':fuchsia$', key):
+      condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia'
+    elif re.search(r':arm32_assuming_neon$', key):
+      condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm'
+    elif re.search(r':do_not_want_O3$', key):
+      # Ruy is a specialist library: we always want code to be compiled
+      # with -O3 unless the build type is Debug or the compiler does not
+      # support that flag syntax.
+      condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC'
+    elif re.search(r':x86_64_and_not_msvc$', key):
+      condition = ('(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR '
+                   'CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC')
+    elif re.search(r':windows_msvc$', key):
+      condition = 'MSVC'
+    elif re.search(r':ruy_profiler$', key):
+      condition = '${RUY_PROFILER}'
+    else:
+      raise ValueError(f'Unhandled key in select: {key}')
 
-        print(f'{new_if_branch_keyword}({condition})')
-        set_cmake_list(select_name, dict[key], '  ')
-        new_if_branch_keyword = 'elseif'
+    print(f'{new_if_branch_keyword}({condition})')
+    set_cmake_list(select_name, dict[key], '  ')
+    new_if_branch_keyword = 'elseif'
 
-    print('else()')
-    set_cmake_list(select_name, default_value, '  ')
+  print('else()')
+  set_cmake_list(select_name, default_value, '  ')
 
-    print('endif()\n')
+  print('endif()\n')
 
 
 def trim_multiple_ruy_prefixes(name):
-    return re.sub(r'(ruy_)+ruy', 'ruy', name)
+  return re.sub(r'(ruy_)+ruy', 'ruy', name)
+
 
 def get_cmake_local_target_name(name):
-    global package_prefix
-    return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}')
+  global package_prefix
+  return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}')
 
 
 def get_cmake_dep_target_name(name):
-    if name in external_targets:
-        return name
-    if name.startswith('$'):
-        # Happens for deps that are the result of expanding a select() that we
-        # have compiled to expanding a variable.
-        return name
-    if name.startswith('//'):
-        after_last_slash = name.split('/')[-1]
-        if not ':' in after_last_slash:
-            name = f'{name}:{after_last_slash}'
-        raw=name[2:].replace('/', '_').replace(':', '_')
-        return trim_multiple_ruy_prefixes(raw)
-    if name.startswith(':'):
-        name = name[1:]
-    return get_cmake_local_target_name(name)
+  if name in external_targets:
+    return name
+  if name.startswith('$'):
+    # Happens for deps that are the result of expanding a select() that we
+    # have compiled to expanding a variable.
+    return name
+  if name.startswith('//'):
+    after_last_slash = name.split('/')[-1]
+    if ':' not in after_last_slash:
+      name = f'{name}:{after_last_slash}'
+    raw = name[2:].replace('/', '_').replace(':', '_')
+    return trim_multiple_ruy_prefixes(raw)
+  if name.startswith(':'):
+    name = name[1:]
+  return get_cmake_local_target_name(name)
 
 
 #
@@ -139,45 +143,45 @@
 
 
 def package(**kwargs):
-    pass
+  pass
 
 
 def exports_files(*args):
-    pass
+  pass
 
 
 def load(filename, *args):
-    if filename.startswith('@'):
-        return
-    elif filename.startswith(':'):
-        filename = os.path.join(bazel_package_dir, filename[1:])
-    elif filename.startswith('//'):
-        split = filename[2:].split(':')
-        filename = os.path.join(bazel_workspace_dir, split[0], split[1])
+  if filename.startswith('@'):
+    return
+  elif filename.startswith(':'):
+    filename = os.path.join(bazel_package_dir, filename[1:])
+  elif filename.startswith('//'):
+    split = filename[2:].split(':')
+    filename = os.path.join(bazel_workspace_dir, split[0], split[1])
 
-    src_file_content = open(filename).read()
-    processed_file_content = preprocess_input_text(src_file_content)
-    exec(processed_file_content, globals(), globals())
+  src_file_content = open(filename).read()
+  processed_file_content = preprocess_input_text(src_file_content)
+  exec(processed_file_content, globals(), globals())
 
 
 def config_setting(**kwargs):
-    # Nothing to do since our implementation of select() is based on parsing
-    # the names of config_settings, not looking deep into their actual
-    # implementation.
-    pass
+  # Nothing to do since our implementation of select() is based on parsing
+  # the names of config_settings, not looking deep into their actual
+  # implementation.
+  pass
 
 
 def filegroup(**kwargs):
-    pass
+  pass
 
 
 def config_setting_group(**kwargs):
-    # See config_setting.
-    pass
+  # See config_setting.
+  pass
 
 
 def bzl_library(**kwargs):
-    pass
+  pass
 
 
 select_index = 0
@@ -185,95 +189,96 @@
 
 
 def select(select_dict):
-    global select_index
-    global select_cache
-    global package_prefix
-    key = pickle.dumps(sorted(select_dict.items()))
-    if key in select_cache:
-        select_name = select_cache[key]
-    else:
-        unique_values = sorted(set(itertools.chain.from_iterable(select_dict.values()))) # sorting ensures determinism, no spurious diffs
-        description = '_'.join(unique_values)
-        select_name = f'{package_prefix}_{select_index}_{description}'
-        select_name = select_name.replace('c++', 'cxx')
-        select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name)
-        select_index = select_index + 1
-        select_cache[key] = select_name
-        generate_cmake_select(select_name, select_dict)
+  global select_index
+  global select_cache
+  global package_prefix
+  key = pickle.dumps(sorted(select_dict.items()))
+  if key in select_cache:
+    select_name = select_cache[key]
+  else:
+    unique_values = sorted(
+        set(itertools.chain.from_iterable(select_dict.values()))
+    )  # sorting ensures determinism, no spurious diffs
+    description = '_'.join(unique_values)
+    select_name = f'{package_prefix}_{select_index}_{description}'
+    select_name = select_name.replace('c++', 'cxx')
+    select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name)
+    select_index = select_index + 1
+    select_cache[key] = select_name
+    generate_cmake_select(select_name, select_dict)
 
-    return [f'${{{select_name}}}']
+  return [f'${{{select_name}}}']
 
 
 def generic_rule(rule_name, **kwargs):
-    print(f'{rule_name}(')
-    for key in kwargs.keys():
-        values = kwargs[key]
-        if type(values) is bool:
-            if values:
-                print(f'  {key.upper()}')
-                continue
-            else:
-                raise ValueError(
-                    'Cannot specify FALSE boolean args in CMake')
-        if key == 'visibility':
-            if values == ['//visibility:public']:
-                print(f'  PUBLIC')
-            continue
-        if key == 'tags':
-            values = list(filter(lambda x : not x.startswith('req_dep'), values))
-        if not values:
-            continue
+  print(f'{rule_name}(')
+  for key in kwargs.keys():
+    values = kwargs[key]
+    if type(values) is bool:
+      if values:
         print(f'  {key.upper()}')
-        if type(values) is list:
-            for value in values:
-                if key == 'deps':
-                    target_name = get_cmake_dep_target_name(value)
-                    print(f'    {target_name}')
-                else:
-                    print(f'    {value}')
+        continue
+      else:
+        raise ValueError('Cannot specify FALSE boolean args in CMake')
+    if key == 'visibility':
+      if values == ['//visibility:public']:
+        print(f'  PUBLIC')
+      continue
+    if key == 'tags':
+      values = list(filter(lambda x: not x.startswith('req_dep'), values))
+    if not values:
+      continue
+    print(f'  {key.upper()}')
+    if type(values) is list:
+      for value in values:
+        if key == 'deps':
+          target_name = get_cmake_dep_target_name(value)
+          print(f'    {target_name}')
         else:
-            if key == 'name':
-                target_name = get_cmake_local_target_name(values)
-                print(f'    {target_name}')
-            else:
-                print(f'    {values}')
-    print(')\n')
+          print(f'    {value}')
+    else:
+      if key == 'name':
+        target_name = get_cmake_local_target_name(values)
+        print(f'    {target_name}')
+      else:
+        print(f'    {values}')
+  print(')\n')
 
 
 def cc_library(**kwargs):
-    generic_rule('ruy_cc_library', **kwargs)
+  generic_rule('ruy_cc_library', **kwargs)
 
 
 def cc_test(**kwargs):
-    generic_rule('ruy_cc_test', **kwargs)
+  generic_rule('ruy_cc_test', **kwargs)
 
 
 def cc_binary(**kwargs):
-    generic_rule('ruy_cc_binary', **kwargs)
+  generic_rule('ruy_cc_binary', **kwargs)
 
 
 #
 # Program entry point.
 #
 if __name__ == "__main__":
-    if len(sys.argv) != 3:
-        print("Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir")
-        sys.exit(1)
+  if len(sys.argv) != 3:
+    print('Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir')
+    sys.exit(1)
 
-    bazel_workspace_dir = sys.argv[1]
-    bazel_package_dir = sys.argv[2]
-    bazel_package_relative_dir = os.path.relpath(
-        bazel_package_dir, bazel_workspace_dir)
-    package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_')
+  bazel_workspace_dir = sys.argv[1]
+  bazel_package_dir = sys.argv[2]
+  bazel_package_relative_dir = os.path.relpath(bazel_package_dir,
+                                               bazel_workspace_dir)
+  package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_')
 
-    print("""# This file is generated (whence no license header). Do not edit!
+  print("""# This file is generated (whence no license header). Do not edit!
 # To regenerate, run:
 #   cmake/bazel_to_cmake.sh
 """)
 
-    src_build_file = os.path.join(bazel_package_dir, "BUILD")
-    src_build_content = open(src_build_file).read()
-    processed_build_content = preprocess_input_text(src_build_content)
-    exec(processed_build_content)
+  src_build_file = os.path.join(bazel_package_dir, 'BUILD')
+  src_build_content = open(src_build_file).read()
+  processed_build_content = preprocess_input_text(src_build_content)
+  exec(processed_build_content)
 
-    print("ruy_add_all_subdirs()")
+  print('ruy_add_all_subdirs()')
diff --git a/ruy/BUILD b/ruy/BUILD
index 37e89ab..d04a45d 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -357,6 +357,7 @@
     deps = [
         ":blocking_counter",
         ":check_macros",
+        ":denormal",
         ":time",
         ":trace",
         ":wait",
@@ -420,6 +421,14 @@
 )
 
 cc_library(
+    name = "denormal",
+    srcs = ["denormal.cc"],
+    hdrs = ["denormal.h"],
+    copts = ruy_copts(),
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
     name = "performance_advisory",
     hdrs = ["performance_advisory.h"],
     copts = ruy_copts(),
@@ -956,6 +965,7 @@
         ":cpu_cache_params",
         ":cpuinfo",
         ":ctx",
+        ":denormal",
         ":mat",
         ":matrix",
         ":mul_params",
@@ -1195,6 +1205,22 @@
     ],
 )
 
+cc_test(
+    name = "test_overflow_dst_zero_point",
+    srcs = [
+        "test_overflow_dst_zero_point.cc",
+    ],
+    copts = ruy_copts(),
+    deps = [
+        ":gtest_wrapper",
+        ":matrix",
+        ":path",
+        ":ruy",
+        ":test_lib",
+        ":tune",
+    ],
+)
+
 bzl_library(
     name = "ruy_test_ext.oss_bzl",
     srcs = ["ruy_test_ext.oss.bzl"],
diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt
index 4c3e394..502ad8a 100644
--- a/ruy/CMakeLists.txt
+++ b/ruy/CMakeLists.txt
@@ -376,6 +376,7 @@
   DEPS
     ruy_blocking_counter
     ruy_check_macros
+    ruy_denormal
     ruy_time
     ruy_trace
     ruy_wait
@@ -455,6 +456,20 @@
 
 ruy_cc_library(
   NAME
+    ruy_denormal
+  SRCS
+    denormal.cc
+  HDRS
+    denormal.h
+  COPTS
+    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+  PUBLIC
+)
+
+ruy_cc_library(
+  NAME
     ruy_performance_advisory
   HDRS
     performance_advisory.h
@@ -1102,6 +1117,7 @@
     ruy_cpu_cache_params
     ruy_cpuinfo
     ruy_ctx
+    ruy_denormal
     ruy_mat
     ruy_matrix
     ruy_mul_params
@@ -1693,4 +1709,22 @@
     slow
 )
 
+ruy_cc_test(
+  NAME
+    ruy_test_overflow_dst_zero_point
+  SRCS
+    test_overflow_dst_zero_point.cc
+  COPTS
+    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+  DEPS
+    ruy_gtest_wrapper
+    ruy_matrix
+    ruy
+    ruy_path
+    ruy_test_lib
+    ruy_tune
+)
+
 ruy_add_all_subdirs()
diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc
index 19bfd88..b28c3b0 100644
--- a/ruy/apply_multiplier.cc
+++ b/ruy/apply_multiplier.cc
@@ -49,7 +49,6 @@
                                            std::int32_t quantized_multiplier,
                                            int shift) {
   RUY_CHECK_GE(shift, -31);
-  RUY_CHECK_LE(shift, 7);
 
   int total_shift = 31 - shift;
 
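Note: the check removed above relaxes the upper bound on the multiplier exponent; the fixed-point/exponent arithmetic itself is unchanged. A minimal scalar sketch (not ruy's exact reference code, which also handles rounding edge cases) of how the multiplier is applied, reproducing the 1000 -> 250 case asserted in apply_multiplier_test.cc below:

#include <cstdint>

// Hypothetical helper name; mirrors the `total_shift = 31 - shift` logic above.
std::int32_t ApplyMultiplierSketch(std::int32_t acc,
                                   std::int32_t multiplier_fixedpoint,
                                   int multiplier_exponent) {
  // Effective multiplier = (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent,
  // so the 64-bit product is rounded and then shifted right by (31 - exponent).
  const int total_shift = 31 - multiplier_exponent;
  const std::int64_t product =
      static_cast<std::int64_t>(acc) * multiplier_fixedpoint;
  const std::int64_t rounding = std::int64_t{1} << (total_shift - 1);
  return static_cast<std::int32_t>((product + rounding) >> total_shift);
}
// ApplyMultiplierSketch(1000, 1 << 30, -1) == 250, matching the test below.
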
diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc
index 2df80d7..ff4cb2c 100644
--- a/ruy/apply_multiplier_test.cc
+++ b/ruy/apply_multiplier_test.cc
@@ -104,14 +104,9 @@
 
 TEST(ApplyMultiplierTest, ApplyMultiplierUniform) {
   MulParams<std::int32_t, std::int8_t> mul_params;
-  // Test that default values give a multiplication by 1.
-  TestApplyMultiplier(mul_params, 0, 1000, 1000);
   mul_params.set_multiplier_fixedpoint(1 << 30);
   mul_params.set_multiplier_exponent(-1);
   TestApplyMultiplier(mul_params, 0, 1000, 250);
-  mul_params.set_multiplier_fixedpoint(1 << 25);
-  mul_params.set_multiplier_exponent(3);
-  TestApplyMultiplier(mul_params, 0, 1000, 125);
 }
 
 TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) {
diff --git a/ruy/block_map.cc b/ruy/block_map.cc
index 8240de2..e04e7af 100644
--- a/ruy/block_map.cc
+++ b/ruy/block_map.cc
@@ -17,6 +17,7 @@
 
 #include <algorithm>
 #include <cstdint>
+#include <limits>
 
 #ifdef RUY_MAKEBLOCKMAP_DEBUG
 #include <cstdio>
@@ -330,7 +331,7 @@
   // as that requires knowing the kernel block layout. Since we just want
   // a coarse estimate with only the guarantee that if we return `true` then
   // linear traversal will be used, it is OK here to over-estimate `rows` and
-  // `cols`, by omitting to divide them by the rectangularness factors.ß
+  // `cols`, by omitting to divide them by the rectangularness factors.
   return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
                            cpu_cache_params) == BlockMapTraversalOrder::kLinear;
 }
diff --git a/ruy/context.cc b/ruy/context.cc
index 4661738..342ce52 100644
--- a/ruy/context.cc
+++ b/ruy/context.cc
@@ -55,4 +55,9 @@
   mutable_ctx()->SetRuntimeEnabledPaths(paths);
 }
 
+Path Context::get_runtime_enabled_paths() {
+  // The `& kAllPaths` hides internal test-only paths.
+  return mutable_ctx()->GetRuntimeEnabledPaths() & ruy::kAllPaths;
+}
+
 }  // namespace ruy
diff --git a/ruy/context.h b/ruy/context.h
index 79a4b5c..f148f0f 100644
--- a/ruy/context.h
+++ b/ruy/context.h
@@ -90,6 +90,9 @@
   // Paths in kNonArchPaths are always implicitly supported.
   void set_runtime_enabled_paths(Path paths);
 
+  // Returns the set of Paths that are available.
+  Path get_runtime_enabled_paths();
+
  private:
   CtxImpl* const impl_;
 
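Note: together with the existing setter, the new Context::get_runtime_enabled_paths() getter lets callers inspect and then restrict the detected paths. A minimal usage sketch, assuming the ruy::Path enumerators (kNeon, kNone) and the bitwise operators declared in ruy/path.h:

#include "ruy/context.h"
#include "ruy/path.h"

// Hypothetical helper: query the detected paths, then restrict ruy to plain
// NEON if it is among them.
void RestrictToNeonIfAvailable(ruy::Context* context) {
  const ruy::Path available = context->get_runtime_enabled_paths();
  if ((available & ruy::Path::kNeon) != ruy::Path::kNone) {
    context->set_runtime_enabled_paths(ruy::Path::kNeon);
  }
}
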
diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc
index b1f54bc..a3e75d7 100644
--- a/ruy/cpuinfo.cc
+++ b/ruy/cpuinfo.cc
@@ -133,6 +133,17 @@
   }
 }
 
+bool CpuInfo::CurrentCpuIsX1() {
+  if (!EnsureInitialized()) {
+    return false;
+  }
+  if (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch ==
+      cpuinfo_uarch_cortex_x1) {
+    return true;
+  }
+  return false;
+}
+
 #else  // not defined RUY_HAVE_CPUINFO
 
 CpuInfo::~CpuInfo() {}
@@ -151,6 +162,7 @@
 bool CpuInfo::Avx512() { return false; }
 bool CpuInfo::AvxVnni() { return false; }
 bool CpuInfo::CurrentCpuIsA55ish() { return false; }
+bool CpuInfo::CurrentCpuIsX1() { return false; }
 
 #endif
 
diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h
index e45fa51..2c7bc6a 100644
--- a/ruy/cpuinfo.h
+++ b/ruy/cpuinfo.h
@@ -39,6 +39,7 @@
   // Common features
   const CpuCacheParams& CacheParams();
   bool CurrentCpuIsA55ish();
+  bool CurrentCpuIsX1();
 
  private:
   enum class InitStatus {
diff --git a/ruy/denormal.cc b/ruy/denormal.cc
new file mode 100644
index 0000000..35bb739
--- /dev/null
+++ b/ruy/denormal.cc
@@ -0,0 +1,121 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/denormal.h"
+
+// NOTE: this is simply a copy of pthreadpool/src/threadpool-utils.h that's not
+// exposed by the pthreadpool library
+// (https://github.com/Maratyszcza/pthreadpool), but with an additional C++
+// helper class to suppress floating-point denormal values.
+
+/* SSE-specific headers */
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+#include <xmmintrin.h>
+#endif
+
+/* MSVC-specific headers */
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
+
+namespace ruy {
+namespace {
+inline struct fpu_state get_fpu_state() {
+  struct fpu_state state = {};
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  state.mxcsr = static_cast<std::uint32_t>(_mm_getcsr());
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  state.fpscr =
+      static_cast<std::uint32_t>(_MoveFromCoprocessor(10, 7, 1, 0, 0));
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  state.fpcr = static_cast<std::uint64_t>(_ReadStatusReg(0x5A20));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r"(state.fpcr));
+#endif
+  return state;
+}
+
+inline void set_fpu_state(const struct fpu_state state) {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  _mm_setcsr(static_cast<unsigned int>(state.mxcsr));
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  _MoveToCoprocessor(static_cast<int>(state.fpscr), 10, 7, 1, 0, 0);
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  _WriteStatusReg(0x5A20, static_cast<__int64>(state.fpcr));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r"(state.fpcr));
+#else
+  (void)state;
+#endif
+}
+
+inline void disable_fpu_denormals() {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  _mm_setcsr(_mm_getcsr() | 0x8040);
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  int fpscr = _MoveFromCoprocessor(10, 7, 1, 0, 0);
+  fpscr |= 0x1000000;
+  _MoveToCoprocessor(fpscr, 10, 7, 1, 0, 0);
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  __int64 fpcr = _ReadStatusReg(0x5A20);
+  fpcr |= 0x1080000;
+  _WriteStatusReg(0x5A20, fpcr);
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  std::uint32_t fpscr;
+#if defined(__thumb__) && !defined(__thumb2__)
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORRS %[fpscr], %[bitmask]\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=l"(fpscr)
+      : [bitmask] "l"(0x1000000)
+      : "cc");
+#else
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORR %[fpscr], #0x1000000\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=r"(fpscr));
+#endif
+#elif defined(__GNUC__) && defined(__aarch64__)
+  std::uint64_t fpcr;
+  __asm__ __volatile__(
+      "MRS %[fpcr], fpcr\n"
+      "ORR %w[fpcr], %w[fpcr], 0x1000000\n"
+      "ORR %w[fpcr], %w[fpcr], 0x80000\n"
+      "MSR fpcr, %[fpcr]\n"
+      : [fpcr] "=r"(fpcr));
+#endif
+}
+}  // namespace
+
+ScopedSuppressDenormals::ScopedSuppressDenormals() {
+  restore_ = get_fpu_state();
+  disable_fpu_denormals();
+}
+
+ScopedSuppressDenormals::~ScopedSuppressDenormals() { set_fpu_state(restore_); }
+}  // namespace ruy
diff --git a/ruy/denormal.h b/ruy/denormal.h
new file mode 100644
index 0000000..e5b836c
--- /dev/null
+++ b/ruy/denormal.h
@@ -0,0 +1,53 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef RUY_RUY_DENORMAL_H_
+#define RUY_RUY_DENORMAL_H_
+
+#include <cstdint>
+
+namespace ruy {
+// NOTE: the following 'fpu_state' struct is copied from
+// pthreadpool/src/threadpool-utils.h that's not exposed by the pthreadpool
+// library (https://github.com/Maratyszcza/pthreadpool).
+struct fpu_state {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  std::uint32_t mxcsr;
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+        (__ARM_FP != 0) ||                                          \
+    defined(_MSC_VER) && defined(_M_ARM)
+  std::uint32_t fpscr;
+#elif defined(__GNUC__) && defined(__aarch64__) || \
+    defined(_MSC_VER) && defined(_M_ARM64)
+  std::uint64_t fpcr;
+#endif
+};
+
+// While this class is active, denormal floating point numbers are suppressed.
+// The destructor restores the original flags.
+class ScopedSuppressDenormals {
+ public:
+  ScopedSuppressDenormals();
+  ~ScopedSuppressDenormals();
+
+ private:
+  fpu_state restore_;
+
+  ScopedSuppressDenormals(const ScopedSuppressDenormals&) = delete;
+  void operator=(const ScopedSuppressDenormals&) = delete;
+};
+}  // namespace ruy
+
+#endif  // RUY_RUY_DENORMAL_H_
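
Note: a minimal usage sketch of the RAII helper declared above; where ruy itself instantiates ScopedSuppressDenormals (e.g. around the multiplication work) is not shown in this diff:

#include "ruy/denormal.h"

void FloatWorkWithoutDenormals() {
  // While `suppress` is alive, denormal float values are flushed to zero,
  // avoiding the large slowdowns denormal arithmetic can cause on most CPUs.
  ruy::ScopedSuppressDenormals suppress;
  // ... float matrix multiplication work goes here ...
}  // Destructor restores the previous FPU state.
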
diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h
index 76cfc82..15a5a89 100644
--- a/ruy/kernel_arm.h
+++ b/ruy/kernel_arm.h
@@ -49,6 +49,7 @@
 void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
 void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
 void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params);
 
 #if RUY_PLATFORM_NEON_64
 template <typename DstScalar>
@@ -104,7 +105,8 @@
 
 #if RUY_PLATFORM_NEON_64
 template <typename DstScalar>
-struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t,
+              DstScalar> {
   static constexpr Path kPath = Path::kNeonDotprod;
   Tuning tuning = Tuning::kAuto;
   using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
@@ -121,6 +123,8 @@
       Kernel8bitNeonDotprod1Col(params);
     } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
       Kernel8bitNeonDotprodA55ish(params);
+    } else if (tuning == Tuning::kX1) {
+      Kernel8bitNeonDotprodX1(params);
     } else {
       Kernel8bitNeonDotprod(params);
     }
@@ -129,6 +133,7 @@
 #endif
 
 void KernelFloatNeon(const KernelParamsFloat<8, 8>& params);
+void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params);
 void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params);
 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params);
 void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params);
@@ -150,6 +155,8 @@
                           end_col, dst, &params);
     if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
       KernelFloatNeonA55ish(params);
+    } else if (tuning == Tuning::kX1) {
+      KernelFloatNeonX1(params);
     } else {
       KernelFloatNeon(params);
     }
@@ -188,8 +195,7 @@
   Tuning tuning = Tuning::kAuto;
   using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
   using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using Base =
-      Kernel<Path::kNeon, float, float, float, float>;
+  using Base = Kernel<Path::kNeon, float, float, float, float>;
   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
   void Run(const PMat<float>& lhs, const PMat<float>& rhs,
            const MulParams<float, float>& mul_params, int start_row,
@@ -199,6 +205,8 @@
                           end_col, dst, &params);
     if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
       KernelFloatNeonDotprodA55ish(params);
+    } else if (tuning == Tuning::kX1) {
+      KernelFloatNeonX1(params);
     } else {
       KernelFloatNeon(params);
     }
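
Note: the dispatch above consumes Tuning::kX1. How that value is produced is not part of this diff; the following is a purely hypothetical sketch of how the new CpuInfo::CurrentCpuIsX1() query could feed it (Tuning::kGeneric and the ruy/tune.h path are assumptions, not shown here):

#include "ruy/cpuinfo.h"
#include "ruy/tune.h"  // assumed location of the Tuning enum

// Hypothetical resolution order, by analogy with the existing A55ish handling;
// ruy's real logic lives in tune.cc and may differ.
ruy::Tuning ResolveTuningSketch(ruy::CpuInfo* cpuinfo) {
  if (cpuinfo->CurrentCpuIsA55ish()) return ruy::Tuning::kA55ish;
  if (cpuinfo->CurrentCpuIsX1()) return ruy::Tuning::kX1;
  return ruy::Tuning::kGeneric;  // assumed enumerator for the default case
}
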
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index b20f668..8782dce 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -1102,7 +1102,7 @@
         "vdup.16 q13, r4\n" // dst_zero_point
 
         // Add the destination zero point
-        "vadd.i16 q14, q14, q13\n"
+        "vqadd.s16 q14, q14, q13\n"
 
         // Cast-and-saturate from int16 to uint8
         // Now all 8 1-byte values are in d30.
@@ -1226,7 +1226,7 @@
         "vdup.16 q13, r4\n" // dst_zero_point
 
         // Add the destination zero point
-        "vadd.i16 q14, q14, q13\n"
+        "vqadd.s16 q14, q14, q13\n"
 
         // Cast-and-saturate from int16 to int8
         // Now all 8 1-byte values are in d30.
@@ -2014,7 +2014,7 @@
         "vdup.16 q13, r4\n" // dst_zero_point
 
         // Add the destination zero point
-        "vadd.i16 q14, q14, q13\n"
+        "vqadd.s16 q14, q14, q13\n"
 
         // Cast-and-saturate from int16 to uint8
         "vqmovun.s16 d30, q14\n"
@@ -2126,7 +2126,7 @@
         "vdup.16 q13, r4\n" // dst_zero_point
 
         // Add the destination zero point
-        "vadd.i16 q14, q14, q13\n"
+        "vqadd.s16 q14, q14, q13\n"
 
         // Cast-and-saturate from int16 to int8
         "vqmovn.s16 d30, q14\n"
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index fe65d9c..5424107 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -623,8 +623,8 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8
         "sqxtun v16.8b, v16.8h\n"
@@ -750,8 +750,8 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to int8
         "sqxtn v16.8b, v16.8h\n"
@@ -1472,7 +1472,7 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8
         // Now all data is in the first 32-bits of v16
@@ -1553,7 +1553,7 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to int8
         "sqxtn v16.8b, v16.8h\n"
@@ -2394,9 +2394,9 @@
         "dup v14.8h, v13.h[4]\n"
         RUY_MAKE_ZERO(v20)
         "add %[rhs_ptr], %[rhs_ptr], #64\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
         RUY_MAKE_ZERO(v21)
-        "add v17.8h, v17.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
         RUY_MAKE_ZERO(v22)
 
         // Cast-and-saturate from int16 to uint8
@@ -2526,9 +2526,9 @@
         "dup v14.8h, v13.h[4]\n"
         RUY_MAKE_ZERO(v20)
         "add %[rhs_ptr], %[rhs_ptr], #64\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
         RUY_MAKE_ZERO(v21)
-        "add v17.8h, v17.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
         RUY_MAKE_ZERO(v22)
 
         // Cast-and-saturate from int16 to uint8
@@ -3713,14 +3713,14 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
-        "add v18.8h, v18.8h, v14.8h\n"
-        "add v19.8h, v19.8h, v14.8h\n"
-        "add v20.8h, v20.8h, v14.8h\n"
-        "add v21.8h, v21.8h, v14.8h\n"
-        "add v22.8h, v22.8h, v14.8h\n"
-        "add v23.8h, v23.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8
         "sqxtun v16.8b, v16.8h\n"
@@ -3888,14 +3888,14 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
-        "add v18.8h, v18.8h, v14.8h\n"
-        "add v19.8h, v19.8h, v14.8h\n"
-        "add v20.8h, v20.8h, v14.8h\n"
-        "add v21.8h, v21.8h, v14.8h\n"
-        "add v22.8h, v22.8h, v14.8h\n"
-        "add v23.8h, v23.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8
         "sqxtn v16.8b, v16.8h\n"
@@ -4402,6 +4402,1261 @@
           "v26", "v27", "v28", "v29", "v30", "v31");
 }
 
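Note: the add -> sqadd (and vadd.i16 -> vqadd.s16) changes in kernel_arm32.cc and kernel_arm64.cc above make the destination zero-point addition saturating. A scalar illustration (not ruy code) of the overflow this prevents; the new test_overflow_dst_zero_point target added in BUILD/CMakeLists.txt is named for this case:

#include <algorithm>
#include <cstdint>

// With acc = 32760 and dst_zero_point = 100, the wrapping add yields -32676
// and would be clamped to 0, while the saturating add yields 32767 and is
// correctly clamped to 255.
std::uint8_t StoreUint8Sketch(std::int16_t acc, std::int16_t dst_zero_point) {
  // Old behaviour: wrapping 16-bit add, like "add v16.8h" / "vadd.i16".
  const std::int16_t wrapped =
      static_cast<std::int16_t>(acc + dst_zero_point);  // wraps modulo 2^16
  (void)wrapped;
  // New behaviour: saturating add, like "sqadd v16.8h" / "vqadd.s16".
  const std::int32_t wide = std::int32_t{acc} + dst_zero_point;
  const std::int16_t saturated = static_cast<std::int16_t>(
      std::min<std::int32_t>(32767, std::max<std::int32_t>(-32768, wide)));
  // Cast-and-saturate to uint8, like the "sqxtun" / "vqmovun.s16" step.
  return static_cast<std::uint8_t>(
      std::min<int>(255, std::max<int>(0, saturated)));
}
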
+// A fork of the above 8bitNeonDotprod kernel that removes the MAX_STREAMING
+// manual unrolling. Manually unrolling the inner loops benefits some GEMM
+// shapes on the Cortex-A76 but destroys performance on the X1 by increasing
+// backend stalls. Therefore, we remove the MAX_STREAMING option in this
+// kernel. The target CPU for this kernel is currently only the Cortex-X1.
+void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) {
+  profiler::ScopeLabel label("Kernel (kNeonDotprod)");
+
+  CheckOffsetsInKernelParams8bit(params);
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* lhs_ptr = lhs_col_ptr;
+  const std::int8_t* rhs_ptr = rhs_col_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  void* dst_ptr = dst_col_ptr;
+  int row = params.start_row;
+  int col = params.start_col;
+
+  // The asm kernel below has the following NEON register allocation:
+  //
+  // v16 -- v31 are int32 accumulators.
+  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+  // RHS: v0 and v1 load an 8x4 block of LHS, and v2 and v3 load a 4x8 block
+  // of RHS, like this:
+  //
+  //                                      int8 RHS 4x8 block
+  //                           /-----------------------------------------|
+  //                           |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
+  //                           |  ...                              ...   |
+  //                           |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
+  //                           \-----------------------------------------/
+  //    int8 LHS 8x4 block
+  //  /---------------------\  /-----------------------------------------|
+  //  |v0.b[0]  ... v0.b[3] |  |v16.s[0]           ...           v30.s[0]|
+  //  |  ...          ...   |  |  ...                              ...   |
+  //  |v0.b[12] ... v0.b[15]|  |v16.s[3]           ...           v30.s[3]|
+  //  |v1.b[0]  ... v1.b[3] |  |v17.s[0]           ...           v31.s[0]|
+  //  |  ...         ...    |  |  ...                              ...   |
+  //  |v1.b[12] ... v1.b[15]|  |v17.s[3]           ...           v31.s[3]|
+  //  \---------------------/  \-----------------------------------------/
+  //                                  int32 accumulators 8x8 block
+  //
+  // Because this fork has no RUY_OPT_MAX_STREAMING path, the accumulation
+  // loop never needs more than v0 -- v3 for LHS and RHS data.
+  //
+  // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used
+  // for the post-accumulation part of the kernel.
+  asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+        // clang-format off
+
+        // Load some parameters into registers.
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+        // Load the first 32 bytes of LHS and RHS data.
+        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+        // Clear accumulators.
+        RUY_MAKE_ZERO(v16)
+        RUY_MAKE_ZERO(v17)
+        RUY_MAKE_ZERO(v18)
+        RUY_MAKE_ZERO(v19)
+        RUY_MAKE_ZERO(v20)
+        RUY_MAKE_ZERO(v21)
+        RUY_MAKE_ZERO(v22)
+        RUY_MAKE_ZERO(v23)
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 4.
+        "mov w1, #4\n"
+
+        // Perform the first few multiply-adds on the data that we have already
+        // loaded.
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+        // Main loop of the whole GEMM, over rows and columns of the
+        // destination matrix.
+        "1:\n"
+
+        // Kernel inner loop (over depth).
+        // Reminder - w1 is how many levels of depth we have already loaded
+        // data for, w12 is the total depth.
+        "cmp w1, w12\n"
+        "beq 79f\n"
+
+        "2:\n"
+
+        // Because of the data that we have already loaded, we can start the
+        // loop body right away with some multiply-adds.
+        ".word 0x4f83e018  // sdot v24.4s, v0.16b, v3.4b[0]\n"
+        ".word 0x4fa3e01a  // sdot v26.4s, v0.16b, v3.4b[1]\n"
+        // Each iteration of this loop advances by 4 levels of depth.
+        "add w1, w1, #4\n"
+        ".word 0x4f83e81c  // sdot v28.4s, v0.16b, v3.4b[2]\n"
+        ".word 0x4fa3e81e  // sdot v30.4s, v0.16b, v3.4b[3]\n"
+        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+        ".word 0x4f82e031  // sdot v17.4s, v1.16b, v2.4b[0]\n"
+        ".word 0x4fa2e033  // sdot v19.4s, v1.16b, v2.4b[1]\n"
+        // Loop termination condition.
+        "cmp w1, w12\n"
+        ".word 0x4f82e835  // sdot v21.4s, v1.16b, v2.4b[2]\n"
+        ".word 0x4fa2e837  // sdot v23.4s, v1.16b, v2.4b[3]\n"
+        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+        ".word 0x4f83e039  // sdot v25.4s, v1.16b, v3.4b[0]\n"
+        ".word 0x4fa3e03b  // sdot v27.4s, v1.16b, v3.4b[1]\n"
+        ".word 0x4f83e83d  // sdot v29.4s, v1.16b, v3.4b[2]\n"
+        ".word 0x4fa3e83f  // sdot v31.4s, v1.16b, v3.4b[3]\n"
+        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+
+        "blt 2b\n"
+
+        "79:\n"
+        // End of the inner loop on depth. Now perform the remaining
+        // multiply-adds of the last 4 levels of depth, for which the LHS
+        // and RHS data is already loaded.
+
+        ".word 0x4f83e018  // sdot v24.4s, v0.16b, v3.4b[0]\n"
+        ".word 0x4fa3e01a  // sdot v26.4s, v0.16b, v3.4b[1]\n"
+        ".word 0x4f83e81c  // sdot v28.4s, v0.16b, v3.4b[2]\n"
+        ".word 0x4fa3e81e  // sdot v30.4s, v0.16b, v3.4b[3]\n"
+        ".word 0x4f82e031  // sdot v17.4s, v1.16b, v2.4b[0]\n"
+        ".word 0x4fa2e033  // sdot v19.4s, v1.16b, v2.4b[1]\n"
+        ".word 0x4f82e835  // sdot v21.4s, v1.16b, v2.4b[2]\n"
+        ".word 0x4fa2e837  // sdot v23.4s, v1.16b, v2.4b[3]\n"
+        ".word 0x4f83e039  // sdot v25.4s, v1.16b, v3.4b[0]\n"
+        ".word 0x4fa3e03b  // sdot v27.4s, v1.16b, v3.4b[1]\n"
+        ".word 0x4f83e83d  // sdot v29.4s, v1.16b, v3.4b[2]\n"
+        ".word 0x4fa3e83f  // sdot v31.4s, v1.16b, v3.4b[3]\n"
+
+        // End of accumulation. The registers v16 -- v31 contain the final
+        // int32 accumulator values of the current 8x8 destination block.
+        // We now have to compute the final 8-bit values from these int32
+        // accumulators, and advance to the next 8x8 block. We intertwine
+        // these two aspects whenever possible for optimal pipelining, both
+        // at the data flow level (prefetch data for next block as early as
+        // possible) and instruction pipelining level (some of the next-block
+        // work can dual-issue with some of the final work on the current
+        // block).
+
+        // Logic to advance to the next block in preparation for the next
+        // iteration of the main loop. For now, we only want to compute
+        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+        // not yet ready to update the values of row and col, as we still need
+        // the current values for the rest of the work on the current block.
+
+        "cmp %w[row], w7\n"  // Have we finished the last row?
+        "bge 4f\n"           // If finished last row, go to 4
+        // Not finished last row: then advance to next row.
+        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+        "b 5f\n"
+        "4:\n"  // Finished last row...
+        "mov %[lhs_col_ptr], x5\n"  // Go back to first row
+        // Now we need to advance to the next column. If we already
+        // finished the last column, then in principle we are done, however
+        // we can't just return here, as we need to allow the end work of the
+        // current block to complete. The good news is that at this point it
+        // doesn't matter what data we load for the next column, since
+        // we will exit from the main loop below before actually storing
+        // anything computed from that data.
+        "cmp %w[col], w8\n"  // Have we finished the last column?
+        "bge 5f\n" // If yes, just carry on without updating the column pointer.
+        // Not finished last column: then advance to next column.
+        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+        "5:\n"
+
+        // Set the LHS and RHS data pointers to the start of the columns just
+        // computed.
+        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+        "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+        // Load some parameters needed for the end work on current block.
+        "mvni v8.4s, #0\n"
+        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+        "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+        "dup v9.4s, w3\n"   // create prod_zp_depth_vec
+
+        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+        // Determine the channel index.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+
+        // Offset the bias pointer as needed given the current row, col.
+        "add x5, x1, x3, lsl #2\n"
+
+        // If there is no bias, use no offset, just address the passed zero
+        // data.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+        "csel x1, x1, x5, eq\n"
+
+        // Load 8 bias values.
+        "ld1 {v14.4s}, [x1], #16\n"
+        "ld1 {v15.4s}, [x1]\n"
+
+        // Now that we know what LHS and RHS data the next iteration of the
+        // main loop will need to load, we start loading the first 32 bytes of
+        // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+        // in the rest of the work on the current block.
+        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+        // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+        // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+        "add v14.4s, v14.4s, v9.4s\n"
+        "add v15.4s, v15.4s, v9.4s\n"
+
+        // Perform the bias-addition (per the above, we have just folded into
+        // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 6f\n"
+        // Case where channels are rows
+        "add v16.4s, v16.4s, v14.4s\n"
+        "add v17.4s, v17.4s, v15.4s\n"
+        "add v18.4s, v18.4s, v14.4s\n"
+        "add v19.4s, v19.4s, v15.4s\n"
+        "add v20.4s, v20.4s, v14.4s\n"
+        "add v21.4s, v21.4s, v15.4s\n"
+        "add v22.4s, v22.4s, v14.4s\n"
+        "add v23.4s, v23.4s, v15.4s\n"
+        "add v24.4s, v24.4s, v14.4s\n"
+        "add v25.4s, v25.4s, v15.4s\n"
+        "add v26.4s, v26.4s, v14.4s\n"
+        "add v27.4s, v27.4s, v15.4s\n"
+        "add v28.4s, v28.4s, v14.4s\n"
+        "add v29.4s, v29.4s, v15.4s\n"
+        "add v30.4s, v30.4s, v14.4s\n"
+        "add v31.4s, v31.4s, v15.4s\n"
+        "b 7f\n"
+
+        "6:\n"
+        // Case where channels are columns
+        "dup v10.4s, v14.s[0]\n"
+        "dup v11.4s, v14.s[1]\n"
+        "dup v12.4s, v14.s[2]\n"
+        "dup v13.4s, v14.s[3]\n"
+        "add v16.4s, v16.4s, v10.4s\n"
+        "add v17.4s, v17.4s, v10.4s\n"
+        "add v18.4s, v18.4s, v11.4s\n"
+        "add v19.4s, v19.4s, v11.4s\n"
+        "add v20.4s, v20.4s, v12.4s\n"
+        "add v21.4s, v21.4s, v12.4s\n"
+        "add v22.4s, v22.4s, v13.4s\n"
+        "add v23.4s, v23.4s, v13.4s\n"
+        "dup v10.4s, v15.s[0]\n"
+        "dup v11.4s, v15.s[1]\n"
+        "dup v12.4s, v15.s[2]\n"
+        "dup v13.4s, v15.s[3]\n"
+        "add v24.4s, v24.4s, v10.4s\n"
+        "add v25.4s, v25.4s, v10.4s\n"
+        "add v26.4s, v26.4s, v11.4s\n"
+        "add v27.4s, v27.4s, v11.4s\n"
+        "add v28.4s, v28.4s, v12.4s\n"
+        "add v29.4s, v29.4s, v12.4s\n"
+        "add v30.4s, v30.4s, v13.4s\n"
+        "add v31.4s, v31.4s, v13.4s\n"
+        "7:\n"
+
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+        "beq 401f\n"
+        "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+        "add x3, x3, %x[col], lsl #2\n"
+        "ld1 {v14.4s}, [x3], #16\n"
+        "ld1 {v15.4s}, [x3]\n"
+        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+        "dup v10.4s, w5\n"  // create lhs_zero_point_vec
+        // Subtract rhs_sums * lhs_zero_point, per
+        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+        "mls v16.4s, v10.4s, v14.s[0]\n"
+        "mls v17.4s, v10.4s, v14.s[0]\n"
+        "mls v18.4s, v10.4s, v14.s[1]\n"
+        "mls v19.4s, v10.4s, v14.s[1]\n"
+        "mls v20.4s, v10.4s, v14.s[2]\n"
+        "mls v21.4s, v10.4s, v14.s[2]\n"
+        "mls v22.4s, v10.4s, v14.s[3]\n"
+        "mls v23.4s, v10.4s, v14.s[3]\n"
+        "mls v24.4s, v10.4s, v15.s[0]\n"
+        "mls v25.4s, v10.4s, v15.s[0]\n"
+        "mls v26.4s, v10.4s, v15.s[1]\n"
+        "mls v27.4s, v10.4s, v15.s[1]\n"
+        "mls v28.4s, v10.4s, v15.s[2]\n"
+        "mls v29.4s, v10.4s, v15.s[2]\n"
+        "mls v30.4s, v10.4s, v15.s[3]\n"
+        "mls v31.4s, v10.4s, v15.s[3]\n"
+        "401:\n"
+
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+        "beq 402f\n"
+        "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+        "add x2, x2, %x[row], lsl #2\n"
+        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+        // Load 4 lhs_sums values.
+        "ld1 {v11.4s}, [x2], #16\n"
+        "ld1 {v12.4s}, [x2]\n"
+        "ins v13.s[1], w5\n" // rhs_zero_point
+        // Compute lhs_sums * rhs_zero_point.
+        "mul v11.4s, v11.4s, v13.s[1]\n"
+        "mul v12.4s, v12.4s, v13.s[1]\n"
+        // Subtract lhs_sums * rhs_zero_point, per
+        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+        "sub v16.4s, v16.4s, v11.4s\n"
+        "sub v17.4s, v17.4s, v12.4s\n"
+        "sub v18.4s, v18.4s, v11.4s\n"
+        "sub v19.4s, v19.4s, v12.4s\n"
+        "sub v20.4s, v20.4s, v11.4s\n"
+        "sub v21.4s, v21.4s, v12.4s\n"
+        "sub v22.4s, v22.4s, v11.4s\n"
+        "sub v23.4s, v23.4s, v12.4s\n"
+        "sub v24.4s, v24.4s, v11.4s\n"
+        "sub v25.4s, v25.4s, v12.4s\n"
+        "sub v26.4s, v26.4s, v11.4s\n"
+        "sub v27.4s, v27.4s, v12.4s\n"
+        "sub v28.4s, v28.4s, v11.4s\n"
+        "sub v29.4s, v29.4s, v12.4s\n"
+        "sub v30.4s, v30.4s, v11.4s\n"
+        "sub v31.4s, v31.4s, v12.4s\n"
+
+        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+        "402:\n"
+
+        // At this point we have computed the final int32 values. Now we
+        // start down-quantizing them to obtain the final 8bit values from them.
+
+        // As part of this down-quantization, our int32 values will be
+        // multiplied by a multiplier that has a fixed-point component and an
+        // exponent component.
+
+        // Load the exponent part of the multiplier.
+        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+        // Determine the channel index.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+        // Compute the multiplier_exponent pointer
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+        "add x5, x1, x3, lsl #2\n"
+        "csel x1, x1, x5, eq\n"
+        // Load multiplier_exponent
+        "ldr q9, [x1]\n"
+        "ldr q10, [x1, #16]\n"
+        // Separate positive and negative exponents
+        "smin v11.4s, v8.4s, v9.4s\n"
+        "smin v12.4s, v8.4s, v10.4s\n"
+        "sub v9.4s, v9.4s, v11.4s\n"
+        "sub v10.4s, v10.4s, v12.4s\n"
+
+        // Compute the multiplier_fixedpoint pointer
+        "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+        "add x5, x4, x3, lsl #2\n"
+        "csel x4, x4, x5, eq\n"
+        // Load multiplier_fixedpoint
+        "ldr q14, [x4]\n"
+        "ldr q15, [x4, #16]\n"
+
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 8f\n"
+        // Case where channels are rows
+
+        // Apply the positive exponent part of the multiplier.
+        "sshl v16.4s, v16.4s, v9.4s\n"
+        "sshl v17.4s, v17.4s, v10.4s\n"
+        "sshl v18.4s, v18.4s, v9.4s\n"
+        "sshl v19.4s, v19.4s, v10.4s\n"
+        "sshl v20.4s, v20.4s, v9.4s\n"
+        "sshl v21.4s, v21.4s, v10.4s\n"
+        "sshl v22.4s, v22.4s, v9.4s\n"
+        "sshl v23.4s, v23.4s, v10.4s\n"
+        "sshl v24.4s, v24.4s, v9.4s\n"
+        "sshl v25.4s, v25.4s, v10.4s\n"
+        "sshl v26.4s, v26.4s, v9.4s\n"
+        "sshl v27.4s, v27.4s, v10.4s\n"
+        "sshl v28.4s, v28.4s, v9.4s\n"
+        "sshl v29.4s, v29.4s, v10.4s\n"
+        "sshl v30.4s, v30.4s, v9.4s\n"
+        "sshl v31.4s, v31.4s, v10.4s\n"
+        "10:\n"
+
+        // Apply the fixed-point part of the multiplier.
+        "sqdmulh v16.4s, v16.4s, v14.4s\n"
+        "sqdmulh v17.4s, v17.4s, v15.4s\n"
+        "sqdmulh v18.4s, v18.4s, v14.4s\n"
+        "sqdmulh v19.4s, v19.4s, v15.4s\n"
+        "sqdmulh v20.4s, v20.4s, v14.4s\n"
+        "sqdmulh v21.4s, v21.4s, v15.4s\n"
+        "sqdmulh v22.4s, v22.4s, v14.4s\n"
+        "sqdmulh v23.4s, v23.4s, v15.4s\n"
+        "sqdmulh v24.4s, v24.4s, v14.4s\n"
+        "sqdmulh v25.4s, v25.4s, v15.4s\n"
+        "sqdmulh v26.4s, v26.4s, v14.4s\n"
+        "sqdmulh v27.4s, v27.4s, v15.4s\n"
+        "sqdmulh v28.4s, v28.4s, v14.4s\n"
+        "sqdmulh v29.4s, v29.4s, v15.4s\n"
+        "sqdmulh v30.4s, v30.4s, v14.4s\n"
+        "sqdmulh v31.4s, v31.4s, v15.4s\n"
+
+        // Apply the negative exponent part of the multiplier.
+        "srshl v16.4s, v16.4s, v11.4s\n"
+        "srshl v17.4s, v17.4s, v12.4s\n"
+        "srshl v18.4s, v18.4s, v11.4s\n"
+        "srshl v19.4s, v19.4s, v12.4s\n"
+        "srshl v20.4s, v20.4s, v11.4s\n"
+        "srshl v21.4s, v21.4s, v12.4s\n"
+        "srshl v22.4s, v22.4s, v11.4s\n"
+        "srshl v23.4s, v23.4s, v12.4s\n"
+        "srshl v24.4s, v24.4s, v11.4s\n"
+        "srshl v25.4s, v25.4s, v12.4s\n"
+        "srshl v26.4s, v26.4s, v11.4s\n"
+        "srshl v27.4s, v27.4s, v12.4s\n"
+        "srshl v28.4s, v28.4s, v11.4s\n"
+        "srshl v29.4s, v29.4s, v12.4s\n"
+        "srshl v30.4s, v30.4s, v11.4s\n"
+        "srshl v31.4s, v31.4s, v12.4s\n"
+        "b 9f\n"
+
+        "8:\n"
+        // Case where channels are columns
+
+        // Apply the positive exponent part of the multiplier.
+        "dup v4.4s, v9.s[0]\n"
+        "dup v5.4s, v9.s[1]\n"
+        "dup v6.4s, v9.s[2]\n"
+        "dup v7.4s, v9.s[3]\n"
+        "sshl v16.4s, v16.4s, v4.4s\n"
+        "sshl v17.4s, v17.4s, v4.4s\n"
+        "sshl v18.4s, v18.4s, v5.4s\n"
+        "sshl v19.4s, v19.4s, v5.4s\n"
+        "sshl v20.4s, v20.4s, v6.4s\n"
+        "sshl v21.4s, v21.4s, v6.4s\n"
+        "sshl v22.4s, v22.4s, v7.4s\n"
+        "sshl v23.4s, v23.4s, v7.4s\n"
+        "dup v4.4s, v10.s[0]\n"
+        "dup v5.4s, v10.s[1]\n"
+        "dup v6.4s, v10.s[2]\n"
+        "dup v7.4s, v10.s[3]\n"
+        "sshl v24.4s, v24.4s, v4.4s\n"
+        "sshl v25.4s, v25.4s, v4.4s\n"
+        "sshl v26.4s, v26.4s, v5.4s\n"
+        "sshl v27.4s, v27.4s, v5.4s\n"
+        "sshl v28.4s, v28.4s, v6.4s\n"
+        "sshl v29.4s, v29.4s, v6.4s\n"
+        "sshl v30.4s, v30.4s, v7.4s\n"
+        "sshl v31.4s, v31.4s, v7.4s\n"
+        "11:\n"
+
+        // Apply the fixed-point part of the multiplier.
+        "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
+        "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
+        "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
+        "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
+        "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
+        "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
+        "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
+        "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
+        "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
+        "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
+        "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
+        "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
+        "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
+        "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
+        "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
+        "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
+
+        // Apply the negative exponent part of the multiplier.
+        "dup v4.4s, v11.s[0]\n"
+        "dup v5.4s, v11.s[1]\n"
+        "dup v6.4s, v11.s[2]\n"
+        "dup v7.4s, v11.s[3]\n"
+        "srshl v16.4s, v16.4s, v4.4s\n"
+        "srshl v17.4s, v17.4s, v4.4s\n"
+        "srshl v18.4s, v18.4s, v5.4s\n"
+        "srshl v19.4s, v19.4s, v5.4s\n"
+        "srshl v20.4s, v20.4s, v6.4s\n"
+        "srshl v21.4s, v21.4s, v6.4s\n"
+        "srshl v22.4s, v22.4s, v7.4s\n"
+        "srshl v23.4s, v23.4s, v7.4s\n"
+        "dup v4.4s, v12.s[0]\n"
+        "dup v5.4s, v12.s[1]\n"
+        "dup v6.4s, v12.s[2]\n"
+        "dup v7.4s, v12.s[3]\n"
+        "srshl v24.4s, v24.4s, v4.4s\n"
+        "srshl v25.4s, v25.4s, v4.4s\n"
+        "srshl v26.4s, v26.4s, v5.4s\n"
+        "srshl v27.4s, v27.4s, v5.4s\n"
+        "srshl v28.4s, v28.4s, v6.4s\n"
+        "srshl v29.4s, v29.4s, v6.4s\n"
+        "srshl v30.4s, v30.4s, v7.4s\n"
+        "srshl v31.4s, v31.4s, v7.4s\n"
+        "9:\n"
+
+        "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+        "ins v13.h[4], w4\n" // dst_zero_point
+
+        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+        RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+        // Cast-and-saturate from int32 to int16
+        "sqxtn v16.4h, v16.4s\n"
+        "sqxtn2 v16.8h, v17.4s\n"
+        "sqxtn v17.4h, v18.4s\n"
+        "sqxtn2 v17.8h, v19.4s\n"
+        "sqxtn v18.4h, v20.4s\n"
+        "sqxtn2 v18.8h, v21.4s\n"
+        "sqxtn v19.4h, v22.4s\n"
+        "sqxtn2 v19.8h, v23.4s\n"
+        "sqxtn v20.4h, v24.4s\n"
+        "sqxtn2 v20.8h, v25.4s\n"
+        "sqxtn v21.4h, v26.4s\n"
+        "sqxtn2 v21.8h, v27.4s\n"
+        "sqxtn v22.4h, v28.4s\n"
+        "sqxtn2 v22.8h, v29.4s\n"
+        "sqxtn v23.4h, v30.4s\n"
+        "sqxtn2 v23.8h, v31.4s\n"
+
+        // At this point, v24 -- v31 aren't used anymore for the current block,
+        // so we can start clearing these accumulators for the next block
+        // (next iteration of the main loop).
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // Add the destination zero point
+        "dup v14.8h, v13.h[4]\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
+
+        // Cast-and-saturate from int16 to uint8
+        "sqxtun v16.8b, v16.8h\n"
+        "sqxtun2 v16.16b, v17.8h\n"
+        "sqxtun v17.8b, v18.8h\n"
+        "sqxtun2 v17.16b, v19.8h\n"
+        "sqxtun v18.8b, v20.8h\n"
+        "sqxtun2 v18.16b, v21.8h\n"
+        "sqxtun v19.8b, v22.8h\n"
+        "sqxtun2 v19.16b, v23.8h\n"
+
+        // Load the clamp_min, clamp_max bounds
+        "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+        "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+        "dup v14.16b, w2\n"  // clamp_min
+        "dup v15.16b, w3\n"  // clamp_max
+
+        // Apply the clamp_min bound
+        "umax v16.16b, v16.16b, v14.16b\n"
+        "umax v17.16b, v17.16b, v14.16b\n"
+        "umax v18.16b, v18.16b, v14.16b\n"
+        "umax v19.16b, v19.16b, v14.16b\n"
+
+        // Apply the clamp_max bound
+        "umin v16.16b, v16.16b, v15.16b\n"
+        "umin v17.16b, v17.16b, v15.16b\n"
+        "umin v18.16b, v18.16b, v15.16b\n"
+        "umin v19.16b, v19.16b, v15.16b\n"
+
+        // Make it so that all of the final 8bit values are stored in the
+        // first 64bits of 128bit NEON registers, so they can be stored
+        // by 64bit st1 store instructions with byte alignment.
+        "dup d20, v16.d[1]\n"
+        "dup d21, v17.d[1]\n"
+        "dup d22, v18.d[1]\n"
+        "dup d23, v19.d[1]\n"
+
+        // Compute how much of the 8x8 block of destination 8bit values that
+        // we have computed fits in the destination matrix. Typically, all of
+        // it fits, but when the destination matrix shape is not a multiple
+        // of 8x8, there are some 8x8 blocks along the boundaries that do
+        // not fit entirely.
+        "sub w1, %w[dst_rows], %w[row]\n"
+        "sub w2, %w[dst_cols], %w[col]\n"
+        "mov w3, #8\n"
+        "cmp w1, #8\n"
+        // Compute w1 = how many rows of the 8x8 block fit
+        "csel w1, w1, w3, le\n"
+        "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+        "cmp w1, w3\n"
+        "ccmp w2, w3, 0, eq\n"
+        // Yes, all of the 8x8 block fits, go to fast path.
+        "beq 30f\n"
+        // Not all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write to dst_tmp_buf
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, #8\n"
+        "b 31f\n"
+        "30:\n"
+        // Yes, all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write directly to destination matrix.
+        "mov x3, %[dst_ptr]\n"
+        "mov x4, x11\n"
+        "31:\n"
+
+        // Write our 8bit values to the destination described by
+        // (x3 address, x4 stride).
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v16.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v16)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v20.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v20)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v17.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v17)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v21.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v21)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v18.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v18)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v22.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v22)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v19.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v19)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v23.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v23)
+
+        // For the next block: perform the first few multiply-adds on the data
+        // that we have already loaded.
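+        // These sdot instructions are hand-encoded as .word values so that the
+        // file assembles even with toolchains that do not know the dotprod
+        // extension.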
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+        // If all of the 8x8 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 41f\n"
+        // Not all of the 8x8 block fits in the destination matrix.  We just
+        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+        // it to copy into the destination matrix the part that fits.
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, %[dst_ptr]\n"
+        "mov w6, #0\n"
+        "50:\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov w5, #0\n"
+        "51:\n"
+        "ldrb w7, [x3, w5, uxtw]\n"
+        "strb w7, [x4, w5, uxtw]\n"
+        "add w5, w5, #1\n"
+        "cmp w5, w1\n"
+        "blt 51b\n"
+        "add w6, w6, #1\n"
+        "add x3, x3, #8\n"
+        "add x4, x4, x11\n"
+        "cmp w6, w2\n"
+        "blt 50b\n"
+        "41:\n"
+        "add %[dst_ptr], %[dst_ptr], #8\n"
+        // At this point we have completely finished writing values to the
+        // destination matrix for the current block.
+
+        "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+        RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+        // Cast-and-saturate from int32 to int16
+        "sqxtn v16.4h, v16.4s\n"
+        "sqxtn2 v16.8h, v17.4s\n"
+        "sqxtn v17.4h, v18.4s\n"
+        "sqxtn2 v17.8h, v19.4s\n"
+        "sqxtn v18.4h, v20.4s\n"
+        "sqxtn2 v18.8h, v21.4s\n"
+        "sqxtn v19.4h, v22.4s\n"
+        "sqxtn2 v19.8h, v23.4s\n"
+        "sqxtn v20.4h, v24.4s\n"
+        "sqxtn2 v20.8h, v25.4s\n"
+        "sqxtn v21.4h, v26.4s\n"
+        "sqxtn2 v21.8h, v27.4s\n"
+        "sqxtn v22.4h, v28.4s\n"
+        "sqxtn2 v22.8h, v29.4s\n"
+        "sqxtn v23.4h, v30.4s\n"
+        "sqxtn2 v23.8h, v31.4s\n"
+
+        // At this point, v24 -- v31 aren't used anymore for the current block,
+        // so we can start clearing these accumulators for the next block
+        // (next iteration of the main loop).
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // Add the destination zero point
+        "dup v14.8h, v13.h[4]\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
+
+        // Cast-and-saturate from int16 to int8
+        "sqxtn v16.8b, v16.8h\n"
+        "sqxtn2 v16.16b, v17.8h\n"
+        "sqxtn v17.8b, v18.8h\n"
+        "sqxtn2 v17.16b, v19.8h\n"
+        "sqxtn v18.8b, v20.8h\n"
+        "sqxtn2 v18.16b, v21.8h\n"
+        "sqxtn v19.8b, v22.8h\n"
+        "sqxtn2 v19.16b, v23.8h\n"
+
+        // Load the clamp_min, clamp_max bounds
+        "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+        "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+        "dup v14.16b, w2\n"  // clamp_min
+        "dup v15.16b, w3\n"  // clamp_max
+
+        // Apply the clamp_min bound
+        "smax v16.16b, v16.16b, v14.16b\n"
+        "smax v17.16b, v17.16b, v14.16b\n"
+        "smax v18.16b, v18.16b, v14.16b\n"
+        "smax v19.16b, v19.16b, v14.16b\n"
+
+        // Apply the clamp_max bound
+        "smin v16.16b, v16.16b, v15.16b\n"
+        "smin v17.16b, v17.16b, v15.16b\n"
+        "smin v18.16b, v18.16b, v15.16b\n"
+        "smin v19.16b, v19.16b, v15.16b\n"
+
+        // Make it so that all of the final 8bit values are stored in the
+        // first 64bits of 128bit NEON registers, so they can be stored
+        // by 64bit st1 store instructions with byte alignment.
+        "dup d20, v16.d[1]\n"
+        "dup d21, v17.d[1]\n"
+        "dup d22, v18.d[1]\n"
+        "dup d23, v19.d[1]\n"
+
+        // Compute how much of the 8x8 block of destination 8bit values that
+        // we have computed fits in the destination matrix. Typically, all of
+        // it fits, but when the destination matrix shape is not a multiple
+        // of 8x8, there are some 8x8 blocks along the boundaries that do
+        // not fit entirely.
+        "sub w1, %w[dst_rows], %w[row]\n"
+        "sub w2, %w[dst_cols], %w[col]\n"
+        "mov w3, #8\n"
+        "cmp w1, #8\n"
+        // Compute w1 = how many rows of the 8x8 block fit
+        "csel w1, w1, w3, le\n"
+        "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+        "cmp w1, w3\n"
+        "ccmp w2, w3, 0, eq\n"
+        // Yes, all of the 8x8 block fits, go to fast path.
+        "beq 130f\n"
+        // Not all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write to dst_tmp_buf
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, #8\n"
+        "b 131f\n"
+        "130:\n"
+        // Yes, all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write directly to destination matrix.
+        "mov x3, %[dst_ptr]\n"
+        "mov x4, x11\n"
+        "131:\n"
+
+        // Write our 8bit values to the destination described by
+        // (x3 address, x4 stride).
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v16.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v16)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v20.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v20)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v17.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v17)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v21.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v21)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v18.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v18)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v22.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v22)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v19.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v19)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v23.8b}, [x3], x4\n"
+        RUY_MAKE_ZERO(v23)
+
+        // For the next block: perform the first few multiply-adds on the data
+        // that we have already loaded.
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+        // If all of the 8x8 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 141f\n"
+        // Not all of the 8x8 block fits in the destination matrix.  We just
+        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+        // it to copy into the destination matrix the part that fits.
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, %[dst_ptr]\n"
+        "mov w6, #0\n"
+        "150:\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov w5, #0\n"
+        "151:\n"
+        "ldrb w7, [x3, w5, uxtw]\n"
+        "strb w7, [x4, w5, uxtw]\n"
+        "add w5, w5, #1\n"
+        "cmp w5, w1\n"
+        "blt 151b\n"
+        "add w6, w6, #1\n"
+        "add x3, x3, #8\n"
+        "add x4, x4, x11\n"
+        "cmp w6, w2\n"
+        "blt 150b\n"
+        "141:\n"
+        "add %[dst_ptr], %[dst_ptr], #8\n"
+        // At this point we have completely finished writing values to the
+        // destination matrix for the current block.
+
+        "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+        RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+        // Add the destination zero point
+        "dup v14.8h, v13.h[4]\n"
+        "saddw v16.4s, v16.4s, v14.4h\n"
+        "saddw v17.4s, v17.4s, v14.4h\n"
+        "saddw v18.4s, v18.4s, v14.4h\n"
+        "saddw v19.4s, v19.4s, v14.4h\n"
+        "saddw v20.4s, v20.4s, v14.4h\n"
+        "saddw v21.4s, v21.4s, v14.4h\n"
+        "saddw v22.4s, v22.4s, v14.4h\n"
+        "saddw v23.4s, v23.4s, v14.4h\n"
+        "saddw v24.4s, v24.4s, v14.4h\n"
+        "saddw v25.4s, v25.4s, v14.4h\n"
+        "saddw v26.4s, v26.4s, v14.4h\n"
+        "saddw v27.4s, v27.4s, v14.4h\n"
+        "saddw v28.4s, v28.4s, v14.4h\n"
+        "saddw v29.4s, v29.4s, v14.4h\n"
+        "saddw v30.4s, v30.4s, v14.4h\n"
+        "saddw v31.4s, v31.4s, v14.4h\n"
+
+        // Cast-and-saturate from int32 to int16
+        "sqxtn v16.4h, v16.4s\n"
+        "sqxtn2 v16.8h, v17.4s\n"
+        "sqxtn v17.4h, v18.4s\n"
+        "sqxtn2 v17.8h, v19.4s\n"
+        "sqxtn v18.4h, v20.4s\n"
+        "sqxtn2 v18.8h, v21.4s\n"
+        "sqxtn v19.4h, v22.4s\n"
+        "sqxtn2 v19.8h, v23.4s\n"
+        "sqxtn v20.4h, v24.4s\n"
+        "sqxtn2 v20.8h, v25.4s\n"
+        "sqxtn v21.4h, v26.4s\n"
+        "sqxtn2 v21.8h, v27.4s\n"
+        "sqxtn v22.4h, v28.4s\n"
+        "sqxtn2 v22.8h, v29.4s\n"
+        "sqxtn v23.4h, v30.4s\n"
+        "sqxtn2 v23.8h, v31.4s\n"
+
+        // At this point, v24 -- v31 aren't used anymore for the current block,
+        // so we can start clearing these accumulators for the next block
+        // (next iteration of the main loop).
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // Load the clamp_min, clamp_max bounds
+        "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+        "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+        "dup v14.8h, w2\n"  // clamp_min
+        "dup v15.8h, w3\n"  // clamp_max
+
+        // Apply the clamp_min bound
+        "smax v16.8h, v16.8h, v14.8h\n"
+        "smax v17.8h, v17.8h, v14.8h\n"
+        "smax v18.8h, v18.8h, v14.8h\n"
+        "smax v19.8h, v19.8h, v14.8h\n"
+        "smax v20.8h, v20.8h, v14.8h\n"
+        "smax v21.8h, v21.8h, v14.8h\n"
+        "smax v22.8h, v22.8h, v14.8h\n"
+        "smax v23.8h, v23.8h, v14.8h\n"
+        // Apply the clamp_max bound
+        "smin v16.8h, v16.8h, v15.8h\n"
+        "smin v17.8h, v17.8h, v15.8h\n"
+        "smin v18.8h, v18.8h, v15.8h\n"
+        "smin v19.8h, v19.8h, v15.8h\n"
+        "smin v20.8h, v20.8h, v15.8h\n"
+        "smin v21.8h, v21.8h, v15.8h\n"
+        "smin v22.8h, v22.8h, v15.8h\n"
+        "smin v23.8h, v23.8h, v15.8h\n"
+
+        // Compute how much of the 8x8 block of destination 16bit values that
+        // we have computed fits in the destination matrix. Typically, all of
+        // it fits, but when the destination matrix shape is not a multiple
+        // of 8x8, there are some 8x8 blocks along the boundaries that do
+        // not fit entirely.
+        "sub w1, %w[dst_rows], %w[row]\n"
+        "sub w2, %w[dst_cols], %w[col]\n"
+        "mov w3, #8\n"
+        "cmp w1, #8\n"
+        // Compute w1 = how many rows of the 8x8 block fit
+        "csel w1, w1, w3, le\n"
+        "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+        "cmp w1, w3\n"
+        "ccmp w2, w3, 0, eq\n"
+        // Yes, all of the 8x8 block fits, go to fast path.
+        "beq 230f\n"
+        // Not all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write to dst_tmp_buf
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, #16\n"
+        "b 231f\n"
+        "230:\n"
+        // Yes, all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write directly to destination matrix.
+        "mov x3, %[dst_ptr]\n"
+        "mov x4, x11\n"
+        "231:\n"
+
+        // Write our 16bit values to the destination described by
+        // (x3 address, x4 stride).
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v16.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v16)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v17.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v17)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v18.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v18)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v19.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v19)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v20.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v20)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v21.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v21)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v22.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v22)
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "st1 {v23.8h}, [x3], x4\n"
+        RUY_MAKE_ZERO(v23)
+
+        // For the next block: perform the first few multiply-adds on the data
+        // that we have already loaded.
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+        // If all of the 8x8 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 241f\n"
+        // Not all of the 8x8 block fits in the destination matrix.  We just
+        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+        // it to copy into the destination matrix the part that fits.
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, %[dst_ptr]\n"
+        "mov w6, #0\n"
+        "250:\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov w5, #0\n"
+        "251:\n"
+        "ldrsh w7, [x3, x5, lsl #1]\n"
+        "strh w7, [x4, x5, lsl #1]\n"
+        "add w5, w5, #1\n"
+        "cmp w5, w1\n"
+        "blt 251b\n"
+        "add w6, w6, #1\n"
+        "add x3, x3, #16\n"
+        "add x4, x4, x11\n"
+        "cmp w6, w2\n"
+        "blt 250b\n"
+        "241:\n"
+        "add %[dst_ptr], %[dst_ptr], #16\n"
+        // At this point we have completely finished writing values to the
+        // destination matrix for the current block.
+
+        "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+        RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+        // Since the store type is the same as the accum type, no need for
+        // downcast. There's also no need for clamp by min/max.
+
+        // Compute how much of the 8x8 block of destination 32bit values that
+        // we have computed fits in the destination matrix. Typically, all of
+        // it fits, but when the destination matrix shape is not a multiple
+        // of 8x8, there are some 8x8 blocks along the boundaries that do
+        // not fit entirely.
+        "sub w1, %w[dst_rows], %w[row]\n"
+        "sub w2, %w[dst_cols], %w[col]\n"
+        "mov w3, #8\n"
+        "cmp w1, #8\n"
+        // Compute w1 = how many rows of the 8x8 block fit
+        "csel w1, w1, w3, le\n"
+        "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+        "cmp w1, w3\n"
+        "ccmp w2, w3, 0, eq\n"
+        // Yes, all of the 8x8 block fits, go to fast path.
+        "beq 330f\n"
+        // Not all of the 8x8 block fits.
+        // Write to dst_tmp_buf
+        "mov x3, %[dst_tmp_buf]\n"
+        "st1 {v16.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v16)
+        "st1 {v17.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v17)
+        "st1 {v18.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v18)
+        "st1 {v19.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v19)
+        "st1 {v20.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v20)
+        "st1 {v21.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v21)
+        "st1 {v22.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v22)
+        "st1 {v23.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v23)
+        "st1 {v24.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v24)
+        "st1 {v25.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v25)
+        "st1 {v26.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v26)
+        "st1 {v27.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v27)
+        "st1 {v28.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v28)
+        "st1 {v29.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v29)
+        "st1 {v30.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v30)
+        "st1 {v31.4s}, [x3], #16\n"
+        RUY_MAKE_ZERO(v31)
+
+        "b 331f\n"
+
+        "330:\n"
+        // Yes, all of the 8x8 block fits.
+        "mov x4, %[dst_ptr]\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v16.4s, v17.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v16)
+        RUY_MAKE_ZERO(v17)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v18.4s, v19.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v18)
+        RUY_MAKE_ZERO(v19)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v20.4s, v21.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v20)
+        RUY_MAKE_ZERO(v21)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v22.4s, v23.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v22)
+        RUY_MAKE_ZERO(v23)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v24.4s, v25.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v26.4s, v27.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v28.4s, v29.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        "add x4, x4, x11\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov x3, x4\n"
+        "st1 {v30.4s, v31.4s}, [x3], #32\n"
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        "331:\n"
+
+        // For the next block: perform the first few multiply-adds on the data
+        // that we have already loaded.
+        ".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
+        ".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
+        ".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
+        ".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+        // If all of the 8x8 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 341f\n"
+
+        // Not all of the 8x8 block fits in the destination matrix.  We just
+        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+        // it to copy into the destination matrix the part that fits.
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, %[dst_ptr]\n"
+        "mov w6, #0\n"
+        "350:\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov w5, #0\n"
+        "351:\n"
+        "ldr w7, [x3, x5, lsl #2]\n"
+        "str w7, [x4, x5, lsl #2]\n"
+        "add w5, w5, #1\n"
+        "cmp w5, w1\n"
+        "blt 351b\n"
+        "add w6, w6, #1\n"
+        "add x3, x3, #32\n"
+        "add x4, x4, x11\n"
+        "cmp w6, w2\n"
+        "blt 350b\n"
+        "341:\n"
+        "add %[dst_ptr], %[dst_ptr], #32\n"
+        // At this point we have completely finished writing values to the
+        // destination matrix for the current block.
+
+        RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+        // Reload some params --- we had used x5 -- x7 for a few other things
+        // since the last time we had loaded them.
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+        // Move to the next block of the destination matrix, for the next iter
+        // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
+        // been updated earlier.
+        // Have we reached the end row?
+        "cmp %w[row], w7\n"
+        "beq 20f\n"  // yes, end row.
+        // Not end row. Move to the next row.
+        "add %w[row], %w[row], #8\n"
+        "b 21f\n"
+        "20:\n"
+        // Was already at end row.
+        "mov %w[row], w6\n"  // Move back to first row.
+        "add %w[col], %w[col], #8\n"  // Move to the next column.
+        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+        "mov %[dst_ptr], %[dst_col_ptr]\n"
+        "21:\n"
+
+        // Main loop exit condition: have we hit the end column?
+        "cmp %w[col], w8\n"
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 4.
+        "mov w1, #4\n"
+
+        "ble 1b\n"
+
+        // clang-format on
+
+        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+          [dst_type_id] "r"(params.dst_type_id)
+        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+          "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+
 // Similar to the above 8-bit dotprod kernel, but specialized for the case of
 // RHS cols == 1.
 // Relevant target CPUs for this kernel include ARM Cortex-A76,
@@ -4692,7 +5947,7 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8, leaving all data in the
         // lower half of v16.
@@ -4788,7 +6043,7 @@
 
         // Add the destination zero point
         "dup v14.8h, v13.h[4]\n"
-        "add v16.8h, v16.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
 
         // Cast-and-saturate from int16 to uint8
         "sqxtn v16.8b, v16.8h\n"
@@ -5691,14 +6946,14 @@
         RUY_MAKE_ZERO(v31)
 
         // Add the destination zero point
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
-        "add v18.8h, v18.8h, v14.8h\n"
-        "add v19.8h, v19.8h, v14.8h\n"
-        "add v20.8h, v20.8h, v14.8h\n"
-        "add v21.8h, v21.8h, v14.8h\n"
-        "add v22.8h, v22.8h, v14.8h\n"
-        "add v23.8h, v23.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
 
         // Load the clamp_min, clamp_max bounds
         "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
@@ -5865,14 +7120,14 @@
         RUY_MAKE_ZERO(v31)
 
         // Add the destination zero point
-        "add v16.8h, v16.8h, v14.8h\n"
-        "add v17.8h, v17.8h, v14.8h\n"
-        "add v18.8h, v18.8h, v14.8h\n"
-        "add v19.8h, v19.8h, v14.8h\n"
-        "add v20.8h, v20.8h, v14.8h\n"
-        "add v21.8h, v21.8h, v14.8h\n"
-        "add v22.8h, v22.8h, v14.8h\n"
-        "add v23.8h, v23.8h, v14.8h\n"
+        "sqadd v16.8h, v16.8h, v14.8h\n"
+        "sqadd v17.8h, v17.8h, v14.8h\n"
+        "sqadd v18.8h, v18.8h, v14.8h\n"
+        "sqadd v19.8h, v19.8h, v14.8h\n"
+        "sqadd v20.8h, v20.8h, v14.8h\n"
+        "sqadd v21.8h, v21.8h, v14.8h\n"
+        "sqadd v22.8h, v22.8h, v14.8h\n"
+        "sqadd v23.8h, v23.8h, v14.8h\n"
 
         // Load the clamp_min, clamp_max bounds
         "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
@@ -7069,6 +8324,472 @@
           "v26", "v27", "v28", "v29", "v30", "v31");
 }
 
+// A fork of the standard float kernel where we omit the manual loop unrolling
+// to recover performance on the X1. For now, the X1 core is the only CPU that
+// uses this kernel.
+void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params) {
+  CheckOffsetsInKernelParamsFloat(params);
+  profiler::ScopeLabel label("Kernel (kNeon) X1");
+
+  const float* lhs_col_ptr = params.lhs_base_ptr;
+  const float* rhs_col_ptr = params.rhs_base_ptr;
+  const float* lhs_ptr = lhs_col_ptr;
+  const float* rhs_ptr = rhs_col_ptr;
+  float* dst_col_ptr = params.dst_base_ptr;
+  float* dst_ptr = dst_col_ptr;
+  int row = params.start_row;
+  int col = params.start_col;
+
+  // The asm kernel below has the following NEON register allocation:
+  //
+  // v16 -- v31 are accumulators.
+  // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
+  // At least v0 and v1 are used to load an 8x1 block of LHS, and v2 and
+  // v3 are used to load a 1x8 block of RHS, like this:
+  //
+  //                                          RHS 1x8 block
+  //                           /-----------------------------------------|
+  //                           |v2.s[0] ... v2.s[3]   v3.s[0] ... v3.s[3]|
+  //                           \-----------------------------------------/
+  //        LHS 8x1 block
+  //  /---------------------\  /-----------------------------------------|
+  //  |        v0.s[0]      |  |v16.s[0]           ...           v30.s[0]|
+  //  |         ...         |  |  ...                              ...   |
+  //  |        v0.s[3]      |  |v16.s[3]           ...           v30.s[3]|
+  //  |        v1.s[0]      |  |v17.s[0]           ...           v31.s[0]|
+  //  |         ...         |  |  ...                              ...   |
+  //  |        v1.s[3]      |  |v17.s[3]           ...           v31.s[3]|
+  //  \---------------------/  \-----------------------------------------/
+  //                                      accumulators 8x8 block
+  //
+  // Unlike the standard float kernel, this fork has no RUY_OPT_MAX_STREAMING
+  // section: the accumulation loop below is not manually unrolled, so only
+  // v0 -- v4 are needed to stream LHS and RHS data.
+  //
+  // v8 -- v15 are used for loading parameters for the post-accumulation part
+  // of the kernel.
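+  //
+  // As a rough scalar sketch (illustrative only, not part of this patch), one
+  // level of depth updates the 8x8 accumulator block as:
+  //
+  //   for (int i = 0; i < 8; ++i)      // row within the 8x1 LHS block
+  //     for (int j = 0; j < 8; ++j)    // column within the 1x8 RHS block
+  //       acc[j][i] += lhs[i] * rhs[j];   // one fmla lane per (i, j)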
+  asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+        // clang-format off
+
+        // Load some parameters into registers.
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+        // Load the first 32 bytes of LHS and RHS data.
+        "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+        "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+        "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+        "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+        // Clear accumulators.
+        RUY_MAKE_ZERO(v16)
+        RUY_MAKE_ZERO(v17)
+        RUY_MAKE_ZERO(v18)
+        RUY_MAKE_ZERO(v19)
+        RUY_MAKE_ZERO(v20)
+        RUY_MAKE_ZERO(v21)
+        RUY_MAKE_ZERO(v22)
+        RUY_MAKE_ZERO(v23)
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 1.
+        "mov w1, #1\n"
+
+        // Main loop of the whole GEMM, over rows and columns of the
+        // destination matrix.
+        "1:\n"
+
+        "fmla v16.4s, v0.4s, v2.s[0]\n"
+        "fmla v18.4s, v0.4s, v2.s[1]\n"
+        "fmla v20.4s, v0.4s, v2.s[2]\n"
+        "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+        // Accumulation loop
+        "cmp w1, w12\n"
+        "beq 79f\n"
+
+        "2:\n"
+        "fmla v24.4s, v0.4s, v3.s[0]\n"
+        "fmla v26.4s, v0.4s, v3.s[1]\n"
+        "ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
+        "fmla v28.4s, v0.4s, v3.s[2]\n"
+        "fmla v30.4s, v0.4s, v3.s[3]\n"
+        "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+        "fmla v25.4s, v1.4s, v3.s[0]\n"
+        "fmla v27.4s, v1.4s, v3.s[1]\n"
+        "add w1, w1, #1\n"
+        "fmla v29.4s, v1.4s, v3.s[2]\n"
+        "fmla v31.4s, v1.4s, v3.s[3]\n"
+        "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+        "fmla v17.4s, v1.4s, v2.s[0]\n"
+        "fmla v19.4s, v1.4s, v2.s[1]\n"
+        "cmp w1, w12\n"
+        "fmla v21.4s, v1.4s, v2.s[2]\n"
+        "fmla v23.4s, v1.4s, v2.s[3]\n"
+        "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+        "fmla v16.4s, v0.4s, v4.s[0]\n"
+        "fmla v18.4s, v0.4s, v4.s[1]\n"
+        "mov v2.16b, v4.16b\n"
+        "fmla v20.4s, v0.4s, v4.s[2]\n"
+        "fmla v22.4s, v0.4s, v4.s[3]\n"
+        "blt 2b\n"
+
+        "79:\n"
+
+        // End of the inner loop on depth. Now perform the remaining
+        // multiply-adds of the last level of depth, for which the LHS
+        // and RHS data is already loaded.
+
+        "fmla v24.4s, v0.4s, v3.s[0]\n"
+        "fmla v26.4s, v0.4s, v3.s[1]\n"
+        "fmla v28.4s, v0.4s, v3.s[2]\n"
+        "fmla v30.4s, v0.4s, v3.s[3]\n"
+        "fmla v25.4s, v1.4s, v3.s[0]\n"
+        "fmla v27.4s, v1.4s, v3.s[1]\n"
+        "fmla v29.4s, v1.4s, v3.s[2]\n"
+        "fmla v31.4s, v1.4s, v3.s[3]\n"
+        "fmla v17.4s, v1.4s, v2.s[0]\n"
+        "fmla v19.4s, v1.4s, v2.s[1]\n"
+        "fmla v21.4s, v1.4s, v2.s[2]\n"
+        "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+        // End of accumulation. The registers v16 -- v31 contain the final
+        // float32 accumulator values of the current 8x8 destination block.
+        // We now have to compute the final float values from these
+        // accumulators, and advance to the next 8x8 block. We intertwine
+        // these two aspects whenever possible for optimal pipelining, both
+        // at the data flow level (prefetch data for next block as early as
+        // possible) and instruction pipelining level (some of the next-block
+        // work can dual-issue with some of the final work on the current
+        // block).
+
+        // Logic to advance to the next block in preparation for the next
+        // iteration of the main loop. For now, we only want to compute
+        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+        // not yet ready to update the values of row and col, as we still need
+        // the current values for the rest of the work on the current block.
+
+        "cmp %w[row], w7\n"  // Have we finished the last row?
+        "bge 4f\n"           // If finished last row, go to 4
+        // Not finished last row: then advance to next row.
+        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+        "b 5f\n"
+        "4:\n"  // Finished last row...
+        "mov %[lhs_col_ptr], x5\n"  // Go back to first row
+        // Now we need to advance to the next column. If we already
+        // finished the last column, then in principle we are done, however
+        // we can't just return here, as we need to allow the end work of the
+        // current block to complete. The good news is that at this point it
+        // doesn't matter what data we load for the next column, since
+        // we will exit from the main loop below before actually storing
+        // anything computed from that data.
+        "cmp %w[col], w8\n"  // Have we finished the last column?
+        "bge 5f\n" // If yes, just carry on without updating the column pointer.
+        // Not finished last column: then advance to next column.
+        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+        "5:\n"
+
+        // Set the LHS and RHS data pointers to the start of the columns just
+        // computed.
+        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+        "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+        // Load some parameters needed for the end work on current block.
+        "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+        // Determine the channel index.
+        "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+
+        // Offset the bias pointer as needed given the current row, col.
+        "add x5, x1, x3, lsl #2\n"
+
+        // If there is no bias, use no offset, just address the passed zero
+        // data.
+        "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+        "csel x1, x1, x5, eq\n"
+
+        // Load 8 bias values.
+        "ld1 {v14.4s}, [x1], #16\n"
+        "ld1 {v15.4s}, [x1]\n"
+
+        // Now that we know what LHS and RHS data the next iteration of the
+        // main loop will need to load, we start loading the first 32 bytes of
+        // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+        // in the rest of the work on the current block.
+        "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+        "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+        "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+        "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+        // Perform the bias-addition.
+        // Jump based on channel dimension.
+        "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 6f\n"
+        // Case where channels are rows
+        "fadd v16.4s, v16.4s, v14.4s\n"
+        "fadd v17.4s, v17.4s, v15.4s\n"
+        "fadd v18.4s, v18.4s, v14.4s\n"
+        "fadd v19.4s, v19.4s, v15.4s\n"
+        "fadd v20.4s, v20.4s, v14.4s\n"
+        "fadd v21.4s, v21.4s, v15.4s\n"
+        "fadd v22.4s, v22.4s, v14.4s\n"
+        "fadd v23.4s, v23.4s, v15.4s\n"
+        "fadd v24.4s, v24.4s, v14.4s\n"
+        "fadd v25.4s, v25.4s, v15.4s\n"
+        "fadd v26.4s, v26.4s, v14.4s\n"
+        "fadd v27.4s, v27.4s, v15.4s\n"
+        "fadd v28.4s, v28.4s, v14.4s\n"
+        "fadd v29.4s, v29.4s, v15.4s\n"
+        "fadd v30.4s, v30.4s, v14.4s\n"
+        "fadd v31.4s, v31.4s, v15.4s\n"
+        "b 7f\n"
+
+        "6:\n"
+        // Case where channels are columns
+        "dup v8.4s, v14.s[0]\n"
+        "dup v9.4s, v14.s[1]\n"
+        "dup v10.4s, v14.s[2]\n"
+        "dup v11.4s, v14.s[3]\n"
+        "dup v12.4s, v15.s[0]\n"
+        "dup v13.4s, v15.s[1]\n"
+        "dup v14.4s, v15.s[2]\n"
+        "dup v15.4s, v15.s[3]\n"
+        "fadd v16.4s, v16.4s, v8.4s\n"
+        "fadd v17.4s, v17.4s, v8.4s\n"
+        "fadd v18.4s, v18.4s, v9.4s\n"
+        "fadd v19.4s, v19.4s, v9.4s\n"
+        "fadd v20.4s, v20.4s, v10.4s\n"
+        "fadd v21.4s, v21.4s, v10.4s\n"
+        "fadd v22.4s, v22.4s, v11.4s\n"
+        "fadd v23.4s, v23.4s, v11.4s\n"
+        "fadd v24.4s, v24.4s, v12.4s\n"
+        "fadd v25.4s, v25.4s, v12.4s\n"
+        "fadd v26.4s, v26.4s, v13.4s\n"
+        "fadd v27.4s, v27.4s, v13.4s\n"
+        "fadd v28.4s, v28.4s, v14.4s\n"
+        "fadd v29.4s, v29.4s, v14.4s\n"
+        "fadd v30.4s, v30.4s, v15.4s\n"
+        "fadd v31.4s, v31.4s, v15.4s\n"
+        "7:\n"
+
+        // Load the clamp_min, clamp_max bounds
+        "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+        "dup v14.4s, w2\n"  // clamp_min
+        "dup v15.4s, w3\n"  // clamp_max
+
+        // Apply the clamp_min bound
+        "fmax v16.4s, v16.4s, v14.4s\n"
+        "fmax v17.4s, v17.4s, v14.4s\n"
+        "fmax v18.4s, v18.4s, v14.4s\n"
+        "fmax v19.4s, v19.4s, v14.4s\n"
+        "fmax v20.4s, v20.4s, v14.4s\n"
+        "fmax v21.4s, v21.4s, v14.4s\n"
+        "fmax v22.4s, v22.4s, v14.4s\n"
+        "fmax v23.4s, v23.4s, v14.4s\n"
+        "fmax v24.4s, v24.4s, v14.4s\n"
+        "fmax v25.4s, v25.4s, v14.4s\n"
+        "fmax v26.4s, v26.4s, v14.4s\n"
+        "fmax v27.4s, v27.4s, v14.4s\n"
+        "fmax v28.4s, v28.4s, v14.4s\n"
+        "fmax v29.4s, v29.4s, v14.4s\n"
+        "fmax v30.4s, v30.4s, v14.4s\n"
+        "fmax v31.4s, v31.4s, v14.4s\n"
+
+        // Apply the clamp_max bound
+        "fmin v16.4s, v16.4s, v15.4s\n"
+        "fmin v17.4s, v17.4s, v15.4s\n"
+        "fmin v18.4s, v18.4s, v15.4s\n"
+        "fmin v19.4s, v19.4s, v15.4s\n"
+        "fmin v20.4s, v20.4s, v15.4s\n"
+        "fmin v21.4s, v21.4s, v15.4s\n"
+        "fmin v22.4s, v22.4s, v15.4s\n"
+        "fmin v23.4s, v23.4s, v15.4s\n"
+        "fmin v24.4s, v24.4s, v15.4s\n"
+        "fmin v25.4s, v25.4s, v15.4s\n"
+        "fmin v26.4s, v26.4s, v15.4s\n"
+        "fmin v27.4s, v27.4s, v15.4s\n"
+        "fmin v28.4s, v28.4s, v15.4s\n"
+        "fmin v29.4s, v29.4s, v15.4s\n"
+        "fmin v30.4s, v30.4s, v15.4s\n"
+        "fmin v31.4s, v31.4s, v15.4s\n"
+
+        // Compute how much of the 8x8 block of destination float values that
+        // we have computed fits in the destination matrix. Typically, all of
+        // it fits, but when the destination matrix shape is not a multiple
+        // of 8x8, there are some 8x8 blocks along the boundaries that do
+        // not fit entirely.
+        "sub w1, %w[dst_rows], %w[row]\n"
+        "sub w2, %w[dst_cols], %w[col]\n"
+        "mov w3, #8\n"
+        "cmp w1, #8\n"
+        // Compute w1 = how many rows of the 8x8 block fit
+        "csel w1, w1, w3, le\n"
+        "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+        "cmp w1, w3\n"
+        "ccmp w2, w3, 0, eq\n"
+        // Yes, all of the 8x8 block fits, go to fast path.
+        "beq 30f\n"
+        // Not all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write to dst_tmp_buf
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, #32\n"
+        "b 31f\n"
+        "30:\n"
+        // Yes, all of the 8x8 block fits.
+        // Set (x3 address, x4 stride) to write directly to destination matrix.
+        "mov x3, %[dst_ptr]\n"
+        "mov x4, x11\n"
+        "31:\n"
+
+        // Write our float values to the destination described by
+        // (x3 address, x4 stride).
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        "str q16, [x3, #0]\n"
+        "str q17, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v16)
+        RUY_MAKE_ZERO(v17)
+        "str q18, [x3, #0]\n"
+        "str q19, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v18)
+        RUY_MAKE_ZERO(v19)
+        "str q20, [x3, #0]\n"
+        "str q21, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v20)
+        RUY_MAKE_ZERO(v21)
+        "str q22, [x3, #0]\n"
+        "str q23, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v22)
+        RUY_MAKE_ZERO(v23)
+        "str q24, [x3, #0]\n"
+        "str q25, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v24)
+        RUY_MAKE_ZERO(v25)
+        "str q26, [x3, #0]\n"
+        "str q27, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v26)
+        RUY_MAKE_ZERO(v27)
+        "str q28, [x3, #0]\n"
+        "str q29, [x3, #16]\n"
+        "add x3, x3, x4\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+        RUY_MAKE_ZERO(v28)
+        RUY_MAKE_ZERO(v29)
+        "str q30, [x3, #0]\n"
+        "str q31, [x3, #16]\n"
+        RUY_MAKE_ZERO(v30)
+        RUY_MAKE_ZERO(v31)
+
+        // If all of the 8x8 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 41f\n"
+        // Not all of the 8x8 block fits in the destination matrix.  We just
+        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+        // it to copy into the destination matrix the part that fits.
+        "mov x3, %[dst_tmp_buf]\n"
+        "mov x4, %[dst_ptr]\n"
+        "mov w6, #0\n"
+        "50:\n"
+        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+        "mov w5, #0\n"
+        "51:\n"
+        "ldr w7, [x3, x5, lsl #2]\n"
+        "str w7, [x4, x5, lsl #2]\n"
+        "add w5, w5, #1\n"
+        "cmp w5, w1\n"
+        "blt 51b\n"
+        "add w6, w6, #1\n"
+        "add x3, x3, #32\n"
+        "add x4, x4, x11\n"
+        "cmp w6, w2\n"
+        "blt 50b\n"
+        "41:\n"
+        "add %[dst_ptr], %[dst_ptr], #32\n"
+        // At this point we have completely finished writing values to the
+        // destination matrix for the current block.
+
+        // Reload some params --- we had used x5 -- x7 for a few other things
+        // since the last time we had loaded them.
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+        // Move to the next block of the destination matrix, for the next iter
+        // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
+        // been updated earlier.
+        // Have we reached the end row?
+        "cmp %w[row], w7\n"
+        "beq 20f\n"  // yes, end row.
+        // Not end row. Move to the next row.
+        "add %w[row], %w[row], #8\n"
+        "b 21f\n"
+        "20:\n"
+        // Was already at end row.
+        "mov %w[row], w6\n"  // Move back to first row.
+        "add %w[col], %w[col], #8\n"  // Move to the next column.
+        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+        "mov %[dst_ptr], %[dst_col_ptr]\n"
+        "21:\n"
+
+        // Main loop exit condition: have we hit the end column?
+        "cmp %w[col], w8\n"
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 1.
+        "mov w1, #1\n"
+
+        "ble 1b\n"
+
+        // clang-format on
+
+        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+          "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
 // Variant of KernelFloatNeon tuned for in-order CPUs that do not
 // support dotprod (while dotprod by itself is not relevant to floating-point,
 // this additional bit of information that we have about the target happens to
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index fddb482..84b9380 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,45 +52,6 @@
 
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
-namespace {
-namespace intrin_utils {
-
-__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
-                           const __m256i& mask) {
-  __m256d result =
-      _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
-                       _mm256_castsi256_pd(mask));
-  return _mm256_castpd_si256(result);
-}
-
-__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
-                           const __m512i& mask) {
-  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
-  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
-  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
-  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
-  __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
-  __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
-  __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
-  __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
-  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
-  return _mm512_inserti64x4(result, hi, 1);
-}
-
-__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
-  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
-  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
-  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
-  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
-  __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
-  __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
-  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
-  return _mm512_inserti64x4(result, hi, 1);
-}
-
-}  // namespace intrin_utils
-}  // namespace
-
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
 
@@ -391,13 +352,13 @@
           // Construct the "nudge" value for each lane if the exponent is
           // greater than 0. Otherwise, the nudge is 0.
           const __m512i zeros = _mm512_setzero_si512();
-          const __m512i mask_rightshift_gtz =
-              intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+          const auto mask_rightshift_gtz =
+              _mm512_cmpgt_epi64_mask(exponent, zeros);
           const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
               _mm512_set1_epi64(1),
               _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
-          __m512i nudge = intrin_utils::mm512_blendv_epi64(
-              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+          __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+                                                one_shift_exp_minus1);
           // Calculate the shifted sum (results + nudge) >> exp.
           const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
           const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -406,14 +367,12 @@
           const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
               _mm512_set1_epi64(1),
               _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
-          const __m512i mask_num_plus_nudge_overflow =
-              intrin_utils::mm512_cmpgt_epi64(
-                  results,
-                  _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+          const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+              results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
           // Fill results with either (results + nudge) >> exponent or
           // 1 << (31 - exp) in the case of overflow.
-          results = intrin_utils::mm512_blendv_epi64(
-              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+          results = _mm512_mask_mov_epi64(
+              shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
         };
 
         if (per_column_multiplier) {
@@ -424,8 +383,8 @@
                 _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
             __m512i m_64bit_val = _mm512_permutexvar_epi64(
                 perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
-            __m512i offset_vector_val = _mm512_permutexvar_epi64(
-                perm_64bit_vals, offset_vector);
+            __m512i offset_vector_val =
+                _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
             __m512i final_right_shift_val = _mm512_permutexvar_epi64(
                 perm_64bit_vals,
                 col < 8 ? final_right_shift_low : final_right_shift_high);
@@ -802,13 +761,13 @@
         // Construct the "nudge" value for each lane if the exponent is
         // greater than 0. Otherwise, the nudge is 0.
         const __m512i zeros = _mm512_setzero_si512();
-        const __m512i mask_rightshift_gtz =
-            intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+        const auto mask_rightshift_gtz =
+            _mm512_cmpgt_epi64_mask(exponent, zeros);
         const __m512i one_shift_exp_minus1 =
             _mm512_sllv_epi64(_mm512_set1_epi64(1),
                               _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
-        __m512i nudge = intrin_utils::mm512_blendv_epi64(
-            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+        __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+                                              one_shift_exp_minus1);
         // Calculate the shifted sum (results + nudge) >> exp.
         const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
         const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -817,14 +776,12 @@
         const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
             _mm512_set1_epi64(1),
             _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
-        const __m512i mask_num_plus_nudge_overflow =
-            intrin_utils::mm512_cmpgt_epi64(
-                results,
-                _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+        const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+            results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
         // Fill results with either (results + nudge) >> exponent or
         // 1 << (31 - exp) in the case of overflow.
-        results = intrin_utils::mm512_blendv_epi64(
-            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+        results = _mm512_mask_mov_epi64(
+            shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
       };
 
       // Shift and round column 0.
@@ -930,9 +887,8 @@
       float* dst_ptr = dst_col_ptr + row;
 
       // Process block in two halves, split by columns.
-      {
-        constexpr int mmm = 0;
-
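+      // "#pragma unroll(1)" requests an unroll factor of 1, i.e. keeps this
+      // loop rolled; each of its two iterations handles one group of 8
+      // destination columns of the 16-column block.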
+#pragma unroll(1)
+      for (int mmm = 0; mmm < 2; ++mmm) {
         __m512 accum_data_v0;
         __m512 accum_data_v1;
         __m512 accum_data_v2;
@@ -972,81 +928,49 @@
         const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
         for (int d = 0; d < (params.depth - 1); ++d) {
           const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          // In this version RHS values are loaded individually rather than
-          // first loading together and then extract with broadcasting. This is
-          // because AVX flavours and instrinsics and compilers in combination
-          // do not handle this pattern of extraction very well.
           const float* rhs_data = rhs_ptr;
           lhs_ptr += 16;
           rhs_ptr += 16;
 
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+          // so if given an rvalue.
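+          // For illustration (assumed typical codegen, not guaranteed for
+          // every compiler version): the set1+FMA pair below can become a
+          // single memory-broadcast FMA such as
+          //   vfmadd231ps zmm0, zmm1, dword ptr [mem]{1to16}
+          // instead of a separate vbroadcastss plus a register-only FMA.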
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
         }
-        {
+        {  // nested extra blocks lead to measurable speed gains
           const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
           const float* rhs_data = rhs_ptr;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
           {
             float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
             accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
@@ -1075,147 +999,7 @@
             _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
           }
         }
-      }  // Inner half-block loop, unrolled, first iteration.
-      {
-        constexpr int mmm = 1;
-
-        __m512 accum_data_v0;
-        __m512 accum_data_v1;
-        __m512 accum_data_v2;
-        __m512 accum_data_v3;
-        __m512 accum_data_v4;
-        __m512 accum_data_v5;
-        __m512 accum_data_v6;
-        __m512 accum_data_v7;
-
-        // Initialize with bias.
-        if (channel_dimension_is_col) {
-          const float* bias_elem_ptr =
-              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
-          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
-          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
-          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
-          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
-          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
-          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
-          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
-          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
-        } else {
-          const __m512 initial_accum_data =
-              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
-
-          accum_data_v0 = initial_accum_data;
-          accum_data_v1 = initial_accum_data;
-          accum_data_v2 = initial_accum_data;
-          accum_data_v3 = initial_accum_data;
-          accum_data_v4 = initial_accum_data;
-          accum_data_v5 = initial_accum_data;
-          accum_data_v6 = initial_accum_data;
-          accum_data_v7 = initial_accum_data;
-        }
-        const float* lhs_ptr = lhs_col_ptr;
-        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
-        for (int d = 0; d < (params.depth - 1); ++d) {
-          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          const float* rhs_data = rhs_ptr;
-          lhs_ptr += 16;
-          rhs_ptr += 16;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
-        }
-        {
-          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          const float* rhs_data = rhs_ptr;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
-          {
-            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
-            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
-            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
-            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
-            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
-            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
-            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
-            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
-            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
-            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
-            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
-            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
-            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
-            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
-            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
-            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
-            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
-          }
-        }
-      }  // Inner half-block loop, unrolled, second iteration.
+      }
     }    // End row-block loop.
 
     // The unrolling within this conditional may be somewhat pointless. It
@@ -1273,73 +1057,45 @@
           const float* rhs_data = rhs_ptr;
           lhs_ptr += 16;
           rhs_ptr += 16;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+          // so if given an rvalue.
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
         }
         {
           const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
           const float* rhs_data = rhs_ptr;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
           {
             float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
             accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index 9509b8f..cff243b 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -177,6 +177,8 @@
   params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
   params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
   if (mul_params.multiplier_fixedpoint_perchannel()) {
+    // Temporary release-assert to debug some crashes in an application.
+    RUY_CHECK(mul_params.multiplier_exponent_perchannel());
     params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
     params->multiplier_fixedpoint =
         mul_params.multiplier_fixedpoint_perchannel();
@@ -200,6 +202,11 @@
   params->dst_type_id = DstTypeId<DstScalar>::kValue;
   params->dst_base_ptr =
       dst->data.get() + start_col * dst->layout.stride + start_row;
+
+  // Temporary release-asserts to debug some crashes in an application.
+  RUY_CHECK(params->multiplier_fixedpoint);
+  RUY_CHECK(params->multiplier_exponent);
+  RUY_CHECK(params->bias);
 }
 
 template <int LhsCols, int RhsCols>
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 2f8fe19..b716502 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -607,14 +607,12 @@
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-        // Load 8 RHS values, then use permute instructions to
-        // broadcast each value to a register.
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
+        // Load 8 RHS values, then use permute instructions to broadcast each
+        // value to a register. _mm256_permute2f128_ps is slow on AMD.
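+        // _mm256_broadcast_ps (vbroadcastf128) loads 4 floats from memory and
+        // duplicates them into both 128-bit lanes, so no cross-lane shuffle is
+        // needed.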
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
@@ -707,13 +705,11 @@
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
 
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
diff --git a/ruy/mat.h b/ruy/mat.h
index 587b208..c2254f9 100644
--- a/ruy/mat.h
+++ b/ruy/mat.h
@@ -327,7 +327,7 @@
   return layout.order == Order::kColMajor;
 }
 
-inline int FlatSize(const MatLayout& layout) {
+inline std::ptrdiff_t FlatSize(const MatLayout& layout) {
   const int outerdim =
       layout.order == Order::kColMajor ? layout.cols : layout.rows;
   return layout.stride * outerdim;
@@ -349,7 +349,7 @@
   return layout.order == Order::kColMajor;
 }
 
-inline int FlatSize(const PMatLayout& layout) {
+inline std::ptrdiff_t FlatSize(const PMatLayout& layout) {
   const int outerdim =
       layout.order == Order::kColMajor ? layout.cols : layout.rows;
   return layout.stride * outerdim;
@@ -429,11 +429,11 @@
 
 // Helpers for PEMat.
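+// Byte sizes are returned as std::ptrdiff_t rather than int so that very
+// large packed matrices do not overflow the count (assumed motivation for the
+// wider type).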
 
-inline int DataBytes(const PEMat& packed) {
+inline std::ptrdiff_t DataBytes(const PEMat& packed) {
   return FlatSize(packed.layout) * packed.data_type.size;
 }
 
-inline int SumsBytes(const PEMat& packed) {
+inline std::ptrdiff_t SumsBytes(const PEMat& packed) {
   // Packed matrices are only relevant for Ruy's TrMul implementations. For
   // TrMul, the number of sums is always equal to the number of columns.
   return packed.layout.cols * packed.sums_type.size;
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index d5aa27b..42a5700 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -103,14 +103,9 @@
   // The bias vector data, if not null.
   const AccumScalar* bias() const { return storage_.bias; }
   void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
-  // Only for non-floating-point cases. The fixed-point part of the multiplier
-  // by which accumulators are multiplied before being casted to the destination
-  // type. This is a fixed-point quantity with 0 integer bits. Since
-  // (as explained in the class comment) AccumScalar must be std::int32_t,
-  // that means that the fixed-point format is Q0.31. For example,
-  // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
-  // by one half (1/2). More generally, the effect is to multiply by
-  // (multiplier_fixedpoint / (2^31)).
+  // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
+  // of the multiplier by which accumulators are multiplied before being cast
+  // to the destination type.
   AccumScalar multiplier_fixedpoint() const {
     return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
   }
@@ -132,10 +127,9 @@
   // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
   // and `multiplier_exponent_perchannel` are used instead.
   //
-  // This must point to a buffer of as many values as there are rows or columns
-  // in the destination matrix, whichever is the channels dimension. Each
-  // channel of the destination matrix will use the corresponding buffer element
-  // instead of multiplier_fixedpoint.
+  // This must point to a buffer of as many values as there are rows in the
+  // destination matrix. Each row of the destination matrix will use the
+  // corresponding buffer element instead of multiplier_fixedpoint.
   const AccumScalar* multiplier_fixedpoint_perchannel() const {
     return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                                : nullptr;
@@ -205,6 +199,16 @@
   detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
 
   void set_perchannel(bool perchannel) {
+    if (storage_.perchannel == perchannel) {
+      return;
+    }
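+    // The checks below assert that the representation being switched away from
+    // is still in its default state, so that no explicitly provided multiplier
+    // is silently dropped by the mode switch (assumed intent of these DCHECKs).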
+    if (perchannel) {
+      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
+      RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
+    } else {
+      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
+      RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
+    }
     storage_.perchannel = perchannel;
   }
 };
@@ -240,25 +244,25 @@
 struct MulParamsStorage<std::int32_t, DstScalar> final {
   using AccumScalar = std::int32_t;
   static_assert(std::is_integral<DstScalar>::value, "");
-  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
+  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
 
   const AccumScalar* bias = nullptr;
-  union {
-    const AccumScalar* multiplier_fixedpoint_perchannel;
-    // Let the default multiplier be effecively a multiplication by 1, so that
-    // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
-    // 1 is not exactly representable in fixedpoint with 0 integer bits, but
-    // using the highest representable value is a sufficiently good
-    // approximation: since this specialization of MulParams is for the case
-    // where DstScalar is at least 2x narrower than MulScalar, the values
-    // for which there would be a difference will get saturated anyway.
-    AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
-  };
-  union {
-    const int* multiplier_exponent_perchannel;
-    // See the above comment about the default value of multiplier_fixedpoint.
-    int multiplier_exponent = 0;
-  };
+  // union {  // This used to be a union, temporarily flattened to debug a crash
+  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+  // Let the default multiplier be effectively a multiplication by 1, so that
+  // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+  // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+  // using the highest representable value is a sufficiently good
+  // approximation: since this specialization of MulParams is for the case
+  // where DstScalar is at least 2x narrower than MulScalar, the values
+  // for which there would be a difference will get saturated anyway.
+  AccumScalar multiplier_fixedpoint = 0;
+  //};
+  // union {  // This used to be a union, temporarily flattened to debug a crash
+  const int* multiplier_exponent_perchannel = nullptr;
+  // See the above comment about the default value of multiplier_fixedpoint.
+  int multiplier_exponent = 0;
+  // };
   DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
   DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
   ChannelDimension channel_dimension = ChannelDimension::kRow;
diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc
index 4bc9f87..feb7dbb 100644
--- a/ruy/mul_params_test.cc
+++ b/ruy/mul_params_test.cc
@@ -31,7 +31,7 @@
 
   MulParamsType mul_params;
   EXPECT_EQ(mul_params.bias(), nullptr);
-  EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max());
+  EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0);
   EXPECT_EQ(mul_params.multiplier_exponent(), 0);
   EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr);
   EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr);
diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc
index ee891cb..5080ca9 100644
--- a/ruy/prepacked_cache.cc
+++ b/ruy/prepacked_cache.cc
@@ -26,10 +26,10 @@
 // Allocates the `data` and `sums` buffers, and sets the corresponding
 // pointer fields, in a PEMat whose other fields, particularly `layout`
 // and the runtime data types, are already populated.
-int AllocateBuffers(PEMat* packed_matrix) {
-  const int data_bytes = DataBytes(*packed_matrix);
+std::ptrdiff_t AllocateBuffers(PEMat* packed_matrix) {
+  const std::ptrdiff_t data_bytes = DataBytes(*packed_matrix);
   packed_matrix->data = detail::SystemAlignedAlloc(data_bytes);
-  int sums_bytes = 0;
+  std::ptrdiff_t sums_bytes = 0;
   if (!packed_matrix->sums_type.is_floating_point) {
     // Integer quantized matrices also need the `sums` buffer.
     sums_bytes = SumsBytes(*packed_matrix);
@@ -93,7 +93,7 @@
   }
 
   // No existing entry found. Allocate new buffers now and insert in the cache.
-  const int new_bytes = AllocateBuffers(packed_matrix);
+  const std::ptrdiff_t new_bytes = AllocateBuffers(packed_matrix);
   EjectUntilRoomFor(new_bytes);
   Entry entry{*packed_matrix, timestamp_++};
   cache_.emplace(key, entry);
@@ -101,7 +101,7 @@
   return Action::kInsertedNewEntry;
 }
 
-void PrepackedCache::EjectUntilRoomFor(int new_bytes) {
+void PrepackedCache::EjectUntilRoomFor(std::ptrdiff_t new_bytes) {
   profiler::ScopeLabel label("PrepackedCacheEjection");
   // While we are above the threshold of ejection, eject the LRU entry.
   while (!cache_.empty() && buffers_bytes_ + new_bytes > max_buffers_bytes_) {
diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h
index cb3a113..c58593e 100644
--- a/ruy/prepacked_cache.h
+++ b/ruy/prepacked_cache.h
@@ -101,7 +101,7 @@
   ~PrepackedCache();
 
   // Returns the total size in bytes of buffers held in this cache.
-  int BuffersBytes() const { return buffers_bytes_; }
+  std::ptrdiff_t BuffersBytes() const { return buffers_bytes_; }
 
   // Returns the number of packed matrices held in this cache.
   int MatrixCount() const { return cache_.size(); }
@@ -128,11 +128,11 @@
 
  private:
   void EjectOne();
-  void EjectUntilRoomFor(int new_bytes);
+  void EjectUntilRoomFor(std::ptrdiff_t new_bytes);
 
   std::unordered_map<Key, Entry, KeyHash> cache_;
-  const int max_buffers_bytes_;
-  int buffers_bytes_ = 0;
+  const std::ptrdiff_t max_buffers_bytes_;
+  std::ptrdiff_t buffers_bytes_ = 0;
   Timestamp timestamp_ = 0;
 };
 
diff --git a/ruy/ruy.h b/ruy/ruy.h
index ddbe192..3cf7bdd 100644
--- a/ruy/ruy.h
+++ b/ruy/ruy.h
@@ -93,14 +93,6 @@
 // (e.g. the number of CPU cores in typical scenarios). At least ruy forces
 // each invocation to make an explicit decision here, there is no automatic
 // detection of the best number of threads to use in ruy.
-//
-// Constraints on the template parameters:
-// * If DstScalar is floating-point then AccumScalar must also be.
-// * If DstScalar is integral then AccumScalar must be std::int32_t.
-// Please refer to MulParams' class comment for more information. When
-// DstScalar is integral and is narrower than AccumScalar, additional
-// MulParams fields must be set to control the scaling of internal accumulators
-// before the final saturating cast to the DstScalar type.
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
 void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
diff --git a/ruy/test.h b/ruy/test.h
index 5aa4c41..0b05399 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -122,6 +122,7 @@
     return #NAME;
   switch (tuning) {
     RUY_SUBPATHNAME_CASE(kA55ish)
+    RUY_SUBPATHNAME_CASE(kX1)
     RUY_SUBPATHNAME_CASE(kGeneric)
     default:
       RUY_CHECK(false);
@@ -1825,7 +1826,7 @@
   }
 #if RUY_PLATFORM_ARM
   if (path == Path::kNeon || path == Path::kNeonDotprod) {
-    return {Tuning::kA55ish, Tuning::kGeneric, Tuning::kAuto};
+    return {Tuning::kA55ish, Tuning::kX1, Tuning::kGeneric, Tuning::kAuto};
   }
 #endif
   (void)path;
diff --git a/ruy/test_overflow_dst_zero_point.cc b/ruy/test_overflow_dst_zero_point.cc
new file mode 100644
index 0000000..db1f08d
--- /dev/null
+++ b/ruy/test_overflow_dst_zero_point.cc
@@ -0,0 +1,133 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test covers destination zero_points that cause internal int16 overflow.
+
+// Kernels tend to perform the addition of the destination zero_point in int16.
+// Although this happens after the rescaling to the destination scale, it is
+// still possible for this int16 addition to overflow. This should be handled
+// by saturating, which ensures correct results as the subsequent cast to
+// the destination 8-bit type is saturating anyway, so this second saturation
+// eats any effect of the previous saturation in the int16 addition of the
+// destination zero_point.
+// When this is not correctly saturating, a typical effect is wrapping around
+// to the opposite end of the range of int16, which causes the latter saturation
+// to the int8/uint8 range to saturate to the opposite end of that, resulting
+// in a large numerical difference in the output values.
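+//
+// Hypothetical numeric sketch: with dst_zero_point = 1 and a rescaled
+// accumulator already at 32767 (int16 max), a non-saturating int16 addition
+// wraps to -32768, and the final saturating cast to int8 then produces -128
+// instead of the expected 127. The test below therefore checks that outputs
+// are clamped to the expected end of the destination range.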
+
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+#include "ruy/context.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/path.h"
+#include "ruy/ruy.h"
+#include "ruy/test.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+namespace {
+
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context,
+                                                   int cols,
+                                                   DstScalar dst_zero_point) {
+  // Set the bias value so that the int16 addition of the zero_point will
+  // overflow.
+  const int bias_value = dst_zero_point > 0
+                             ? std::numeric_limits<std::int16_t>::max()
+                             : std::numeric_limits<std::int16_t>::min();
+  // This is the end of the DstScalar range that we expect values will be
+  // clamped to.
+  const int expected_dst_value = dst_zero_point > 0
+                                     ? std::numeric_limits<DstScalar>::max()
+                                     : std::numeric_limits<DstScalar>::min();
+
+  const std::vector<std::int8_t> lhs_data(1, 0);
+  const std::vector<std::int8_t> rhs_data(cols, 0);
+  std::vector<DstScalar> dst_data(cols, 0);
+
+  ruy::MulParams<std::int32_t, DstScalar> mul_params;
+  std::int32_t bias_data[1] = {bias_value};
+  mul_params.set_bias(bias_data);
+  // Set the quantized multiplier to essentially 1 so we get unscaled
+  // accumulators in the output, only clamped.
+  mul_params.set_multiplier_fixedpoint(
+      std::numeric_limits<std::int32_t>::max());
+
+  ruy::Matrix<std::int8_t> lhs;
+  ruy::MakeSimpleLayout(1, 1, ruy::Order::kColMajor, lhs.mutable_layout());
+  lhs.set_data(lhs_data.data());
+
+  ruy::Matrix<std::int8_t> rhs;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, rhs.mutable_layout());
+  rhs.set_data(rhs_data.data());
+
+  ruy::Matrix<DstScalar> dst;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, dst.mutable_layout());
+  dst.set_data(dst_data.data());
+  dst.set_zero_point(dst_zero_point);
+
+  ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+  // Check that the DstScalar overflow was clamped, not wrapped around.
+  for (auto d : dst_data) {
+    EXPECT_EQ(d, expected_dst_value);
+  }
+}
+
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context) {
+  // Test both a matrix*vector and a general matrix*matrix (in the sense that
+  // cols>1) as these may exercise different kernels.
+  TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 1, 1);
+  TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 8, 1);
+  if (std::is_signed<DstScalar>::value) {
+    TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 1, -1);
+    TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 8, -1);
+  }
+}
+
+TEST(RuyTest, OverflowingAdditionOfDestinationZeroPoint) {
+  ruy::Context context;
+  ruy::Path runtime_enabled_paths = context.get_runtime_enabled_paths();
+  for (unsigned bit = 0; bit < 8 * sizeof(ruy::Path); bit++) {
+    ruy::Path path = static_cast<ruy::Path>(1 << bit);
+    if ((path & runtime_enabled_paths) == ruy::Path::kNone) {
+      continue;
+    }
+    context.set_runtime_enabled_paths(path);
+    for (ruy::Tuning tuning :
+         {ruy::Tuning::kGeneric, ruy::Tuning::kA55ish, ruy::Tuning::kX1}) {
+      fprintf(stderr, "Testing path %s, tuning %s\n", PathName(path),
+              TuningName(tuning));
+      context.set_explicit_tuning(tuning);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::int8_t>(&context);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::uint8_t>(&context);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::int16_t>(&context);
+    }
+  }
+}
+
+}  // namespace
+}  // namespace ruy
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 100cfe3..5f22a13 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -25,6 +25,7 @@
 #include <thread>  // NOLINT(build/c++11)
 
 #include "ruy/check_macros.h"
+#include "ruy/denormal.h"
 #include "ruy/trace.h"
 #include "ruy/wait.h"
 
@@ -113,6 +114,9 @@
     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
     ChangeState(State::Ready);
 
+    // Suppress denormals to avoid computation inefficiency.
+    ScopedSuppressDenormals suppress_denormals;
+
     // Thread main loop
     while (true) {
       RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 9345f0c..602660b 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -30,6 +30,7 @@
 #include "ruy/cpu_cache_params.h"
 #include "ruy/cpuinfo.h"
 #include "ruy/ctx.h"
+#include "ruy/denormal.h"
 #include "ruy/mat.h"
 #include "ruy/matrix.h"
 #include "ruy/mul_params.h"
@@ -307,6 +308,12 @@
       GetTentativeThreadCount(ctx, rows, cols, depth);
   const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams();
 
+  // Suppress denormals to avoid computation inefficiency.
+  // Note this only handles the denormal suppression on the main thread. As for
+  // worker threads, the suppression is handled in each thread's main loop. See
+  // the corresponding code in thread_pool.cc for details.
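+  // (ScopedSuppressDenormals presumably sets the usual flush-to-zero /
+  // denormals-are-zero floating-point mode bits and restores the previous
+  // mode when it goes out of scope.)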
+  ScopedSuppressDenormals suppress_denormals;
+
   // Case of running this TrMul as a simple loop.
   // This is a good place to start reading this function: all the rest
   // of this function is just an optimized, but functionally equivalent,
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
index e68d909..486a6c6 100644
--- a/ruy/trmul_params.h
+++ b/ruy/trmul_params.h
@@ -53,7 +53,9 @@
                                 kMaxMulParamsSizeQuantizedIntegerCase));
 
 // OK to adjust as needed, but we want to avoid unnecessarily inflating that.
-static_assert(kMaxMulParamsSize <= 32, "");
+// Temporarily bumped from 32 to 48 while MulParams does not use unions (see
+// the temporary flattening in mul_params.h).
+static_assert(kMaxMulParamsSize <= 48, "");
 
 // Type-erased data needed for implementing TrMul.
 struct TrMulParams {
diff --git a/ruy/tune.cc b/ruy/tune.cc
index 1f615bf..004bd5a 100644
--- a/ruy/tune.cc
+++ b/ruy/tune.cc
@@ -23,7 +23,13 @@
 namespace ruy {
 
 Tuning TuningResolver::ResolveNow(CpuInfo* cpuinfo) {
-  return cpuinfo->CurrentCpuIsA55ish() ? Tuning::kA55ish : Tuning::kGeneric;
+  if (cpuinfo->CurrentCpuIsA55ish()) {
+    return Tuning::kA55ish;
+  }
+  if (cpuinfo->CurrentCpuIsX1()) {
+    return Tuning::kX1;
+  }
+  return Tuning::kGeneric;
 }
 
 TuningResolver::TuningResolver()
diff --git a/ruy/tune.h b/ruy/tune.h
index c9beed9..f50c750 100644
--- a/ruy/tune.h
+++ b/ruy/tune.h
@@ -69,7 +69,13 @@
   // A55r1 supports dotprod unlike A55r0 and A53, they are not using the same
   // kernels in practice anyway, so there was no need to distinguish them with
   // separate Tuning values.
-  kA55ish
+  kA55ish,
+  // Use code tuned for Cortex-X1 CPUs. Currently, the reason to distinguish
+  // this CPU is to get maximum performance on the dotprod kernels, where we
+  // attain high performance simply by avoiding any manual loop unrolling. As
+  // the X1 is a purely performance-oriented microarchitecture, there will
+  // likely be additional reasons to distinguish it from other CPUs.
+  kX1
 };
 
 // Why a TuningResolver class?