Sparse CSR CPU: cuSolverSP backend for `linalg.solve` (#71399)

Summary:
This PR introduces the `cuSolverSP` backend for `linalg.solve` with sparse CSR input matrices. The motivation comes from issue https://github.com/pytorch/pytorch/issues/69538.

`cuSolver` provides the [`cusolverSp<t>csrlsvluHost`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu) API. A few things to note:

1. As the documentation states, `only CPU (Host) path is provided.` The profiling below confirms this: no GPU compute kernels are launched for this function.
2. Since only a `host` path is provided, the CPU path uses `csrlsvluHost`. Because the function ships with cuSolver, PyTorch still needs to be installed/built with CUDA support (a quick runtime check is sketched after this list).
3. The documentation mentions that reordering can reduce zero fill-in, but it is unclear how it affects performance. Several reordering schemes are available (`symrcm`, `symamd`, `csrmetisnd`), so we stick with `reorder = 0` (no reordering) as the default.
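
Regarding point 2, a minimal sketch of how a user can check at runtime whether the installed PyTorch binary was compiled with CUDA support, using the public `torch.backends.cuda.is_built()` and `torch.cuda.is_available()` APIs:

```python
import torch

# True if this PyTorch binary was compiled with CUDA support; this is what
# the host csrlsvluHost path needs, even when running on CPU inputs.
print(torch.backends.cuda.is_built())

# Additionally requires a visible GPU; needed for the device csrlsvqr path.
print(torch.cuda.is_available())
```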

`cuSolver` also has the [`csrlsvqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr) function, which provides a `device` path for solving the linear system. This function is used for the CUDA path in this PR.

**Gist:**

For the CPU path, we call cuSolver's [`csrlsvluHost`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu).
For the CUDA path, we call cuSolver's [`csrlsvqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr).
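
For illustration, a minimal usage sketch of both paths on a hypothetical toy system (with this PR, CPU CSR inputs take the host LU route and CUDA CSR inputs take the device QR route):

```python
import torch

# Toy diagonally dominant 4 x 4 system; the values are arbitrary.
a = (torch.eye(4, dtype=torch.float64) * 2.0).to_sparse_csr()
b = torch.ones(4, dtype=torch.float64)

# CPU path: dispatches to cusolverSp<t>csrlsvluHost
# (requires a CUDA-enabled PyTorch build).
x_cpu = torch.linalg.solve(a, b)

# CUDA path: dispatches to cusolverSp<t>csrlsvqr on the device.
if torch.cuda.is_available():
    x_cuda = torch.linalg.solve(a.cuda(), b.cuda())
```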

**Profiling:** `nvprof` output for the `csrlsvluHost` function on a 1000 x 1000 sparse input with a right-hand-side vector of length 1000, showing that no GPU compute kernels are launched:

```text
==3999651== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:  100.00%  2.1440us         1  2.1440us  2.1440us  2.1440us  [CUDA memcpy HtoD]
      API calls:   99.72%  1.07199s         9  119.11ms     500ns  1.07164s  cudaFree
                    0.11%  1.2182ms       398  3.0600us     140ns  137.94us  cuDeviceGetAttribute
                    0.06%  674.45us         4  168.61us  165.50us  173.64us  cuDeviceTotalMem
                    0.03%  357.07us         4  89.268us  2.7800us  201.89us  cudaMalloc
                    0.03%  309.29us         1  309.29us  309.29us  309.29us  cudaGetDeviceProperties
                    0.01%  160.47us       332     483ns     350ns  3.3300us  cudaFuncSetAttribute
                    0.01%  115.12us         4  28.780us  26.290us  33.410us  cuDeviceGetName
                    0.00%  28.591us         5  5.7180us     440ns  16.921us  cudaGetDevice
                    0.00%  22.061us         4  5.5150us     871ns  18.690us  cudaDeviceSynchronize
                    0.00%  20.370us        18  1.1310us     410ns  6.9900us  cudaEventDestroy
                    0.00%  16.390us         1  16.390us  16.390us  16.390us  cudaMemcpy
                    0.00%  11.540us         2  5.7700us  1.4900us  10.050us  cuDeviceGetPCIBusId
                    0.00%  10.510us        18     583ns     430ns  1.6200us  cudaEventCreateWithFlags
                    0.00%  7.9100us        21     376ns     290ns     700ns  cudaDeviceGetAttribute
                    0.00%  1.4300us         6     238ns     150ns     590ns  cuDeviceGet
                    0.00%  1.2200us         4     305ns     190ns     500ns  cuDeviceGetCount
                    0.00%     900ns         1     900ns     900ns     900ns  cuInit
                    0.00%     860ns         4     215ns     180ns     260ns  cuDeviceGetUuid
                    0.00%     240ns         1     240ns     240ns     240ns  cuDriverGetVersion
                    0.00%     230ns         1     230ns     230ns     230ns  cudaGetDeviceCount
```

Script:

```python
import torch

def solve(x, other, out):
    torch.linalg.solve(x, other, out=out)

if __name__ == "__main__":
    dense_inp = torch.randn((1000, 1000), dtype=torch.float64)
    # Randomly zero out ~50% of the values (dropout also rescales the
    # surviving entries, which does not affect the sparsity pattern)
    dense_inp = torch.nn.functional.dropout(dense_inp, p=0.5)
    sparse_inp = dense_inp.to_sparse_csr()

    # Right-hand-side vector and an output buffer for the out= variant
    other = torch.randint(100, (1000,), dtype=torch.float64)
    out = torch.randint(1, (1000,), dtype=torch.float64)

    solve(sparse_inp, other, out)
```

The following error is raised when the function is called on a CPU device with PyTorch built/installed without CUDA support:

```python
/home/krshrimali/pytorch/torch/autograd/profiler.py:151: UserWarning: CUDA is not available, disabling CUDA profiling
  warn("CUDA is not available, disabling CUDA profiling")
Traceback (most recent call last):
  File "/home/krshrimali/pytorch/test_sp.py", line 17, in <module>
    solve(x, other, out)
  File "/home/krshrimali/pytorch/test_sp.py", line 5, in solve
    torch.linalg.solve(x, other, out=out)
RuntimeError: PyTorch was not built with CUDA support. Please use PyTorch built with CUDA support
```

**Performance comparison** (vs. SciPy's [`scipy.sparse.linalg.spsolve`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html)):

* `scipy.sparse.linalg.spsolve`: 0.595 seconds
* `torch.linalg.solve` on CPU: 4.565 seconds
* `torch.linalg.solve` on CUDA: 1.838 seconds

The inputs have dimensions (17281, 17281) and (17281, 1) and were taken from https://math.nist.gov/MatrixMarket/extreme.html.
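
A rough sketch of how such a comparison can be reproduced (the matrix file path is a placeholder; any square MatrixMarket matrix works, and timings will of course vary by hardware):

```python
import time

import numpy as np
import scipy.io
import scipy.sparse.linalg
import torch

# Placeholder path: any square matrix in MatrixMarket format.
A = scipy.io.mmread("matrix.mtx").tocsr()
# Manufacture a right-hand side so that the exact solution is all ones.
b = A @ np.ones(A.shape[0])

start = time.perf_counter()
x_scipy = scipy.sparse.linalg.spsolve(A, b)
print("scipy.sparse.linalg.spsolve:", time.perf_counter() - start, "s")

# Build the equivalent PyTorch sparse CSR tensor.
A_t = torch.sparse_csr_tensor(
    torch.from_numpy(A.indptr), torch.from_numpy(A.indices),
    torch.from_numpy(A.data), size=A.shape)
b_t = torch.from_numpy(b)

start = time.perf_counter()
x_cpu = torch.linalg.solve(A_t, b_t)  # host LU path
print("torch.linalg.solve (CPU):", time.perf_counter() - start, "s")

if torch.cuda.is_available():
    start = time.perf_counter()
    x_cuda = torch.linalg.solve(A_t.cuda(), b_t.cuda())  # device QR path
    print("torch.linalg.solve (CUDA):", time.perf_counter() - start, "s")
```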

Thanks to IvanYashchuk for helping me with this PR and guiding me through it.

cc: IvanYashchuk pearu nikitaved cpuhrsch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71399

Reviewed By: VitalyFedyunin

Differential Revision: D33767740

Pulled By: cpuhrsch

fbshipit-source-id: a945f065210cd719096eb8d7cdbf8e8937c2fce9
(cherry picked from commit f4f35c17da414e1ca6c6d91402933521857aa1ea)
diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h
index 0167cd5..4862778 100644
--- a/aten/src/ATen/cuda/CUDAContext.h
+++ b/aten/src/ATen/cuda/CUDAContext.h
@@ -8,6 +8,7 @@
 
 #ifdef CUDART_VERSION
 #include <cusolverDn.h>
+#include <cusolverSp.h>
 #endif
 
 #include <ATen/core/ATenGeneral.h>
@@ -74,6 +75,7 @@
 
 #ifdef CUDART_VERSION
 TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
+TORCH_CUDA_CPP_API cusolverSpHandle_t getCurrentCUDASolverSpHandle();
 #endif
 
 } // namespace cuda
diff --git a/aten/src/ATen/cuda/CusolverSpHandlePool.cpp b/aten/src/ATen/cuda/CusolverSpHandlePool.cpp
new file mode 100644
index 0000000..f6ef188
--- /dev/null
+++ b/aten/src/ATen/cuda/CusolverSpHandlePool.cpp
@@ -0,0 +1,51 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/DeviceThreadHandles.h>
+
+#ifdef CUDART_VERSION
+
+namespace at { namespace cuda {
+namespace {
+
+void createCusolverSpHandle(cusolverSpHandle_t *handle) {
+  TORCH_CUSOLVER_CHECK(cusolverSpCreate(handle));
+}
+
+void destroyCusolverSpHandle(cusolverSpHandle_t handle) {
+// this is because of something dumb in the ordering of
+// destruction. Sometimes atexit, the cuda context (or something)
+// would already be destroyed by the time this gets destroyed. It
+// happens in fbcode setting. @colesbury and @soumith decided to not destroy
+// the handle as a workaround.
+//   - Comments of @soumith copied from cuDNN handle pool implementation
+#ifdef NO_CUDNN_DESTROY_HANDLE
+#else
+    cusolverSpDestroy(handle);
+#endif
+}
+
+using CuSolverSpPoolType = DeviceThreadHandlePool<cusolverSpHandle_t, createCusolverSpHandle, destroyCusolverSpHandle>;
+
+} // namespace
+
+cusolverSpHandle_t getCurrentCUDASolverSpHandle() {
+  int device;
+  AT_CUDA_CHECK(cudaGetDevice(&device));
+
+  // Thread local PoolWindows are lazily-initialized
+  // to avoid initialization issues that caused hangs on Windows.
+  // See: https://github.com/pytorch/pytorch/pull/22405
+  // This thread local unique_ptrs will be destroyed when the thread terminates,
+  // releasing its reserved handles back to the pool.
+  static auto pool = std::make_shared<CuSolverSpPoolType>();
+  thread_local std::unique_ptr<CuSolverSpPoolType::PoolWindow> myPoolWindow(
+      pool->newPoolWindow());
+
+  auto handle = myPoolWindow->reserve(device);
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  TORCH_CUSOLVER_CHECK(cusolverSpSetStream(handle, stream));
+  return handle;
+}
+
+}} // namespace at::cuda
+
+#endif // CUDART_VERSION
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
index 6cba66a..cf83ffd 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
@@ -148,6 +148,149 @@
       info));
 }
 
+template <>
+void lsvqr<float, float>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpScsrlsvqr(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        csrValA,
+        csrRowPtrA,
+        csrColIndA,
+        b,
+        tol,
+        reorder,
+        x,
+        singularity));
+  }
+
+template <>
+void lsvqr<double, double>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpDcsrlsvqr(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        csrValA,
+        csrRowPtrA,
+        csrColIndA,
+        b,
+        tol,
+        reorder,
+        x,
+        singularity));
+  }
+
+template <>
+void lsvqr<c10::complex<float>, float>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpCcsrlsvqr(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        reinterpret_cast<const cuComplex*>(csrValA),
+        csrRowPtrA,
+        csrColIndA,
+        reinterpret_cast<const cuComplex*>(b),
+        tol,
+        reorder,
+        reinterpret_cast<cuComplex*>(x),
+        singularity));
+  }
+
+template <>
+void lsvqr<c10::complex<double>, double>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpZcsrlsvqr(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        reinterpret_cast<const cuDoubleComplex*>(csrValA),
+        csrRowPtrA,
+        csrColIndA,
+        reinterpret_cast<const cuDoubleComplex*>(b),
+        tol,
+        reorder,
+        reinterpret_cast<cuDoubleComplex*>(x),
+        singularity));
+  }
+
+template <>
+void lsvlu<float, float>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpScsrlsvluHost(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        csrValA,
+        csrRowPtrA,
+        csrColIndA,
+        b,
+        tol,
+        reorder,
+        x,
+        singularity));
+  }
+
+template <>
+void lsvlu<double, double>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpDcsrlsvluHost(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        csrValA,
+        csrRowPtrA,
+        csrColIndA,
+        b,
+        tol,
+        reorder,
+        x,
+        singularity));
+  }
+
+template <>
+void lsvlu<c10::complex<float>, float>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpCcsrlsvluHost(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        reinterpret_cast<const cuComplex*>(csrValA),
+        csrRowPtrA,
+        csrColIndA,
+        reinterpret_cast<const cuComplex*>(b),
+        tol,
+        reorder,
+        reinterpret_cast<cuComplex*>(x),
+        singularity));
+  }
+
+template <>
+void lsvlu<c10::complex<double>, double>(
+  CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double)) {
+    TORCH_CUSOLVER_CHECK(cusolverSpZcsrlsvluHost(
+        handle,
+        n,
+        nnzA,
+        descrA,
+        reinterpret_cast<const cuDoubleComplex*>(csrValA),
+        csrRowPtrA,
+        csrColIndA,
+        reinterpret_cast<const cuDoubleComplex*>(b),
+        tol,
+        reorder,
+        reinterpret_cast<cuDoubleComplex*>(x),
+        singularity));
+  }
 
 template<>
 void gesvd_buffersize<float>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h
index bd8c5cc..084fcb7 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h
@@ -13,6 +13,45 @@
 namespace cuda {
 namespace solver {
 
+#define CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t)                          \
+    cusolverSpHandle_t handle, int n, int nnzA, const cusparseMatDescr_t descrA,    \
+    const scalar_t *csrValA, const int *csrRowPtrA, const int *csrColIndA,          \
+    const scalar_t *b, value_t tol, int reorder, scalar_t *x, int *singularity
+
+template <typename scalar_t, typename value_t>
+inline void lsvlu(CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t)) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "at::cuda::sparse::lsvlu: not implemented for ",
+        typeid(scalar_t).name());
+}
+
+template <>
+void lsvlu<float, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float));
+template <>
+void lsvlu<double, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double));
+template <>
+void lsvlu<c10::complex<float>, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float));
+template <>
+void lsvlu<c10::complex<double>, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double));
+
+template <typename scalar_t, typename value_t>
+inline void lsvqr(CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t)) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "at::cuda::sparse::lsvqr: not implemented for ",
+        typeid(scalar_t).name());
+}
+
+template <>
+void lsvqr<float, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float));
+template <>
+void lsvqr<double, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double));
+template <>
+void lsvqr<c10::complex<float>, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float));
+template <>
+void lsvqr<c10::complex<double>, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double));
+
 #define CUDASOLVER_GETRF_ARGTYPES(Dtype)  \
     cusolverDnHandle_t handle, int m, int n, Dtype* dA, int ldda, int* ipiv, int* info
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2fce3eb..ae28689 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11094,11 +11094,13 @@
   variants: function
   dispatch:
     CPU, CUDA: linalg_solve
+    SparseCsrCPU, SparseCsrCUDA: linalg_solve_sparse_csr
 
 - func: linalg_solve.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
   dispatch:
     CPU, CUDA: linalg_solve_out
+    SparseCsrCPU, SparseCsrCUDA: linalg_solve_sparse_csr_out
 
 - func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor
   python_module: linalg
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index d5d9ead..1bdb555 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -9,6 +9,8 @@
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/sparse/SparseCsrTensorMath.h>
 #include <ATen/native/mkl/SparseBlasImpl.h>
 #include <ATen/native/sparse/SparseBlasImpl.h>
 #include <c10/util/irange.h>
@@ -60,6 +62,7 @@
 #include <ATen/ops/isneginf_native.h>
 #include <ATen/ops/isposinf.h>
 #include <ATen/ops/isposinf_native.h>
+#include <ATen/ops/linalg_solve_native.h>
 #include <ATen/ops/log1p.h>
 #include <ATen/ops/log1p_native.h>
 #include <ATen/ops/mm_native.h>
@@ -92,6 +95,7 @@
 #include <ATen/ops/trunc.h>
 #include <ATen/ops/trunc_native.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
 #include <ATen/ops/zero_native.h>
 #endif
 
@@ -206,6 +210,17 @@
 
 }
 
+// This function exists for completeness; in practice the CUDA dispatch defined
+// in aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu handles CPU inputs too.
+void linalg_solve_sparse_csr_kernel_error(
+  const Tensor& input,
+  const Tensor& other,
+  const Tensor& result,
+  int& singularity
+) {
+  TORCH_CHECK(false, "Not implemented for the given backend: ", input.device().type());
+}
+
 } // end anonymous namespace
 
 namespace native {
@@ -685,5 +700,71 @@
   }
 }
 
+Tensor& linalg_solve_sparse_csr_out(const Tensor& input, const Tensor& other, Tensor& result) {
+  if (at::globalContext().hasCUDA()) {
+    TORCH_INTERNAL_ASSERT(input.is_sparse_csr());
+
+    c10::MaybeOwned<Tensor> other_ = other.expect_contiguous();
+    c10::MaybeOwned<Tensor> result_ = result.expect_contiguous();
+
+    if (other.ndimension() > 1) {
+      TORCH_CHECK(other.size(-1) == 1, "Support for multiple right-hand-side vectors in 'other' is not implemented yet.");
+    }
+
+    // the "other" Tensor needs to be a vector
+    TORCH_CHECK(
+      other.ndimension() == 1 || (other.ndimension() == 2 && other.size(1) == 1),
+      "other tensor must be a vector, but got tensor with dimension: ",
+      other.ndimension());
+
+    // The API expects a square matrix
+    TORCH_CHECK(
+      input.size(0) == input.size(1),
+      "Expected a sparse matrix of dimension N x N (square matrix), but got: ",
+      input.sizes()
+    );
+
+    // Ensure that the vector shape is either (n, 1) or (n,) given input sparse matrix of
+    // shape (n, n)
+    TORCH_CHECK(
+      other.size(0) == input.size(0),
+      "Dimension mismatch for the vector, got shape: ", other.sizes(), " should have been (",
+      input.size(0), ", 1) or (", input.size(0), ",)"
+    );
+
+    TORCH_CHECK(
+      other.scalar_type() == result.scalar_type(),
+      "other (got: ", result.scalar_type(), ") and out (got: ", result.scalar_type(), ") tensors must have same dtype.");
+
+    at::native::resize_output(result, other.sizes());
+    // Return for an empty other tensor
+    if (other.numel() == 0) return result;
+
+    int singularity = -1;
+
+    linalg_solve_sparse_csr_stub(kCUDA, input, *other_, *result_, singularity);
+
+    result.copy_(*result_);
+
+    TORCH_CHECK_LINALG(singularity == -1, "torch.linalg.solve",
+        ": The diagonal element ", singularity , " is zero, the solve could not be completed because the input matrix is singular.");
+    return result;
+  } else {
+    TORCH_CHECK(false, "PyTorch was not built with CUDA support. Please use PyTorch built CUDA support");
+  }
+}
+
+Tensor linalg_solve_sparse_csr(const Tensor& input, const Tensor& other) {
+  // Result tensor will be a vector, dimension same as that of other tensor
+  // The checks for the input tensors (input, other) are done in the call to out variant
+  auto result = at::zeros_like(other);
+  linalg_solve_sparse_csr_out(input, other, result);
+  return result;
+}
+
+DEFINE_DISPATCH(linalg_solve_sparse_csr_stub);
+
+REGISTER_ALL_CPU_DISPATCH(linalg_solve_sparse_csr_stub, &linalg_solve_sparse_csr_kernel_error);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.h b/aten/src/ATen/native/sparse/SparseCsrTensorMath.h
new file mode 100644
index 0000000..3f990a3
--- /dev/null
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at { namespace native {
+
+using linalg_solve_fn = void(*)(
+  const at::Tensor&, const at::Tensor&, const at::Tensor&, int&
+);
+DECLARE_DISPATCH(linalg_solve_fn, linalg_solve_sparse_csr_stub);
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
index c13984f..636c0c5 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -9,6 +9,10 @@
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/sparse/SparseCsrTensorMath.h>
+#include <ATen/cuda/CUDASparseDescriptors.h>
+#include <ATen/native/cuda/linalg/CUDASolver.h>
 #include <algorithm>
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -32,6 +36,7 @@
 #include <ATen/native/sparse/cuda/SparseBlasImpl.h>
 #include <ATen/native/sparse/cuda/SparseCUDABlas.h>
 #include <ATen/native/sparse/cuda/SparseCUDATensorMath.cuh>
+#include <ATen/native/cuda/MiscUtils.h>
 
 #include <thrust/device_ptr.h>
 #include <thrust/execution_policy.h>
@@ -112,6 +117,46 @@
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+template <typename scalar_t>
+void _apply_sparse_csr_linear_solve(
+  const Tensor& input,
+  const Tensor& other,
+  const Tensor& result,
+  int &_singularity) {
+#ifdef USE_ROCM
+  TORCH_CHECK(
+      false,
+      "Calling torch.linalg.solve with sparse tensors requires compiling ",
+      "PyTorch with CUDA and not supported in ROCm build.");
+#else
+  using value_t = typename c10::scalar_value_type<scalar_t>::type;
+  auto values = input.values();
+  const scalar_t *values_data_ptr = values.data_ptr<scalar_t>();
+  auto crow_indices = input.crow_indices().to(kInt);
+  const int *crow_indices_data_ptr = crow_indices.data_ptr<int>();
+  auto col_indices = input.col_indices().to(kInt);
+  const int *col_indices_data_ptr = col_indices.data_ptr<int>();
+  auto handle = at::cuda::getCurrentCUDASolverSpHandle();
+  auto descrA = at::cuda::sparse::CuSparseMatDescriptor();
+
+  const scalar_t *b = other.data_ptr<scalar_t>();
+  int n = cuda_int_cast(input.size(-1), "n");
+  int nnzA = input._nnz();
+  value_t tol = 0.0;
+  // reorder = 0 means no reordering is applied.
+  // Should reorder be exposed as a user-facing argument to choose between
+  // symrcm, symamd, csrmetisnd (1, 2, 3)?
+  int reorder = 0;
+  scalar_t *x = result.data_ptr<scalar_t>();
+
+  // cuSolver API: lsvqr provides device path, while lsvlu is only available on Host
+  auto cusolver_func = input.is_cuda() ? at::cuda::solver::lsvqr<scalar_t, value_t> : at::cuda::solver::lsvlu<scalar_t, value_t>;
+
+  cusolver_func(handle, n, nnzA, descrA.descriptor(), values_data_ptr,
+    crow_indices_data_ptr, col_indices_data_ptr, b, tol, reorder, x, &_singularity);
+#endif
+}
+
 } // namespace
 
 using namespace at::sparse_csr;
@@ -275,5 +320,17 @@
   }
 }
 
+void linalg_solve_sparse_csr_kernel(
+  const Tensor& input,
+  const Tensor& other,
+  const Tensor& result,
+  int& singularity) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "sparse_csr_solve", [&] {
+    _apply_sparse_csr_linear_solve<scalar_t>(input, other, result, singularity);
+  });
+}
+
+REGISTER_CUDA_DISPATCH(linalg_solve_sparse_csr_stub, &linalg_solve_sparse_csr_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 178e4cb..6ca229d 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -475,6 +475,7 @@
    :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
    :func:`torch.pca_lowrank`; yes; ``PCA(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
    :func:`torch.svd_lowrank`; yes; ``SVD(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
+   :func:`torch.linalg.solve`; no; ``M[sparse_csr] @ V[strided] -> V[strided]``
 
 where "Sparse grad?" column indicates if the PyTorch operation supports
 backward with respect to sparse matrix argument. All PyTorch operations,
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index add4147..5b111b8 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -106,7 +106,6 @@
     test_case.assertEqual(res1, res2)
     test_case.assertEqual(res1, res3)
 
-
 class TestSparseCSRSampler(TestCase):
 
     def test_make_crow_indices(self):
@@ -1321,9 +1320,9 @@
             if sample.input.ndim != 2:
                 continue
 
-            expected = op(sample.input)
+            expected = op(sample.input, *sample.args, **sample.kwargs)
             assert torch.is_tensor(expected)
-            output = op(sample.input.to_sparse_csr())
+            output = op(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
             assert torch.is_tensor(output)
 
             self.assertEqual(output.to_dense(), expected)
@@ -1543,6 +1542,41 @@
 
             self.assertEqual(coo_sparse.to_sparse_csr().to_sparse_coo(), coo_sparse)
 
+    @unittest.skipIf(TEST_WITH_ROCM, "The test doesn't support ROCM")
+    @dtypes(*floating_and_complex_types())
+    def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
+        from torch.testing._internal.common_methods_invocations import sample_inputs_linalg_solve
+
+        if (device == 'meta') or (device == 'cpu' and not torch.cuda.is_available()):
+            self.skipTest("Skipped!")
+
+        samples = sample_inputs_linalg_solve(None, device, dtype)
+
+        for sample in samples:
+            if sample.input.ndim != 2:
+                continue
+
+            out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
+            if not torch.cuda.is_available():
+                with self.assertRaisesRegex(RuntimeError, "PyTorch was not built with CUDA support"):
+                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                break
+
+            if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
+                with self.assertRaisesRegex(RuntimeError, "not implemented yet"):
+                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                break
+            if not sample.args[0].numel():
+                with self.assertRaisesRegex(RuntimeError,
+                                            "Expected non-empty other tensor, but found empty tensor"):
+                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                break
+
+            expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
+            sample.input = sample.input.to_sparse_csr()
+            torch.linalg.solve(sample.input, *sample.args, **sample.kwargs, out=out)
+            self.assertEqual(expect, out)
+
     @skipMeta
     @dtypes(*get_all_dtypes())
     def test_transpose(self, device, dtype):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7a6a082..1f0e2a2c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15623,6 +15623,7 @@
 binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)]
 spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
 sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
+sparse_csr_funcs = [op for op in op_db if op.supports_sparse_csr]
 sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
 shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
 reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo)]