SumKernel (BFloat16): use float as accumulation type (#55217)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55217
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D28836794
Pulled By: VitalyFedyunin
fbshipit-source-id: 46ed3a862c2bb4c6325c78ecfc5d01761f7a113a
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
index 21de715..e56da95 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -742,4 +742,26 @@
#endif
+struct Vec2f { // Pair of float vectors that can be converted to one Vectorized<BFloat16>.
+ Vectorized<float> val0, val1; // the two float halves; val0 is passed first to the conversion helper below
+ Vec2f() {} // members are default-constructed; no explicit value is assigned here
+ Vec2f(float v) : val0(v), val1(v) {} // both halves built from the same scalar (presumably broadcast to every lane - confirm against Vectorized<float>)
+ Vec2f(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {} // wrap two existing float vectors
+ operator Vectorized<BFloat16>() const { // non-explicit conversion back to the BFloat16 vector type
+ return convert_float_bfloat16(val0, val1); // helper defined elsewhere; presumably narrows both halves into one bfloat16 vector - TODO confirm lane order
+ }
+};
+inline Vec2f& operator+= (Vec2f& a, const Vec2f& b) { // In-place add of one Vec2f into another, half by half.
+ a.val0 += b.val0; // first halves
+ a.val1 += b.val1; // second halves
+ return a; // returns the modified left operand by reference
+}
+inline Vec2f& operator+= (Vec2f& a, const Vectorized<BFloat16>& b) { // In-place add of a BFloat16 vector into a Vec2f; the additions are done on float vectors.
+ Vectorized<float> b0, b1; // receive the two float vectors produced from b
+ std::tie(b0, b1) = convert_bfloat16_float(b); // helper defined elsewhere; presumably widens b into two float vectors - TODO confirm which lanes go to b0 vs b1
+ a.val0 += b0; // first converted vector goes into val0
+ a.val1 += b1; // second converted vector goes into val1
+ return a; // returns the modified left operand by reference
+}
+
}}}
diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp
index 56a5afa..9ac1581 100644
--- a/aten/src/ATen/native/cpu/SumKernel.cpp
+++ b/aten/src/ATen/native/cpu/SumKernel.cpp
@@ -11,6 +11,16 @@
namespace native {
namespace {
+// AccType maps a value type to the type used to accumulate it; float is used as the accumulation type for BFloat16
+template <typename scalar_t> struct AccType { using type = scalar_t; }; // default: a type accumulates in itself
+template <> struct AccType<BFloat16> { using type = float; }; // scalar BFloat16 accumulates in float
+
+template <typename scalar_t> struct AccType<Vectorized<scalar_t>> { using type = Vectorized<scalar_t>; }; // vectors accumulate in the same vector type by default
+template <> struct AccType<Vectorized<BFloat16>> { using type = Vec2f; }; // a BFloat16 vector accumulates in Vec2f (two float vectors)
+
+template <typename scalar_t> // convenience alias: acc_type<T> is AccType<T>::type
+using acc_type = typename AccType<scalar_t>::type;
+
template <typename scalar_t>
struct LoadPolicy {
static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
@@ -207,8 +217,9 @@
const int64_t level_mask = level_step - 1;
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
- scalar_t acc[num_levels][nrows];
- std::fill_n(&acc[0][0], num_levels * nrows, scalar_t(0));
+ using accscalar_t = acc_type<scalar_t>;
+ accscalar_t acc[num_levels][nrows];
+ std::fill_n(&acc[0][0], num_levels * nrows, accscalar_t(0));
int64_t i = 0;
for (; i + level_step <= size;) {
@@ -228,7 +239,7 @@
#endif
for (int64_t k = 0; k < nrows; ++k) {
acc[j][k] += acc[j-1][k];
- acc[j-1][k] = scalar_t(0);
+ acc[j-1][k] = accscalar_t(0);
}
const auto mask = (level_mask << (j * level_power));
@@ -260,7 +271,7 @@
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<scalar_t, nrows> ret;
for (int64_t k = 0; k < nrows; ++k) {
- ret[k] = acc[0][k];
+ ret[k] = scalar_t(acc[0][k]);
}
return ret;
}