[BE] _get_torch_cuda_version should return tuple (#52409)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52409
Reviewed By: jbschlosser, glaringlee
Differential Revision: D26513924
Pulled By: walterddr
fbshipit-source-id: ee18ef357c326c5ad344d80c59821cc2b8814734
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 0577732..8b416c1 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -979,7 +979,7 @@
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
- TEST_CUDA and _get_torch_cuda_version() < [10, 1],
+ TEST_CUDA and _get_torch_cuda_version() < (10, 1),
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm(self):
@@ -1043,7 +1043,7 @@
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
- _get_torch_cuda_version() < [10, 1],
+ _get_torch_cuda_version() < (10, 1),
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm_deterministic(self):
@@ -1078,7 +1078,7 @@
@cuda_only
@unittest.skipIf(
- not IS_WINDOWS or _get_torch_cuda_version() >= [11, 0],
+ not IS_WINDOWS or _get_torch_cuda_version() >= (11, 0),
"this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
)
def test_bmm_windows_error(self):
@@ -1092,7 +1092,7 @@
@cuda_only
@skipIfRocm
@unittest.skipIf(
- _get_torch_cuda_version() >= [10, 1],
+ _get_torch_cuda_version() >= (10, 1),
"this test ensures bmm gives error if CUDA version is less than 10.1"
)
def test_bmm_cuda_version_error(self):
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index af6b27f..d3af775 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -149,6 +149,6 @@
def _get_torch_cuda_version():
if torch.version.cuda is None:
- return [0, 0]
+ return (0, 0)
cuda_version = str(torch.version.cuda)
- return [int(x) for x in cuda_version.split(".")]
+ return tuple(int(x) for x in cuda_version.split("."))
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index f52f62d..25c1f69 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -863,7 +863,7 @@
def skipCUDAIfNoMagmaAndNoCusolver(fn):
version = _get_torch_cuda_version()
- if version >= [10, 2]:
+ if version >= (10, 2):
return fn
else:
# cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA