Fix CSR addmm bug (#58768)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58768
Fixes gh-58757
This PR fixes the CPU version of the addmm op. For context, before this PR only CSR @ vector was supported. I found a minor bug in `addmm_out_sparse_csr_dense_cpu` in the non-MKL code path, which is fixed in this PR.
Moreover, I discovered a limitation in the current MKL implementation: it only produces results within an acceptable output-error tolerance for square matrices. After looking into this in depth, I found that it could be a limitation of the MKL API.
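For illustration, here is a minimal sketch of a non-square case that, with this PR, goes through the fixed non-MKL CPU path (mirroring the new `test_mm` test). It assumes a build where `torch.sparse_csr_tensor` is available; the shapes, indices, and values are illustrative only:

```python
import torch

# Small 2x3 CSR matrix built by hand (indices/values are illustrative only).
crow = torch.tensor([0, 2, 3], dtype=torch.int32)
col = torch.tensor([0, 2, 1], dtype=torch.int32)
val = torch.tensor([1., 2., 3.], dtype=torch.double)
x = torch.sparse_csr_tensor(crow, col, val, size=(2, 3))

y = torch.randn(3, 4, dtype=torch.double)   # non-square dense operand
t = torch.randn(2, 4, dtype=torch.double)

# res = beta * t + alpha * (x @ y), compared against the dense reference.
res = torch.addmm(t, x, y, beta=0.5, alpha=2.0)
expected = torch.addmm(t, x.to_dense(), y, beta=0.5, alpha=2.0)
print((res - expected).abs().max())  # should be ~0 on the fixed CPU path
```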
I used this [gist](https://gist.github.com/aocsa/0606e833cd16a8bfb7d37a5fbb3a5b14), based on [this DeepBench benchmark](https://github.com/baidu-research/DeepBench/blob/master/code/intel/spmm/spmm_bench.cpp), to test this behavior.
As you can see, the output error (last column) is acceptable when the matrices are square and unacceptable when they are not. I reported the issue here: https://github.com/pytorch/pytorch/issues/58770
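As a rough PyTorch-level analogue of that error check (not the gist itself), one can compare the CSR result against the dense reference and look at the relative error. The helper below is hypothetical and assumes `Tensor.to_sparse_csr()` is available in the build; on a build that still routes non-square shapes through the MKL path, the non-square case is where the error becomes large:

```python
import torch

def rel_error(m, k, n, dtype=torch.float):
    # Hypothetical helper: relative error of CSR @ dense vs. the dense reference.
    a = torch.randn(m, k, dtype=dtype)
    a[a.abs() < 1.0] = 0                 # sparsify the left operand
    a_csr = a.to_sparse_csr()            # assumes to_sparse_csr() exists in this build
    b = torch.randn(k, n, dtype=dtype)
    res = a_csr.matmul(b)
    ref = a.matmul(b)
    return ((res - ref).norm() / ref.norm()).item()

print(rel_error(128, 128, 128))  # square shapes: small relative error
print(rel_error(128, 256, 64))   # non-square shapes: the problematic case
```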
Looking forward to your comments.
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D28629563
Pulled By: malfet
fbshipit-source-id: 5ee00ae667336e0d9301e5117057213f472cbc86
diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
index bbc1380..bf84d58 100644
--- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
+++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
@@ -100,6 +100,7 @@
retval);
}
+ // res(nrows x dense_ncols) = (sparse(nrows x ncols) @ dense(ncols x dense_ncols))
inline void sparse_mm(
float* res,
float* dense,
@@ -108,19 +109,32 @@
MKL_INT nrows,
MKL_INT ncols,
MKL_INT dense_ncols) {
- int stat = mkl_sparse_s_mm(
+ int stat;
+ if (dense_ncols == 1) {
+ stat = mkl_sparse_s_mv(
+ SPARSE_OPERATION_NON_TRANSPOSE,
+ alpha,
+ A,
+ desc,
+ dense,
+ beta,
+ res);
+ TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
+ } else {
+ stat = mkl_sparse_s_mm(
SPARSE_OPERATION_NON_TRANSPOSE,
alpha,
A,
desc,
SPARSE_LAYOUT_ROW_MAJOR,
dense,
- dense_ncols,
- dense_ncols,
+ nrows,
+ ncols,
beta,
res,
dense_ncols);
- TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
+ TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
+ }
}
inline void sparse_mm(
@@ -131,19 +145,33 @@
MKL_INT nrows,
MKL_INT ncols,
MKL_INT dense_ncols) {
- int stat = mkl_sparse_d_mm(
+ int stat;
+ if (dense_ncols == 1) {
+ stat = mkl_sparse_d_mv(
+ SPARSE_OPERATION_NON_TRANSPOSE,
+ alpha,
+ A,
+ desc,
+ dense,
+ beta,
+ res);
+ TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
+ }
+ else {
+ stat = mkl_sparse_d_mm(
SPARSE_OPERATION_NON_TRANSPOSE,
alpha,
A,
desc,
SPARSE_LAYOUT_ROW_MAJOR,
dense,
- dense_ncols,
- dense_ncols,
+ nrows,
+ ncols,
beta,
res,
dense_ncols);
- TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
+ TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
+ }
}
~SparseCsrMKLInterface() {
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 92a1011..750440f 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -9,6 +9,7 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUBlas.h>
+#include <ATen/native/Resize.h>
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
#include <algorithm>
@@ -20,165 +21,155 @@
// certain utility functions are usable from sparse COO.
using namespace at::sparse;
-static constexpr bool is_msvc() {
+static constexpr bool is_mkl_supported() {
#ifdef _MSC_VER
- return true;
-#else
return false;
+#elif __APPLE__ || __MACH__
+ return false;
+#else
+ return true;
#endif
}
+// Only accept square sparse matrices or a dense input that is a vector
+// TODO: Check what happens with MKL; the output error reported with non-square matrices tends to be high
+// See: https://github.com/pytorch/pytorch/issues/58770
+bool is_square_or_vec(int64_t dim_i, int64_t dim_j, int64_t dim_k) {
+ return (dim_i == dim_k && dim_k == dim_j) || (dim_i == dim_j && dim_k == 1);
+}
+
+template <typename scalar_t>
+void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, Scalar beta, const Tensor& t, Scalar alpha, const Tensor& csr, const Tensor& col_indices, const Tensor& values, const Tensor& dense) {
+
+ scalar_t cast_alpha = alpha.to<scalar_t>();
+ scalar_t cast_beta = beta.to<scalar_t>();
+ if (cast_beta == 0) {
+ r.zero_();
+ } else if (cast_beta == 1) {
+ if (!is_same_tensor(r, t)) {
+ r.copy_(t);
+ }
+ } else {
+ at::mul_out(r, t, scalar_to_tensor(beta));
+ }
+ AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
+ auto csr_accessor = csr.accessor<index_t, 1>();
+ auto col_indices_accessor = col_indices.accessor<index_t, 1>();
+
+ auto values_accessor = values.accessor<scalar_t, 1>();
+ scalar_t* dense_ptr = dense.data<scalar_t>();
+ scalar_t* r_ptr = r.data<scalar_t>();
+
+ int64_t dense_stride0 = dense.stride(0);
+ int64_t dense_stride1 = dense.stride(1);
+ int64_t r_stride0 = r.stride(0);
+ int64_t r_stride1 = r.stride(1);
+
+ at::parallel_for(
+ 0,
+ dim_i,
+ internal::GRAIN_SIZE,
+ [&](int64_t irow_start, int64_t irow_end) {
+ for (index_t h = irow_start; h < irow_end; ++h) {
+ index_t i_start = csr_accessor[h];
+ index_t i_end = csr_accessor[h+1];
+ for (index_t i = i_start; i < i_end; i++) {
+ scalar_t val = values_accessor[i];
+ index_t col = col_indices_accessor[i];
+ at::native::cpublas::axpy<scalar_t>(dim_k,
+ cast_alpha * val,
+ dense_ptr + col * dense_stride0, dense_stride1,
+ r_ptr + h * r_stride0, r_stride1);
+ }
+ }
+ });
+ });
+}
+
// Functions for matrix multiplication.
Tensor& addmm_out_sparse_csr_dense_cpu(
const Tensor& self,
- const SparseCsrTensor& op1,
- const Tensor& op2,
+ const SparseCsrTensor& sparse,
+ const Tensor& dense,
const Scalar& beta,
const Scalar& alpha,
- Tensor& out) {
- AT_ASSERT(op1.is_sparse_csr());
- Tensor expand_self = *expand_size(self, {op1.size(0), op2.size(1)}, "addmm_out_sparse_csr");
+ Tensor& r) {
+ TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
+ Tensor t = *expand_size(self, {sparse.size(0), dense.size(1)}, "addmm_out_sparse_csr");
- AT_ASSERT(expand_self.device().type() == kCPU);
+ TORCH_INTERNAL_ASSERT(t.device().type() == kCPU);
TORCH_CHECK(
- out.device().type() == kCPU,
+ r.device().type() == kCPU,
"addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
- op1.device().type() == kCPU,
+ sparse.device().type() == kCPU,
"addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
- op2.device().type() == kCPU,
+ dense.device().type() == kCPU,
"addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
- op1.dim() == 2,
+ sparse.dim() == 2,
"addmm: 2-D matrices expected, got ",
- op1.dim(),
+ sparse.dim(),
"D tensor");
TORCH_CHECK(
- op2.dim() == 2,
+ dense.dim() == 2,
"addmm: 2-D matrices expected, got ",
- op2.dim(),
+ dense.dim(),
"D tensor");
TORCH_CHECK(
- out.is_contiguous(),
+ r.is_contiguous(),
"out argument must be contiguous, but got: ",
- out.suggest_memory_format());
+ r.suggest_memory_format());
- // ixk * kxj = ixj
- int64_t dim_i = op1.size(0);
- int64_t dim_j = op2.size(1);
- int64_t dim_k = op1.size(1);
+ // ixj * jxk = ixk
+ int64_t dim_i = sparse.size(0);
+ int64_t dim_j = sparse.size(1);
+ int64_t dim_k = dense.size(1);
TORCH_CHECK(
- op2.size(0) == dim_k,
+ dense.size(0) == dim_j,
"addmm: Expected dense matrix (op2) size(0)=",
- dim_k,
+ dim_j,
", got ",
- op2.size(0));
+ dense.size(0));
TORCH_CHECK(
- op1.size(1) == dim_k,
+ sparse.size(1) == dim_j,
"addmm: Expected sparse matrix (op1) size(1)=",
- dim_k,
+ dim_j,
", got ",
- op1.size(1));
- out.resize_({dim_i, dim_j});
-
- auto col_indices = op1.col_indices();
- auto crow_indices = op1.crow_indices();
- auto values = op1.values();
-
- AT_DISPATCH_FLOATING_TYPES(
- values.scalar_type(), "addmm_sparse_csr_dense", [&] {
- scalar_t cast_beta = beta.to<scalar_t>();
- if (!is_same_tensor(out, expand_self)) {
- out.copy_(expand_self);
- }
- if (cast_beta == 0) {
- out.zero_();
- } else {
- at::mul_out(out, expand_self, scalar_to_tensor(beta));
- }
- });
+ sparse.size(1));
+ resize_output(r, {dim_i, dim_k});
+ auto col_indices = sparse.col_indices();
+ auto crow_indices = sparse.crow_indices();
+ auto values = sparse.values();
+ int64_t nnz = sparse._nnz();
// Do not use MKL for Windows due to linking issues with sparse MKL routines.
- if (at::hasMKL() && !is_msvc()) {
- _sparse_mm_mkl_(out, op1, op2, expand_self, alpha, beta);
+ if (at::hasMKL() && is_mkl_supported() && is_square_or_vec(dim_i, dim_j, dim_k)) {
+ AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] {
+ scalar_t cast_beta = beta.to<scalar_t>();
+ if (cast_beta == 0) {
+ r.zero_();
+ } else if (cast_beta == 1) {
+ if (!is_same_tensor(r, t)) {
+ r.copy_(t);
+ }
+ } else {
+ at::mul_out(r, t, scalar_to_tensor(beta));
+ }
+ // r = r + alpha * sparse * dense
+ _sparse_mm_mkl_(r, sparse, dense, t, alpha, Scalar(static_cast<scalar_t>(1.0)));
+ });
} else {
- int64_t dense_stride0 = op1.stride(0);
- int64_t dense_stride1 = op1.stride(1);
- int64_t out_stride0 = out.stride(0);
- int64_t out_stride1 = out.stride(1);
-
- AT_DISPATCH_FLOATING_TYPES(
- values.scalar_type(),
- "sparse_csr_mm_cpu",
- [&alpha,
- &beta,
- &op1,
- &out,
- &values,
- &crow_indices,
- &col_indices,
- &dense_stride0,
- &dense_stride1,
- &out_stride0,
- &out_stride1,
- &dim_k]() {
- AT_DISPATCH_INDEX_TYPES(
- crow_indices.scalar_type(),
- "csr_mm_crow_indices",
- [&alpha,
- &beta,
- &op1,
- &out,
- &values,
- &crow_indices,
- &col_indices,
- &dense_stride0,
- &dense_stride1,
- &out_stride0,
- &out_stride1,
- &dim_k]() {
- scalar_t cast_alpha = alpha.to<scalar_t>();
- // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
- scalar_t cast_beta = beta.to<scalar_t>();
- scalar_t* dense_ptr = op1.data_ptr<scalar_t>();
- scalar_t* out_ptr = out.data_ptr<scalar_t>();
-
- auto col_indices_accessor = col_indices.accessor<index_t, 1>();
- auto crow_indices_accessor =
- crow_indices.accessor<index_t, 1>();
- auto values_accessor = values.accessor<scalar_t, 1>();
-
- at::parallel_for(
- 0,
- crow_indices.size(0) - 1,
- internal::GRAIN_SIZE,
- [&](int64_t irow_start, int64_t irow_end) {
- for (int irow = irow_start; irow < irow_end; ++irow) {
- int start_index = crow_indices_accessor[irow];
- int end_index = crow_indices_accessor[irow + 1];
-
- for (int i = start_index; i < end_index; ++i) {
- auto val = values_accessor[i];
- auto icol = col_indices_accessor[i];
-
- at::native::cpublas::axpy<scalar_t>(
- dim_k,
- cast_alpha * val,
- dense_ptr + icol * dense_stride0,
- dense_stride1,
- out_ptr + irow * out_stride0,
- out_stride1);
- }
- }
- });
- });
- });
+ // r = beta * t + alpha * sparse * dense
+ AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] {
+ s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, crow_indices, col_indices, values, dense);
+ });
}
- return out;
+ return r;
}
Tensor addmm_sparse_csr_dense_cpu(
@@ -229,9 +220,9 @@
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
- AT_ASSERT(dense.layout() == kStrided);
- AT_ASSERT(src.is_sparse_csr());
- AT_ASSERT(dense.device() == kCPU);
+ TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
+ TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
+ TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
TORCH_CHECK(
out.is_contiguous(),
@@ -263,11 +254,12 @@
out.scalar_type(),
" in add operation");
- auto src_values = src.values().to(commonDtype);
+ auto src_values = src.values();
auto src_crow_indices = src.crow_indices();
auto src_col_indices = src.col_indices();
- out.resize_as_(dense);
+ resize_output(out, dense.sizes());
+
Tensor resultBuffer = out;
Tensor valuesBuffer = src_values.to(commonDtype);
@@ -280,21 +272,21 @@
AT_DISPATCH_ALL_TYPES(
commonDtype,
"add_out_op2_sparse_csr",
- [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
+ [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
AT_DISPATCH_INDEX_TYPES(
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
- [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
- auto values_accessor = src_values.accessor<scalar_t, 1>();
- scalar_t* out_ptr = out.data_ptr<scalar_t>();
+ [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
+ auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
+ scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
src_crow_indices.accessor<index_t, 1>();
auto col_indices_accessor =
src_col_indices.accessor<index_t, 1>();
- auto out_strides0 = out.strides()[0];
- auto out_strides1 = out.strides()[1];
+ auto out_strides0 = resultBuffer.strides()[0];
+ auto out_strides1 = resultBuffer.strides()[1];
for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1;
++irow) {
@@ -303,13 +295,16 @@
for (int i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
- auto index = out.storage_offset() + irow * out_strides0 +
+ auto index = resultBuffer.storage_offset() + irow * out_strides0 +
icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});
});
+ if (out.scalar_type() != commonDtype) {
+ out.copy_(resultBuffer);
+ }
return out;
}
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 23d3d1d..4c80e92 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1,6 +1,7 @@
import torch
import warnings
import unittest
+import random
from torch.testing._internal.common_utils import \
(IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
@@ -163,7 +164,6 @@
@coalescedonoff
@onlyCPU
@dtypes(torch.double)
- @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
def test_coo_to_csr_convert(self, device, dtype, coalesced):
size = (5, 5)
sparse_dim = 2
@@ -192,8 +192,8 @@
self.assertEqual(coo.matmul(vec), csr.matmul(vec))
@onlyCPU
+ @unittest.skipIf(IS_MACOS or IS_WINDOWS, "MKL doesn't work on windows or mac")
@dtypes(torch.float, torch.double)
- @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
def test_mkl_matvec_warnings(self, device, dtype):
if torch.has_mkl:
for index_dtype in [torch.int32, torch.int64]:
@@ -219,7 +219,6 @@
@onlyCPU
@dtypes(torch.float, torch.double)
- @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
def test_csr_matvec(self, device, dtype):
side = 100
for index_dtype in [torch.int32, torch.int64]:
@@ -236,6 +235,34 @@
csr.matmul(bad_vec)
@onlyCPU
+ @dtypes(torch.double)
+ def test_mm(self, device, dtype):
+ def test_shape(di, dj, dk, nnz):
+ x = self.genSparseCSRTensor((di, dj), nnz, device=device, dtype=dtype, index_dtype=torch.int32)
+ t = torch.randn(di, dk, dtype=dtype, device=device)
+ y = torch.randn(dj, dk, dtype=dtype, device=device)
+ alpha = random.random()
+ beta = random.random()
+
+ # res = beta * t + alpha * (x @ y)
+ res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
+ expected = torch.addmm(t, x.to_dense(), y, beta=beta, alpha=alpha)
+ self.assertEqual(res, expected)
+
+ res = torch.addmm(t, x, y)
+ expected = torch.addmm(t, x.to_dense(), y)
+ self.assertEqual(res, expected)
+
+ res = torch.mm(x, y)
+ expected = torch.mm(x.to_dense(), y)
+ self.assertEqual(res, expected)
+
+ for i in range(2, 5):
+ for j in range(2, 8):
+ for k in range(2, 8):
+ test_shape(i, j, k, i * j // 2)
+
+ @onlyCPU
@dtypes(*torch.testing.floating_types())
def test_coo_csr_conversion(self, device, dtype):
size = (5, 5)