Optimize sparse.mm reduce in BFloat16 data type in CPU backend (#103239)

### Description

This PR optimizes sparse.mm reduce for the BFloat16 data type in the CPU backend, one of the tasks tracked in https://github.com/pyg-team/pytorch_geometric/issues/7057. Half support (which requires an addmm Half implementation) will follow once https://github.com/pytorch/pytorch/pull/99498 lands upstream.

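For context, here is a minimal usage sketch of the user-facing op this kernel backs, assuming `torch.sparse.mm`'s `reduce` overload for CSR inputs on CPU (shapes and values below are arbitrary):

```python
import torch

# Hypothetical toy example: a 2x3 BFloat16 CSR matrix times a 3x8 BFloat16
# dense matrix, reducing the per-row products with 'amax' instead of summing.
crow = torch.tensor([0, 2, 4])
col = torch.tensor([0, 2, 1, 2])
val = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.bfloat16)
a = torch.sparse_csr_tensor(crow, col, val, size=(2, 3))
b = torch.randn(3, 8, dtype=torch.bfloat16)

out = torch.sparse.mm(a, b, reduce="amax")  # also: 'sum', 'mean', 'amin'
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([2, 8])
```
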
Next steps:
- [x] Add benchmarks
- [x] Update UTs
- [x] Check backward behaviors
- [x] Refactor code

### Performance test (Updated)

BFloat16 tested on an Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz, with jemalloc and iomp.

Single socket (40 cores)
![image](https://github.com/pytorch/pytorch/assets/61222868/509e8482-9160-4b85-bc39-5b6aad510283)

Single core
![image](https://github.com/pytorch/pytorch/assets/61222868/c953a494-8f8e-4dbd-a8a7-421d8c22e946)

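For reference, a rough sketch of how this kind of measurement could be reproduced (hypothetical sizes and reduce type, not the exact benchmark script used here):

```python
import time
import torch

# Hypothetical problem size: 10k x 10k sparse matrix with 200k nonzeros,
# multiplied by a 10k x 128 dense BFloat16 matrix with reduce='sum'.
M, N, K, nnz = 10000, 10000, 128, 200000
idx = torch.stack([torch.randint(M, (nnz,)), torch.randint(N, (nnz,))])
vals = torch.rand(nnz, dtype=torch.bfloat16)
a = torch.sparse_coo_tensor(idx, vals, (M, N)).coalesce().to_sparse_csr()
b = torch.randn(N, K, dtype=torch.bfloat16)

for _ in range(5):                      # warmup
    torch.sparse.mm(a, b, reduce="sum")
start = time.perf_counter()
iters = 20
for _ in range(iters):
    torch.sparse.mm(a, b, reduce="sum")
print(f"avg: {(time.perf_counter() - start) / iters * 1e3:.3f} ms")
```
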
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103239
Approved by: https://github.com/mingfeima, https://github.com/albanD
diff --git a/aten/src/ATen/native/cpu/ReduceUtils.h b/aten/src/ATen/native/cpu/ReduceUtils.h
index 96cb1c9..c54dc49 100644
--- a/aten/src/ATen/native/cpu/ReduceUtils.h
+++ b/aten/src/ATen/native/cpu/ReduceUtils.h
@@ -7,6 +7,8 @@
 #include <ATen/native/ReductionType.h>
 #include <c10/util/irange.h>
 #include <ATen/OpMathType.h>
+#include <ATen/native/cpu/utils.h>
+#include <ATen/OpMathType.h>
 
 namespace at::native {
 inline namespace CPU_CAPABILITY {
@@ -104,7 +106,8 @@
 }
 
 template <typename scalar_t>
-inline scalar_t _max(const scalar_t& x, const scalar_t& y) {
+inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
+_max(const scalar_t& x, const scalar_t& y) {
   return at::_isnan(y) ? y : std::max(x, y);
 }
 
@@ -114,8 +117,16 @@
   return vec::maximum(x, y);
 }
 
+template <typename vec_t>
+inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
+_max(const vec_t& x, const vec_t& y) {
+  // vec::maximum propagates NaN
+  return maximum(x, y);
+}
+
 template <typename scalar_t>
-inline scalar_t _min(const scalar_t& x, const scalar_t& y) {
+inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
+_min(const scalar_t& x, const scalar_t& y) {
   return at::_isnan(y) ? y : std::min(x, y);
 }
 
@@ -125,6 +136,13 @@
   return vec::minimum(x, y);
 }
 
+template <typename vec_t>
+inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
+_min(const vec_t& x, const vec_t& y) {
+  // vec::minimum propagates NaN
+  return minimum(x, y);
+}
+
 template <typename scalar_t, typename accumut, typename Op,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map_acc(
diff --git a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
index 6f0cfe0..d9aa9a3 100644
--- a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
+++ b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
@@ -9,12 +9,14 @@
 #include <ATen/native/cpu/ReduceUtils.h>
 #include <ATen/native/cpu/utils.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else
 #include <ATen/ops/empty.h>
 #include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros.h>
 #endif
 
 namespace at { namespace native {
@@ -22,6 +24,46 @@
 namespace {
 
 template <typename scalar_t, typename index_t, ReductionType reduce>
+inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, const scalar_t val, scalar_t* other_data, int64_t K) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using Vec = vec::Vectorized<scalar_t>;
+  using aVec = VecType<scalar_t>;
+  constexpr int64_t kVecSize = Vec::size();
+  constexpr int64_t kVLEN = kVecSize * 4;
+
+  int64_t k = 0;
+  aVec val_vec = aVec((opmath_t)val);
+  scalar_t* other_ptr = other_data + c * K;
+
+  for (; k < K - (K % kVLEN); k += kVLEN) {
+    aVec out_vec0 = aVec::loadu(out_ptr + k);
+    aVec out_vec1 = aVec::loadu(out_ptr + k + kVecSize);
+    aVec out_vec2 = aVec::loadu(out_ptr + k + kVecSize * 2);
+    aVec out_vec3 = aVec::loadu(out_ptr + k + kVecSize * 3);
+
+    out_vec0 = update<aVec, reduce>(out_vec0, aVec::loadu(other_ptr + k) * val_vec);
+    out_vec1 = update<aVec, reduce>(out_vec1, aVec::loadu(other_ptr + k + kVecSize) * val_vec);
+    out_vec2 = update<aVec, reduce>(out_vec2, aVec::loadu(other_ptr + k + kVecSize * 2) * val_vec);
+    out_vec3 = update<aVec, reduce>(out_vec3, aVec::loadu(other_ptr + k + kVecSize * 3) * val_vec);
+
+    out_vec0.store(out_ptr + k);
+    out_vec1.store(out_ptr + k + kVecSize);
+    out_vec2.store(out_ptr + k + kVecSize * 2);
+    out_vec3.store(out_ptr + k + kVecSize * 3);
+  }
+  for (; k < K - (K % kVecSize); k += kVecSize) {
+    aVec out_vec = aVec::loadu(out_ptr + k);
+    out_vec = update<aVec, reduce>(out_vec, aVec::loadu(other_ptr + k) * val_vec);
+    out_vec.store(out_ptr + k);
+  }
+  for (; k < K; k++) {
+    opmath_t out_val = opmath_t(out_ptr[k]);
+    out_val = update<opmath_t, reduce>(out_val, opmath_t(other_ptr[k]) * opmath_t(val));
+    out_ptr[k] = out_val;
+  }
+}
+
+template <typename scalar_t, typename index_t, ReductionType reduce>
 void spmm_reduce_kernel_impl(
     const Tensor& out,
     const Tensor& crow_indices,
@@ -46,69 +88,54 @@
   int64_t M = crow_indices.numel() - 1;
   int64_t K = other.size(-1);
 
-  using Vec = vec::Vectorized<scalar_t>;
+  int num_threads = at::get_num_threads();
+  using opmath_t = at::opmath_type<scalar_t>;
+  Tensor buffer;
+  opmath_t* buffer_data = nullptr;
+  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
+  if constexpr (need_acc) {
+    auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
+    buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
+    buffer_data = buffer.data_ptr<opmath_t>();
+  }
+
   utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
-    int64_t row_start, row_end, c;
+    int tid = at::get_thread_num();
+    TORCH_CHECK(tid < num_threads,
+                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
+    opmath_t* buffer_ptr = nullptr;
+
+    int64_t row_start, row_end;
     for (const auto m : c10::irange(begin, end)) {
       row_start = csr_data[m];
       row_end = csr_data[m + 1];
 
       scalar_t* out_ptr = out_data + m * K;
-
-      constexpr int64_t kVecSize = Vec::size();
-      constexpr int64_t kVLEN = kVecSize * 4;
-      constexpr int64_t CHUNK_SIZE = 16;
+      if constexpr (need_acc) {
+        buffer_ptr = buffer_data + tid * K;
+      } else {
+        buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
+      }
 
       // step 1: reinit the output row for reduce type 'amax' and 'amin'
       int64_t count = row_end - row_start;
       if (count != 0) {
-        init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
+        _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
       }
 
       // step 2: reduce, do blocking on rowwise to reduce write memory bandwidth
+      constexpr int64_t CHUNK_SIZE = 16;
       for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
         int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);
-
-        int64_t k = 0;
-        for (; k < K - (K % kVLEN); k += kVLEN) {
-          Vec out_vec0 = Vec::loadu(out_ptr + k);
-          Vec out_vec1 = Vec::loadu(out_ptr + k + kVecSize);
-          Vec out_vec2 = Vec::loadu(out_ptr + k + kVecSize * 2);
-          Vec out_vec3 = Vec::loadu(out_ptr + k + kVecSize * 3);
-          for (const auto e : c10::irange(e0, e1)) {
-            c = col_data[e];
-            scalar_t val = val_data[e];
-            scalar_t* other_ptr = other_data + c * K + k;
-
-            out_vec0 = update<Vec, reduce>(out_vec0, Vec::loadu(other_ptr) * Vec(val));
-            out_vec1 = update<Vec, reduce>(out_vec1, Vec::loadu(other_ptr + kVecSize) * Vec(val));
-            out_vec2 = update<Vec, reduce>(out_vec2, Vec::loadu(other_ptr + kVecSize * 2) * Vec(val));
-            out_vec3 = update<Vec, reduce>(out_vec3, Vec::loadu(other_ptr + kVecSize * 3) * Vec(val));
-          }
-          out_vec0.store(out_ptr + k);
-          out_vec1.store(out_ptr + k + kVecSize);
-          out_vec2.store(out_ptr + k + kVecSize * 2);
-          out_vec3.store(out_ptr + k + kVecSize * 3);
+        for (const auto e : c10::irange(e0, e1)) {
+          int64_t c = col_data[e];
+          scalar_t val = val_data[e];
+          _update<scalar_t, index_t, reduce>(buffer_ptr, e, c, val, other_data, K);
         }
-        for (; k < K - (K % kVecSize); k += kVecSize) {
-          Vec out_vec = Vec::loadu(out_ptr + k);
-          for (const auto e : c10::irange(e0, e1)) {
-            c = col_data[e];
-            scalar_t val = val_data[e];
-            scalar_t* other_ptr = other_data + c * K;
-            out_vec = update<Vec, reduce>(out_vec, Vec::loadu(other_ptr + k) * Vec(val));
-          }
-          out_vec.store(out_ptr + k);
-        }
-        for (; k < K; k++) {
-          scalar_t out_val = out_ptr[k];
-          for (const auto e : c10::irange(e0, e1)) {
-            c = col_data[e];
-            scalar_t val = val_data[e];
-            scalar_t* other_ptr = other_data + c * K;
-            out_val = update<scalar_t, reduce>(out_val, other_ptr[k] * val);
-          }
-          out_ptr[k] = out_val;
+      }
+      if constexpr (need_acc) {
+        if (count != 0) {
+          vec::convert(buffer_ptr, out_ptr, K);
         }
       }
 
@@ -159,7 +186,23 @@
   int64_t M = crow_indices.numel() - 1;
   int64_t K = other.size(-1);
 
+  int num_threads = at::get_num_threads();
+  using opmath_t = at::opmath_type<scalar_t>;
+  Tensor buffer;
+  opmath_t* buffer_data = nullptr;
+  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
+  if constexpr (need_acc) {
+    auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
+    buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
+    buffer_data = buffer.data_ptr<opmath_t>();
+  }
+
   at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
+    int tid = at::get_thread_num();
+    TORCH_CHECK(tid < num_threads,
+                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
+    opmath_t* buffer_ptr = nullptr;
+
     int64_t row_start, row_end, c;
     for (const auto m : c10::irange(begin, end)) {
       row_start = csr_data[m];
@@ -167,20 +210,30 @@
 
       scalar_t* out_ptr = out_data + m * K;
       index_t* arg_out_ptr = arg_out_data + m * K;
+      if constexpr (need_acc) {
+        buffer_ptr = buffer_data + tid * K;
+      } else {
+        buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
+      }
 
       if (row_end != row_start) {
-        init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
+        _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
         for (const auto e : c10::irange(row_start, row_end)) {
           c = col_data[e];
-          scalar_t val = val_data[e];
+          opmath_t val = opmath_t(val_data[e]);
 
           scalar_t* other_ptr = other_data + c * K;
           for (const auto k : c10::irange(K)) {
-            update_with_index<scalar_t, index_t, reduce>(
-                &out_ptr[k], val *  other_ptr[k], &arg_out_ptr[k], index_t(e));
+            update_with_index<opmath_t, index_t, reduce>(
+                &buffer_ptr[k], opmath_t(val *  other_ptr[k]), &arg_out_ptr[k], index_t(e));
           };
         }
       }
+      if constexpr (need_acc) {
+        if (row_end != row_start) {
+          vec::convert(buffer_ptr, out_ptr, K);
+        }
+      }
     }
   });
 }
@@ -381,14 +434,14 @@
     const Tensor& values,
     const Tensor& other,
     ReductionType reduce_op) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
-    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
-      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
-        spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
-            out, crow_indices, col_indices, values, other);
+    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+      AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
+        AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
+          spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
+              out, crow_indices, col_indices, values, other);
+        });
       });
     });
-  });
 }
 
 void spmm_reduce_arg_kernel(
diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h
index e029b27..8c2888f 100644
--- a/aten/src/ATen/native/cpu/utils.h
+++ b/aten/src/ATen/native/cpu/utils.h
@@ -50,13 +50,24 @@
     std::tie(v0, v1) = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
     return {v0, v1};
   }
+  static Vec2 loadu(const float* ptr) {
+    return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
+  }
   void store(BFloat16* ptr) const {
     Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
     val.store(ptr);
   }
+  void store(float* ptr) const {
+    val0.store(ptr);
+    val1.store(ptr + Vectorized<float>::size());
+  }
 };
 inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
 inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
+inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
+inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
+inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
+inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
 
 template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
 template <> struct VectorizedType<BFloat16> { using type = Vec2; };
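
As a closing note on the design (not part of the PR text), a rough Python illustration of why the reduced-precision path above accumulates into a per-thread fp32 buffer (`at::opmath_type<scalar_t>`) and converts back to BFloat16 only once at the end via `vec::convert`:

```python
import torch

torch.manual_seed(0)
vals = torch.rand(10000)   # fp32 inputs in [0, 1)
ref = vals.sum()           # fp32 reference

# Accumulating directly in a BFloat16 accumulator rounds after every add,
# so the running sum stalls once its magnitude dwarfs the addends.
acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
for v in vals:
    acc_bf16 = acc_bf16 + v.to(torch.bfloat16)

# Accumulating in fp32 and rounding to BFloat16 once at the end stays
# close to the fp32 reference; this mirrors the kernel's opmath buffer.
acc_fp32 = torch.zeros(())
for v in vals:
    acc_fp32 = acc_fp32 + v.to(torch.bfloat16).float()
once = acc_fp32.to(torch.bfloat16)

print((acc_bf16.float() - ref).abs().item(), (once.float() - ref).abs().item())
```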