Add Half for sparse.mm reduce (#133672)
This PR is to add Half support for sparse.mm reduce in CPU backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672
Approved by: https://github.com/Skylion007
diff --git a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
index cf9749a..b620985 100644
--- a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
+++ b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
@@ -434,7 +434,7 @@
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
@@ -452,7 +452,7 @@
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
@@ -471,7 +471,7 @@
const Tensor& row_indices,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
@@ -489,7 +489,7 @@
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
grad_self, grad_out, col_indices, other, arg_out);
@@ -502,7 +502,7 @@
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
normalized_values, values, crow_indices, row_indices);
@@ -545,7 +545,7 @@
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
grad_other, grad_out, col_indices, values, arg_out);
diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h
index 641ac0c..bf6af9a 100644
--- a/aten/src/ATen/native/cpu/utils.h
+++ b/aten/src/ATen/native/cpu/utils.h
@@ -53,7 +53,7 @@
return false;
}
-// Helper struct for bfloat16 vectorization
+// Helper struct for bfloat16/float16 vectorization
// Useful when you need float as immediate dtype or accumulate dtype
using namespace vec;
struct Vec2 {
@@ -64,6 +64,10 @@
auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
return {v0, v1};
}
+ static Vec2 loadu(const Half* ptr) {
+ auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
+ return {v0, v1};
+ }
static Vec2 loadu(const float* ptr) {
return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
}
@@ -71,6 +75,10 @@
Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
val.store(ptr);
}
+ void store(Half* ptr) const {
+ Vectorized<Half> val = convert_float_half(val0, val1);
+ val.store(ptr);
+ }
void store(float* ptr) const {
val0.store(ptr);
val1.store(ptr + Vectorized<float>::size());
@@ -85,6 +93,7 @@
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
+template <> struct VectorizedType<Half> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
// Helper for mixed data type parameter Vec::load
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 25cfadb..fdac0cc 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -226,7 +226,7 @@
("normal", "in_place"): {f16, f32, f64},
("normal", "number_mean"): {f16, f32, f64},
"normal": {f16, f32, f64},
- ("sparse.mm", "reduce"): {f32, f64},
+ ("sparse.mm", "reduce"): {f32, f64, f16},
"sparse.sampled_addmm": {f32, f64},
"to_sparse": {
f32,
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 14e8a85..8ab03c1 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -2575,7 +2575,7 @@
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
@onlyCPU
- @dtypes(torch.float32, torch.float64, torch.bfloat16)
+ @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
@precisionOverride({torch.bfloat16: 0.01})
def test_sparse_mm_reduce_sum(self, device, dtype):
def run_test(m, n, k, nnz, train):
@@ -2613,8 +2613,8 @@
@skipIfTorchDynamo()
@onlyCPU
- @dtypes(torch.float32, torch.float64, torch.bfloat16)
- @precisionOverride({torch.bfloat16: 0.01})
+ @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
+ @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
def test_sparse_mm_reduce(self, device, dtype):
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
@@ -2649,7 +2649,7 @@
out = torch.sparse.mm(csr, mat, reduce_type)
self.assertEqual(out, ref_out)
- if train and dtype is not torch.bfloat16:
+ if train and dtype not in (torch.bfloat16, torch.float16):
ref_out.sum().backward()
out.sum().backward()
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e2e5469..78f99a9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13576,7 +13576,7 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
)),
OpInfo('sparse.mm',
- dtypes=floating_types_and(torch.bfloat16),
+ dtypes=floating_types_and(torch.bfloat16, torch.float16),
variant_test_name='reduce',
supports_autograd=True,
supports_out=False,