Upgrade ruy to c08ec529fc91722bde519628d9449258082eb847

This project was upgraded with external_updater.
Usage: tools/external_updater/updater.sh update external/ruy
For more info, check https://cs.android.com/android/platform/superproject/main/+/main:tools/external_updater/README.md

Test: TreeHugger
Change-Id: I2c5bc994f7db7683f9200cd1b85cdcb511603f7c
diff --git a/BUILD b/BUILD
index 8c2d62e..342aad7 100644
--- a/BUILD
+++ b/BUILD
@@ -1,7 +1,15 @@
 # Ruy is not BLAS
 
+load("//tools/build_defs/license:license.bzl", "license")
+
 package(
+    default_applicable_licenses = ["//third_party/ruy:license"],
     licenses = ["notice"],  # Apache 2.0
 )
 
+license(
+    name = "license",
+    package_name = "ruy",
+)
+
 exports_files(["LICENSE"])
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 98d480d..f4fe893 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,7 +18,7 @@
 cmake_minimum_required(VERSION 3.13)  # Copied from IREE
 set(CMAKE_CXX_STANDARD 14)
 
-
+include(GNUInstallDirs)
 
 if (PROJECT_NAME STREQUAL CMAKE_PROJECT_NAME)
   set(RUY_IS_TOPLEVEL TRUE)
@@ -35,41 +35,49 @@
 
 option(RUY_PROFILER "Enable ruy's built-in profiler (harms performance)" OFF)
 
+option(RUY_ENABLE_INSTALL "Enable install rule" ${RUY_IS_TOPLEVEL})
+
 include(cmake/ruy_add_all_subdirs.cmake)
 include(cmake/ruy_cc_library.cmake)
 include(cmake/ruy_cc_binary.cmake)
 include(cmake/ruy_cc_test.cmake)
 
+option(RUY_FIND_CPUINFO "Use find_package to find cpuinfo" OFF)
+
 # Skip cpuinfo if it was already generated, which can happen when ruy is
 # a subdirectory in a wider project that already uses cpuinfo.
-if (NOT TARGET cpuinfo)
-  # Test if the third_party/cpuinfo submodule was checked out before
-  # adding that subdirectory, so we can do more helpful things below in the
-  # else() block when it's not.
-  set(RUY_CPUINFO_CMAKELISTS_FILE "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cpuinfo/CMakeLists.txt")
-  if (EXISTS "${RUY_CPUINFO_CMAKELISTS_FILE}")
-    # Disabling cpuinfo's tests and benchmarks to prevent a copy of its
-    # googletest dependency getting downloaded into a 'deps' directory in the
-    # source tree!
-    set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
-    set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE)
-    set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE)
-    add_subdirectory("third_party/cpuinfo" EXCLUDE_FROM_ALL)
+if (NOT TARGET cpuinfo::cpuinfo)
+  if (RUY_FIND_CPUINFO)
+    find_package(cpuinfo REQUIRED)
   else()
-    # third_party/cpuinfo is not checked out. That could be intentional when
-    # ruy is a subdirectory in a wider project that is already providing
-    # the cpuinfo target. Maybe that wider project's CMakeLists is ordered
-    # in such a way that cpuinfo gets generated after ruy. In that case,
-    # it's helpful that we continue silently. In the worst case if the cpuinfo
-    # target never gets defined, ruy will fail to compile.
-    # On the other hand, if ruy is the top-level project here (not part of a
-    # wider project) then nothing will define the cpuinfo target for us,
-    # so we will definitely fail to compile, so we may as well fail right here.
-    if (RUY_IS_TOPLEVEL)
-      message(FATAL_ERROR "This file does not exist:\n${RUY_CPUINFO_CMAKELISTS_FILE}\n"
-                    "That typically means that the git submodules of the ruy "
-                    "repository haven't been checked out. Try this in the ruy "
-                    "git directory:\n  git submodule update --init")
+    # Test if the third_party/cpuinfo submodule was checked out before
+    # adding that subdirectory, so we can do more helpful things below in the
+    # else() block when it's not.
+    set(RUY_CPUINFO_CMAKELISTS_FILE "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cpuinfo/CMakeLists.txt")
+    if (EXISTS "${RUY_CPUINFO_CMAKELISTS_FILE}")
+      # Disabling cpuinfo's tests and benchmarks to prevent a copy of its
+      # googletest dependency getting downloaded into a 'deps' directory in the
+      # source tree!
+      set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
+      set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE)
+      set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE)
+      add_subdirectory("third_party/cpuinfo" EXCLUDE_FROM_ALL)
+    else()
+      # third_party/cpuinfo is not checked out. That could be intentional when
+      # ruy is a subdirectory in a wider project that is already providing
+      # the cpuinfo target. Maybe that wider project's CMakeLists is ordered
+      # in such a way that cpuinfo gets generated after ruy. In that case,
+      # it's helpful that we continue silently. In the worst case if the cpuinfo
+      # target never gets defined, ruy will fail to compile.
+      # On the other hand, if ruy is the top-level project here (not part of a
+      # wider project) then nothing will define the cpuinfo target for us,
+      # so we will definitely fail to compile, so we may as well fail right here.
+      if (RUY_IS_TOPLEVEL)
+        message(FATAL_ERROR "This file does not exist:\n${RUY_CPUINFO_CMAKELISTS_FILE}\n"
+                      "That typically means that the git submodules of the ruy "
+                      "repository haven't been checked out. Try this in the ruy "
+                      "git directory:\n  git submodule update --init")
+      endif()
     endif()
   endif()
 endif()
@@ -88,3 +96,22 @@
 if (NOT RUY_MINIMAL_BUILD)
   add_subdirectory("example")
 endif()
+
+if (RUY_ENABLE_INSTALL)
+  install(EXPORT ${PROJECT_NAME}Targets
+    NAMESPACE ${PROJECT_NAME}::
+    DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}"
+  )
+
+  include(CMakePackageConfigHelpers)
+
+  configure_package_config_file(
+    "cmake/${PROJECT_NAME}Config.cmake.in"
+    "${PROJECT_BINARY_DIR}/${PROJECT_NAME}Config.cmake"
+    INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}"
+  )
+
+  install(FILES "${PROJECT_BINARY_DIR}/${PROJECT_NAME}Config.cmake"
+    DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}"
+  )
+endif()
diff --git a/METADATA b/METADATA
index e9e9490..d9eb697 100644
--- a/METADATA
+++ b/METADATA
@@ -1,15 +1,19 @@
+# This project was upgraded with external_updater.
+# Usage: tools/external_updater/updater.sh update external/ruy
+# For more info, check https://cs.android.com/android/platform/superproject/main/+/main:tools/external_updater/README.md
+
 name: "ruy"
 description: "ruy is a matrix multiplication library."
 third_party {
-  url {
-    type: GIT
-    value: "https://github.com/google/ruy"
-  }
-  version: "9c56af3fce210a8a103eda19bd6f47c08a9e3d90"
   license_type: NOTICE
   last_upgrade_date {
-    year: 2021
-    month: 8
-    day: 11
+    year: 2024
+    month: 11
+    day: 8
+  }
+  identifier {
+    type: "Git"
+    value: "https://github.com/google/ruy"
+    version: "c08ec529fc91722bde519628d9449258082eb847"
   }
 }
diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py
index 8f972ba..caf9cbf 100755
--- a/cmake/bazel_to_cmake.py
+++ b/cmake/bazel_to_cmake.py
@@ -49,7 +49,7 @@
     ['selects.config_setting_group', 'config_setting_group'],
     ['@com_google_googletest//:gtest', 'gtest'],
     ['@com_google_googletest//:gtest_main', 'gtest_main'],
-    ['@cpuinfo', 'cpuinfo'],
+    ['@cpuinfo', 'cpuinfo::cpuinfo'],
 ]
 
 
diff --git a/cmake/ruyConfig.cmake.in b/cmake/ruyConfig.cmake.in
new file mode 100644
index 0000000..0f3a4f1
--- /dev/null
+++ b/cmake/ruyConfig.cmake.in
@@ -0,0 +1,9 @@
+# ruy CMake configuration file.
+
+include(CMakeFindDependencyMacro)
+
+find_dependency(cpuinfo)
+
+@PACKAGE_INIT@
+
+include("${CMAKE_CURRENT_LIST_DIR}/@PROJECT_NAME@Targets.cmake")
diff --git a/cmake/ruy_cc_library.cmake b/cmake/ruy_cc_library.cmake
index 38accc5..3f3a062 100644
--- a/cmake/ruy_cc_library.cmake
+++ b/cmake/ruy_cc_library.cmake
@@ -42,12 +42,16 @@
     set(_RULE_IS_INTERFACE 0)
   endif()
 
+  file(RELATIVE_PATH _SUBDIR ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_LIST_DIR})
+
   if(_RULE_IS_INTERFACE)
     # Generating a header-only library.
     add_library(${_NAME} INTERFACE)
+    set_target_properties(${_NAME} PROPERTIES PUBLIC_HEADER "${_RULE_HDRS}")
     target_include_directories(${_NAME}
       INTERFACE
-        "${PROJECT_SOURCE_DIR}"
+        "$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}>"
+        "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
     )
     target_link_libraries(${_NAME}
       INTERFACE
@@ -60,12 +64,8 @@
     )
   else()
     # Generating a static binary library.
-    add_library(${_NAME} STATIC "")
-    target_sources(${_NAME}
-      PRIVATE
-        ${_RULE_SRCS}
-        ${_RULE_HDRS}
-    )
+    add_library(${_NAME} STATIC ${_RULE_SRCS} ${_RULE_HDRS})
+    set_target_properties(${_NAME} PROPERTIES PUBLIC_HEADER "${_RULE_HDRS}")
     ruy_include_directories(${_NAME} "${_RULE_DEPS}")
     target_compile_options(${_NAME}
       PRIVATE
@@ -82,4 +82,15 @@
         ${_RULE_DEFINES}
     )
   endif()
+
+  add_library(${PROJECT_NAME}::${_NAME} ALIAS ${_NAME})
+
+  if(NOT _RULE_TESTONLY)
+    install(
+      TARGETS ${_NAME}
+      EXPORT ruyTargets
+      LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+      PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${_SUBDIR}
+    )
+  endif()
 endfunction()
diff --git a/cmake/ruy_include_directories.cmake b/cmake/ruy_include_directories.cmake
index e9b50a9..a90ab61 100644
--- a/cmake/ruy_include_directories.cmake
+++ b/cmake/ruy_include_directories.cmake
@@ -14,20 +14,8 @@
 
 function(ruy_include_directories NAME DEPS)
   target_include_directories(${NAME}
-      PUBLIC
-      "${PROJECT_SOURCE_DIR}"
+    PUBLIC
+      "$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}>"
+      "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
   )
-  if (cpuinfo IN_LIST DEPS)
-    target_include_directories(${NAME}
-      PRIVATE
-        "${PROJECT_SOURCE_DIR}/third_party/cpuinfo/include"
-    )
-  endif()
-  if ((gtest IN_LIST DEPS) OR
-      (gtest_main IN_LIST DEPS))
-    target_include_directories(${NAME}
-      PRIVATE
-        "${PROJECT_SOURCE_DIR}/third_party/googletest/googletest"
-    )
-  endif()
-endfunction()
\ No newline at end of file
+endfunction()
diff --git a/doc/depgraph.sh b/doc/depgraph.sh
index d66d44f..d1f72af 100755
--- a/doc/depgraph.sh
+++ b/doc/depgraph.sh
@@ -29,7 +29,7 @@
     ':validate'
     'profiler:instrumentation'
     '\bclog\b'
-    '\bcpuinfo_impl\b'
+    '\bcpuinfo\b'
     ':apply_multiplier'
     '\blabel='
 )
diff --git a/example/BUILD b/example/BUILD
index 738c33e..912fb2d 100644
--- a/example/BUILD
+++ b/example/BUILD
@@ -1,4 +1,5 @@
 package(
+    default_applicable_licenses = ["//third_party/ruy:license"],
     licenses = ["notice"],  # Apache 2.0
 )
 
diff --git a/example/example.cc b/example/example.cc
index 3bb95f4..6d4fff2 100644
--- a/example/example.cc
+++ b/example/example.cc
@@ -126,6 +126,7 @@
   std::cout << "RHS:\n" << rhs;
   std::cout << "Result:\n" << dst << "\n";
 }
+
 void ExampleMulInt8GetRawAccumulators(ruy::Context *context) {
   const std::int8_t lhs_data[] = {1, 2, 3, 4};
   const std::int8_t rhs_data[] = {1, 2, 3, 4};
@@ -151,6 +152,35 @@
   std::cout << "Result:\n" << dst << "\n";
 }
 
+void ExampleMulInt8TimesInt16PerChannelQuantized(ruy::Context *context) {
+  const std::int8_t lhs_data[] = {1, 2, 3, 4};
+  const std::int16_t rhs_data[] = {1000, 2000, 3000, 4000};
+  const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
+  const int exponent_data[] = {1, -2};
+  std::int16_t dst_data[4];
+
+  ruy::Matrix<std::int8_t> lhs;
+  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+  lhs.set_data(lhs_data);
+  ruy::Matrix<std::int16_t> rhs;
+  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+  rhs.set_data(rhs_data);
+  ruy::Matrix<std::int16_t> dst;
+  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+  dst.set_data(dst_data);
+
+  ruy::MulParams<std::int32_t, std::int16_t> mul_params;
+  mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
+  mul_params.set_multiplier_exponent_perchannel(exponent_data);
+  ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+  std::cout << "Example Mul, int8 times int16 quantized with per-channel "
+               "multipliers\n";
+  std::cout << "LHS:\n" << lhs;
+  std::cout << "RHS:\n" << rhs;
+  std::cout << "Result:\n" << dst << "\n";
+}
+
 int main() {
   ruy::Context context;
   ExampleMulFloat(&context);
@@ -158,4 +188,5 @@
   ExampleMulUint8AsymmetricQuantized(&context);
   ExampleMulInt8PerChannelQuantized(&context);
   ExampleMulInt8GetRawAccumulators(&context);
+  ExampleMulInt8TimesInt16PerChannelQuantized(&context);
 }
diff --git a/example/parametrized_example.cc b/example/parametrized_example.cc
index ef6ad23..253d911 100644
--- a/example/parametrized_example.cc
+++ b/example/parametrized_example.cc
@@ -140,7 +140,8 @@
   }
   Params params;
   const char* allowed_types =
-      "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8";
+      "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8, "
+      "i8xi16->i16, i16xi8->i16";
   const char* allowed_orders = "row-major, column-major";
   read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32",
                     allowed_types, &params.types);
@@ -172,7 +173,7 @@
                     allowed_orders, &params.lhs_order);
   read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
                     allowed_orders, &params.rhs_order);
-  read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
+  read_cmdline_args(help, argc, argv, "--dst_order", "%s", "row-major",
                     allowed_orders, &params.dst_order);
 
   if (help) {
@@ -191,6 +192,10 @@
     run<std::uint8_t, std::uint8_t, std::int16_t>(params);
   } else if (!strcmp(params.types, "u8xi8->u8")) {
     run<std::uint8_t, std::int8_t, std::uint8_t>(params);
+  } else if (!strcmp(params.types, "i8xi16->i16")) {
+    run<std::int8_t, std::int16_t, std::int16_t>(params);
+  } else if (!strcmp(params.types, "i16xi8->i16")) {
+    run<std::int16_t, std::int8_t, std::int16_t>(params);
   } else {
     fprintf(stderr, "Unknown types: %s\n", params.types);
     exit(1);
diff --git a/ruy/BUILD b/ruy/BUILD
index d04a45d..81d336a 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -8,66 +8,55 @@
 load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
 
 package(
+    default_applicable_licenses = ["//third_party/ruy:license"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-config_setting(
-    name = "armeabi-v7a",
-    values = {"cpu": "armeabi-v7a"},
-)
-
-config_setting(
-    name = "armv7a",
-    values = {"cpu": "armv7a"},
-)
-
 # Detect ARM 32-bit targets where we are going to just assume NEON support.
-selects.config_setting_group(
+config_setting(
     name = "arm32_assuming_neon",
-    match_any = [
-        ":armeabi-v7a",
-        ":armv7a",
+    constraint_values = [
+        "@platforms//cpu:armv7",
     ],
 )
 
 config_setting(
-    name = "x86_64_k8",
-    values = {"cpu": "k8"},
-)
-
-config_setting(
-    name = "x86_64_haswell",
-    values = {"cpu": "haswell"},
-)
-
-# MSVC toolchains define a different "cpu" value, which helps us as we need
-# to pass different flags on MSVC vs GCC-compatible toolchains to enable
-# x86 SIMD extensions.
-selects.config_setting_group(
     name = "x86_64_and_not_msvc",
-    match_any = [
-        ":x86_64_k8",
-        ":x86_64_haswell",
+    constraint_values = [
+        "@platforms//cpu:x86_64",
+        "@platforms//os:linux",
     ],
 )
 
 config_setting(
+    name = "windows_msvc",
+    constraint_values = [
+        "@platforms//os:windows",
+    ],
+    flag_values = {
+        "//tools/cpp:compiler": "msvc",
+    },
+)
+
+config_setting(
     name = "ppc",
-    values = {
-        "cpu": "ppc",
-    },
+    constraint_values = [
+        "@platforms//cpu:ppc",
+    ],
 )
 
 config_setting(
     name = "s390x",
-    values = {
-        "cpu": "s390x",
-    },
+    constraint_values = [
+        "@platforms//cpu:s390x",
+    ],
 )
 
 config_setting(
     name = "fuchsia",
-    values = {"cpu": "fuchsia"},
+    constraint_values = [
+        "@platforms//os:fuchsia",
+    ],
 )
 
 config_setting(
@@ -87,7 +76,7 @@
 selects.config_setting_group(
     name = "do_not_want_O3",
     match_any = [
-        "@bazel_tools//src/conditions:windows_msvc",
+        ":windows_msvc",
         ":dbg_build",
         ":fastbuild_build",
     ],
@@ -380,7 +369,7 @@
     ],
     copts = ruy_copts() +
             select({
-                "@bazel_tools//src/conditions:windows": [],
+                "@platforms//os:windows": [],
                 "//conditions:default": [
                     # ruy_copts contains -Wundef, but cpuinfo's header warns with that.
                     "-Wno-undef",
@@ -397,9 +386,9 @@
         "//conditions:default": ["-DRUY_HAVE_CPUINFO"],
     }),
     deps = [
-        ":platform",
         ":check_macros",
         ":cpu_cache_params",
+        ":platform",
     ] + select({
         # This select must match the similar select in `copts`
         ":ppc": [],
@@ -436,6 +425,13 @@
 )
 
 cc_library(
+    name = "strategy_controls",
+    hdrs = ["strategy_controls.h"],
+    copts = ruy_copts(),
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
     name = "matrix",
     hdrs = ["matrix.h"],
     copts = ruy_copts(),
@@ -859,6 +855,7 @@
         ":performance_advisory",
         ":platform",
         ":prepacked_cache",
+        ":strategy_controls",
         ":thread_pool",
         ":tune",
     ],
@@ -874,6 +871,7 @@
         ":path",
         ":platform",
         ":prepacked_cache",
+        ":strategy_controls",
         ":tune",
     ],
 )
@@ -907,6 +905,7 @@
         ":performance_advisory",
         ":platform",
         ":prepacked_cache",
+        ":strategy_controls",
         ":thread_pool",
         ":trace",
         ":tune",
@@ -937,6 +936,7 @@
         ":gtest_wrapper",
         ":path",
         ":platform",
+        ":strategy_controls",
     ],
 )
 
@@ -972,6 +972,7 @@
         ":opt_set",
         ":side_pair",
         ":size_util",
+        ":strategy_controls",
         ":thread_pool",
         ":trace",
         ":trmul_params",
@@ -1126,24 +1127,24 @@
     # need defines, not copts, because it's controlling a header, test.h
     defines = ruy_test_ext_defines(),
     linkopts = select({
-        "@bazel_tools//src/conditions:windows": [],
+        "@platforms//os:windows": [],
         "//conditions:default": ["-lm"],
     }),
     deps = [
         ":allocator",
-        ":size_util",
-        ":reference_mul",
-        ":matrix",
-        ":pmu",
-        ":ruy",
-        ":mul_params",
-        ":time",
-        ":gtest_wrapper",
-        ":platform",
         ":context",
-        ":ctx",
         ":context_get_ctx",
+        ":ctx",
+        ":gtest_wrapper",
+        ":matrix",
+        ":mul_params",
         ":pack_common",
+        ":platform",
+        ":pmu",
+        ":reference_mul",
+        ":ruy",
+        ":size_util",
+        ":time",
         "//ruy/profiler",
     ] + ruy_test_ext_deps(),
 )
@@ -1159,6 +1160,8 @@
         ("i8", "i8", "i32", "i8"),
         ("u8", "u8", "i32", "i16"),
         ("i8", "i8", "i32", "i32"),
+        ("i8", "i16", "i32", "i16"),
+        ("i16", "i8", "i32", "i16"),
     ],
     deps = [
         ":test_lib",
@@ -1180,6 +1183,8 @@
         ("u8", "u8", "i32", "i16"),
         ("i8", "i8", "i32", "i32"),
         ("i8", "u8", "i32", "i32"),
+        ("i8", "i16", "i32", "i16"),
+        ("i16", "i8", "i32", "i16"),
     ],
     deps = [
         ":test_lib",
@@ -1197,6 +1202,8 @@
         ("i8", "i8", "i32", "i8"),
         ("u8", "u8", "i32", "i16"),
         ("i8", "i8", "i32", "i32"),
+        ("i8", "i16", "i32", "i16"),
+        ("i16", "i8", "i32", "i16"),
     ],
     tags = ["slow"],
     deps = [
diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt
index 502ad8a..8e493de 100644
--- a/ruy/CMakeLists.txt
+++ b/ruy/CMakeLists.txt
@@ -3,9 +3,9 @@
 #   cmake/bazel_to_cmake.sh
 
 if(CMAKE_SYSTEM_NAME STREQUAL Windows)
-  set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "")
+  set(ruy_0_Wall_Wextra_Wundef "")
 else()
-  set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "-Wall;-Wextra;-Wc++14-compat;-Wundef")
+  set(ruy_0_Wall_Wextra_Wundef "-Wall;-Wextra;-Wundef")
 endif()
 
 if(CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
@@ -26,7 +26,7 @@
   HDRS
     trace.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -43,7 +43,7 @@
   HDRS
     platform.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -64,7 +64,7 @@
   HDRS
     check_macros.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -75,7 +75,7 @@
   SRCS
     check_macros_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -89,7 +89,7 @@
   HDRS
     opt_set.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -100,7 +100,7 @@
   HDRS
     time.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -119,7 +119,7 @@
   HDRS
     wait.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   LINKOPTS
@@ -134,7 +134,7 @@
   SRCS
     wait_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   LINKOPTS
@@ -151,7 +151,7 @@
   HDRS
     size_util.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -164,7 +164,7 @@
   SRCS
     size_util_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -180,7 +180,7 @@
   HDRS
     tune.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -199,7 +199,7 @@
   HDRS
     system_aligned_alloc.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -212,7 +212,7 @@
   HDRS
     prepacked_cache.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -227,7 +227,7 @@
   SRCS
     tune_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -242,7 +242,7 @@
   SRCS
     prepacked_cache_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -265,7 +265,7 @@
   HDRS
     allocator.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -280,7 +280,7 @@
   SRCS
     allocator_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -294,7 +294,7 @@
   HDRS
     side_pair.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -309,7 +309,7 @@
   HDRS
     block_map.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -328,7 +328,7 @@
   SRCS
     block_map_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -348,7 +348,7 @@
   HDRS
     blocking_counter.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   LINKOPTS
@@ -367,7 +367,7 @@
   HDRS
     thread_pool.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   LINKOPTS
@@ -388,7 +388,7 @@
   HDRS
     cpu_cache_params.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
 )
@@ -410,13 +410,13 @@
 endif()
 
 if(CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le)
-  set(ruy_6_cpuinfo "")
+  set(ruy_6_cpuinfo_cpuinfo "")
 elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x)
-  set(ruy_6_cpuinfo "")
+  set(ruy_6_cpuinfo_cpuinfo "")
 elseif(CMAKE_SYSTEM_NAME STREQUAL Fuchsia)
-  set(ruy_6_cpuinfo "")
+  set(ruy_6_cpuinfo_cpuinfo "")
 else()
-  set(ruy_6_cpuinfo "cpuinfo")
+  set(ruy_6_cpuinfo_cpuinfo "cpuinfo::cpuinfo")
 endif()
 
 ruy_cc_library(
@@ -427,7 +427,7 @@
   HDRS
     cpuinfo.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_4_Wno_undef}
@@ -436,7 +436,7 @@
     ruy_platform
     ruy_check_macros
     ruy_cpu_cache_params
-    ${ruy_6_cpuinfo}
+    ${ruy_6_cpuinfo_cpuinfo}
 )
 
 ruy_cc_library(
@@ -445,7 +445,7 @@
   HDRS
     path.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -462,7 +462,7 @@
   HDRS
     denormal.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -474,7 +474,7 @@
   HDRS
     performance_advisory.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -486,7 +486,7 @@
   HDRS
     matrix.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -500,7 +500,7 @@
   SRCS
     matrix_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -514,7 +514,7 @@
   HDRS
     mul_params.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -529,7 +529,7 @@
   SRCS
     mul_params_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -543,7 +543,7 @@
   HDRS
     mat.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -558,7 +558,7 @@
   HDRS
     asm_helpers.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -573,7 +573,7 @@
   HDRS
     apply_multiplier.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -587,7 +587,7 @@
   SRCS
     apply_multiplier_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -602,7 +602,7 @@
   HDRS
     kernel_common.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -626,7 +626,7 @@
   HDRS
     pack_common.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -649,7 +649,7 @@
   HDRS
     kernel_arm.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -675,7 +675,7 @@
   HDRS
     pack_arm.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -706,7 +706,7 @@
   HDRS
     kernel_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
@@ -730,7 +730,7 @@
   HDRS
     pack_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
@@ -753,7 +753,7 @@
   HDRS
     have_built_path_for.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
@@ -778,7 +778,7 @@
   HDRS
     kernel_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_8_mavx2_mfma_arch_AVX2}
@@ -802,7 +802,7 @@
   HDRS
     pack_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_8_mavx2_mfma_arch_AVX2}
@@ -825,7 +825,7 @@
   HDRS
     have_built_path_for.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_8_mavx2_mfma_arch_AVX2}
@@ -850,7 +850,7 @@
   HDRS
     kernel_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_9_mavx_arch_AVX}
@@ -874,7 +874,7 @@
   HDRS
     pack_x86.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_9_mavx_arch_AVX}
@@ -897,7 +897,7 @@
   HDRS
     have_built_path_for.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     ${ruy_9_mavx_arch_AVX}
@@ -912,7 +912,7 @@
   HDRS
     kernel.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -942,7 +942,7 @@
   HDRS
     pack.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -982,7 +982,7 @@
   HDRS
     context.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -1004,7 +1004,7 @@
   SRCS
     context_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1033,7 +1033,7 @@
     ctx.h
     ctx_impl.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1058,7 +1058,7 @@
   HDRS
     context_get_ctx.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1072,7 +1072,7 @@
   SRCS
     ctx_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1088,7 +1088,7 @@
   HDRS
     trmul_params.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1107,7 +1107,7 @@
   HDRS
     trmul.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1139,7 +1139,7 @@
   HDRS
     prepare_packed_matrices.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1158,7 +1158,7 @@
   HDRS
     create_trmul_params.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1183,7 +1183,7 @@
   HDRS
     validate.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1201,7 +1201,7 @@
   HDRS
     frontend.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1228,7 +1228,7 @@
     path.h
     ruy.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -1252,7 +1252,7 @@
   SRCS
     perchannel_buffers_reallocation_test.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1274,7 +1274,7 @@
   HDRS
     pmu.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
@@ -1287,7 +1287,7 @@
   HDRS
     reference_mul.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   PUBLIC
@@ -1310,7 +1310,7 @@
   HDRS
     test.h
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   LINKOPTS
@@ -1340,7 +1340,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=f32
@@ -1359,7 +1359,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1378,7 +1378,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1397,7 +1397,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1416,7 +1416,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1435,7 +1435,7 @@
   SRCS
     benchmark.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1447,13 +1447,51 @@
     ruy_profiler_instrumentation
 )
 
+ruy_cc_binary(
+  NAME
+    ruy_benchmark_i8_i16_i32_i16
+  TESTONLY
+  SRCS
+    benchmark.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i8
+    -DRUY_TEST_RHSSCALAR=i16
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+  NAME
+    ruy_benchmark_i16_i8_i32_i16
+  TESTONLY
+  SRCS
+    benchmark.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i16
+    -DRUY_TEST_RHSSCALAR=i8
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    ruy_profiler_instrumentation
+)
+
 ruy_cc_test(
   NAME
     ruy_test_fast_f32_f32_f32_f32
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=f32
@@ -1471,7 +1509,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=f64
@@ -1489,7 +1527,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=f32
@@ -1507,7 +1545,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1525,7 +1563,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1543,7 +1581,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1561,7 +1599,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1579,7 +1617,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1597,7 +1635,7 @@
   SRCS
     test_fast.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1611,11 +1649,47 @@
 
 ruy_cc_test(
   NAME
+    ruy_test_fast_i8_i16_i32_i16
+  SRCS
+    test_fast.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i8
+    -DRUY_TEST_RHSSCALAR=i16
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    gtest_main
+)
+
+ruy_cc_test(
+  NAME
+    ruy_test_fast_i16_i8_i32_i16
+  SRCS
+    test_fast.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i16
+    -DRUY_TEST_RHSSCALAR=i8
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    gtest_main
+)
+
+ruy_cc_test(
+  NAME
     ruy_test_slow_f32_f32_f32_f32
   SRCS
     test_slow.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=f32
@@ -1635,7 +1709,7 @@
   SRCS
     test_slow.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1655,7 +1729,7 @@
   SRCS
     test_slow.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1675,7 +1749,7 @@
   SRCS
     test_slow.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=u8
@@ -1695,7 +1769,7 @@
   SRCS
     test_slow.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
     -DRUY_TEST_LHSSCALAR=i8
@@ -1711,18 +1785,58 @@
 
 ruy_cc_test(
   NAME
+    ruy_test_slow_i8_i16_i32_i16
+  SRCS
+    test_slow.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i8
+    -DRUY_TEST_RHSSCALAR=i16
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    gtest_main
+  TAGS
+    slow
+)
+
+ruy_cc_test(
+  NAME
+    ruy_test_slow_i16_i8_i32_i16
+  SRCS
+    test_slow.cc
+  COPTS
+    ${ruy_0_Wall_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+    -DRUY_TEST_LHSSCALAR=i16
+    -DRUY_TEST_RHSSCALAR=i8
+    -DRUY_TEST_ACCUMSCALAR=i32
+    -DRUY_TEST_DSTSCALAR=i16
+  DEPS
+    ruy_test_lib
+    gtest_main
+  TAGS
+    slow
+)
+
+ruy_cc_test(
+  NAME
     ruy_test_overflow_dst_zero_point
   SRCS
     test_overflow_dst_zero_point.cc
   COPTS
-    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_0_Wall_Wextra_Wundef}
     ${ruy_1_mfpu_neon}
     ${ruy_2_O3}
   DEPS
     ruy_gtest_wrapper
     ruy_matrix
-    ruy
     ruy_path
+    ruy
     ruy_test_lib
     ruy_tune
 )
diff --git a/ruy/allocator.cc b/ruy/allocator.cc
index 64da664..3b9bcf0 100644
--- a/ruy/allocator.cc
+++ b/ruy/allocator.cc
@@ -103,20 +103,24 @@
     return;
   }
 
-  // No rounding-up of the size means linear instead of logarithmic
+  // Free all memory before reallocating `ptr_`.
+  // This minimizes the memory high-water-mark.
+  detail::SystemAlignedFree(ptr_);
+  for (void* p : fallback_blocks_) {
+    detail::SystemAlignedFree(p);
+  }
+
+  // We reallocate to the exact new size, rather than growing
+  // exponentially like std::vector. This means linear instead of logarithmic
   // bound on the number of allocation in some worst-case calling patterns.
   // This is considered worth it because minimizing memory usage is important
   // and actual calling patterns in applications that we care about still
   // reach the no-further-allocations steady state in a small finite number
   // of iterations.
   std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_;
-  detail::SystemAlignedFree(ptr_);
   ptr_ = detail::SystemAlignedAlloc(new_size);
   size_ = new_size;
 
-  for (void* p : fallback_blocks_) {
-    detail::SystemAlignedFree(p);
-  }
   fallback_blocks_.clear();
   fallback_blocks_total_size_ = 0;
 }
diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc
index 3c63249..d551852 100644
--- a/ruy/benchmark.cc
+++ b/ruy/benchmark.cc
@@ -103,10 +103,20 @@
 }
 
 void Benchmark() {
+  // For now, support for int8*int16 cases is limited to the
+  // symmetric case (zero_point==0) because that appears to be
+  // the case in the initial use cases, and that limits complexity
+  // in thinking about accumulator overflows. This would not be a concern
+  // in the future if the accumulator type was int64, but for now its int32.
+  const bool is_int8_times_int16 =
+      (std::is_same<LhsScalar, std::int8_t>::value &&
+       std::is_same<RhsScalar, std::int16_t>::value) ||
+      (std::is_same<LhsScalar, std::int16_t>::value &&
+       std::is_same<RhsScalar, std::int8_t>::value);
   const bool symm_lhs = std::is_floating_point<LhsScalar>::value ||
-                        GetBoolEnvVarOrFalse("SYMM_LHS");
+                        is_int8_times_int16 || GetBoolEnvVarOrFalse("SYMM_LHS");
   const bool symm_rhs = std::is_floating_point<RhsScalar>::value ||
-                        GetBoolEnvVarOrFalse("SYMM_RHS");
+                        is_int8_times_int16 || GetBoolEnvVarOrFalse("SYMM_RHS");
   const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") ||
                                GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST");
   const int explicit_rows = GetIntEnvVarOrZero("ROWS");
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
index 836f47a..d7c6f21 100644
--- a/ruy/build_defs.bzl
+++ b/ruy/build_defs.bzl
@@ -4,7 +4,7 @@
 # Returns warnings flags to use for all ruy code.
 def ruy_copts_warnings():
     return select({
-        "@bazel_tools//src/conditions:windows": [
+        "//tools/cc_target_os:windows": [
             # We run into trouble on Windows toolchains with warning flags,
             # as mentioned in the comments below on each flag.
             # We could be more aggressive in enabling supported warnings on each
@@ -15,9 +15,6 @@
             "-Wall",
             # Some clang-based Windows toolchains have more warnings in -Wextra.
             "-Wextra",
-            # TensorFlow is C++14 at the moment. This flag ensures that we warn
-            # on any code that isn't C++14, but MSVC does not support it.
-            "-Wc++14-compat",
             # Warn on preprocessor expansion of an undefined token, e.g. catching
             # typos such as `#ifdef __linus__` instead of `#ifdef __linux__`.
             # Not supported by MSVC.
@@ -57,14 +54,14 @@
 def ruy_copts_avx():
     return select({
         "//ruy:x86_64_and_not_msvc": ["-mavx"],
-        "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX"],
+        "//tools/cc_target_os:windows_msvc": ["/arch:AVX"],
         "//conditions:default": [],
     })
 
 def ruy_copts_avx2_fma():
     return select({
         "//ruy:x86_64_and_not_msvc": ["-mavx2", "-mfma"],
-        "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX2"],
+        "//tools/cc_target_os:windows_msvc": ["/arch:AVX2"],
         "//conditions:default": [],
     })
 
@@ -74,6 +71,6 @@
     # in optimized builds (-c opt).
     return select({
         "//ruy:x86_64_and_not_msvc": ["$(STACK_FRAME_UNLIMITED)", "-mavx512f", "-mavx512vl", "-mavx512cd", "-mavx512bw", "-mavx512dq"],
-        "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX512"],
+        "//tools/cc_target_os:windows_msvc": ["/arch:AVX512"],
         "//conditions:default": [],
     })
diff --git a/ruy/build_defs.oss.bzl b/ruy/build_defs.oss.bzl
index e405b41..6d34ba6 100644
--- a/ruy/build_defs.oss.bzl
+++ b/ruy/build_defs.oss.bzl
@@ -10,6 +10,6 @@
     # with Bazel. Instead we do the following, which is copied from
     # https://github.com/abseil/abseil-cpp/blob/1112609635037a32435de7aa70a9188dcb591458/absl/base/BUILD.bazel#L155
     return select({
-        "@bazel_tools//src/conditions:windows": [],
+        "//tools/cc_target_os:windows": [],
         "//conditions:default": ["-pthread"],
     })
diff --git a/ruy/context.cc b/ruy/context.cc
index 342ce52..ec651f9 100644
--- a/ruy/context.cc
+++ b/ruy/context.cc
@@ -17,6 +17,7 @@
 
 #include "ruy/ctx.h"
 #include "ruy/ctx_impl.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/path.h"
 #include "ruy/performance_advisory.h"
 #include "ruy/prepacked_cache.h"
@@ -44,6 +45,12 @@
 void Context::set_max_num_threads(int value) {
   mutable_ctx()->set_max_num_threads(value);
 }
+NumThreadsStrategy Context::num_threads_strategy() const {
+  return ctx().num_threads_strategy();
+}
+void Context::set_num_threads_strategy(NumThreadsStrategy strategy) {
+  mutable_ctx()->set_num_threads_strategy(strategy);
+}
 
 void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); }
 
diff --git a/ruy/context.h b/ruy/context.h
index f148f0f..16f40e7 100644
--- a/ruy/context.h
+++ b/ruy/context.h
@@ -28,6 +28,7 @@
 enum class Path : std::uint8_t;
 enum class Tuning;
 enum class PerformanceAdvisory;
+enum class NumThreadsStrategy : std::uint8_t;
 
 // A Context holds runtime information used by Ruy. It holds runtime resources
 // such as the workers thread pool and the allocator (which holds buffers for
@@ -71,6 +72,10 @@
   int max_num_threads() const;
   void set_max_num_threads(int value);
 
+  // Controls the logic to determine how many threads to use.
+  NumThreadsStrategy num_threads_strategy() const;
+  void set_num_threads_strategy(NumThreadsStrategy strategy);
+
   // Returns true of the last ruy::Mul using this Context flagged the specified
   // `advisory`. This is reset by each ruy::Mul call.
   bool performance_advisory(PerformanceAdvisory advisory) const;
diff --git a/ruy/context_test.cc b/ruy/context_test.cc
index 4e69e65..6497c77 100644
--- a/ruy/context_test.cc
+++ b/ruy/context_test.cc
@@ -16,6 +16,7 @@
 #include "ruy/context.h"
 
 #include "ruy/gtest_wrapper.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/path.h"
 #include "ruy/prepacked_cache.h"
 #include "ruy/tune.h"
@@ -30,10 +31,14 @@
   EXPECT_EQ(&context.thread_pool(), context.mutable_thread_pool());
   EXPECT_NE(context.mutable_thread_pool(), nullptr);
   EXPECT_EQ(context.max_num_threads(), 1);
+  EXPECT_EQ(context.num_threads_strategy(), NumThreadsStrategy::kDefault);
   context.set_explicit_tuning(Tuning::kGeneric);
   context.set_max_num_threads(2);
+  context.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads);
   EXPECT_EQ(context.explicit_tuning(), Tuning::kGeneric);
   EXPECT_EQ(context.max_num_threads(), 2);
+  EXPECT_EQ(context.num_threads_strategy(),
+            NumThreadsStrategy::kForceMaxNumThreads);
 }
 
 }  // namespace
diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc
index a3e75d7..5daee0b 100644
--- a/ruy/cpuinfo.cc
+++ b/ruy/cpuinfo.cc
@@ -39,7 +39,7 @@
 }
 
 namespace {
-void QueryCacheParams(CpuCacheParams* cache_params) {
+bool QueryCacheParams(CpuCacheParams* cache_params) {
   const int processors_count = cpuinfo_get_processors_count();
   RUY_DCHECK_GT(processors_count, 0);
   int overall_local_cache_size = std::numeric_limits<int>::max();
@@ -56,11 +56,19 @@
         continue;  // continue, not break, it is possible to have L1+L3 but no
                    // L2.
       }
-      const bool is_local =
-          cpuinfo_get_processor(cache->processor_start)->core ==
-          cpuinfo_get_processor(cache->processor_start +
-                                cache->processor_count - 1)
-              ->core;
+      if (!cache->processor_count) {
+        // This may happen in a sand-boxed process, e.g.: a browser renderer.
+        continue;
+      }
+      const cpuinfo_processor* processor_start =
+          cpuinfo_get_processor(cache->processor_start);
+      const cpuinfo_processor* processor_end = cpuinfo_get_processor(
+          cache->processor_start + cache->processor_count - 1);
+      if (!processor_start || !processor_end) {
+        // This may happen in a sand-boxed process, e.g.: a browser renderer.
+        continue;
+      }
+      const bool is_local = processor_start->core == processor_end->core;
       if (is_local) {
         local_cache_size = cache->size;
       }
@@ -70,8 +78,9 @@
     if (!local_cache_size) {
       local_cache_size = last_level_cache_size;
     }
-    RUY_DCHECK_GT(local_cache_size, 0);
-    RUY_DCHECK_GT(last_level_cache_size, 0);
+    if (local_cache_size == 0 || last_level_cache_size == 0) {
+      return false;
+    }
     RUY_DCHECK_GE(last_level_cache_size, local_cache_size);
     overall_local_cache_size =
         std::min(overall_local_cache_size, local_cache_size);
@@ -80,6 +89,7 @@
   }
   cache_params->local_cache_size = overall_local_cache_size;
   cache_params->last_level_cache_size = overall_last_level_cache_size;
+  return true;
 }
 }  // end namespace
 
@@ -89,7 +99,10 @@
     MakeDummyCacheParams(&cache_params_);
     return InitStatus::kFailed;
   }
-  QueryCacheParams(&cache_params_);
+  if (!QueryCacheParams(&cache_params_)) {
+    MakeDummyCacheParams(&cache_params_);
+    return InitStatus::kFailed;
+  }
   return InitStatus::kInitialized;
 }
 
@@ -123,7 +136,12 @@
     return false;
   }
 
-  switch (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch) {
+  const struct cpuinfo_uarch_info* cpuinfo_uarch =
+      cpuinfo_get_uarch(cpuinfo_get_current_uarch_index());
+  if (!cpuinfo_uarch) {
+    return false;
+  }
+  switch (cpuinfo_uarch->uarch) {
     case cpuinfo_uarch_cortex_a53:
     case cpuinfo_uarch_cortex_a55r0:
     case cpuinfo_uarch_cortex_a55:
@@ -137,8 +155,12 @@
   if (!EnsureInitialized()) {
     return false;
   }
-  if (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch ==
-      cpuinfo_uarch_cortex_x1) {
+  const struct cpuinfo_uarch_info* cpuinfo_uarch =
+      cpuinfo_get_uarch(cpuinfo_get_current_uarch_index());
+  if (!cpuinfo_uarch) {
+    return false;
+  }
+  if (cpuinfo_uarch->uarch == cpuinfo_uarch_cortex_x1) {
     return true;
   }
   return false;
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index 0ef098d..5d6afd4 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -26,6 +26,7 @@
 #include "ruy/path.h"
 #include "ruy/performance_advisory.h"
 #include "ruy/platform.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/prepacked_cache.h"
 #include "ruy/trace.h"
 
@@ -56,6 +57,12 @@
   return (impl().performance_advisory_ & advisory) !=
          PerformanceAdvisory::kNone;
 }
+void Ctx::set_num_threads_strategy(NumThreadsStrategy strategy) {
+  mutable_impl()->num_threads_strategy_ = strategy;
+}
+NumThreadsStrategy Ctx::num_threads_strategy() const {
+  return impl().num_threads_strategy_;
+}
 
 void Ctx::SetRuntimeEnabledPaths(Path paths) {
   if (paths == Path::kNone) {
diff --git a/ruy/ctx.h b/ruy/ctx.h
index df9dee2..f576a90 100644
--- a/ruy/ctx.h
+++ b/ruy/ctx.h
@@ -32,6 +32,7 @@
 enum class Path : std::uint8_t;
 enum class Tuning;
 enum class PerformanceAdvisory;
+enum class NumThreadsStrategy : std::uint8_t;
 
 // Ctx is the internal context class used throughout ruy code. Whereas Context
 // is exposed to users, Ctx is internal to ruy. As many of ruy's internal
@@ -53,6 +54,8 @@
   void clear_performance_advisories();
   void set_performance_advisory(PerformanceAdvisory advisory);
   bool performance_advisory(PerformanceAdvisory advisory) const;
+  void set_num_threads_strategy(NumThreadsStrategy strategy);
+  NumThreadsStrategy num_threads_strategy() const;
 
   // Returns the set of Path's that are available. By default, this is based on
   // runtime detection of CPU features, as well as on which code paths were
diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h
index 0a07ef6..be64553 100644
--- a/ruy/ctx_impl.h
+++ b/ruy/ctx_impl.h
@@ -29,6 +29,7 @@
 #include "ruy/path.h"
 #include "ruy/performance_advisory.h"
 #include "ruy/prepacked_cache.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/thread_pool.h"
 #include "ruy/tune.h"
 
@@ -63,6 +64,7 @@
   Tuning explicit_tuning_ = Tuning::kAuto;
   ThreadPool thread_pool_;
   int max_num_threads_ = 1;
+  NumThreadsStrategy num_threads_strategy_ = NumThreadsStrategy::kDefault;
   // Allocator for main thread work before invoking the threadpool.
   // Our simple Allocator does not allow reserving/allocating more blocks
   // while it's already in committed state, so the main thread needs both
diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc
index e55dcfc..c40f2d6 100644
--- a/ruy/ctx_test.cc
+++ b/ruy/ctx_test.cc
@@ -15,6 +15,7 @@
 
 #include "ruy/ctx_impl.h"
 #include "ruy/gtest_wrapper.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/path.h"
 #include "ruy/platform.h"
 
@@ -67,6 +68,14 @@
   }
 }
 
+TEST(ContextInternalTest, SetNumThreadsStrategy) {
+  CtxImpl ctx;
+  EXPECT_EQ(ctx.num_threads_strategy(), NumThreadsStrategy::kDefault);
+  ctx.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads);
+  EXPECT_EQ(ctx.num_threads_strategy(),
+            NumThreadsStrategy::kForceMaxNumThreads);
+}
+
 }  // namespace
 }  // namespace ruy
 
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index 8782dce..be0c267 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -282,16 +282,20 @@
         // Let r8 be stack offset of the row or column variable, whichever
         // is the channel index.
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
-        "ite eq\n"
-        "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
-        "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+        "bne 1000f\n"
+        "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
+        "b 1001f\n"
+        "1000:\n"
+        "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+        "1001:\n"
         // Let r8 be the channel index.
         "ldr r8, [sp, r8]\n"
         // Compute the bias pointer, by conditionally using the channel index
         // (r8) as offset into bias buffer (r1).
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
-        "it ne\n"
-        "addne r1, r1, r8, lsl #2\n"
+        "beq 1002f\n"
+        "add r1, r1, r8, lsl #2\n"
+        "1002:\n"
 
         // Load 4 bias values. When the channel dimension is rows, we will load
         // another 4 bias values just before performing the bias addition below,
@@ -630,7 +634,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
 
@@ -895,16 +900,21 @@
         // Let r8 be stack offset of the row or column variable, whichever
         // is the channel index.
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
-        "ite eq\n"
-        "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
-        "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+        "bne 1000f\n"
+        "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
+        "b 1001f\n"
+        "1000:\n"
+        "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+        "1001:\n"
+
         // Let r8 be the channel index.
         "ldr r8, [sp, r8]\n"
         // Compute the bias pointer, by conditionally using the channel index
         // (r8) as offset into bias buffer (r1).
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
-        "it ne\n"
-        "addne r1, r1, r8, lsl #2\n"
+        "beq 1002f\n"
+        "add r1, r1, r8, lsl #2\n"
+        "1002:\n"
 
         // Load 2 bias values. When the channel dimension is rows, we will load
         // another 2 bias values just before performing the bias addition below,
@@ -1011,10 +1021,10 @@
         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
         // r6 has flags, r8 has channel index
         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "it ne\n"
-        "addne r1, r1, r8, lsl #2\n"
-        "it ne\n"
-        "addne r2, r2, r8, lsl #2\n"
+        "beq 1003f\n"
+        "add r1, r1, r8, lsl #2\n"
+        "add r2, r2, r8, lsl #2\n"
+        "1003:\n"
 
         // Load the first 2 values of multiplier exponent and fixedpoint data
         // Since this kernel is rectangular 4x2, we will only conditionally load
@@ -1630,7 +1640,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
 
@@ -1868,8 +1879,9 @@
         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
 
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
-        "it ne\n"
-        "addne r1, r1, r8, lsl #2\n"
+        "beq 1000f\n"
+        "add r1, r1, r8, lsl #2\n"
+        "1000:\n"
 
         // Load 4 bias values.
         "vld1.32 {d24, d25}, [r1]\n"
@@ -1956,8 +1968,9 @@
         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
-        "it ne\n"
-        "addne r1, r1, r4, lsl #2\n"
+        "beq 1001f\n"
+        "add r1, r1, r4, lsl #2\n"
+        "1001:\n"
 
         "vld1.32 {q10}, [r1]\n"
 
@@ -1972,8 +1985,9 @@
         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
         // r6 has flags, r4 has row
         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "it ne\n"
-        "addne r1, r1, r4, lsl #2\n"
+        "beq 1002f\n"
+        "add r1, r1, r4, lsl #2\n"
+        "1002:\n"
         "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
 
         // Apply the fixed-point part of the multiplier.
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index 5424107..532138d 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -101,7 +101,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -1160,7 +1161,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -1832,7 +1834,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -2987,7 +2990,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -4413,7 +4417,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -5667,7 +5672,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
@@ -6362,7 +6368,8 @@
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   const std::int8_t* lhs_ptr = lhs_col_ptr;
   const std::int8_t* rhs_ptr = rhs_col_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
index 2405735..0f7e2e3 100644
--- a/ruy/kernel_avx.cc
+++ b/ruy/kernel_avx.cc
@@ -462,7 +462,8 @@
     RUY_DCHECK(false);
   }
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   void* dst_col_ptr = params.dst_base_ptr;
 
   for (int col = params.start_col; col <= params.last_col;
@@ -1184,7 +1185,8 @@
   int bias_ptr_block_increment =
       params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* rhs_col_ptr =
+      static_cast<const int8_t*>(params.rhs_base_ptr);
   void* dst_col_ptr = params.dst_base_ptr;
   const std::int32_t* bias_col_ptr = params.bias;
   if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index eae333c..e725777 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -121,7 +121,7 @@
     RUY_DCHECK(false);
   }
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const void* rhs_col_ptr = params.rhs_base_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
 
   for (int col = params.start_col; col <= params.last_col;
@@ -251,7 +251,7 @@
       }
 
       const std::int8_t* lhs_ptr = lhs_col_ptr;
-      const std::int8_t* rhs_ptr = rhs_col_ptr;
+      const void* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
         const __m256i lhs_data =
             _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
@@ -259,21 +259,29 @@
             _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
 
         // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
-        std::int32_t rhs_data[16];
-        const __m128i rhs_data_bottom_lane =
-            _mm256_castsi256_si128(rhs_data_8bit);
-        const __m128i rhs_data_top_lane =
-            _mm256_extracti128_si256(rhs_data_8bit, 1);
-        const __m256i rhs_16_bit_dup_low =
-            _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
-        const __m256i rhs_16_bit_dup_high =
-            _mm256_cvtepi8_epi16(rhs_data_top_lane);
-        // Now that we have cast the RHS data, we store it so that each value
-        // can be separately loaded in the accumulation loop.
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
-                            rhs_16_bit_dup_low);
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
-                            rhs_16_bit_dup_high);
+        std::int32_t rhs_data_buf[16];
+        const std::int32_t* rhs_data =
+            reinterpret_cast<const std::int32_t*>(rhs_ptr);
+
+        if (params.rhs_scalar_size == 1) {
+          rhs_data = rhs_data_buf;
+          const __m128i rhs_data_bottom_lane =
+              _mm256_castsi256_si128(rhs_data_8bit);
+          const __m128i rhs_data_top_lane =
+              _mm256_extracti128_si256(rhs_data_8bit, 1);
+          const __m256i rhs_16_bit_dup_low =
+              _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
+          const __m256i rhs_16_bit_dup_high =
+              _mm256_cvtepi8_epi16(rhs_data_top_lane);
+          // Now that we have cast the RHS data, we store it so that each value
+          // can be separately loaded in the accumulation loop.
+          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf),
+                              rhs_16_bit_dup_low);
+          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8),
+                              rhs_16_bit_dup_high);
+        } else {
+          RUY_DCHECK(params.rhs_scalar_size == 2);
+        }
 
         const __m256i lhs_data_split =
             _mm256_shuffle_epi8(lhs_data, splitter_idx);
@@ -339,7 +347,9 @@
         process_column(tmp2, tmp3, accum_data_v7);
 
         lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
-        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+        rhs_ptr = static_cast<const void*>(
+            static_cast<const char*>(rhs_ptr) +
+            kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size);
       }
 
       if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -717,7 +727,9 @@
 
     dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                      kAvx8bitBlockSize * params.dst_stride);
-    rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+    rhs_col_ptr =
+        static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
+                                 kAvx8bitBlockSize * params.rhs_stride);
   }  // End col-block loop.
 }  // NOLINT(readability/fn_size)
 
@@ -743,7 +755,7 @@
   int bias_ptr_block_increment =
       params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const void* rhs_col_ptr = params.rhs_base_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
   const std::int32_t* bias_col_ptr = params.bias;
   if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
@@ -807,20 +819,29 @@
     }
 
     const std::int8_t* lhs_ptr = lhs_col_ptr;
-    const std::int8_t* rhs_ptr = rhs_col_ptr;
+    const void* rhs_ptr = rhs_col_ptr;
     for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
       const __m256i lhs_data =
           _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
-      const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+      const std::int32_t* rhs_data =
+          reinterpret_cast<const std::int32_t*>(rhs_ptr);
 
       // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
       // For simplicity we load 4x the data that we need and process twice the
       // data  that we need  and store only the data we need.
-      std::int32_t rhs_data[2];
-      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
-      // Now that we have cast the RHS data, we store it so that each value
-      // can be separately loaded in the accumulation loop.
-      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+      std::int32_t rhs_data_buf[2];
+      if (params.rhs_scalar_size == 1) {
+        rhs_data = rhs_data_buf;
+        const __m128i rhs_data_8bit =
+            intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+        const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+        // Now that we have cast the RHS data, we store it so that each value
+        // can be separately loaded in the accumulation loop.
+        _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
+                        rhs_16_bit_dup);
+      } else {
+        RUY_DCHECK(params.rhs_scalar_size == 2);
+      }
 
       // NOTE: There may be opportunities for permuting the data in the packing
       // code instead of here.
@@ -851,7 +872,9 @@
           _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
 
       lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
-      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+      rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+                                         kAvx8bitBlockSize * kAvx8bitInnerSize *
+                                             params.rhs_scalar_size);
     }
 
     if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -989,7 +1012,8 @@
 
   dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                    kAvx8bitBlockSize * params.dst_stride);
-  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+  rhs_col_ptr = static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
+                                         kAvx8bitBlockSize * params.rhs_stride);
 }  // NOLINT(readability/fn_size)
 
 void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 84b9380..654ba27 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -67,7 +67,7 @@
     RUY_DCHECK(false);
   }
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const void* rhs_col_ptr = params.rhs_base_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
 
   for (int col = params.start_col; col <= params.last_col; col += 16) {
@@ -247,27 +247,34 @@
       }
 
       const std::int8_t* lhs_ptr = lhs_col_ptr;
-      const std::int8_t* rhs_ptr = rhs_col_ptr;
+      const void* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; d += 4) {
         const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
         __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);
 
         // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
-        std::int32_t rhs_data[32];
-        const __m256i rhs_data_bottom_lane =
-            _mm512_castsi512_si256(rhs_data_8bit);
-        const __m256i rhs_data_top_lane =
-            _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
-        const __m512i rhs_16_bit_dup_low =
-            _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
-        const __m512i rhs_16_bit_dup_high =
-            _mm512_cvtepi8_epi16(rhs_data_top_lane);
-        // Now that we have cast the RHS data, we store it so that each value
-        // can be separately loaded in the accumulation loop.
-        _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data),
-                            rhs_16_bit_dup_low);
-        _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16),
-                            rhs_16_bit_dup_high);
+        std::int32_t rhs_data_buf[32];
+        const std::int32_t* rhs_data =
+            reinterpret_cast<const std::int32_t*>(rhs_ptr);
+        if (params.rhs_scalar_size == 1) {
+          rhs_data = rhs_data_buf;
+          const __m256i rhs_data_bottom_lane =
+              _mm512_castsi512_si256(rhs_data_8bit);
+          const __m256i rhs_data_top_lane =
+              _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
+          const __m512i rhs_16_bit_dup_low =
+              _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
+          const __m512i rhs_16_bit_dup_high =
+              _mm512_cvtepi8_epi16(rhs_data_top_lane);
+          // Now that we have cast the RHS data, we store it so that each value
+          // can be separately loaded in the accumulation loop.
+          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf),
+                              rhs_16_bit_dup_low);
+          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16),
+                              rhs_16_bit_dup_high);
+        } else {
+          RUY_DCHECK(params.rhs_scalar_size == 2);
+        }
 
         // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
         const __m512i lhs_16_bit_low =
@@ -305,7 +312,8 @@
         process_column(15, accum_data_vf);
 
         lhs_ptr += 16 * 4;
-        rhs_ptr += 16 * 4;
+        rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+                                           16 * 4 * params.rhs_scalar_size);
       }
 
       if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -612,7 +620,8 @@
 
     dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                      16 * params.dst_stride);
-    rhs_col_ptr += 16 * params.rhs_stride;
+    rhs_col_ptr = static_cast<const void*>(
+        static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride);
   }  // End col-block loop.
 }  // NOLINT(readability/fn_size)
 
@@ -625,7 +634,7 @@
 
   int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
 
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const void* rhs_col_ptr = params.rhs_base_ptr;
   void* dst_col_ptr = params.dst_base_ptr;
   const std::int32_t* bias_col_ptr = params.bias;
   if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
@@ -684,20 +693,28 @@
     }
 
     const std::int8_t* lhs_ptr = lhs_col_ptr;
-    const std::int8_t* rhs_ptr = rhs_col_ptr;
+    const void* rhs_ptr = rhs_col_ptr;
     for (int d = 0; d < params.depth; d += 4) {
       const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
-      const __m128i rhs_data_8bit =
-          _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+      const std::int32_t* rhs_data =
+          reinterpret_cast<const std::int32_t*>(rhs_ptr);
 
       // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
       // For simplicity we load 4x the data that we need and process twice the
       // data  that we need  and store only the data we need.
-      std::int32_t rhs_data[2];
-      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
-      // Now that we have cast the RHS data, we store it so that each value
-      // can be separately loaded in the accumulation loop.
-      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+      std::int32_t rhs_data_buf[2];
+      if (params.rhs_scalar_size == 1) {
+        rhs_data = rhs_data_buf;
+        const __m128i rhs_data_8bit =
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+        const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+        // Now that we have cast the RHS data, we store it so that each value
+        // can be separately loaded in the accumulation loop.
+        _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
+                        rhs_16_bit_dup);
+      } else {
+        RUY_DCHECK(params.rhs_scalar_size == 2);
+      }
 
       // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
       const __m512i lhs_16_bit_low =
@@ -721,7 +738,8 @@
       accum_data_v0 = accum_v;
 
       lhs_ptr += 16 * 4;
-      rhs_ptr += 16 * 4;
+      rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+                                         16 * 4 * params.rhs_scalar_size);
     }
 
     if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index cff243b..69e819b 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -101,7 +101,8 @@
   const std::int8_t* lhs_base_ptr;
   const std::int32_t* multiplier_fixedpoint;
   const std::int32_t* multiplier_exponent;
-  const std::int8_t* rhs_base_ptr;
+  // Make it void* to support 8bit(LHS)x16bit(RHS) case.
+  const void* rhs_base_ptr;
   void* dst_base_ptr;
   std::int32_t lhs_zero_point;
   std::int32_t rhs_zero_point;
@@ -125,11 +126,12 @@
   std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
   std::int32_t multiplier_fixedpoint_buf[LhsCols];
   std::int32_t multiplier_exponent_buf[LhsCols];
+  std::size_t rhs_scalar_size;
 };
 
-template <typename DstScalar, int LhsCols, int RhsCols>
+template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols>
 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
-                          const PMat<std::int8_t>& rhs,
+                          const PMat<RhsScalar>& rhs,
                           const MulParams<std::int32_t, DstScalar>& mul_params,
                           int start_row, int start_col, int end_row,
                           int end_col, Mat<DstScalar>* dst,
@@ -145,6 +147,7 @@
   RUY_DCHECK_EQ(end_col % RhsCols, 0);
 
   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+  params->rhs_scalar_size = sizeof(RhsScalar);
   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
   params->flags = 0;
   params->bias = params->zero_data;
@@ -168,7 +171,7 @@
   params->last_row = end_row - LhsCols;
   params->last_col = end_col - RhsCols;
   params->lhs_stride = lhs.layout.stride;
-  params->rhs_stride = rhs.layout.stride;
+  params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride;
   params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
   params->lhs_zero_point = lhs.zero_point;
   params->rhs_zero_point = rhs.zero_point;
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index b716502..51787b9 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -31,8 +31,8 @@
 
 #if RUY_PLATFORM_X86
 
-RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
 RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
+RUY_INHERIT_KERNEL(Path::kAvx, Path::kAvx2Fma)
 RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)
 
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
@@ -60,6 +60,29 @@
   }
 };
 
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t,
+              DstScalar> {
+  static constexpr Path kPath = Path::kAvx512;
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
+           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
+      Kernel8bitAvx512SingleCol(params);
+    } else {
+      Kernel8bitAvx512(params);
+    }
+  }
+};
+
 void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
 void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param);
 
@@ -111,6 +134,29 @@
   }
 };
 
+template <typename DstScalar>
+struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t,
+              DstScalar> {
+  static constexpr Path kPath = Path::kAvx2Fma;
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
+           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
+      Kernel8bitAvx2SingleCol(params);
+    } else {
+      Kernel8bitAvx2(params);
+    }
+  }
+};
+
 void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
 void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
 
diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc
index c337986..91f823b 100644
--- a/ruy/pack_arm.cc
+++ b/ruy/pack_arm.cc
@@ -1592,7 +1592,7 @@
                                     int packed_stride, std::int32_t* sums_ptr,
                                     int input_xor) {
   profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)");
-  asm(
+  asm volatile(
       // clang-format off
           // Prefetch data. This was tuned on Cortex-A55-rev1 cores.
           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n")
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
index 5281fa8..29a1850 100644
--- a/ruy/pack_avx512.cc
+++ b/ruy/pack_avx512.cc
@@ -38,6 +38,12 @@
   RUY_DCHECK(false);
 }
 
+void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int,
+                                int, int, std::int16_t*, std::int32_t*) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
 void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
                                 float*) {
   // CPU-ID-based checks should disable the path that would reach this point.
@@ -56,20 +62,24 @@
 using PackImpl8bitAvx512 =
     PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
              std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
+using PackImpl16bitAvx512 =
+    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+             std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>;
 
 namespace {
 
-inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
-                               std::int8_t* packed_ptr) {
-  using Layout = PackImpl8bitAvx512::Layout;
+template <typename PackImplAvx512, typename Scalar>
+inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
+                           Scalar* packed_ptr, int chunked_row_mask) {
+  using Layout = typename PackImplAvx512::Layout;
   static constexpr int kHalfLayoutCols =
-      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
-                                            // block.
+      PackImplAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                        // block.
   RUY_DCHECK_EQ(kHalfLayoutCols, 8);
   RUY_DCHECK_EQ(Layout::kCols, 16);
   RUY_DCHECK_EQ(Layout::kRows, 4);
 
-  const int non_trailing_blocks = (src_rows & ~31) >> 2;
+  const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2;
   // This routine fills half blocks, and typically fills the second halves.
   // Thus packed_ptr is already offset by 8 * 4.
   for (int k = 0; k < non_trailing_blocks; ++k) {
@@ -79,8 +89,8 @@
   }
 }
 
-inline __m512i LoaduTwo(const std::int8_t* addr_lo,
-                        const std::int8_t* addr_hi) {
+template <typename Scalar>
+inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) {
   __m512i lower_filled = _mm512_castsi256_si512(
       _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
   return _mm512_inserti32x8(
@@ -98,6 +108,16 @@
       1);
 }
 
+inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
+                            const std::int16_t* addr_lo,
+                            const std::int16_t* addr_hi) {
+  const __m512i lower_filled = _mm512_castsi256_si512(
+      _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo));
+  return _mm512_inserti32x8(
+      lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi),
+      1);
+}
+
 inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
                                std::int8_t input_xor,
                                const std::int8_t* zerobuf, int src_stride,
@@ -454,6 +474,193 @@
   }
 }
 
+inline void HalfPack16bitAvx512(const std::int16_t* src_ptr,
+                                const std::int16_t* zerobuf, int src_stride,
+                                int remaining_src_cols, int src_rows,
+                                std::int16_t* packed_ptr,
+                                std::int32_t* sums_ptr,
+                                std::int16_t* trailing_buf) {
+  using Layout = PackImpl16bitAvx512::Layout;
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
+  // We process 4 of these chunks at a time, padding std::int16_t input chunks.
+  constexpr int kNumRowChunks = 4;
+  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+  const std::int16_t* src_ptr0 = src_ptr;
+  const std::int16_t* src_ptr1 = src_ptr0 + src_stride;
+  const std::int16_t* src_ptr2 = src_ptr1 + src_stride;
+  const std::int16_t* src_ptr3 = src_ptr2 + src_stride;
+  const std::int16_t* src_ptr4 = src_ptr3 + src_stride;
+  const std::int16_t* src_ptr5 = src_ptr4 + src_stride;
+  const std::int16_t* src_ptr6 = src_ptr5 + src_stride;
+  const std::int16_t* src_ptr7 = src_ptr6 + src_stride;
+  std::int64_t src_inc0 = kNumChunkedSrcRows;
+  std::int64_t src_inc1 = kNumChunkedSrcRows;
+  std::int64_t src_inc2 = kNumChunkedSrcRows;
+  std::int64_t src_inc3 = kNumChunkedSrcRows;
+  std::int64_t src_inc4 = kNumChunkedSrcRows;
+  std::int64_t src_inc5 = kNumChunkedSrcRows;
+  std::int64_t src_inc6 = kNumChunkedSrcRows;
+  std::int64_t src_inc7 = kNumChunkedSrcRows;
+  // Handle cases where source does not have kHalfLayoutCols (8) columns.
+  if (remaining_src_cols < 8) {
+    if (remaining_src_cols <= 0) {
+      src_ptr0 = zerobuf;
+      src_inc0 = 0;
+    }
+    if (remaining_src_cols <= 1) {
+      src_ptr1 = zerobuf;
+      src_inc1 = 0;
+    }
+    if (remaining_src_cols <= 2) {
+      src_ptr2 = zerobuf;
+      src_inc2 = 0;
+    }
+    if (remaining_src_cols <= 3) {
+      src_ptr3 = zerobuf;
+      src_inc3 = 0;
+    }
+    if (remaining_src_cols <= 4) {
+      src_ptr4 = zerobuf;
+      src_inc4 = 0;
+    }
+    if (remaining_src_cols <= 5) {
+      src_ptr5 = zerobuf;
+      src_inc5 = 0;
+    }
+    if (remaining_src_cols <= 6) {
+      src_ptr6 = zerobuf;
+      src_inc6 = 0;
+    }
+    src_ptr7 = zerobuf;
+    src_inc7 = 0;
+  }
+
+  const std::int16_t zero_point = zerobuf[0];
+
+  if (sums_ptr) {
+    // i: kHalfLayoutCols.
+    for (int i = 0; i < 8; ++i) {
+      sums_ptr[i] = 0;
+    }
+  }
+  std::int32_t sums_adjustment = 0;
+  const __m512i ones_16bit = _mm512_set1_epi16(1);
+  __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
+
+  // The overall packing effectively pads the source rows to
+  // (src_rows + 31) & ~31. The iteration over k may skip when m=1, and then we
+  // only pack for (src_rows + 15) & ~15. When there is an incomplete
+  // destination block, this is stored into trailing_buf instead of packed_ptr.
+  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
+    // m: {0, 1} for 2 chunks of rows.
+    for (int m = 0; m < 2; ++m) {
+      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
+
+      // Available source rows.
+      // If this is less than 0 (for m=1), we skip, having filled trailing
+      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+      // exactly to the end of the column in the packed buffer.
+      if (available_src_rows > 0) {
+        __m512i t0, t1, t2, t3;
+        __m512i r0, r1, r2, r3;
+        std::int16_t* dst_ptr = packed_ptr;
+
+        if (available_src_rows >= kNumChunkedSrcRows) {
+          t0 = LoaduTwo(src_ptr0, src_ptr4);
+          t1 = LoaduTwo(src_ptr1, src_ptr5);
+          t2 = LoaduTwo(src_ptr2, src_ptr6);
+          t3 = LoaduTwo(src_ptr3, src_ptr7);
+        } else {
+          RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
+          // We do not care what goes into the trailing buffer, but we want
+          // in_data[...] == zero_point for irrelevant values in the summation.
+          //
+          // We compensate for padding-with-zero_point by initializing the
+          // summations with the compensating offset.
+          sums_adjustment +=
+              -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2));
+
+          const __m256i zero_point_v = _mm256_set1_epi16(zero_point);
+          const __mmask32 row_mask =
+              (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
+
+          t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
+          t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
+          t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
+          t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
+          dst_ptr = trailing_buf;
+        }
+
+        r0 = _mm512_unpacklo_epi64(t0, t1);
+        r2 = _mm512_unpackhi_epi64(t0, t1);
+        r1 = _mm512_unpacklo_epi64(t2, t3);
+        r3 = _mm512_unpackhi_epi64(t2, t3);
+
+        r1 = _mm512_permutex_epi64(r1, 0x4e);
+        r3 = _mm512_permutex_epi64(r3, 0x4e);
+
+        t0 = _mm512_mask_blend_epi64(0xcc, r0, r1);
+        t1 = _mm512_mask_blend_epi64(0x33, r0, r1);
+        t2 = _mm512_mask_blend_epi64(0xcc, r2, r3);
+        t3 = _mm512_mask_blend_epi64(0x33, r2, r3);
+
+        t1 = _mm512_permutex_epi64(t1, 0x4e);
+        t3 = _mm512_permutex_epi64(t3, 0x4e);
+
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4),
+                            t0);
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4),
+                            t1);
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4),
+                            t2);
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4),
+                            t3);
+
+        if (sums_ptr) {
+          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+                                            _mm512_madd_epi16(t0, ones_16bit));
+          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+                                            _mm512_madd_epi16(t1, ones_16bit));
+          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+                                            _mm512_madd_epi16(t2, ones_16bit));
+          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+                                            _mm512_madd_epi16(t3, ones_16bit));
+        }
+      }
+
+      packed_ptr += 16 * kNumChunkedSrcRows;
+      src_ptr0 += src_inc0;
+      src_ptr1 += src_inc1;
+      src_ptr2 += src_inc2;
+      src_ptr3 += src_inc3;
+      src_ptr4 += src_inc4;
+      src_ptr5 += src_inc5;
+      src_ptr6 += src_inc6;
+      src_ptr7 += src_inc7;
+    }
+  }
+
+  if (sums_ptr) {
+    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+    __m256i sums =
+        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+    const __m512i idx =
+        _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
+
+    const __m512i sums_2x8_32bit =
+        _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
+    sums = _mm256_add_epi32(sums, sums_adjustment_v);
+    sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
+    sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
+
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+  }
+}
+
 inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
   const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
   return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
@@ -658,6 +865,7 @@
       kNumRowChunks * Layout::kCols * Layout::kRows;
   std::int8_t trailing_buf[kTrailingBufSize];
   memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
 
   std::int32_t* second_sums_ptr =
       sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
@@ -674,8 +882,9 @@
     HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                        remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                        trailing_buf);
-    ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
-                       packed_ptr + kHalfBlockOffset);
+    ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
+        src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset,
+        kChunkedRowMask);
     // The kernel may not need the second half-blocks sums to be set.
     if (second_sums_ptr) {
       for (int i = 0; i < kHalfLayoutCols; ++i) {
@@ -683,7 +892,6 @@
       }
     }
   }
-  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
   const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
   // If the number of source rows is not a multiple of kChunkedRowMask, there
   // will be data in the trailing buffer,
@@ -697,6 +905,68 @@
   }
 }
 
+void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
+                                const std::int16_t* zerobuf, int src_stride,
+                                int remaining_src_cols, int src_rows,
+                                std::int16_t* packed_ptr,
+                                std::int32_t* sums_ptr) {
+  profiler::ScopeLabel label("Pack kAvx512 16bit");
+
+  using Layout = PackImpl16bitAvx512::Layout;
+  constexpr int kHalfBlockOffset = 32;
+  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
+  static constexpr int kHalfLayoutCols =
+      PackImpl16bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                             // block.
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
+  // We process 8 of these chunks at a time, padding short input chunks.
+  constexpr int kNumRowChunks = 4;
+
+  // Each packed block is 4*16, and there are normally 8. The trailing block is
+  // only slightly shorter.
+  constexpr int kTrailingBufSize =
+      kNumRowChunks * Layout::kCols * Layout::kRows;
+  std::int16_t trailing_buf[kTrailingBufSize] = {0};
+  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+
+  std::int32_t* second_sums_ptr =
+      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
+  if (remaining_src_cols > kHalfLayoutCols) {
+    HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, sums_ptr, trailing_buf);
+    HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf,
+                        src_stride, remaining_src_cols - kHalfLayoutCols,
+                        src_rows, packed_ptr + kHalfBlockOffset,
+                        second_sums_ptr, trailing_buf + kHalfBlockOffset);
+  } else {
+    HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, sums_ptr, trailing_buf);
+    ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
+        src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask);
+    // The kernel may not need the second half-blocks sums to be set.
+    if (second_sums_ptr) {
+      for (int i = 0; i < kHalfLayoutCols; ++i) {
+        second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3);
+      }
+    }
+  }
+  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kChunkedRowMask, there
+  // will be data in the trailing buffer,
+  if (trailing_data) {
+    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+    // Destination "rows" are padded to next highest multiple of Layout::kRows.
+    const int dst_rows = (src_rows + 3) & ~3;
+    const int trailing_rows = dst_rows - non_trailing_rows;
+    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+           Layout::kCols * trailing_rows * sizeof(std::int16_t));
+  }
+}
+
 void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                 int src_stride, int remaining_src_cols,
                                 int src_rows, float* packed_ptr) {
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index f3ea54e..a28bbc9 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -16,6 +16,7 @@
 #ifndef RUY_RUY_PACK_X86_H_
 #define RUY_RUY_PACK_X86_H_
 
+#include <algorithm>
 #include <cstdint>
 #include <cstring>
 #include <type_traits>
@@ -271,6 +272,52 @@
   }
 };
 
+void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
+                                const std::int16_t* zerobuf, int src_stride,
+                                int remaining_src_cols, int src_rows,
+                                std::int16_t* packed_ptr,
+                                std::int32_t* sums_ptr);
+
+template <>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+                std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> {
+  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  static constexpr int kHalfLayoutCols =
+      8;  // Half the number of cols in a block.
+
+  static void Run(Tuning, const Mat<std::int16_t>& src_matrix,
+                  PMat<std::int16_t>* packed_matrix, int start_col,
+                  int end_col) {
+    profiler::ScopeLabel label("Pack (AVX-512 16-bit)");
+
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+    std::int32_t* sums = packed_matrix->sums;
+    std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows];
+    std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows,
+              static_cast<int16_t>(packed_matrix->zero_point));
+    for (int block_col = start_col; block_col < end_col;
+         block_col += Layout::kCols) {
+      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+      int src_stride = src_matrix.layout.stride;
+      const std::int16_t* src_ptr =
+          src_matrix.data.get() + src_stride * block_col;
+      int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+      static constexpr int block_col_mask = ~(Layout::kCols - 1);
+      std::int16_t* packed_ptr =
+          packed_matrix->data +
+          packed_matrix->layout.stride * (block_col & block_col_mask);
+      Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride,
+                                 remaining_src_cols, src_matrix.layout.rows,
+                                 packed_ptr, sums_ptr);
+    }
+  }
+};
+
 void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                 int src_stride, int remaining_src_cols,
                                 int src_rows, float* packed_ptr);
diff --git a/ruy/platform.h b/ruy/platform.h
index eb51931..9b67416 100644
--- a/ruy/platform.h
+++ b/ruy/platform.h
@@ -28,8 +28,11 @@
 // Detect APPLE.
 #ifdef __APPLE__
 #define RUY_PLATFORM_APPLE 1
+#include <TargetConditionals.h>
+#define RUY_PLATFORM_APPLE_IPHONE_SIMULATOR TARGET_IPHONE_SIMULATOR
 #else
 #define RUY_PLATFORM_APPLE 0
+#define RUY_PLATFORM_APPLE_IPHONE_SIMULATOR 0
 #endif
 
 // Detect APPLE.
@@ -108,11 +111,11 @@
 // Enable on sufficiently recent Android NDK. Earlier versions had broken
 // intrinsics headers.
 #define RUY_PLATFORM_X86_ENHANCEMENTS 1
-#elif defined(__linux__) && defined(__clang__) && (__clang_major__ >= 8)
-// Enable on recent versions of Clang on Linux. Might be possible
+#elif ((RUY_PLATFORM_APPLE && !RUY_PLATFORM_APPLE_IPHONE_SIMULATOR) || \
+       defined(__linux__)) &&                                          \
+    defined(__clang__) && (__clang_major__ >= 8)
+// Enable on recent versions of Clang. Might be possible
 // to relax this version requirement.
-// Not enabling on Apple at the moment because b/138922878, see comment #8, we
-// may only need to disable this on XCode <= 10.2.
 #define RUY_PLATFORM_X86_ENHANCEMENTS 1
 #elif defined(__GNUC__) && (__GNUC__ >= 9)
 // Enable on recent versions of GCC. Might be possible
diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD
index 64754bf..7ec8e5f 100644
--- a/ruy/profiler/BUILD
+++ b/ruy/profiler/BUILD
@@ -3,6 +3,7 @@
 load("//ruy:build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
 
 package(
+    default_applicable_licenses = ["//third_party/ruy:license"],
     licenses = ["notice"],  # Apache 2.0
 )
 
diff --git a/ruy/strategy_controls.h b/ruy/strategy_controls.h
new file mode 100644
index 0000000..629c2b8
--- /dev/null
+++ b/ruy/strategy_controls.h
@@ -0,0 +1,34 @@
+/* Copyright 2022 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_STRATEGY_CONTROLS_H_
+#define RUY_RUY_STRATEGY_CONTROLS_H_
+
+#include <cstdint>
+
+namespace ruy {
+
+enum class NumThreadsStrategy : std::uint8_t {
+  // kDefault means using smart heuristic logic that has been optimized
+  // for cubic ColxRowxDepth matrix multiplication.
+  kDefault,
+  // kForceMaxNumThreads means using ctx->max_num_thread()
+  // for multi-thread computing.
+  kForceMaxNumThreads
+};
+
+}  // namespace ruy
+
+#endif  // RUY_RUY_STRATEGY_CONTROLS_H_
diff --git a/ruy/test.h b/ruy/test.h
index 0b05399..6517519 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -1063,19 +1063,17 @@
                                                         : dst->layout().rows());
   using DimPair =
       typename Eigen::Tensor<Scalar, 1, 0, Eigen::Index>::DimensionPair;
-  Eigen::array<DimPair, 1> contract_dims(
+  Eigen::array<DimPair, 1> contract_dims{
       {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0,
-               (RhsOrder == Order::kColMajor) ? 0 : 1)});
-  Eigen::array<int, 2> shuffle(DstOrder == Order::kColMajor ? 0 : 1,
-                               DstOrder == Order::kColMajor ? 1 : 0);
+               (RhsOrder == Order::kColMajor) ? 0 : 1)}};
   static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1);
   static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
   if (mul_params.bias()) {
     TensorBiasType tensor_bias(mul_params.bias(), dst->layout().rows());
-    Eigen::array<int, 2> bias_2d_shape(tr ? 1 : dst->layout().rows(),
-                                       tr ? dst->layout().rows() : 1);
-    Eigen::array<int, 2> bcast(tr ? dst->layout().cols() : 1,
-                               tr ? 1 : dst->layout().cols());
+    Eigen::array<int, 2> bias_2d_shape{tr ? 1 : dst->layout().rows(),
+                                       tr ? dst->layout().rows() : 1};
+    Eigen::array<int, 2> bcast{tr ? dst->layout().cols() : 1,
+                               tr ? 1 : dst->layout().cols()};
     if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() &&
         mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) {
       tensor_dst.device(device) =
@@ -1715,6 +1713,16 @@
           typename DstScalar>
 void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeZeroPoints() {
   RUY_CHECK_EQ(life_stage, LifeStage::kInitial);
+  if (std::is_same<LhsScalar, std::int16_t>::value ||
+      std::is_same<RhsScalar, std::int16_t>::value) {
+    // For now, support for int16 source types is limited to the
+    // symmetric case (zero_point==0) because that appears to be
+    // the case in the initial use cases, and that limits complexity
+    // in thinking about accumulator overflows.
+    // Setting use_specified_zero_points causes the default values 0 to be
+    // used unless explicitly overridden.
+    use_specified_zero_points = true;
+  }
   if (!benchmark && !use_specified_zero_points) {
     MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point);
     MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point);
@@ -1847,6 +1855,12 @@
     paths_bitfield = get_ctx(&context)->GetRuntimeEnabledPaths();
   }
 
+  // Disable the internal test-only variants of the StandardCpp path in
+  // benchmarks
+  if (benchmark) {
+    paths_bitfield = paths_bitfield & kAllPaths;
+  }
+
   // Disable the internal test-only variants of the StandardCpp path on large
   // tests.
   // This constant be large enough to exercise some interesting BlockMap logic,
diff --git a/ruy/test_overflow_dst_zero_point.cc b/ruy/test_overflow_dst_zero_point.cc
index db1f08d..96ee38c 100644
--- a/ruy/test_overflow_dst_zero_point.cc
+++ b/ruy/test_overflow_dst_zero_point.cc
@@ -58,7 +58,7 @@
                                      ? std::numeric_limits<DstScalar>::max()
                                      : std::numeric_limits<DstScalar>::min();
 
-  const std::vector<const std::int8_t> lhs_data(1, 0);
+  const std::vector<std::int8_t> lhs_data(1, 0);
   const std::vector<std::int8_t> rhs_data(cols, 0);
   std::vector<DstScalar> dst_data(cols, 0);
 
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 5f22a13..2e2ca2c 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -34,135 +34,178 @@
 // A worker thread.
 class Thread {
  public:
+  explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration)
+      : state_(State::Startup),
+        count_busy_threads_(count_busy_threads),
+        spin_duration_(spin_duration) {
+    thread_.reset(new std::thread(ThreadFunc, this));
+  }
+
+  void RequestExitAsSoonAsPossible() {
+    ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible);
+  }
+
+  ~Thread() {
+    RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible);
+    thread_->join();
+  }
+
+  // Called by an outside thead to give work to the worker thread.
+  void StartWork(Task* task) {
+    ChangeStateFromOutsideThread(State::HasWork, task);
+  }
+
+ private:
   enum class State {
-    Startup,  // The initial state before the thread main loop runs.
+    Startup,  // The initial state before the thread loop runs.
     Ready,    // Is not working, has not yet received new work to do.
     HasWork,  // Has work to do.
     ExitAsSoonAsPossible  // Should exit at earliest convenience.
   };
 
-  explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
-                  Duration spin_duration)
-      : task_(nullptr),
-        state_(State::Startup),
-        counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
-        spin_duration_(spin_duration) {
-    thread_.reset(new std::thread(ThreadFunc, this));
+  // Implements the state_ change to State::Ready, which is where we consume
+  // task_. Only called on the worker thread.
+  // Reads task_, so assumes ordering past any prior writes to task_.
+  void RevertToReadyState() {
+    RUY_TRACE_SCOPE_NAME("Worker thread task");
+    // See task_ member comment for the ordering of accesses.
+    if (task_) {
+      task_->Run();
+      task_ = nullptr;
+    }
+    // No need to notify state_cond_, since only the worker thread ever waits
+    // on it, and we are that thread.
+    // Relaxed order because ordering is already provided by the
+    // count_busy_threads_->DecrementCount() at the next line, since the next
+    // state_ mutation will be to give new work and that won't happen before
+    // the outside thread has finished the current batch with a
+    // count_busy_threads_->Wait().
+    state_.store(State::Ready, std::memory_order_relaxed);
+    count_busy_threads_->DecrementCount();
   }
 
-  ~Thread() {
-    ChangeState(State::ExitAsSoonAsPossible);
-    thread_->join();
-  }
-
-  // Changes State; may be called from either the worker thread
-  // or the master thread; however, not all state transitions are legal,
-  // which is guarded by assertions.
+  // Changes State, from outside thread.
   //
   // The Task argument is to be used only with new_state==HasWork.
   // It specifies the Task being handed to this Thread.
-  void ChangeState(State new_state, Task* task = nullptr) {
-    state_mutex_.lock();
-    State old_state = state_.load(std::memory_order_relaxed);
+  //
+  // new_task is only used with State::HasWork.
+  void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) {
+    RUY_DCHECK(new_state == State::ExitAsSoonAsPossible ||
+               new_state == State::HasWork);
+    RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork));
+
+#ifndef NDEBUG
+    // Debug-only sanity checks based on old_state.
+    State old_state = state_.load();
     RUY_DCHECK_NE(old_state, new_state);
-    switch (old_state) {
-      case State::Startup:
-        RUY_DCHECK_EQ(new_state, State::Ready);
-        break;
-      case State::Ready:
-        RUY_DCHECK(new_state == State::HasWork ||
-                   new_state == State::ExitAsSoonAsPossible);
-        break;
+    RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork);
+    RUY_DCHECK_NE(old_state, new_state);
+#endif
+
+    switch (new_state) {
       case State::HasWork:
-        RUY_DCHECK(new_state == State::Ready ||
-                   new_state == State::ExitAsSoonAsPossible);
+        // See task_ member comment for the ordering of accesses.
+        RUY_DCHECK(!task_);
+        task_ = new_task;
+        break;
+      case State::ExitAsSoonAsPossible:
         break;
       default:
         abort();
     }
-    switch (new_state) {
-      case State::Ready:
-        if (task_) {
-          // Doing work is part of reverting to 'ready' state.
-          task_->Run();
-          task_ = nullptr;
-        }
-        break;
-      case State::HasWork:
-        RUY_DCHECK(!task_);
-        task_ = task;
-        break;
-      default:
-        break;
-    }
-    state_.store(new_state, std::memory_order_relaxed);
-    state_cond_.notify_all();
-    state_mutex_.unlock();
-    if (new_state == State::Ready) {
-      counter_to_decrement_when_ready_->DecrementCount();
-    }
+    // Release order because the worker thread will read this with acquire
+    // order.
+    state_.store(new_state, std::memory_order_release);
+    state_cond_mutex_.lock();
+    state_cond_.notify_one();  // Only this one worker thread cares.
+    state_cond_mutex_.unlock();
   }
 
   static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
 
-  // Called by the master thead to give this thread work to do.
-  void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+  // Waits for state_ to be different from State::Ready, and returns that
+  // new value.
+  State GetNewStateOtherThanReady() {
+    State new_state;
+    const auto& new_state_not_ready = [this, &new_state]() {
+      new_state = state_.load(std::memory_order_acquire);
+      return new_state != State::Ready;
+    };
+    RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
+    Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_);
+    return new_state;
+  }
 
- private:
   // Thread entry point.
   void ThreadFuncImpl() {
     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
-    ChangeState(State::Ready);
+    RevertToReadyState();
 
     // Suppress denormals to avoid computation inefficiency.
     ScopedSuppressDenormals suppress_denormals;
 
-    // Thread main loop
-    while (true) {
-      RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
-      // In the 'Ready' state, we have nothing to do but to wait until
-      // we switch to another state.
-      const auto& condition = [this]() {
-        return state_.load(std::memory_order_acquire) != State::Ready;
-      };
-      RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
-      Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
-
-      // Act on new state.
-      switch (state_.load(std::memory_order_acquire)) {
-        case State::HasWork: {
-          RUY_TRACE_SCOPE_NAME("Worker thread task");
-          // Got work to do! So do it, and then revert to 'Ready' state.
-          ChangeState(State::Ready);
-          break;
-        }
-        case State::ExitAsSoonAsPossible:
-          return;
-        default:
-          abort();
-      }
+    // Thread loop
+    while (GetNewStateOtherThanReady() == State::HasWork) {
+      RevertToReadyState();
     }
+
+    // Thread end. We should only get here if we were told to exit.
+    RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible);
   }
 
-  // The underlying thread.
+  // The underlying thread. Used to join on destruction.
   std::unique_ptr<std::thread> thread_;
 
   // The task to be worked on.
-  Task* task_;
+  //
+  // The ordering of reads and writes to task_ is as follows.
+  //
+  // 1. The outside thread gives new work by calling
+  //      ChangeStateFromOutsideThread(State::HasWork, new_task);
+  //    That does:
+  //    - a. Write task_ = new_task (non-atomic).
+  //    - b. Store state_ = State::HasWork (memory_order_release).
+  // 2. The worker thread picks up the new state by calling
+  //      GetNewStateOtherThanReady()
+  //    That does:
+  //    - c. Load state (memory_order_acquire).
+  //    The worker thread then reads the new task in RevertToReadyState().
+  //    That does:
+  //    - d. Read task_ (non-atomic).
+  // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and
+  //    does:
+  //    - e. Write task_ = nullptr (non-atomic).
+  //    And then calls Call count_busy_threads_->DecrementCount()
+  //    which does
+  //    - f. Store count_busy_threads_ (memory_order_release).
+  // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker
+  //    threads by calling count_busy_threads_->Wait(), which does:
+  //    - g. Load count_busy_threads_ (memory_order_acquire).
+  //
+  // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are
+  // ordered by the release-acquire relationship of accesses to state_ (b. ->
+  // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by
+  // the release-acquire relationship of accesses to count_busy_threads_ (f. ->
+  // g.).
+  Task* task_ = nullptr;
 
-  // The condition variable and mutex guarding state changes.
+  // Condition variable used by the outside thread to notify the worker thread
+  // of a state change.
   std::condition_variable state_cond_;
-  std::mutex state_mutex_;
+
+  // Mutex used to guard state_cond_
+  std::mutex state_cond_mutex_;
 
   // The state enum tells if we're currently working, waiting for work, etc.
-  // Its concurrent accesses by the thread and main threads are guarded by
-  // state_mutex_, and can thus use memory_order_relaxed. This still needs
-  // to be a std::atomic because we use WaitForVariableChange.
+  // It is written to from either the outside thread or the worker thread,
+  // in the ChangeState method.
+  // It is only ever read by the worker thread.
   std::atomic<State> state_;
 
   // pointer to the master's thread BlockingCounter object, to notify the
   // master thread of when this thread switches to the 'Ready' state.
-  BlockingCounter* const counter_to_decrement_when_ready_;
+  BlockingCounter* const count_busy_threads_;
 
   // See ThreadPool::spin_duration_.
   const Duration spin_duration_;
@@ -180,7 +223,7 @@
 
   // Task #0 will be run on the current thread.
   CreateThreads(task_count - 1);
-  counter_to_decrement_when_ready_.Reset(task_count - 1);
+  count_busy_threads_.Reset(task_count - 1);
   for (int i = 1; i < task_count; i++) {
     RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
     auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
@@ -193,7 +236,7 @@
 
   RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
   // Wait for the threads submitted above to finish.
-  counter_to_decrement_when_ready_.Wait(spin_duration_);
+  count_busy_threads_.Wait(spin_duration_);
 }
 
 // Ensures that the pool has at least the given count of threads.
@@ -205,15 +248,18 @@
   if (threads_.size() >= unsigned_threads_count) {
     return;
   }
-  counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
+  count_busy_threads_.Reset(threads_count - threads_.size());
   while (threads_.size() < unsigned_threads_count) {
-    threads_.push_back(
-        new Thread(&counter_to_decrement_when_ready_, spin_duration_));
+    threads_.push_back(new Thread(&count_busy_threads_, spin_duration_));
   }
-  counter_to_decrement_when_ready_.Wait(spin_duration_);
+  count_busy_threads_.Wait(spin_duration_);
 }
 
 ThreadPool::~ThreadPool() {
+  // Send all exit requests upfront so threads can work on them in parallel.
+  for (auto w : threads_) {
+    w->RequestExitAsSoonAsPossible();
+  }
   for (auto w : threads_) {
     delete w;
   }
diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h
index e3b6803..946be3d 100644
--- a/ruy/thread_pool.h
+++ b/ruy/thread_pool.h
@@ -98,12 +98,12 @@
   // copy construction disallowed
   ThreadPool(const ThreadPool&) = delete;
 
-  // The threads in this pool. They are owned by the pool:
+  // The worker threads in this pool. They are owned by the pool:
   // the pool creates threads and destroys them in its destructor.
   std::vector<Thread*> threads_;
 
   // The BlockingCounter used to wait for the threads.
-  BlockingCounter counter_to_decrement_when_ready_;
+  BlockingCounter count_busy_threads_;
 
   // This value was empirically derived with some microbenchmark, we don't have
   // high confidence in it.
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 602660b..2ff519f 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -21,6 +21,7 @@
 #include <atomic>
 #include <cstdint>
 #include <cstring>
+#include <limits>
 #include <memory>
 #include <vector>
 
@@ -34,6 +35,7 @@
 #include "ruy/mat.h"
 #include "ruy/matrix.h"
 #include "ruy/mul_params.h"
+#include "ruy/strategy_controls.h"
 #include "ruy/opt_set.h"
 #include "ruy/profiler/instrumentation.h"
 #include "ruy/side_pair.h"
@@ -256,12 +258,28 @@
   // Empirically determined rule for reasonable number of
   // threads to use. This is proportional to the number of arithmetic ops
   // in this Mul (product of the 3 sizes).
-  static constexpr int kDivisorLog2 = 15;
-  const int guess_log2 = std::max(
-      0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2);
-  int tentative_thread_count =
-      std::min(1 << guess_log2, ctx->max_num_threads());
-  RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT);
+  // Be defensive here by explicitly promoting operands to int64 to avoid the
+  // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow.
+  if (ctx->num_threads_strategy() == NumThreadsStrategy::kForceMaxNumThreads) {
+    return ctx->max_num_threads();
+  }
+  RUY_CHECK_EQ(ctx->num_threads_strategy(), NumThreadsStrategy::kDefault);
+  const std::int64_t rows_i64 = rows;
+  const std::int64_t cols_i64 = cols;
+  const std::int64_t depth_i64 = depth;
+  const std::int64_t problem_size = rows_i64 * cols_i64 * depth_i64;
+  // Division is cheap when the denominator is constant
+  static constexpr std::int64_t kSizePerAdditionalThread = 32768;
+  std::int64_t tentative_thread_count = problem_size / kSizePerAdditionalThread;
+  // tentative_thread_count is still an int64, still not necessarily in the
+  // range of type int. It probably is as long as kSizePerAdditionalThread is
+  // large, but imagine that that constant might change in the future.
+  tentative_thread_count = std::max<std::int64_t>(tentative_thread_count, 1);
+  tentative_thread_count =
+      std::min<std::int64_t>(tentative_thread_count, ctx->max_num_threads());
+  // now tentative_thread_count must be in the range of type int, because
+  // ctx->max_num_threads() is.
+  RUY_DCHECK_LE(tentative_thread_count, std::numeric_limits<int>::max());
   return tentative_thread_count;
 }
 
@@ -377,20 +395,22 @@
   // reservation granule.
   std::atomic<int>* atomic_block_id;
   main_allocator->Allocate(1, &atomic_block_id);
-
-  // Create task objects.
-  TrMulTask* tasks;
-  main_allocator->Allocate(thread_count, &tasks);
-
   atomic_block_id->store(thread_count);
 
+  // Create task objects. We allocate a single buffer and then use placement-new
+  // to construct N TrMulTask objects within it. To avoid having the Clang CFI
+  // sanitizer complain about a TrMulTask* pointer temporarily pointing to
+  // garbage, we keep the pointer a plain char* until finished constructing.
+  char* tasks_buf =
+      main_allocator->Allocate<char>(thread_count * sizeof(TrMulTask));
   for (int i = 0; i < thread_count; i++) {
     auto* allocator = ctx->GetThreadSpecificAllocator(i);
     auto* tuning_resolver = ctx->GetThreadSpecificTuningResolver(i);
-    new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i,
-                              need_atomics, packing_status, tuning_resolver,
-                              allocator, ctx->mutable_cpuinfo());
+    new (tasks_buf + i * sizeof(TrMulTask)) TrMulTask(
+        params, block_map, atomic_block_id, i, need_atomics, packing_status,
+        tuning_resolver, allocator, ctx->mutable_cpuinfo());
   }
+  TrMulTask* tasks = reinterpret_cast<TrMulTask*>(tasks_buf);
 
   // Do the computation.
   ctx->mutable_thread_pool()->Execute(thread_count, tasks);
diff --git a/ruy/validate.h b/ruy/validate.h
index b164530..c19cf67 100644
--- a/ruy/validate.h
+++ b/ruy/validate.h
@@ -44,6 +44,18 @@
   CheckZeroPoint(rhs_zero_point);
   CheckZeroPoint(dst_zero_point);
 
+  // For now, support for int16 source types is limited to the
+  // symmetric case (zero_point==0) because that appears to be
+  // the case in the initial use cases, and that limits complexity
+  // in thinking about accumulator overflows.
+  const bool has_16bit_input = std::is_same<LhsScalar, std::int16_t>::value ||
+                               std::is_same<RhsScalar, std::int16_t>::value;
+  if (has_16bit_input) {
+    RUY_DCHECK(!lhs_zero_point);
+    RUY_DCHECK(!rhs_zero_point);
+    RUY_DCHECK(!dst_zero_point);
+  }
+
   // Guard against the case when both LHS and RHS zero_point's are equal to
   // the minimum representable value. In that case, padding with zero_point
   // values will generate the bad case for fast int8 kernels on NEON