Sparse CSR CPU: cuSolverSP backend for `linalg.solve` (#71399)
Summary:
This PR introduces the `cuSolverSP` backend for `linalg.solve` with sparse CSR input matrices. The motivation comes from the issue: https://github.com/pytorch/pytorch/issues/69538.
`cuSolver` provides the [`cusolverSp<t>csrlsvlu`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu) API. A few things to note:
1. As the documentation states, `only CPU (Host) path is provided.` Profiling confirms this: no GPU compute kernels are launched for the solve (see the profiling output below).
2. Since only the `host` path is provided, the CPU path uses `csrlsvluHost` (this still requires PyTorch to be built/installed with CUDA support).
3. The documentation mentions that reordering can reduce zero fill-in, but it isn't clear how much it affects performance in practice. Several reordering options exist, so we stick with `reorder = 0` (no reordering) as the default.
`cuSolver` also provides the [`csrlsvqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr) function, which offers a `device` path for solving the linear system; this PR uses it for the CUDA path.
**Gist:**
For the CPU path, we call the [`csrlsvluHost` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu).
For the CUDA path, we call the [`csrlsvqr` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr).
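A minimal usage sketch of the two paths (both go through cuSolver, so a CUDA-enabled build is assumed even for the CPU path):
```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)

# CPU CSR input: dispatches to the host API (cusolverSp<t>csrlsvluHost)
x_cpu = torch.linalg.solve(A.to_sparse_csr(), b)

# CUDA CSR input: dispatches to the device API (cusolverSp<t>csrlsvqr)
x_cuda = torch.linalg.solve(A.cuda().to_sparse_csr(), b.cuda())
```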
**Profiling** (on a sparse input tensor of size 1000 x 1000, with a vector of length 1000) of the `csrlsvlu` function, showing that no GPU compute kernels are launched:
```
==3999651== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 100.00% 2.1440us 1 2.1440us 2.1440us 2.1440us [CUDA memcpy HtoD]
API calls: 99.72% 1.07199s 9 119.11ms 500ns 1.07164s cudaFree
0.11% 1.2182ms 398 3.0600us 140ns 137.94us cuDeviceGetAttribute
0.06% 674.45us 4 168.61us 165.50us 173.64us cuDeviceTotalMem
0.03% 357.07us 4 89.268us 2.7800us 201.89us cudaMalloc
0.03% 309.29us 1 309.29us 309.29us 309.29us cudaGetDeviceProperties
0.01% 160.47us 332 483ns 350ns 3.3300us cudaFuncSetAttribute
0.01% 115.12us 4 28.780us 26.290us 33.410us cuDeviceGetName
0.00% 28.591us 5 5.7180us 440ns 16.921us cudaGetDevice
0.00% 22.061us 4 5.5150us 871ns 18.690us cudaDeviceSynchronize
0.00% 20.370us 18 1.1310us 410ns 6.9900us cudaEventDestroy
0.00% 16.390us 1 16.390us 16.390us 16.390us cudaMemcpy
0.00% 11.540us 2 5.7700us 1.4900us 10.050us cuDeviceGetPCIBusId
0.00% 10.510us 18 583ns 430ns 1.6200us cudaEventCreateWithFlags
0.00% 7.9100us 21 376ns 290ns 700ns cudaDeviceGetAttribute
0.00% 1.4300us 6 238ns 150ns 590ns cuDeviceGet
0.00% 1.2200us 4 305ns 190ns 500ns cuDeviceGetCount
0.00% 900ns 1 900ns 900ns 900ns cuInit
0.00% 860ns 4 215ns 180ns 260ns cuDeviceGetUuid
0.00% 240ns 1 240ns 240ns 240ns cuDriverGetVersion
0.00% 230ns 1 230ns 230ns 230ns cudaGetDeviceCount
```
Script:
```python
import torch
def solve(x, other, out):
torch.linalg.solve(x, other, out=out)
if __name__ == "__main__":
dense_inp = torch.randn((1000, 1000), dtype=torch.float64)
# Set 50% of the values to 0 randomly
dense_inp = torch.nn.functional.dropout(dense_inp, p=0.5)
sparse_inp = dense_inp.to_sparse_csr()
other = torch.randint(100, (1000,), dtype=torch.float64)
out = torch.randint(1, (1000,), dtype=torch.float64)
solve(sparse_inp, other, out)
```
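The profile above was collected for the host path (e.g. with `nvprof python test_sp.py`). A sketch of the device-path variant of the same script, assuming a CUDA device is available:
```python
# Device-path variant: dispatches to csrlsvqr instead of csrlsvluHost.
sparse_inp_cuda = dense_inp.cuda().to_sparse_csr()
other_cuda = other.cuda()
out_cuda = out.cuda()
solve(sparse_inp_cuda, other_cuda, out_cuda)
torch.cuda.synchronize()  # ensure the device work has finished before timing
```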
When PyTorch is built/installed without CUDA support, using the function on CPU inputs raises the following error:
```python
/home/krshrimali/pytorch/torch/autograd/profiler.py:151: UserWarning: CUDA is not available, disabling CUDA profiling
warn("CUDA is not available, disabling CUDA profiling")
Traceback (most recent call last):
File "/home/krshrimali/pytorch/test_sp.py", line 17, in <module>
solve(x, other, out)
File "/home/krshrimali/pytorch/test_sp.py", line 5, in solve
torch.linalg.solve(x, other, out=out)
RuntimeError: PyTorch was not built with CUDA support. Please use PyTorch built with CUDA support
```
**Performance Comparison** (vs SciPy's [`scipy.sparse.linalg.spsolve`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html)):
* `scipy.sparse.linalg.spsolve`: 0.595 seconds
* `torch.linalg.solve` on CPU: 4.565 seconds
* `torch.linalg.solve` on CUDA: 1.838 seconds
The inputs are of dimensions: (17281, 17281) and (17281, 1), and were taken from https://math.nist.gov/MatrixMarket/extreme.html.
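A sketch of the comparison script: the file name `matrix.mtx` is a stand-in for the downloaded MatrixMarket file, and the CUDA timing would additionally move the tensors to the device and call `torch.cuda.synchronize()` before reading the clock.
```python
import time

import numpy as np
import scipy.io
import scipy.sparse.linalg
import torch

# Load the test matrix; "matrix.mtx" is a placeholder for the MatrixMarket download.
A = scipy.io.mmread("matrix.mtx").tocsr()
b = np.random.randn(A.shape[0])

# SciPy baseline
t0 = time.perf_counter()
x_scipy = scipy.sparse.linalg.spsolve(A, b)
print(f"scipy spsolve: {time.perf_counter() - t0:.3f} s")

# PyTorch sparse CSR input (host path via csrlsvluHost)
A_torch = torch.sparse_csr_tensor(
    torch.from_numpy(A.indptr),
    torch.from_numpy(A.indices),
    torch.from_numpy(A.data),
    size=A.shape,
)
b_torch = torch.from_numpy(b)

t0 = time.perf_counter()
x_torch = torch.linalg.solve(A_torch, b_torch)
print(f"torch CPU: {time.perf_counter() - t0:.3f} s")
```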
Thanks to IvanYashchuk for helping me with this PR and guiding me through it.
cc: IvanYashchuk pearu nikitaved cpuhrsch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71399
Reviewed By: VitalyFedyunin
Differential Revision: D33767740
Pulled By: cpuhrsch
fbshipit-source-id: a945f065210cd719096eb8d7cdbf8e8937c2fce9
(cherry picked from commit f4f35c17da414e1ca6c6d91402933521857aa1ea)
diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h
index 0167cd5..4862778 100644
--- a/aten/src/ATen/cuda/CUDAContext.h
+++ b/aten/src/ATen/cuda/CUDAContext.h
@@ -8,6 +8,7 @@
#ifdef CUDART_VERSION
#include <cusolverDn.h>
+#include <cusolverSp.h>
#endif
#include <ATen/core/ATenGeneral.h>
@@ -74,6 +75,7 @@
#ifdef CUDART_VERSION
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
+TORCH_CUDA_CPP_API cusolverSpHandle_t getCurrentCUDASolverSpHandle();
#endif
} // namespace cuda
diff --git a/aten/src/ATen/cuda/CusolverSpHandlePool.cpp b/aten/src/ATen/cuda/CusolverSpHandlePool.cpp
new file mode 100644
index 0000000..f6ef188
--- /dev/null
+++ b/aten/src/ATen/cuda/CusolverSpHandlePool.cpp
@@ -0,0 +1,51 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/DeviceThreadHandles.h>
+
+#ifdef CUDART_VERSION
+
+namespace at { namespace cuda {
+namespace {
+
+void createCusolverSpHandle(cusolverSpHandle_t *handle) {
+ TORCH_CUSOLVER_CHECK(cusolverSpCreate(handle));
+}
+
+void destroyCusolverSpHandle(cusolverSpHandle_t handle) {
+// this is because of something dumb in the ordering of
+// destruction. Sometimes atexit, the cuda context (or something)
+// would already be destroyed by the time this gets destroyed. It
+// happens in fbcode setting. @colesbury and @soumith decided to not destroy
+// the handle as a workaround.
+// - Comments of @soumith copied from cuDNN handle pool implementation
+#ifdef NO_CUDNN_DESTROY_HANDLE
+#else
+ cusolverSpDestroy(handle);
+#endif
+}
+
+using CuSolverSpPoolType = DeviceThreadHandlePool<cusolverSpHandle_t, createCusolverSpHandle, destroyCusolverSpHandle>;
+
+} // namespace
+
+cusolverSpHandle_t getCurrentCUDASolverSpHandle() {
+ int device;
+ AT_CUDA_CHECK(cudaGetDevice(&device));
+
+ // Thread local PoolWindows are lazily-initialized
+ // to avoid initialization issues that caused hangs on Windows.
+ // See: https://github.com/pytorch/pytorch/pull/22405
+ // This thread local unique_ptrs will be destroyed when the thread terminates,
+ // releasing its reserved handles back to the pool.
+ static auto pool = std::make_shared<CuSolverSpPoolType>();
+ thread_local std::unique_ptr<CuSolverSpPoolType::PoolWindow> myPoolWindow(
+ pool->newPoolWindow());
+
+ auto handle = myPoolWindow->reserve(device);
+ auto stream = c10::cuda::getCurrentCUDAStream();
+ TORCH_CUSOLVER_CHECK(cusolverSpSetStream(handle, stream));
+ return handle;
+}
+
+}} // namespace at::cuda
+
+#endif // CUDART_VERSION
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
index 6cba66a..cf83ffd 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
@@ -148,6 +148,149 @@
info));
}
+template <>
+void lsvqr<float, float>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpScsrlsvqr(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ csrValA,
+ csrRowPtrA,
+ csrColIndA,
+ b,
+ tol,
+ reorder,
+ x,
+ singularity));
+ }
+
+template <>
+void lsvqr<double, double>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpDcsrlsvqr(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ csrValA,
+ csrRowPtrA,
+ csrColIndA,
+ b,
+ tol,
+ reorder,
+ x,
+ singularity));
+ }
+
+template <>
+void lsvqr<c10::complex<float>, float>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpCcsrlsvqr(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ reinterpret_cast<const cuComplex*>(csrValA),
+ csrRowPtrA,
+ csrColIndA,
+ reinterpret_cast<const cuComplex*>(b),
+ tol,
+ reorder,
+ reinterpret_cast<cuComplex*>(x),
+ singularity));
+ }
+
+template <>
+void lsvqr<c10::complex<double>, double>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpZcsrlsvqr(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ reinterpret_cast<const cuDoubleComplex*>(csrValA),
+ csrRowPtrA,
+ csrColIndA,
+ reinterpret_cast<const cuDoubleComplex*>(b),
+ tol,
+ reorder,
+ reinterpret_cast<cuDoubleComplex*>(x),
+ singularity));
+ }
+
+template <>
+void lsvlu<float, float>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpScsrlsvluHost(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ csrValA,
+ csrRowPtrA,
+ csrColIndA,
+ b,
+ tol,
+ reorder,
+ x,
+ singularity));
+ }
+
+template <>
+void lsvlu<double, double>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpDcsrlsvluHost(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ csrValA,
+ csrRowPtrA,
+ csrColIndA,
+ b,
+ tol,
+ reorder,
+ x,
+ singularity));
+ }
+
+template <>
+void lsvlu<c10::complex<float>, float>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpCcsrlsvluHost(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ reinterpret_cast<const cuComplex*>(csrValA),
+ csrRowPtrA,
+ csrColIndA,
+ reinterpret_cast<const cuComplex*>(b),
+ tol,
+ reorder,
+ reinterpret_cast<cuComplex*>(x),
+ singularity));
+ }
+
+template <>
+void lsvlu<c10::complex<double>, double>(
+ CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double)) {
+ TORCH_CUSOLVER_CHECK(cusolverSpZcsrlsvluHost(
+ handle,
+ n,
+ nnzA,
+ descrA,
+ reinterpret_cast<const cuDoubleComplex*>(csrValA),
+ csrRowPtrA,
+ csrColIndA,
+ reinterpret_cast<const cuDoubleComplex*>(b),
+ tol,
+ reorder,
+ reinterpret_cast<cuDoubleComplex*>(x),
+ singularity));
+ }
template<>
void gesvd_buffersize<float>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h
index bd8c5cc..084fcb7 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h
@@ -13,6 +13,45 @@
namespace cuda {
namespace solver {
+#define CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t) \
+ cusolverSpHandle_t handle, int n, int nnzA, const cusparseMatDescr_t descrA, \
+ const scalar_t *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
+ const scalar_t *b, value_t tol, int reorder, scalar_t *x, int *singularity
+
+template <typename scalar_t, typename value_t>
+inline void lsvlu(CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t)) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "at::cuda::sparse::lsvlu: not implemented for ",
+ typeid(scalar_t).name());
+}
+
+template <>
+void lsvlu<float, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float));
+template <>
+void lsvlu<double, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double));
+template <>
+void lsvlu<c10::complex<float>, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float));
+template <>
+void lsvlu<c10::complex<double>, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double));
+
+template <typename scalar_t, typename value_t>
+inline void lsvqr(CUSOLVER_LINEAR_SOLVER_ARGTYPES(scalar_t, value_t)) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "at::cuda::sparse::lsvqr: not implemented for ",
+ typeid(scalar_t).name());
+}
+
+template <>
+void lsvqr<float, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(float, float));
+template <>
+void lsvqr<double, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(double, double));
+template <>
+void lsvqr<c10::complex<float>, float>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<float>, float));
+template <>
+void lsvqr<c10::complex<double>, double>(CUSOLVER_LINEAR_SOLVER_ARGTYPES(c10::complex<double>, double));
+
#define CUDASOLVER_GETRF_ARGTYPES(Dtype) \
cusolverDnHandle_t handle, int m, int n, Dtype* dA, int ldda, int* ipiv, int* info
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2fce3eb..ae28689 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11094,11 +11094,13 @@
variants: function
dispatch:
CPU, CUDA: linalg_solve
+ SparseCsrCPU, SparseCsrCUDA: linalg_solve_sparse_csr
- func: linalg_solve.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_solve_out
+ SparseCsrCPU, SparseCsrCUDA: linalg_solve_sparse_csr_out
- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor
python_module: linalg
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index d5d9ead..1bdb555 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -9,6 +9,8 @@
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/Resize.h>
+#include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/sparse/SparseCsrTensorMath.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <c10/util/irange.h>
@@ -60,6 +62,7 @@
#include <ATen/ops/isneginf_native.h>
#include <ATen/ops/isposinf.h>
#include <ATen/ops/isposinf_native.h>
+#include <ATen/ops/linalg_solve_native.h>
#include <ATen/ops/log1p.h>
#include <ATen/ops/log1p_native.h>
#include <ATen/ops/mm_native.h>
@@ -92,6 +95,7 @@
#include <ATen/ops/trunc.h>
#include <ATen/ops/trunc_native.h>
#include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
#include <ATen/ops/zero_native.h>
#endif
@@ -206,6 +210,17 @@
}
+// This error kernel is registered only for completeness: CPU inputs are routed to
+// the CUDA dispatch defined in aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+void linalg_solve_sparse_csr_kernel_error(
+ const Tensor& input,
+ const Tensor& other,
+ const Tensor& result,
+ int& singularity
+) {
+ TORCH_CHECK(false, "Not implemented for the given backend: ", input.device().type());
+}
+
} // end anonymous namespace
namespace native {
@@ -685,5 +700,71 @@
}
}
+Tensor& linalg_solve_sparse_csr_out(const Tensor& input, const Tensor& other, Tensor& result) {
+ if (at::globalContext().hasCUDA()) {
+ TORCH_INTERNAL_ASSERT(input.is_sparse_csr());
+
+ c10::MaybeOwned<Tensor> other_ = other.expect_contiguous();
+ c10::MaybeOwned<Tensor> result_ = result.expect_contiguous();
+
+ if (other.ndimension() > 1) {
+      TORCH_CHECK(other.size(-1) == 1, "Solving for multiple right-hand sides stored in the 'other' tensor is not implemented yet.");
+ }
+
+ // the "other" Tensor needs to be a vector
+ TORCH_CHECK(
+ other.ndimension() == 1 || (other.ndimension() == 2 && other.size(1) == 1),
+ "other tensor must be a vector, but got tensor with dimension: ",
+ other.ndimension());
+
+ // The API expects a square matrix
+ TORCH_CHECK(
+ input.size(0) == input.size(1),
+ "Expected a sparse matrix of dimension N x N (square matrix), but got: ",
+ input.sizes()
+ );
+
+ // Ensure that the vector shape is either (n, 1) or (n,) given input sparse matrix of
+ // shape (n, n)
+ TORCH_CHECK(
+ other.size(0) == input.size(0),
+ "Dimension mismatch for the vector, got shape: ", other.sizes(), " should have been (",
+ input.size(0), ", 1) or (", input.size(0), ",)"
+ );
+
+ TORCH_CHECK(
+ other.scalar_type() == result.scalar_type(),
+      "other (got: ", other.scalar_type(), ") and out (got: ", result.scalar_type(), ") tensors must have same dtype.");
+
+ at::native::resize_output(result, other.sizes());
+  // The solver expects a non-empty right-hand side
+  TORCH_CHECK(other.numel() != 0, "Expected non-empty other tensor, but found empty tensor");
+
+ int singularity = -1;
+
+ linalg_solve_sparse_csr_stub(kCUDA, input, *other_, *result_, singularity);
+
+ result.copy_(*result_);
+
+ TORCH_CHECK_LINALG(singularity == -1, "torch.linalg.solve",
+ ": The diagonal element ", singularity , " is zero, the solve could not be completed because the input matrix is singular.");
+ return result;
+ } else {
+    TORCH_CHECK(false, "PyTorch was not built with CUDA support. Please use PyTorch built with CUDA support");
+ }
+}
+
+Tensor linalg_solve_sparse_csr(const Tensor& input, const Tensor& other) {
+ // Result tensor will be a vector, dimension same as that of other tensor
+ // The checks for the input tensors (input, other) are done in the call to out variant
+ auto result = at::zeros_like(other);
+ linalg_solve_sparse_csr_out(input, other, result);
+ return result;
+}
+
+DEFINE_DISPATCH(linalg_solve_sparse_csr_stub);
+
+REGISTER_ALL_CPU_DISPATCH(linalg_solve_sparse_csr_stub, &linalg_solve_sparse_csr_kernel_error);
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.h b/aten/src/ATen/native/sparse/SparseCsrTensorMath.h
new file mode 100644
index 0000000..3f990a3
--- /dev/null
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at { namespace native {
+
+using linalg_solve_fn = void(*)(
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, int&
+);
+DECLARE_DISPATCH(linalg_solve_fn, linalg_solve_sparse_csr_stub);
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
index c13984f..636c0c5 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -9,6 +9,10 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/Resize.h>
+#include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/sparse/SparseCsrTensorMath.h>
+#include <ATen/cuda/CUDASparseDescriptors.h>
+#include <ATen/native/cuda/linalg/CUDASolver.h>
#include <algorithm>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -32,6 +36,7 @@
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>
#include <ATen/native/sparse/cuda/SparseCUDATensorMath.cuh>
+#include <ATen/native/cuda/MiscUtils.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
@@ -112,6 +117,46 @@
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
+template <typename scalar_t>
+void _apply_sparse_csr_linear_solve(
+ const Tensor& input,
+ const Tensor& other,
+ const Tensor& result,
+ int &_singularity) {
+#ifdef USE_ROCM
+ TORCH_CHECK(
+ false,
+ "Calling torch.linalg.solve with sparse tensors requires compiling ",
+      "PyTorch with CUDA; it is not supported in ROCm builds.");
+#else
+ using value_t = typename c10::scalar_value_type<scalar_t>::type;
+ auto values = input.values();
+ const scalar_t *values_data_ptr = values.data_ptr<scalar_t>();
+ auto crow_indices = input.crow_indices().to(kInt);
+ const int *crow_indices_data_ptr = crow_indices.data_ptr<int>();
+ auto col_indices = input.col_indices().to(kInt);
+ const int *col_indices_data_ptr = col_indices.data_ptr<int>();
+ auto handle = at::cuda::getCurrentCUDASolverSpHandle();
+ auto descrA = at::cuda::sparse::CuSparseMatDescriptor();
+
+ const scalar_t *b = other.data_ptr<scalar_t>();
+ int n = cuda_int_cast(input.size(-1), "n");
+ int nnzA = input._nnz();
+ value_t tol = 0.0;
+  // reorder = 0 applies no reordering.
+  // Should reorder be exposed as an argument so users can choose among
+  // symrcm, symamd, csrmetisnd (reorder = 1, 2, 3)?
+  int reorder = 0;
+ scalar_t *x = result.data_ptr<scalar_t>();
+
+  // cuSolver API: csrlsvqr provides a device path, while csrlsvlu is only available on the host
+ auto cusolver_func = input.is_cuda() ? at::cuda::solver::lsvqr<scalar_t, value_t> : at::cuda::solver::lsvlu<scalar_t, value_t>;
+
+ cusolver_func(handle, n, nnzA, descrA.descriptor(), values_data_ptr,
+ crow_indices_data_ptr, col_indices_data_ptr, b, tol, reorder, x, &_singularity);
+#endif
+}
+
} // namespace
using namespace at::sparse_csr;
@@ -275,5 +320,17 @@
}
}
+void linalg_solve_sparse_csr_kernel(
+ const Tensor& input,
+ const Tensor& other,
+ const Tensor& result,
+ int& singularity) {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "sparse_csr_solve", [&] {
+ _apply_sparse_csr_linear_solve<scalar_t>(input, other, result, singularity);
+ });
+}
+
+REGISTER_CUDA_DISPATCH(linalg_solve_sparse_csr_stub, &linalg_solve_sparse_csr_kernel);
+
} // namespace native
} // namespace at
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 178e4cb..6ca229d 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -475,6 +475,7 @@
:func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
:func:`torch.pca_lowrank`; yes; ``PCA(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
:func:`torch.svd_lowrank`; yes; ``SVD(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
+ :func:`torch.linalg.solve`; no; ``M[sparse_csr] @ V[strided] -> V[strided]``
where "Sparse grad?" column indicates if the PyTorch operation supports
backward with respect to sparse matrix argument. All PyTorch operations,
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index add4147..5b111b8 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -106,7 +106,6 @@
test_case.assertEqual(res1, res2)
test_case.assertEqual(res1, res3)
-
class TestSparseCSRSampler(TestCase):
def test_make_crow_indices(self):
@@ -1321,9 +1320,9 @@
if sample.input.ndim != 2:
continue
- expected = op(sample.input)
+ expected = op(sample.input, *sample.args, **sample.kwargs)
assert torch.is_tensor(expected)
- output = op(sample.input.to_sparse_csr())
+ output = op(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
assert torch.is_tensor(output)
self.assertEqual(output.to_dense(), expected)
@@ -1543,6 +1542,41 @@
self.assertEqual(coo_sparse.to_sparse_csr().to_sparse_coo(), coo_sparse)
+ @unittest.skipIf(TEST_WITH_ROCM, "The test doesn't support ROCM")
+ @dtypes(*floating_and_complex_types())
+ def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
+ from torch.testing._internal.common_methods_invocations import sample_inputs_linalg_solve
+
+ if (device == 'meta') or (device == 'cpu' and not torch.cuda.is_available()):
+ self.skipTest("Skipped!")
+
+ samples = sample_inputs_linalg_solve(None, device, dtype)
+
+ for sample in samples:
+ if sample.input.ndim != 2:
+ continue
+
+ out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
+ if not torch.cuda.is_available():
+ with self.assertRaisesRegex(RuntimeError, "PyTorch was not built with CUDA support"):
+ torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                break
+
+ if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
+ with self.assertRaisesRegex(RuntimeError, "not implemented yet"):
+ torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                continue
+ if not sample.args[0].numel():
+ with self.assertRaisesRegex(RuntimeError,
+ "Expected non-empty other tensor, but found empty tensor"):
+ torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
+                continue
+
+ expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
+ sample.input = sample.input.to_sparse_csr()
+ torch.linalg.solve(sample.input, *sample.args, **sample.kwargs, out=out)
+ self.assertEqual(expect, out)
+
@skipMeta
@dtypes(*get_all_dtypes())
def test_transpose(self, device, dtype):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7a6a082..1f0e2a2c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15623,6 +15623,7 @@
binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)]
spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
+sparse_csr_funcs = [op for op in op_db if op.supports_sparse_csr]
sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo)]