Sparse Compressed mm avoid creating temp sparse (#104062)

When mm forwards to addmm it creates a zeroed out self this tensor
should take options from the result not one of the sparse arguments.

The bug was leading to an error when calling linear with an `out` kwarg.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104062
Approved by: https://github.com/nikitaved, https://github.com/pearu
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index e2b2da5..a7c50a4 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -31,7 +31,6 @@
 #include <ATen/ops/_sparse_csr_sum_native.h>
 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
 #include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
-#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
 #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
 #include <ATen/ops/_unique.h>
 #include <ATen/ops/abs.h>
@@ -121,6 +120,7 @@
 #include <ATen/ops/trunc_native.h>
 #include <ATen/ops/zero_native.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
 #endif
 
 #include <algorithm>
@@ -735,7 +735,7 @@
     const Tensor& mat1,
     const Tensor& mat2,
     Tensor& result) {
-  auto zero = at::zeros({mat1.size(0), mat2.size(1)}, mat2.options());
+  auto zero = at::zeros_like(result);
   return at::addmm_out(result, zero, mat1, mat2, 0.0, 1.0);
 }
 
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 3eb3c50..89e5506 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -3487,8 +3487,9 @@
             if bsr.dim() == 2 and dtype != torch.float:
                 # Test against linear to check dispatch
                 # which takes place for torch.half and torch.bfloat16.
-                res_tri = torch.nn.functional.linear(dense, bsr)
                 res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+                res_tri_out = torch.empty_like(res_dense)
+                res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)
 
                 # Check dispatch worked with non-trivial outputs
                 if m > 0 and n > 0 and k > 0:
@@ -3500,7 +3501,8 @@
                 res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
                 res_tri_out = torch.empty_like(res_dense)
                 res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)
-                self.assertTrue(res_tri is res_tri_out)
+
+            self.assertTrue(res_tri is res_tri_out)
             self.assertEqual(res_tri, res_dense)
 
             res_dense = bsr.to_dense() @ dense.transpose(-2, -1)