Sparse Compressed mm: avoid creating temp sparse tensor (#104062)
When mm forwards to addmm, it creates a zeroed-out `self` tensor. This
tensor should take its options from the result, not from one of the
sparse arguments. The bug was leading to an error when calling linear
with an `out` kwarg.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104062
Approved by: https://github.com/nikitaved, https://github.com/pearu
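
A minimal sketch of the scenario the new test exercises (an assumption-laden
illustration, not part of the patch: it assumes a CUDA device where
half-precision BSR linear dispatches to the sparse path, and the sizes, block
size, and variable names are made up for illustration; `out=` is passed to
linear exactly as the new test does):

    import torch

    m, k, n = 8, 16, 32  # illustrative sizes
    dense = torch.randn(m, k, dtype=torch.half, device="cuda")
    weight = torch.randn(n, k, dtype=torch.half, device="cuda")
    bsr = weight.to_sparse_bsr(blocksize=(8, 8))

    res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
    res_tri_out = torch.empty_like(res_dense)
    # Before this change, the temporary zero `self` built inside mm -> addmm
    # took its options from the sparse argument, so passing `out=` errored.
    res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)
    assert res_tri is res_tri_out  # out= now round-trips instead of raising
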
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index e2b2da5..a7c50a4 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -31,7 +31,6 @@
#include <ATen/ops/_sparse_csr_sum_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
-#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/abs.h>
@@ -121,6 +120,7 @@
#include <ATen/ops/trunc_native.h>
#include <ATen/ops/zero_native.h>
#include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
#endif
#include <algorithm>
@@ -735,7 +735,7 @@
const Tensor& mat1,
const Tensor& mat2,
Tensor& result) {
- auto zero = at::zeros({mat1.size(0), mat2.size(1)}, mat2.options());
+ auto zero = at::zeros_like(result);
return at::addmm_out(result, zero, mat1, mat2, 0.0, 1.0);
}
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 3eb3c50..89e5506 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -3487,8 +3487,9 @@
if bsr.dim() == 2 and dtype != torch.float:
# Test against linear to check dispatch
# which takes place for torch.half and torch.bfloat16.
- res_tri = torch.nn.functional.linear(dense, bsr)
res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+ res_tri_out = torch.empty_like(res_dense)
+ res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)
# Check dispatch worked with non-trivial outputs
if m > 0 and n > 0 and k > 0:
@@ -3500,7 +3501,8 @@
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
res_tri_out = torch.empty_like(res_dense)
res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)
- self.assertTrue(res_tri is res_tri_out)
+
+ self.assertTrue(res_tri is res_tri_out)
self.assertEqual(res_tri, res_dense)
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)