Create a utility library to suppress floating-point denormals, and apply it to every task execution of every thread.

PiperOrigin-RevId: 366919663
diff --git a/ruy/BUILD b/ruy/BUILD
index 37e89ab..b16f161 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -357,6 +357,7 @@
     deps = [
         ":blocking_counter",
         ":check_macros",
+        ":denormal",
         ":time",
         ":trace",
         ":wait",
@@ -420,6 +421,14 @@
 )
 
 cc_library(
+    name = "denormal",
+    srcs = ["denormal.cc"],
+    hdrs = ["denormal.h"],
+    copts = ruy_copts(),
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
     name = "performance_advisory",
     hdrs = ["performance_advisory.h"],
     copts = ruy_copts(),
@@ -956,6 +965,7 @@
         ":cpu_cache_params",
         ":cpuinfo",
         ":ctx",
+        ":denormal",
         ":mat",
         ":matrix",
         ":mul_params",
diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt
index 4c3e394..b83bc8c 100644
--- a/ruy/CMakeLists.txt
+++ b/ruy/CMakeLists.txt
@@ -376,6 +376,7 @@
   DEPS
     ruy_blocking_counter
     ruy_check_macros
+    ruy_denormal
     ruy_time
     ruy_trace
     ruy_wait
@@ -455,6 +456,20 @@
 
 ruy_cc_library(
   NAME
+    ruy_denormal
+  SRCS
+    denormal.cc
+  HDRS
+    denormal.h
+  COPTS
+    ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+    ${ruy_1_mfpu_neon}
+    ${ruy_2_O3}
+  PUBLIC
+)
+
+ruy_cc_library(
+  NAME
     ruy_performance_advisory
   HDRS
     performance_advisory.h
@@ -1102,6 +1117,7 @@
     ruy_cpu_cache_params
     ruy_cpuinfo
     ruy_ctx
+    ruy_denormal
     ruy_mat
     ruy_matrix
     ruy_mul_params
diff --git a/ruy/denormal.cc b/ruy/denormal.cc
new file mode 100644
index 0000000..b3c0850
--- /dev/null
+++ b/ruy/denormal.cc
@@ -0,0 +1,121 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/denormal.h"
+
+// NOTE: this is simply a copy of pthreadpool/src/threadpool-utils.h that's not
+// exposed by the pthreadpool library
+// (https://github.com/Maratyszcza/pthreadpool), but with an additional C++
+// helper class to suppress floating-point denormal values.
+
+/* SSE-specific headers */
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+#include <xmmintrin.h>
+#endif
+
+/* MSVC-specific headers */
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
+
+namespace ruy {
+namespace {
+inline struct fpu_state get_fpu_state() {
+  struct fpu_state state = {};
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  state.mxcsr = static_cast<std::uint32_t>(_mm_getcsr());
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  state.fpscr =
+      static_cast<std::uint32_t>(_MoveFromCoprocessor(10, 7, 1, 0, 0));
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  state.fpcr = static_cast<std::uint64_t>(_ReadStatusReg(0x5A20));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r"(state.fpcr));
+#endif
+  return state;
+}
+
+inline void set_fpu_state(const struct fpu_state state) {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  _mm_setcsr(static_cast<unsigned int>(state.mxcsr));
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  _MoveToCoprocessor(static_cast<int>(state.fpscr, 10, 7, 1, 0, 0));
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  _WriteStatusReg(0x5A20, static_cast<__int64>(state.fpcr));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r"(state.fpcr));
+#else
+  (void)state;
+#endif
+}
+
+inline void disable_fpu_denormals() {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  _mm_setcsr(_mm_getcsr() | 0x8040);
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  int fpscr = _MoveFromCoprocessor(10, 7, 1, 0, 0);
+  fpscr |= 0x1000000;
+  _MoveToCoprocessor(fpscr, 10, 7, 1, 0, 0);
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  __int64 fpcr = _ReadStatusReg(0x5A20);
+  fpcr |= 0x1080000;
+  _WriteStatusReg(0x5A20, fpcr);
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  std::uint32_t fpscr;
+#if defined(__thumb__) && !defined(__thumb2__)
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORRS %[fpscr], %[bitmask]\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=l"(fpscr)
+      : [bitmask] "l"(0x1000000)
+      : "cc");
+#else
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORR %[fpscr], #0x1000000\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=r"(fpscr));
+#endif
+#elif defined(__GNUC__) && defined(__aarch64__)
+  std::uint64_t fpcr;
+  __asm__ __volatile__(
+      "MRS %[fpcr], fpcr\n"
+      "ORR %w[fpcr], %w[fpcr], 0x1000000\n"
+      "ORR %w[fpcr], %w[fpcr], 0x80000\n"
+      "MSR fpcr, %[fpcr]\n"
+      : [fpcr] "=r"(fpcr));
+#endif
+}
+}  // namespace
+
+ScopedSuppressDenormals::ScopedSuppressDenormals() {
+  restore_ = get_fpu_state();
+  disable_fpu_denormals();
+}
+
+ScopedSuppressDenormals::~ScopedSuppressDenormals() { set_fpu_state(restore_); }
+}  // namespace ruy
diff --git a/ruy/denormal.h b/ruy/denormal.h
new file mode 100644
index 0000000..e5b836c
--- /dev/null
+++ b/ruy/denormal.h
@@ -0,0 +1,53 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef RUY_RUY_DENORMAL_H_
+#define RUY_RUY_DENORMAL_H_
+
+#include <cstdint>
+
+namespace ruy {
+// NOTE: the following 'fpu_state' struct is copied from
+// pthreadpool/src/threadpool-utils.h that's not exposed by the pthreadpool
+// library (https://github.com/Maratyszcza/pthreadpool).
+struct fpu_state {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  std::uint32_t mxcsr;
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+        (__ARM_FP != 0) ||                                          \
+    defined(_MSC_VER) && defined(_M_ARM)
+  std::uint32_t fpscr;
+#elif defined(__GNUC__) && defined(__aarch64__) || \
+    defined(_MSC_VER) && defined(_M_ARM64)
+  std::uint64_t fpcr;
+#endif
+};
+
+// While this class is active, denormal floating point numbers are suppressed.
+// The destructor restores the original flags.
+class ScopedSuppressDenormals {
+ public:
+  ScopedSuppressDenormals();
+  ~ScopedSuppressDenormals();
+
+ private:
+  fpu_state restore_;
+
+  ScopedSuppressDenormals(const ScopedSuppressDenormals&) = delete;
+  void operator=(const ScopedSuppressDenormals&) = delete;
+};
+}  // namespace ruy
+
+#endif  // RUY_RUY_DENORMAL_H_
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 100cfe3..5f22a13 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -25,6 +25,7 @@
 #include <thread>  // NOLINT(build/c++11)
 
 #include "ruy/check_macros.h"
+#include "ruy/denormal.h"
 #include "ruy/trace.h"
 #include "ruy/wait.h"
 
@@ -113,6 +114,9 @@
     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
     ChangeState(State::Ready);
 
+    // Suppress denormals to avoid computation inefficiency.
+    ScopedSuppressDenormals suppress_denormals;
+
     // Thread main loop
     while (true) {
       RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 9345f0c..602660b 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -30,6 +30,7 @@
 #include "ruy/cpu_cache_params.h"
 #include "ruy/cpuinfo.h"
 #include "ruy/ctx.h"
+#include "ruy/denormal.h"
 #include "ruy/mat.h"
 #include "ruy/matrix.h"
 #include "ruy/mul_params.h"
@@ -307,6 +308,12 @@
       GetTentativeThreadCount(ctx, rows, cols, depth);
   const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams();
 
+  // Suppress denormals to avoid computation inefficiency.
+  // Note this only handles the denormal suppression on the main thread. As for
+  // worker threads, the suppression is handled in each thread's main loop. See
+  // the corresponding code in thread_pool.cc for details.
+  ScopedSuppressDenormals suppress_denormals;
+
   // Case of running this TrMul as a simple loop.
   // This is a good place to start reading this function: all the rest
   // of this function is just an optimized, but functionally equivalent,