SumKernel (BFloat16): use float as accumulation type (#55217)

Summary: Accumulate BFloat16 sums in float rather than in BFloat16, so the running total is not repeatedly rounded back to bf16 precision. Pull Request resolved: https://github.com/pytorch/pytorch/pull/55217

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28836794

Pulled By: VitalyFedyunin

fbshipit-source-id: 46ed3a862c2bb4c6325c78ecfc5d01761f7a113a
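
BFloat16 has only 8 significand bits, so a running sum kept in BFloat16 stops
absorbing new terms once the total grows large relative to the addends. The
standalone sketch below (illustration only, not part of this patch; the helper
round_to_bf16 is a simplified emulation of round-to-nearest-even bf16
conversion that ignores NaN handling) shows the failure mode the patch avoids:

  #include <cstdint>
  #include <cstring>
  #include <iostream>

  // Round a float to the nearest bfloat16 value and return it as a float,
  // emulating what a BFloat16 accumulator would hold after each addition.
  float round_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    bits = (bits + rounding_bias) & 0xFFFF0000u;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
  }

  int main() {
    float acc_bf16 = 0.0f;   // accumulator rounded to bf16 after every add
    float acc_float = 0.0f;  // plain float accumulator
    for (int i = 0; i < 1000000; ++i) {
      acc_bf16 = round_to_bf16(acc_bf16 + 1.0f);
      acc_float += 1.0f;
    }
    // The bf16 accumulator stalls at 256 (256 + 1 rounds back to 256),
    // while the float accumulator reaches 1e6 exactly.
    std::cout << acc_bf16 << " vs " << acc_float << "\n";
  }
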
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
index 21de715..e56da95 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -742,4 +742,26 @@
 
 #endif
 
+struct Vec2f {
+  Vectorized<float> val0, val1;
+  Vec2f() {}
+  Vec2f(float v) : val0(v), val1(v) {}
+  Vec2f(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
+  operator Vectorized<BFloat16>() const {
+    return convert_float_bfloat16(val0, val1);
+  }
+};
+inline Vec2f& operator+= (Vec2f& a, const Vec2f& b) {
+  a.val0 += b.val0;
+  a.val1 += b.val1;
+  return a;
+}
+inline Vec2f& operator+= (Vec2f& a, const Vectorized<BFloat16>& b) {
+  Vectorized<float> b0, b1;
+  std::tie(b0, b1) = convert_bfloat16_float(b);
+  a.val0 += b0;
+  a.val1 += b1;
+  return a;
+}
+
 }}}
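
A minimal usage sketch of the Vec2f helper added above (hypothetical caller,
not part of the patch; assumes the Vectorized<BFloat16> loadu/size API from
this header):

  // Sum one contiguous strip of BFloat16 data, accumulating in float via
  // Vec2f and rounding back to BFloat16 only once at the end.
  Vectorized<BFloat16> sum_strip(const BFloat16* data, int64_t n) {
    constexpr int64_t kWidth = Vectorized<BFloat16>::size();
    Vec2f acc(0.0f);                                  // two float lanes per bf16 vector
    for (int64_t i = 0; i + kWidth <= n; i += kWidth) {
      acc += Vectorized<BFloat16>::loadu(data + i);   // widen to float, then add
    }
    return acc;  // operator Vectorized<BFloat16>() rounds back to bf16 once
  }
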
diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp
index 56a5afa..9ac1581 100644
--- a/aten/src/ATen/native/cpu/SumKernel.cpp
+++ b/aten/src/ATen/native/cpu/SumKernel.cpp
@@ -11,6 +11,16 @@
 namespace native {
 namespace {
 
+// use float as accumulation type for BFloat16
+template <typename scalar_t> struct AccType { using type = scalar_t; };
+template <> struct AccType<BFloat16> { using type = float; };
+
+template <typename scalar_t> struct AccType<Vectorized<scalar_t>> { using type = Vectorized<scalar_t>; };
+template <> struct AccType<Vectorized<BFloat16>> { using type = Vec2f; };
+
+template <typename scalar_t>
+using acc_type = typename AccType<scalar_t>::type;
+
 template <typename scalar_t>
 struct LoadPolicy {
   static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
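
For reference, the AccType mapping introduced above resolves as follows
(compile-time illustration only, not part of the patch):

  static_assert(std::is_same<acc_type<float>, float>::value, "");
  static_assert(std::is_same<acc_type<BFloat16>, float>::value, "");
  static_assert(std::is_same<acc_type<Vectorized<float>>, Vectorized<float>>::value, "");
  static_assert(std::is_same<acc_type<Vectorized<BFloat16>>, Vec2f>::value, "");
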
@@ -207,8 +217,9 @@
   const int64_t level_mask = level_step - 1;
 
   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-  scalar_t acc[num_levels][nrows];
-  std::fill_n(&acc[0][0], num_levels * nrows, scalar_t(0));
+  using accscalar_t = acc_type<scalar_t>;
+  accscalar_t acc[num_levels][nrows];
+  std::fill_n(&acc[0][0], num_levels * nrows, accscalar_t(0));
 
   int64_t i = 0;
   for (; i + level_step <= size;) {
@@ -228,7 +239,7 @@
       #endif
       for (int64_t k = 0; k < nrows; ++k) {
         acc[j][k] += acc[j-1][k];
-        acc[j-1][k] = scalar_t(0);
+        acc[j-1][k] = accscalar_t(0);
       }
 
       const auto mask = (level_mask << (j * level_power));
@@ -260,7 +271,7 @@
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   std::array<scalar_t, nrows> ret;
   for (int64_t k = 0; k < nrows; ++k) {
-    ret[k] = acc[0][k];
+    ret[k] = scalar_t(acc[0][k]);
   }
   return ret;
 }
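
In scalar terms, the change amounts to the following pattern (sketch only, not
the actual kernel code): keep the running sum in acc_type<scalar_t>, which is
float when scalar_t is BFloat16, and narrow back to scalar_t only when the
per-row result is written out, as in the ret[k] = scalar_t(acc[0][k]) line above.

  template <typename scalar_t>
  scalar_t sum_row(const scalar_t* data, int64_t n) {
    using accscalar_t = acc_type<scalar_t>;  // float when scalar_t is BFloat16
    accscalar_t acc(0);
    for (int64_t i = 0; i < n; ++i) {
      acc += accscalar_t(data[i]);           // widen each element before adding
    }
    return scalar_t(acc);                    // single rounding at the end
  }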