Eliminate unnecessary copy in CUDA addmm with sparse compressed block operand (#114484)

Instead of unconditionally copying `self` into `result` in the CUDA sparse `addmm`/`baddbmm` entry points, the copy is now performed only on the code paths that actually require `result` to be pre-initialized with the input (and only when `beta != 0`). The block-sparse (BSR/BSC) paths pass `input` down to `addmm_out_sparse_csr`/`block_sparse_mm` directly, so the Triton-backed `float16`/`bfloat16` block-sparse kernels skip the extra copy altogether.

As a result, `nn.linear(<strided tensor>, <BSR tensor>, bias=<strided tensor>)` performance increases as follows (`float16`, `NVIDIA A100-SXM4-80GB`):
- 256x256 weights: 14-27% speed-up
- 512x512 weights: 9-25% speed-up
- 1024x1024 weights: 5-20% speed-up
- 2048x2048 weights: 3-16% speed-up
- 4096x4096 weights: 2-9% speed-up
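
For reference, a minimal sketch of the benchmarked call; shapes, batch size, and block size below are illustrative, not the exact benchmark configuration:

```python
import torch

# Illustrative shapes only; the PR benchmarks square float16 weights on an A100.
x = torch.randn(64, 1024, dtype=torch.float16, device="cuda")
w = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
b = torch.randn(1024, dtype=torch.float16, device="cuda")

# Convert the weight to blocked sparse row (BSR) layout; linear() then
# dispatches to the sparse addmm path touched by this PR.
w_bsr = w.to_sparse_bsr((32, 32))

y = torch.nn.functional.linear(x, w_bsr, b)
```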

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114484
Approved by: https://github.com/cpuhrsch
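
At the ATen level the affected entry point is `addmm` with a block-sparse operand (see the diff below). A hedged sanity check, with illustrative shapes, block size, and tolerances (not part of the PR's test suite), that the sparse path matches a dense reference for non-trivial `alpha`/`beta` now that the copy is deferred:

```python
import torch

# addmm with a BSR mat1 on CUDA should agree with the dense reference; the
# input-to-result copy is now performed only where the kernel needs it.
inp = torch.randn(256, 256, dtype=torch.float16, device="cuda")
m1 = torch.randn(256, 256, dtype=torch.float16, device="cuda")
m2 = torch.randn(256, 256, dtype=torch.float16, device="cuda")

out_sparse = torch.addmm(inp, m1.to_sparse_bsr((16, 16)), m2, beta=0.5, alpha=2.0)
out_dense = torch.addmm(inp, m1, m2, beta=0.5, alpha=2.0)

# Loose tolerances to absorb float16 accumulation differences between kernels.
torch.testing.assert_close(out_sparse, out_dense, rtol=2e-2, atol=2e-2)
```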
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
index 297e4b6..6cac383 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
@@ -126,13 +126,12 @@
               "x",
               self_->size(1));
 
-  if (&result != &self) {
+  if (!result.is_same(self)) {
     if (result.layout() == kStrided) {
       at::native::resize_output(result, self_->sizes());
     } else {
       result.resize_as_sparse_(*self_);
     }
-    result.copy_(*self_);
   }
 
   if (result.numel() == 0) {
@@ -142,15 +141,21 @@
   if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
     // According to docs, when beta==0 values in self should be ignored.
     // nans and infs should not propagate
-    if (beta.toComplexDouble() == 0.) {
+    const auto beta_val = beta.toComplexDouble();
+    if (beta_val == 0.) {
       result.zero_();
     } else {
-      result.mul_(beta);
+      if (!result.is_same(self)) {
+        result.copy_(*self_);
+      }
+      if (beta_val != 1.) {
+        result.mul_(beta);
+      }
     }
     return result;
   }
 
-  sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
+  sparse::impl::cuda::addmm_out_sparse_csr(*self_, mat1, mat2, beta, alpha, result);
   return result;
 }
 
@@ -167,9 +172,8 @@
   TORCH_CHECK(mat2.layout() == kStrided, "torch.baddbmm: Expect mat2 to be strided, but got ", mat2.layout());
   TORCH_CHECK(result.layout() == kStrided, "torch.baddbmm: Expect result to be strided, but got ", result.layout());
 
-  if (&result != &self) {
+  if (!result.is_same(self)) {
     at::native::resize_output(result, self.sizes());
-    result.copy_(self);
   }
 
   if (mat1._nnz() == 0) {
@@ -178,12 +182,17 @@
     if (beta.toComplexDouble() == 0.) {
       result.zero_();
     } else {
-      result.mul_(beta);
+      if (!result.is_same(self)) {
+        result.copy_(self);
+      }
+      if (beta.toComplexDouble() != 1.) {
+        result.mul_(beta);
+      }
     }
     return result;
   }
 
-  sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
+  sparse::impl::cuda::addmm_out_sparse_csr(self, mat1, mat2, beta, alpha, result);
   return result;
 }
 
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index 2309a40..408d25b 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -464,6 +464,7 @@
 }
 
 void block_sparse_mm(
+    const Tensor& input,
     const at::sparse_csr::SparseCsrTensor& mat1,
     const Tensor& mat2,
     const Scalar& beta,
@@ -486,7 +487,7 @@
   // especially for not very sparse inputs.
   if (mat1.scalar_type() == ScalarType::Half || mat1.scalar_type() == ScalarType::BFloat16) {
     at::native::sparse::impl::_compressed_row_strided_addmm_out(
-        result,
+        input,
         mat1,
         mat2,
         /*beta=*/beta,
@@ -497,6 +498,10 @@
     return;
   }
 
+  if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
+    result.copy_(input);
+  }
+
   const cusparseDirection_t block_layout = mat1.values().is_contiguous()
       ? CUSPARSE_DIRECTION_ROW
       : CUSPARSE_DIRECTION_COLUMN;
@@ -838,6 +843,7 @@
 } // anonymous namespace
 
 void addmm_out_sparse_csr(
+    const Tensor& input,
     const Tensor& mat1,
     const Tensor& mat2,
     const Scalar& beta,
@@ -853,6 +859,39 @@
   // Valid combinations terminate in a return
   // Invalid combinations are omitted and will fall though to the TORCH check
   // generating an informative error message
+
+  // mm functions that copy input to result when needed (e.g. mm
+  // triton kernels do not require result being initialized with
+  // input):
+  if (mat1.layout() == kSparseBsr) {
+    if (mat2.layout() == kStrided) {
+      if (result.layout() == kStrided)
+        return block_sparse_mm(input, mat1, mat2, beta, alpha, result);
+    }
+  }
+
+  if (mat1.layout() == kStrided) {
+    if (mat2.layout() == kSparseBsc) {
+      if (result.layout() == kStrided) {
+        auto result_t = result.transpose(-2, -1);
+        auto input_t = (result.is_same(input) ? result_t : input.transpose(-2, -1));
+        return block_sparse_mm(
+            input_t,
+            mat2.transpose(-2, -1),
+            mat1.transpose(-2, -1),
+            beta,
+            alpha,
+            result_t);
+      }
+    }
+  }
+
+  // copy input to result:
+  if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
+    result.copy_(input);
+  }
+
+  // mm functions that assume that result contains input:
   if (mat1.layout() == kStrided) {
     if (mat2.layout() == kSparseCsr) {
       if (result.layout() == kStrided) {
@@ -875,16 +914,6 @@
             result.transpose(-2, -1));
       }
     }
-    if (mat2.layout() == kSparseBsc) {
-      if (result.layout() == kStrided) {
-        return block_sparse_mm(
-            mat2.transpose(-2, -1),
-            mat1.transpose(-2, -1),
-            beta,
-            alpha,
-            result.transpose(-2, -1));
-      }
-    }
   }
   if (mat1.layout() == kSparseCsr) {
     if (mat2.layout() == kStrided) {
@@ -933,12 +962,6 @@
       }
     }
   }
-  if (mat1.layout() == kSparseBsr) {
-    if (mat2.layout() == kStrided) {
-      if (result.layout() == kStrided)
-        return block_sparse_mm(mat1, mat2, beta, alpha, result);
-    }
-  }
   TORCH_CHECK(
       false,
       "addmm: computation on CUDA is not implemented for ",
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
index 4bd7281..b2bae73 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
@@ -11,6 +11,7 @@
 namespace cuda {
 
 void addmm_out_sparse_csr(
+    const Tensor& input,
     const at::sparse_csr::SparseCsrTensor& mat1,
     const Tensor& mat2,
     const Scalar& beta,