Fix CSR addmm bug (#58768)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58768

Fixes gh-58757

This PR fixes the CPU version of the addmm op. For context, before this PR only CSR @ vector was supported. I also found a minor bug in addmm_out_sparse_csr_dense_cpu in the non-MKL code path, which is fixed in this PR.
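
To illustrate the newly supported case, here is a minimal sketch mirroring the `test_mm` test added in `test_sparse_csr.py` (shapes and values are arbitrary): a CSR matrix multiplied with a dense matrix through `torch.addmm`, checked against the dense reference.

```python
import torch

# A small 3x4 CSR matrix: crow_indices has length rows+1, col_indices/values have length nnz.
crow_indices = torch.tensor([0, 2, 3, 5], dtype=torch.int32)
col_indices = torch.tensor([0, 3, 1, 0, 2], dtype=torch.int32)
values = torch.tensor([1., 2., 3., 4., 5.], dtype=torch.double)
x = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(3, 4))

t = torch.randn(3, 2, dtype=torch.double)  # the "self" operand added to the product
y = torch.randn(4, 2, dtype=torch.double)  # dense right-hand operand

# res = beta * t + alpha * (x @ y); before this PR only the y-is-a-vector case worked on CPU.
res = torch.addmm(t, x, y, beta=0.5, alpha=2.0)
expected = torch.addmm(t, x.to_dense(), y, beta=0.5, alpha=2.0)
assert torch.allclose(res, expected)
```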

Moreover, I discovered a limitation in the current MKL implementation: it only works well (output error within an acceptable tolerance) for square matrices. After looking into this issue in depth, I found that it may be a limitation of the MKL API.

I used this [gist](https://gist.github.com/aocsa/0606e833cd16a8bfb7d37a5fbb3a5b14), based on [this DeepBench benchmark](https://github.com/baidu-research/DeepBench/blob/master/code/intel/spmm/spmm_bench.cpp), to test this behavior.

As you can see, the output error (last column) is acceptable when the matrices are square, but not when they are non-square. I reported the issue here: https://github.com/pytorch/pytorch/issues/58770
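
The gist drives MKL's spmm routines directly from C++; as a rough Python stand-in, the check it performs amounts to comparing a CSR product against its dense reference and reporting the worst relative error for a square and a non-square shape. Note that with this PR applied the non-square case no longer hits the MKL path on CPU, so this sketch will not reproduce the reported numbers; shapes and density below are arbitrary.

```python
import torch

def worst_rel_error(rows, cols, dense_cols, density=0.1, dtype=torch.float32):
    # Random matrix with roughly `density` fraction of nonzeros, stored as CSR.
    a = torch.randn(rows, cols, dtype=dtype) * (torch.rand(rows, cols) < density)
    a_csr = a.to_sparse_csr()
    b = torch.randn(cols, dense_cols, dtype=dtype)

    out = a_csr.matmul(b)        # sparse CSR @ dense
    ref = a.matmul(b)            # dense reference
    denom = ref.abs().clamp_min(1e-6)
    return ((out - ref).abs() / denom).max().item()

print("square:    ", worst_rel_error(512, 512, 512))
print("non-square:", worst_rel_error(512, 256, 128))
```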

Looking forward to your comments.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D28629563

Pulled By: malfet

fbshipit-source-id: 5ee00ae667336e0d9301e5117057213f472cbc86
diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
index bbc1380..bf84d58 100644
--- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
+++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp
@@ -100,6 +100,7 @@
         retval);
   }
 
+ // res(nrows, dense_ncols) = (sparse(nrows * ncols) @ dense(ncols x dense_ncols))
   inline void sparse_mm(
       float* res,
       float* dense,
@@ -108,19 +109,32 @@
       MKL_INT nrows,
       MKL_INT ncols,
       MKL_INT dense_ncols) {
-    int stat = mkl_sparse_s_mm(
+    int stat;
+    if (dense_ncols == 1) {
+      stat = mkl_sparse_s_mv(
+        SPARSE_OPERATION_NON_TRANSPOSE,
+        alpha,
+        A,
+        desc,
+        dense,
+        beta,
+        res);
+      TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
+    } else {
+      stat = mkl_sparse_s_mm(
         SPARSE_OPERATION_NON_TRANSPOSE,
         alpha,
         A,
         desc,
         SPARSE_LAYOUT_ROW_MAJOR,
         dense,
-        dense_ncols,
-        dense_ncols,
+        nrows,
+        ncols,
         beta,
         res,
         dense_ncols);
-    TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
+      TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
+    }
   }
 
   inline void sparse_mm(
@@ -131,19 +145,33 @@
       MKL_INT nrows,
       MKL_INT ncols,
       MKL_INT dense_ncols) {
-    int stat = mkl_sparse_d_mm(
+    int stat;
+    if (dense_ncols == 1) {
+      stat = mkl_sparse_d_mv(
+        SPARSE_OPERATION_NON_TRANSPOSE,
+        alpha,
+        A,
+        desc,
+        dense,
+        beta,
+        res);
+      TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
+    }
+    else {
+      stat = mkl_sparse_d_mm(
         SPARSE_OPERATION_NON_TRANSPOSE,
         alpha,
         A,
         desc,
         SPARSE_LAYOUT_ROW_MAJOR,
         dense,
-        dense_ncols,
-        dense_ncols,
+        nrows,
+        ncols,
         beta,
         res,
         dense_ncols);
-    TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
+      TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
+    }
   }
 
   ~SparseCsrMKLInterface() {
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 92a1011..750440f 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -9,6 +9,7 @@
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/CPUBlas.h>
+#include <ATen/native/Resize.h>
 #include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
 
 #include <algorithm>
@@ -20,165 +21,155 @@
 // certain utiliy functions are usable from sparse COO.
 using namespace at::sparse;
 
-static constexpr bool is_msvc() {
+static constexpr bool is_mkl_supported() {
 #ifdef _MSC_VER
-  return true;
-#else
   return false;
+#elif  __APPLE__ || __MACH__
+  return false;
+#else
+  return true;
 #endif
 }
 
+// Only accept squares sparse matrices or dense input as a vector
+// TODO: Check what happens with MKL, the output error reported with non square matrices tends to be high
+// See: https://github.com/pytorch/pytorch/issues/58770
+bool is_square_or_vec(int64_t dim_i, int64_t dim_j, int64_t dim_k) {
+  return (dim_i == dim_k  && dim_k == dim_j) || (dim_i == dim_j && dim_k == 1);
+}
+
+template <typename scalar_t>
+void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, Scalar beta, const Tensor& t, Scalar alpha, const Tensor& csr, const Tensor& col_indices, const Tensor& values, const Tensor& dense) {
+
+  scalar_t cast_alpha = alpha.to<scalar_t>();
+  scalar_t cast_beta = beta.to<scalar_t>();
+  if (cast_beta == 0) {
+    r.zero_();
+  } else if (cast_beta == 1) {
+    if (!is_same_tensor(r, t)) {
+      r.copy_(t);
+    }
+  } else {
+    at::mul_out(r, t, scalar_to_tensor(beta));
+  }
+  AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
+    auto csr_accessor = csr.accessor<index_t, 1>();
+    auto col_indices_accessor = col_indices.accessor<index_t, 1>();
+
+    auto values_accessor = values.accessor<scalar_t, 1>();
+    scalar_t* dense_ptr = dense.data<scalar_t>();
+    scalar_t* r_ptr = r.data<scalar_t>();
+
+    int64_t dense_stride0 = dense.stride(0);
+    int64_t dense_stride1 = dense.stride(1);
+    int64_t r_stride0 = r.stride(0);
+    int64_t r_stride1 = r.stride(1);
+
+    at::parallel_for(
+        0,
+        dim_i,
+        internal::GRAIN_SIZE,
+        [&](int64_t irow_start, int64_t irow_end) {
+            for (index_t h = irow_start; h < irow_end; ++h) {
+              index_t i_start = csr_accessor[h];
+              index_t i_end = csr_accessor[h+1];
+              for (index_t i = i_start; i < i_end; i++) {
+                scalar_t val = values_accessor[i];
+                index_t col = col_indices_accessor[i];
+                at::native::cpublas::axpy<scalar_t>(dim_k,
+                    cast_alpha * val,
+                    dense_ptr + col * dense_stride0, dense_stride1,
+                    r_ptr + h * r_stride0, r_stride1);
+              }
+            }
+    });
+  });
+}
+
 // Functions for matrix multiplication.
 Tensor& addmm_out_sparse_csr_dense_cpu(
     const Tensor& self,
-    const SparseCsrTensor& op1,
-    const Tensor& op2,
+    const SparseCsrTensor& sparse,
+    const Tensor& dense,
     const Scalar& beta,
     const Scalar& alpha,
-    Tensor& out) {
-  AT_ASSERT(op1.is_sparse_csr());
-  Tensor expand_self = *expand_size(self, {op1.size(0), op2.size(1)}, "addmm_out_sparse_csr");
+    Tensor& r) {
+  TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
+  Tensor t = *expand_size(self, {sparse.size(0), dense.size(1)}, "addmm_out_sparse_csr");
 
-  AT_ASSERT(expand_self.device().type() == kCPU);
+  TORCH_INTERNAL_ASSERT(t.device().type() == kCPU);
   TORCH_CHECK(
-      out.device().type() == kCPU,
+      r.device().type() == kCPU,
       "addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
   TORCH_CHECK(
-      op1.device().type() == kCPU,
+      sparse.device().type() == kCPU,
       "addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
   TORCH_CHECK(
-      op2.device().type() == kCPU,
+      dense.device().type() == kCPU,
       "addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
 
   TORCH_CHECK(
-      op1.dim() == 2,
+      sparse.dim() == 2,
       "addmm: 2-D matrices expected, got ",
-      op1.dim(),
+      sparse.dim(),
       "D tensor");
   TORCH_CHECK(
-      op2.dim() == 2,
+      dense.dim() == 2,
       "addmm: 2-D matrices expected, got ",
-      op2.dim(),
+      dense.dim(),
       "D tensor");
 
   TORCH_CHECK(
-      out.is_contiguous(),
+      r.is_contiguous(),
       "out argument must be contiguous, but got: ",
-      out.suggest_memory_format());
+      r.suggest_memory_format());
 
-  // ixk * kxj = ixj
-  int64_t dim_i = op1.size(0);
-  int64_t dim_j = op2.size(1);
-  int64_t dim_k = op1.size(1);
+  // ixj * jxk = ixk
+  int64_t dim_i = sparse.size(0);
+  int64_t dim_j = sparse.size(1);
+  int64_t dim_k = dense.size(1);
 
   TORCH_CHECK(
-      op2.size(0) == dim_k,
+      dense.size(0) == dim_j,
       "addmm: Expected dense matrix (op2) size(0)=",
-      dim_k,
+      dim_j,
       ", got ",
-      op2.size(0));
+      dense.size(0));
   TORCH_CHECK(
-      op1.size(1) == dim_k,
+      sparse.size(1) == dim_j,
       "addmm: Expected sparse matrix (op1) size(1)=",
-      dim_k,
+      dim_j,
       ", got ",
-      op1.size(1));
-  out.resize_({dim_i, dim_j});
-
-  auto col_indices = op1.col_indices();
-  auto crow_indices = op1.crow_indices();
-  auto values = op1.values();
-
-  AT_DISPATCH_FLOATING_TYPES(
-      values.scalar_type(), "addmm_sparse_csr_dense", [&] {
-        scalar_t cast_beta = beta.to<scalar_t>();
-        if (!is_same_tensor(out, expand_self)) {
-          out.copy_(expand_self);
-        }
-        if (cast_beta == 0) {
-          out.zero_();
-        } else {
-          at::mul_out(out, expand_self, scalar_to_tensor(beta));
-        }
-      });
+      sparse.size(1));
+  resize_output(r, {dim_i, dim_k});
+  auto col_indices = sparse.col_indices();
+  auto crow_indices = sparse.crow_indices();
+  auto values = sparse.values();
+  int64_t nnz        = sparse._nnz();
 
   // Do not use MKL for Windows due to linking issues with sparse MKL routines.
-  if (at::hasMKL() && !is_msvc()) {
-    _sparse_mm_mkl_(out, op1, op2, expand_self, alpha, beta);
+  if (at::hasMKL() && is_mkl_supported() && is_square_or_vec(dim_i, dim_j, dim_k)) {
+    AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] {
+        scalar_t cast_beta = beta.to<scalar_t>();
+        if (cast_beta == 0) {
+          r.zero_();
+        } else if (cast_beta == 1) {
+          if (!is_same_tensor(r, t)) {
+            r.copy_(t);
+          }
+        } else {
+          at::mul_out(r, t, scalar_to_tensor(beta));
+        }
+        // r = r + alpha * sparse * dense
+        _sparse_mm_mkl_(r, sparse, dense, t, alpha, Scalar(static_cast<scalar_t>(1.0)));
+    });
   } else {
-    int64_t dense_stride0 = op1.stride(0);
-    int64_t dense_stride1 = op1.stride(1);
-    int64_t out_stride0 = out.stride(0);
-    int64_t out_stride1 = out.stride(1);
-
-    AT_DISPATCH_FLOATING_TYPES(
-        values.scalar_type(),
-        "sparse_csr_mm_cpu",
-        [&alpha,
-         &beta,
-         &op1,
-         &out,
-         &values,
-         &crow_indices,
-         &col_indices,
-         &dense_stride0,
-         &dense_stride1,
-         &out_stride0,
-         &out_stride1,
-         &dim_k]() {
-          AT_DISPATCH_INDEX_TYPES(
-              crow_indices.scalar_type(),
-              "csr_mm_crow_indices",
-              [&alpha,
-               &beta,
-               &op1,
-               &out,
-               &values,
-               &crow_indices,
-               &col_indices,
-               &dense_stride0,
-               &dense_stride1,
-               &out_stride0,
-               &out_stride1,
-               &dim_k]() {
-                scalar_t cast_alpha = alpha.to<scalar_t>();
-                // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-                scalar_t cast_beta = beta.to<scalar_t>();
-                scalar_t* dense_ptr = op1.data_ptr<scalar_t>();
-                scalar_t* out_ptr = out.data_ptr<scalar_t>();
-
-                auto col_indices_accessor = col_indices.accessor<index_t, 1>();
-                auto crow_indices_accessor =
-                    crow_indices.accessor<index_t, 1>();
-                auto values_accessor = values.accessor<scalar_t, 1>();
-
-                at::parallel_for(
-                    0,
-                    crow_indices.size(0) - 1,
-                    internal::GRAIN_SIZE,
-                    [&](int64_t irow_start, int64_t irow_end) {
-                      for (int irow = irow_start; irow < irow_end; ++irow) {
-                        int start_index = crow_indices_accessor[irow];
-                        int end_index = crow_indices_accessor[irow + 1];
-
-                        for (int i = start_index; i < end_index; ++i) {
-                          auto val = values_accessor[i];
-                          auto icol = col_indices_accessor[i];
-
-                          at::native::cpublas::axpy<scalar_t>(
-                              dim_k,
-                              cast_alpha * val,
-                              dense_ptr + icol * dense_stride0,
-                              dense_stride1,
-                              out_ptr + irow * out_stride0,
-                              out_stride1);
-                        }
-                      }
-                    });
-              });
-        });
+    // r = beta * t + alpha * sparse * dense
+    AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] {
+        s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, crow_indices, col_indices, values, dense);
+    });
   }
-  return out;
+  return r;
 }
 
 Tensor addmm_sparse_csr_dense_cpu(
@@ -229,9 +220,9 @@
     const Tensor& dense,
     const SparseCsrTensor& src,
     const Scalar& alpha) {
-  AT_ASSERT(dense.layout() == kStrided);
-  AT_ASSERT(src.is_sparse_csr());
-  AT_ASSERT(dense.device() == kCPU);
+  TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
+  TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
+  TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
 
   TORCH_CHECK(
       out.is_contiguous(),
@@ -263,11 +254,12 @@
       out.scalar_type(),
       " in add operation");
 
-  auto src_values = src.values().to(commonDtype);
+  auto src_values = src.values();
   auto src_crow_indices = src.crow_indices();
   auto src_col_indices = src.col_indices();
 
-  out.resize_as_(dense);
+  resize_output(out, dense.sizes());
+
   Tensor resultBuffer = out;
   Tensor valuesBuffer = src_values.to(commonDtype);
 
@@ -280,21 +272,21 @@
   AT_DISPATCH_ALL_TYPES(
       commonDtype,
       "add_out_op2_sparse_csr",
-      [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
+      [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
         AT_DISPATCH_INDEX_TYPES(
             src_crow_indices.scalar_type(),
             "csr_add_out_crow_indices",
-            [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
-              auto values_accessor = src_values.accessor<scalar_t, 1>();
-              scalar_t* out_ptr = out.data_ptr<scalar_t>();
+            [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
+              auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
+              scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
               scalar_t cast_value = alpha.to<scalar_t>();
 
               auto crow_indices_accessor =
                   src_crow_indices.accessor<index_t, 1>();
               auto col_indices_accessor =
                   src_col_indices.accessor<index_t, 1>();
-              auto out_strides0 = out.strides()[0];
-              auto out_strides1 = out.strides()[1];
+              auto out_strides0 = resultBuffer.strides()[0];
+              auto out_strides1 = resultBuffer.strides()[1];
 
               for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1;
                    ++irow) {
@@ -303,13 +295,16 @@
 
                 for (int i = start_index; i < end_index; ++i) {
                   auto icol = col_indices_accessor[i];
-                  auto index = out.storage_offset() + irow * out_strides0 +
+                  auto index = resultBuffer.storage_offset() + irow * out_strides0 +
                       icol * out_strides1;
                   out_ptr[index] += cast_value * values_accessor[i];
                 }
               }
             });
       });
+  if (out.scalar_type() != commonDtype) {
+    out.copy_(resultBuffer);
+  }
   return out;
 }
 
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 23d3d1d..4c80e92 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1,6 +1,7 @@
 import torch
 import warnings
 import unittest
+import random
 from torch.testing._internal.common_utils import \
     (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
 from torch.testing._internal.common_device_type import \
@@ -163,7 +164,6 @@
     @coalescedonoff
     @onlyCPU
     @dtypes(torch.double)
-    @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
     def test_coo_to_csr_convert(self, device, dtype, coalesced):
         size = (5, 5)
         sparse_dim = 2
@@ -192,8 +192,8 @@
         self.assertEqual(coo.matmul(vec), csr.matmul(vec))
 
     @onlyCPU
+    @unittest.skipIf(IS_MACOS or IS_WINDOWS, "MKL doesn't work on windows or mac")
     @dtypes(torch.float, torch.double)
-    @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
     def test_mkl_matvec_warnings(self, device, dtype):
         if torch.has_mkl:
             for index_dtype in [torch.int32, torch.int64]:
@@ -219,7 +219,6 @@
 
     @onlyCPU
     @dtypes(torch.float, torch.double)
-    @unittest.skipIf(IS_MACOS or IS_WINDOWS, "see: https://github.com/pytorch/pytorch/issues/58757")
     def test_csr_matvec(self, device, dtype):
         side = 100
         for index_dtype in [torch.int32, torch.int64]:
@@ -236,6 +235,34 @@
                 csr.matmul(bad_vec)
 
     @onlyCPU
+    @dtypes(torch.double)
+    def test_mm(self, device, dtype):
+        def test_shape(di, dj, dk, nnz):
+            x = self.genSparseCSRTensor((di, dj), nnz, device=device, dtype=dtype, index_dtype=torch.int32)
+            t = torch.randn(di, dk, dtype=dtype, device=device)
+            y = torch.randn(dj, dk, dtype=dtype, device=device)
+            alpha = random.random()
+            beta = random.random()
+
+            # res = beta * t  + alpha * (x @ y)
+            res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
+            expected = torch.addmm(t, x.to_dense(), y, beta=beta, alpha=alpha)
+            self.assertEqual(res, expected)
+
+            res = torch.addmm(t, x, y)
+            expected = torch.addmm(t, x.to_dense(), y)
+            self.assertEqual(res, expected)
+
+            res = torch.mm(x, y)
+            expected = torch.mm(x.to_dense(), y)
+            self.assertEqual(res, expected)
+
+        for i in range(2, 5):
+            for j in range(2, 8):
+                for k in range(2, 8):
+                    test_shape(i, j, k, i * j // 2)
+
+    @onlyCPU
     @dtypes(*torch.testing.floating_types())
     def test_coo_csr_conversion(self, device, dtype):
         size = (5, 5)