Eliminate unnecessary copy in CUDA addmm with sparse compressed block operand (#114484)
As in the title: the unconditional copy of `self` into `result` before dispatching to the sparse `addmm` kernels is removed. The copy is now performed only when the selected kernel actually needs `result` pre-initialized with the input (the Triton BSR kernels do not) and `beta != 0`.
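For reference, a minimal sketch of the operation this path implements; sizes, block size, and scalar values below are illustrative, not taken from the PR's benchmarks:

```python
import torch

# torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) computes
#   beta * input + alpha * (mat1 @ mat2).
# With a CUDA BSR mat1 in float16/bfloat16 the Triton kernel handles the
# `beta * input` term itself, so `result` no longer has to be pre-filled
# with a copy of `input`.
inp = torch.randn(128, 128, device="cuda", dtype=torch.float16)
mat1 = torch.randn(128, 128, device="cuda", dtype=torch.float16).to_sparse_bsr((16, 16))
mat2 = torch.randn(128, 128, device="cuda", dtype=torch.float16)

out = torch.addmm(inp, mat1, mat2, beta=1.0, alpha=1.0)
```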
As a result, `nn.linear(<strided tensor>, <BSR tensor>, bias=<strided tensor>)` performance improves as follows (`float16`, `NVIDIA A100-SXM4-80GB`; a usage sketch follows the list):
- 256x256 weights: 14-27% speedup
- 512x512 weights: 9-25% speedup
- 1024x1024 weights: 5-20% speedup
- 2048x2048 weights: 3-16% speedup
- 4096x4096 weights: 2-9% speedup
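A sketch of how the affected `nn.linear` path can be exercised; the tensor sizes and the 32x32 block size are arbitrary choices for illustration, not the benchmark configuration:

```python
import torch
import torch.nn.functional as F

M, K, N = 64, 1024, 1024  # illustrative shapes
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
w = torch.randn(N, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, device="cuda", dtype=torch.float16)

# Convert the strided weight to BSR; the 32x32 block size is an arbitrary choice here.
w_bsr = w.to_sparse_bsr((32, 32))

# Dispatches to the sparse CUDA addmm path touched by this PR.
y = F.linear(x, w_bsr, b)
```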
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114484
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
index 297e4b6..6cac383 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
@@ -126,13 +126,12 @@
"x",
self_->size(1));
- if (&result != &self) {
+ if (!result.is_same(self)) {
if (result.layout() == kStrided) {
at::native::resize_output(result, self_->sizes());
} else {
result.resize_as_sparse_(*self_);
}
- result.copy_(*self_);
}
if (result.numel() == 0) {
@@ -142,15 +141,21 @@
if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
// According to docs, when beta==0 values in self should be ignored.
// nans and infs should not propagate
- if (beta.toComplexDouble() == 0.) {
+ const auto beta_val = beta.toComplexDouble();
+ if (beta_val == 0.) {
result.zero_();
} else {
- result.mul_(beta);
+ if (!result.is_same(self)) {
+ result.copy_(*self_);
+ }
+ if (beta_val != 1.) {
+ result.mul_(beta);
+ }
}
return result;
}
- sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
+ sparse::impl::cuda::addmm_out_sparse_csr(*self_, mat1, mat2, beta, alpha, result);
return result;
}
@@ -167,9 +172,8 @@
TORCH_CHECK(mat2.layout() == kStrided, "torch.baddbmm: Expect mat2 to be strided, but got ", mat2.layout());
TORCH_CHECK(result.layout() == kStrided, "torch.baddbmm: Expect result to be strided, but got ", result.layout());
- if (&result != &self) {
+ if (!result.is_same(self)) {
at::native::resize_output(result, self.sizes());
- result.copy_(self);
}
if (mat1._nnz() == 0) {
@@ -178,12 +182,17 @@
if (beta.toComplexDouble() == 0.) {
result.zero_();
} else {
- result.mul_(beta);
+ if (!result.is_same(self)) {
+ result.copy_(self);
+ }
+ if (beta.toComplexDouble() != 1.) {
+ result.mul_(beta);
+ }
}
return result;
}
- sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
+ sparse::impl::cuda::addmm_out_sparse_csr(self, mat1, mat2, beta, alpha, result);
return result;
}
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index 2309a40..408d25b 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -464,6 +464,7 @@
}
void block_sparse_mm(
+ const Tensor& input,
const at::sparse_csr::SparseCsrTensor& mat1,
const Tensor& mat2,
const Scalar& beta,
@@ -486,7 +487,7 @@
// especially for not very sparse inputs.
if (mat1.scalar_type() == ScalarType::Half || mat1.scalar_type() == ScalarType::BFloat16) {
at::native::sparse::impl::_compressed_row_strided_addmm_out(
- result,
+ input,
mat1,
mat2,
/*beta=*/beta,
@@ -497,6 +498,10 @@
return;
}
+ if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
+ result.copy_(input);
+ }
+
const cusparseDirection_t block_layout = mat1.values().is_contiguous()
? CUSPARSE_DIRECTION_ROW
: CUSPARSE_DIRECTION_COLUMN;
@@ -838,6 +843,7 @@
} // anonymous namespace
void addmm_out_sparse_csr(
+ const Tensor& input,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
@@ -853,6 +859,39 @@
// Valid combinations terminate in a return
// Invalid combinations are omitted and will fall though to the TORCH check
// generating an informative error message
+
+ // mm functions that copy input to result when needed (e.g. mm
+ // triton kernels do not require result being initialized with
+ // input):
+ if (mat1.layout() == kSparseBsr) {
+ if (mat2.layout() == kStrided) {
+ if (result.layout() == kStrided)
+ return block_sparse_mm(input, mat1, mat2, beta, alpha, result);
+ }
+ }
+
+ if (mat1.layout() == kStrided) {
+ if (mat2.layout() == kSparseBsc) {
+ if (result.layout() == kStrided) {
+ auto result_t = result.transpose(-2, -1);
+ auto input_t = (result.is_same(input) ? result_t : input.transpose(-2, -1));
+ return block_sparse_mm(
+ input_t,
+ mat2.transpose(-2, -1),
+ mat1.transpose(-2, -1),
+ beta,
+ alpha,
+ result_t);
+ }
+ }
+ }
+
+ // copy input to result:
+ if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
+ result.copy_(input);
+ }
+
+ // mm functions that assume that result contains input:
if (mat1.layout() == kStrided) {
if (mat2.layout() == kSparseCsr) {
if (result.layout() == kStrided) {
@@ -875,16 +914,6 @@
result.transpose(-2, -1));
}
}
- if (mat2.layout() == kSparseBsc) {
- if (result.layout() == kStrided) {
- return block_sparse_mm(
- mat2.transpose(-2, -1),
- mat1.transpose(-2, -1),
- beta,
- alpha,
- result.transpose(-2, -1));
- }
- }
}
if (mat1.layout() == kSparseCsr) {
if (mat2.layout() == kStrided) {
@@ -933,12 +962,6 @@
}
}
}
- if (mat1.layout() == kSparseBsr) {
- if (mat2.layout() == kStrided) {
- if (result.layout() == kStrided)
- return block_sparse_mm(mat1, mat2, beta, alpha, result);
- }
- }
TORCH_CHECK(
false,
"addmm: computation on CUDA is not implemented for ",
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
index 4bd7281..b2bae73 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
@@ -11,6 +11,7 @@
namespace cuda {
void addmm_out_sparse_csr(
+ const Tensor& input,
const at::sparse_csr::SparseCsrTensor& mat1,
const Tensor& mat2,
const Scalar& beta,