Optimize sparse.mm reduce in BFloat16 data type in CPU backend (#103239)
### Description
This PR optimizes sparse.mm reduce for the BFloat16 data type in the CPU backend, which is one of the tasks in https://github.com/pyg-team/pytorch_geometric/issues/7057. Half support (which needs an addmm Half implementation) will follow once https://github.com/pytorch/pytorch/pull/99498 is upstreamed.
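For reference, the operator exercised here is `torch.sparse.mm` with a `reduce` argument on a CPU CSR tensor; below is a minimal usage sketch in BFloat16 (shapes and sparsity are illustrative assumptions, not the benchmark configuration):
```python
import torch

# Illustrative shapes and sparsity; any CPU CSR matrix times a dense matrix
# takes the same spmm_reduce path.
mask = torch.rand(256, 128) < 0.01
sparse = (torch.randn(256, 128) * mask).to(torch.bfloat16).to_sparse_csr()
dense = torch.randn(128, 64, dtype=torch.bfloat16)

# reduce can be "sum", "mean", "amax" or "amin".
out = torch.sparse.mm(sparse, dense, "sum")
```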
Next steps:
- [x] Add benchmarks
- [x] Update UTs
- [x] Check backward behaviors (see the sketch after this list)
- [x] Refactor code
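For the backward item above, a minimal sketch of the kind of smoke test meant (the assumption that gradients flow to the dense operand in BFloat16 is illustrative; this is not the UT added in the PR):
```python
import torch

# Hypothetical backward smoke test for the reduced-precision path.
sparse = torch.eye(8, dtype=torch.bfloat16).to_sparse_csr()
dense = torch.randn(8, 4, dtype=torch.bfloat16, requires_grad=True)

out = torch.sparse.mm(sparse, dense, "sum")
out.sum().backward()
print(dense.grad.dtype)  # expected to stay torch.bfloat16
```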
### Performance test (Updated)
Tested BFloat16 on an Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz, with jemalloc and iomp.
Single socket (40 cores)

Single core
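A minimal sketch of how such a timing can be collected; the sizes, density, and use of `torch.utils.benchmark` are illustrative assumptions, not the actual benchmark script:
```python
import torch
from torch.utils import benchmark

# Illustrative problem size and density; the thread count follows the environment
# (e.g. OMP_NUM_THREADS=40 for single socket, OMP_NUM_THREADS=1 for single core).
M, N, K, density = 4096, 4096, 128, 0.001
weight = torch.rand(M, N)
weight[weight > density] = 0
sparse = weight.to(torch.bfloat16).to_sparse_csr()
dense = torch.randn(N, K, dtype=torch.bfloat16)

for reduce in ("sum", "mean", "amax", "amin"):
    t = benchmark.Timer(
        stmt="torch.sparse.mm(sparse, dense, reduce)",
        globals={"torch": torch, "sparse": sparse, "dense": dense, "reduce": reduce},
    )
    print(reduce, t.blocked_autorange())
```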

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103239
Approved by: https://github.com/mingfeima, https://github.com/albanD
diff --git a/aten/src/ATen/native/cpu/ReduceUtils.h b/aten/src/ATen/native/cpu/ReduceUtils.h
index 96cb1c9..c54dc49 100644
--- a/aten/src/ATen/native/cpu/ReduceUtils.h
+++ b/aten/src/ATen/native/cpu/ReduceUtils.h
@@ -7,6 +7,8 @@
#include <ATen/native/ReductionType.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>
+#include <ATen/native/cpu/utils.h>
+#include <ATen/OpMathType.h>
namespace at::native {
inline namespace CPU_CAPABILITY {
@@ -104,7 +106,8 @@
}
template <typename scalar_t>
-inline scalar_t _max(const scalar_t& x, const scalar_t& y) {
+inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
+_max(const scalar_t& x, const scalar_t& y) {
return at::_isnan(y) ? y : std::max(x, y);
}
@@ -114,8 +117,16 @@
return vec::maximum(x, y);
}
+template <typename vec_t>
+inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
+_max(const vec_t& x, const vec_t& y) {
+ // vec::maximum propagates NaN
+ return maximum(x, y);
+}
+
template <typename scalar_t>
-inline scalar_t _min(const scalar_t& x, const scalar_t& y) {
+inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
+_min(const scalar_t& x, const scalar_t& y) {
return at::_isnan(y) ? y : std::min(x, y);
}
@@ -125,6 +136,13 @@
return vec::minimum(x, y);
}
+template <typename vec_t>
+inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
+_min(const vec_t& x, const vec_t& y) {
+ // vec::minimum propagates NaN
+ return minimum(x, y);
+}
+
template <typename scalar_t, typename accumut, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map_acc(
diff --git a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
index 6f0cfe0..d9aa9a3 100644
--- a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
+++ b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
@@ -9,12 +9,14 @@
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros.h>
#endif
namespace at { namespace native {
@@ -22,6 +24,46 @@
namespace {
template <typename scalar_t, typename index_t, ReductionType reduce>
+inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, const scalar_t val, scalar_t* other_data, int64_t K) {
+ using opmath_t = at::opmath_type<scalar_t>;
+ using Vec = vec::Vectorized<scalar_t>;
+ using aVec = VecType<scalar_t>;
+ constexpr int64_t kVecSize = Vec::size();
+ constexpr int64_t kVLEN = kVecSize * 4;
+
+ int64_t k = 0;
+ aVec val_vec = aVec((opmath_t)val);
+ scalar_t* other_ptr = other_data + c * K;
+
+ for (; k < K - (K % kVLEN); k += kVLEN) {
+ aVec out_vec0 = aVec::loadu(out_ptr + k);
+ aVec out_vec1 = aVec::loadu(out_ptr + k + kVecSize);
+ aVec out_vec2 = aVec::loadu(out_ptr + k + kVecSize * 2);
+ aVec out_vec3 = aVec::loadu(out_ptr + k + kVecSize * 3);
+
+ out_vec0 = update<aVec, reduce>(out_vec0, aVec::loadu(other_ptr + k) * val_vec);
+ out_vec1 = update<aVec, reduce>(out_vec1, aVec::loadu(other_ptr + k + kVecSize) * val_vec);
+ out_vec2 = update<aVec, reduce>(out_vec2, aVec::loadu(other_ptr + k + kVecSize * 2) * val_vec);
+ out_vec3 = update<aVec, reduce>(out_vec3, aVec::loadu(other_ptr + k + kVecSize * 3) * val_vec);
+
+ out_vec0.store(out_ptr + k);
+ out_vec1.store(out_ptr + k + kVecSize);
+ out_vec2.store(out_ptr + k + kVecSize * 2);
+ out_vec3.store(out_ptr + k + kVecSize * 3);
+ }
+ for (; k < K - (K % kVecSize); k += kVecSize) {
+ aVec out_vec = aVec::loadu(out_ptr + k);
+ out_vec = update<aVec, reduce>(out_vec, aVec::loadu(other_ptr + k) * val_vec);
+ out_vec.store(out_ptr + k);
+ }
+ for (; k < K; k++) {
+ opmath_t out_val = opmath_t(out_ptr[k]);
+ out_val = update<opmath_t, reduce>(out_val, opmath_t(other_ptr[k]) * opmath_t(val));
+ out_ptr[k] = out_val;
+ }
+}
+
+template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_kernel_impl(
const Tensor& out,
const Tensor& crow_indices,
@@ -46,69 +88,54 @@
int64_t M = crow_indices.numel() - 1;
int64_t K = other.size(-1);
- using Vec = vec::Vectorized<scalar_t>;
+ int num_threads = at::get_num_threads();
+ using opmath_t = at::opmath_type<scalar_t>;
+ Tensor buffer;
+ opmath_t* buffer_data = nullptr;
+ static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
+ if constexpr (need_acc) {
+ auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
+ buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
+ buffer_data = buffer.data_ptr<opmath_t>();
+ }
+
utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
- int64_t row_start, row_end, c;
+ int tid = at::get_thread_num();
+ TORCH_CHECK(tid < num_threads,
+ "expect thread id smaller than ", num_threads, ", got thread id ", tid);
+ opmath_t* buffer_ptr = nullptr;
+
+ int64_t row_start, row_end;
for (const auto m : c10::irange(begin, end)) {
row_start = csr_data[m];
row_end = csr_data[m + 1];
scalar_t* out_ptr = out_data + m * K;
-
- constexpr int64_t kVecSize = Vec::size();
- constexpr int64_t kVLEN = kVecSize * 4;
- constexpr int64_t CHUNK_SIZE = 16;
+ if constexpr (need_acc) {
+ buffer_ptr = buffer_data + tid * K;
+ } else {
+ buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
+ }
// step 1: reinit the output row for reduce type 'amax' and 'amin'
int64_t count = row_end - row_start;
if (count != 0) {
- init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
+ _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
}
// step 2: reduce, do blocking on rowwise to reduce write memory bandwidth
+ constexpr int64_t CHUNK_SIZE = 16;
for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);
-
- int64_t k = 0;
- for (; k < K - (K % kVLEN); k += kVLEN) {
- Vec out_vec0 = Vec::loadu(out_ptr + k);
- Vec out_vec1 = Vec::loadu(out_ptr + k + kVecSize);
- Vec out_vec2 = Vec::loadu(out_ptr + k + kVecSize * 2);
- Vec out_vec3 = Vec::loadu(out_ptr + k + kVecSize * 3);
- for (const auto e : c10::irange(e0, e1)) {
- c = col_data[e];
- scalar_t val = val_data[e];
- scalar_t* other_ptr = other_data + c * K + k;
-
- out_vec0 = update<Vec, reduce>(out_vec0, Vec::loadu(other_ptr) * Vec(val));
- out_vec1 = update<Vec, reduce>(out_vec1, Vec::loadu(other_ptr + kVecSize) * Vec(val));
- out_vec2 = update<Vec, reduce>(out_vec2, Vec::loadu(other_ptr + kVecSize * 2) * Vec(val));
- out_vec3 = update<Vec, reduce>(out_vec3, Vec::loadu(other_ptr + kVecSize * 3) * Vec(val));
- }
- out_vec0.store(out_ptr + k);
- out_vec1.store(out_ptr + k + kVecSize);
- out_vec2.store(out_ptr + k + kVecSize * 2);
- out_vec3.store(out_ptr + k + kVecSize * 3);
+ for (const auto e : c10::irange(e0, e1)) {
+ int64_t c = col_data[e];
+ scalar_t val = val_data[e];
+ _update<scalar_t, index_t, reduce>(buffer_ptr, e, c, val, other_data, K);
}
- for (; k < K - (K % kVecSize); k += kVecSize) {
- Vec out_vec = Vec::loadu(out_ptr + k);
- for (const auto e : c10::irange(e0, e1)) {
- c = col_data[e];
- scalar_t val = val_data[e];
- scalar_t* other_ptr = other_data + c * K;
- out_vec = update<Vec, reduce>(out_vec, Vec::loadu(other_ptr + k) * Vec(val));
- }
- out_vec.store(out_ptr + k);
- }
- for (; k < K; k++) {
- scalar_t out_val = out_ptr[k];
- for (const auto e : c10::irange(e0, e1)) {
- c = col_data[e];
- scalar_t val = val_data[e];
- scalar_t* other_ptr = other_data + c * K;
- out_val = update<scalar_t, reduce>(out_val, other_ptr[k] * val);
- }
- out_ptr[k] = out_val;
+ }
+ if constexpr (need_acc) {
+ if (count != 0) {
+ vec::convert(buffer_ptr, out_ptr, K);
}
}
@@ -159,7 +186,23 @@
int64_t M = crow_indices.numel() - 1;
int64_t K = other.size(-1);
+ int num_threads = at::get_num_threads();
+ using opmath_t = at::opmath_type<scalar_t>;
+ Tensor buffer;
+ opmath_t* buffer_data = nullptr;
+ static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
+ if constexpr (need_acc) {
+ auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
+ buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
+ buffer_data = buffer.data_ptr<opmath_t>();
+ }
+
at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
+ int tid = at::get_thread_num();
+ TORCH_CHECK(tid < num_threads,
+ "expect thread id smaller than ", num_threads, ", got thread id ", tid);
+ opmath_t* buffer_ptr = nullptr;
+
int64_t row_start, row_end, c;
for (const auto m : c10::irange(begin, end)) {
row_start = csr_data[m];
@@ -167,20 +210,30 @@
scalar_t* out_ptr = out_data + m * K;
index_t* arg_out_ptr = arg_out_data + m * K;
+ if constexpr (need_acc) {
+ buffer_ptr = buffer_data + tid * K;
+ } else {
+ buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
+ }
if (row_end != row_start) {
- init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
+ _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
for (const auto e : c10::irange(row_start, row_end)) {
c = col_data[e];
- scalar_t val = val_data[e];
+ opmath_t val = opmath_t(val_data[e]);
scalar_t* other_ptr = other_data + c * K;
for (const auto k : c10::irange(K)) {
- update_with_index<scalar_t, index_t, reduce>(
- &out_ptr[k], val * other_ptr[k], &arg_out_ptr[k], index_t(e));
+ update_with_index<opmath_t, index_t, reduce>(
+ &buffer_ptr[k], opmath_t(val * other_ptr[k]), &arg_out_ptr[k], index_t(e));
};
}
}
+ if constexpr (need_acc) {
+ if (row_end != row_start) {
+ vec::convert(buffer_ptr, out_ptr, K);
+ }
+ }
}
});
}
@@ -381,14 +434,14 @@
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
- AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
- AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
- spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
- out, crow_indices, col_indices, values, other);
+ AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+ AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
+ AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
+ spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
+ out, crow_indices, col_indices, values, other);
+ });
});
});
- });
}
void spmm_reduce_arg_kernel(
diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h
index e029b27..8c2888f 100644
--- a/aten/src/ATen/native/cpu/utils.h
+++ b/aten/src/ATen/native/cpu/utils.h
@@ -50,13 +50,24 @@
std::tie(v0, v1) = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
return {v0, v1};
}
+ static Vec2 loadu(const float* ptr) {
+ return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
+ }
void store(BFloat16* ptr) const {
Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
val.store(ptr);
}
+ void store(float* ptr) const {
+ val0.store(ptr);
+ val1.store(ptr + Vectorized<float>::size());
+ }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
+inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
+inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
+inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
+inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };