[MPS] fix float32 error on mps, in linalg.matrix_rank and linalg.pinv (#114771)
Fixes #114285
(However, a NotImplementedError is still raised:
```NotImplementedError: The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.```)
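For context, a minimal repro sketch (assumes a machine where MPS is available; shapes are illustrative). Before this change the default atol/rtol tensors were created as float64, which the MPS backend does not support; with this fix, and with `PYTORCH_ENABLE_MPS_FALLBACK=1` set for the still-missing SVD kernel, the calls go through:
```python
import torch

# assumes torch.backends.mps.is_available() is True
a = torch.randn(5, 5, dtype=torch.float32, device="mps")

# Previously both calls failed while building float64 tolerance tensors on MPS.
# The SVD itself still falls back to CPU (PYTORCH_ENABLE_MPS_FALLBACK=1).
print(torch.linalg.matrix_rank(a))
print(torch.linalg.pinv(a))
```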
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114771
Approved by: https://github.com/lezcano
diff --git a/.gitignore b/.gitignore
index 20019ec..458d476 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,6 +126,7 @@
.circleci/scripts/COMMIT_MSG
scripts/release_notes/*.json
sccache-stats*.json
+lint.json
# These files get copied over on invoking setup.py
torchgen/packaged/*
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 444fac5..eb95456 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -444,7 +444,12 @@
const optional<Tensor>& atol_opt,
const optional<Tensor>& rtol_opt,
const c10::string_view function_name) {
- auto options = input.options().dtype(ScalarType::Double);
+ auto options = input.options();
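+ // MPS does not support float64; keep the default tolerance tensors in float32 on Metal/MPS devices.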
+ if (input.device().type() == kMetal || input.device().type() == kMPS) {
+ options = options.dtype(ScalarType::Float);
+ } else {
+ options = options.dtype(ScalarType::Double);
+ }
auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
checkNotComplexTolerance(atol, function_name, "atol");
Tensor rtol;
@@ -465,7 +470,7 @@
const Tensor& input,
optional<double> atol_opt,
optional<double> rtol_opt) {
- double atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
+ auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
c10::SymFloat rtol;
if (rtol_opt.has_value()) {
rtol = rtol_opt.value();
@@ -476,7 +481,12 @@
? 0.0
: default_rtol;
}
- auto options = input.options().dtype(ScalarType::Double);
+ auto options = input.options();
+ if (input.device().type() == kMetal || input.device().type() == kMPS) {
+ options = options.dtype(ScalarType::Float);
+ } else {
+ options = options.dtype(ScalarType::Double);
+ }
auto atol_tensor = at::full({}, atol, options);
auto rtol_tensor = at::full({}, rtol, options);
return std::make_tuple(atol_tensor, rtol_tensor);
@@ -545,7 +555,12 @@
Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
// For NumPy compatibility the rcond argument is used as relative tolerance
checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
- auto options = input.options().dtype(ScalarType::Double);
+ auto options = input.options();
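+ // As above, MPS lacks float64 support, so the zeros tensor used as atol stays float32 on Metal/MPS.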
+ if (input.device().type() == kMetal || input.device().type() == kMPS) {
+ options = options.dtype(ScalarType::Float);
+ } else {
+ options = options.dtype(ScalarType::Double);
+ }
return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 6a8807b..d27681b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -189,6 +189,12 @@
'msort': [torch.float16],
}
+ ON_MPS_XFAILLIST = {
+ # Failures due to unimplemented downstream functions on the MPS backend.
+ # TODO: remove this once the downstream function 'aten::_linalg_svd.U' has been implemented.
+ 'linalg.matrix_rank': None,
+ }
+
def addDecorator(op, d) -> None:
op.decorators = list(op.decorators) if op.decorators is not None else []
op.decorators.append(d)
@@ -205,6 +211,11 @@
unittest.skip,
dtypes=SKIPLIST_GRAD[key]))
+ if key in ON_MPS_XFAILLIST:
+ addDecorator(op, DecorateInfo(
+ unittest.expectedFailure,
+ dtypes=ON_MPS_XFAILLIST[key]))
+
if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
addDecorator(op, DecorateInfo(
unittest.expectedFailure,
@@ -722,7 +733,6 @@
'nn.functional.norm': None,
'ormqr': None,
'pca_lowrank': None,
- 'pinverse': None,
'qr': None,
'quantile': None,
'rsub': None,
@@ -792,9 +802,7 @@
'softmaxwith_dtype': None,
'float_power': None,
'full_like': None,
- 'linalg.matrix_rank': None,
'linalg.matrix_rankhermitian': None,
- 'linalg.pinv': None,
'linalg.pinvhermitian': None,
'nonzero_static': None,
@@ -918,6 +926,12 @@
'logit': [torch.float16],
}
+ ON_MPS_XFAILLIST = {
+ # Failures due to unimplemented downstream functions on the MPS backend.
+ # TODO: remove this once the downstream function 'aten::_linalg_svd.U' has been implemented.
+ 'linalg.matrix_rank': None,
+ }
+
EMPTY_OPS_SKIPLIST = {
# Fill tensors with uninitialized data, causing mismatch with CPU.
# They occasionally match, thus skipping them.
@@ -954,7 +968,7 @@
dtypes=EMPTY_OPS_SKIPLIST[key]))
if key in SKIPLIST:
addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
- for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST]:
+ for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
if key in xfaillist:
addDecorator(op, DecorateInfo(
unittest.expectedFailure,
@@ -8729,6 +8743,129 @@
m2 = torch.randn(25, device=device).to(dtype)
self._test_addr(torch.addr, M, m1, m2, beta=0)
+ def test_matrix_rank(self, device="mps", dtype=torch.float32):
+ matrix_rank = torch.linalg.matrix_rank
+
+ def run_test(shape0, shape1, batch):
+ a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
+ rank_a = matrix_rank(a)
+
+ self.assertEqual(rank_a, matrix_rank(a.mH))
+ aaH = torch.matmul(a, a.mH)
+ rank_aaH = matrix_rank(aaH)
+ rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
+ self.assertEqual(rank_aaH, rank_aaH_hermitian)
+ aHa = torch.matmul(a.mH, a)
+ self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
+
+ # check against NumPy
+ self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
+ self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
+
+ self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
+ self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
+
+ # hermitian flag for NumPy was added in 1.14.0
+ if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
+ self.assertEqual(rank_aaH_hermitian,
+ np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
+ self.assertEqual(matrix_rank(aaH, 0.01, True),
+ np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
+
+ # check out= variant
+ out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
+ ans = matrix_rank(a, out=out)
+ self.assertEqual(ans, out)
+ self.assertEqual(ans, rank_a)
+
+ shapes = (3, 13)
+ batches = ((), (0, ), (4, ), (3, 5, ))
+ for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
+ # Only tolerate the NotImplementedError raised by the unimplemented downstream function.
+ # TODO: remove this workaround once the required function is implemented.
+ try:
+ run_test(shape0, shape1, batch)
+ except NotImplementedError as e:
+ with self.assertRaisesRegex(
+ NotImplementedError,
+ "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
+ raise e
+
+ def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
+ from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+ def run_test_main(A, hermitian):
+ # Testing against definition for pseudo-inverses
+ A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
+ np_A = A.cpu().numpy()
+ np_A_pinv = A_pinv.cpu().numpy()
+ if A.numel() > 0:
+ self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
+ self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
+ self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+ self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+ else:
+ self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
+
+ # Check out= variant
+ out = torch.empty_like(A_pinv)
+ ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
+ self.assertEqual(ans, out)
+ self.assertEqual(ans, A_pinv)
+
+ def run_test_numpy(A, hermitian):
+ # Check against NumPy output
+ # Test float rcond, and specific value for each matrix
+ rconds = [float(torch.rand(1)), ]
+ # Test different types of rcond tensor
+ for rcond_type in MPS_DTYPES:
+ rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
+ # Test broadcasting of rcond
+ if A.ndim > 2:
+ rconds.append(torch.rand(A.shape[-3], device=device))
+ for rcond in rconds:
+ actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
+ torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
+ self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
+ numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
+ expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
+ self.assertEqual(actual, expected, atol=precision, rtol=precision)
+
+ for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices
+ (3, 2), (5, 3, 2), (2, 5, 3, 2), # fat matrices
+ (2, 3), (5, 2, 3), (2, 5, 2, 3), # thin matrices
+ (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices
+ A = torch.randn(*sizes, dtype=dtype, device=device)
+ hermitian = False
+ run_test_main(A, hermitian)
+ run_test_numpy(A, hermitian)
+
+ # Check hermitian = True
+ for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices
+ (0, 0), (3, 0, 0), ]: # zero numel square matrices
+ A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
+ hermitian = True
+ # Only tolerate the NotImplementedError raised by the unimplemented downstream function.
+ # TODO: remove this workaround once the required function is implemented.
+ try:
+ run_test_main(A, hermitian)
+ except NotImplementedError as e:
+ with self.assertRaisesRegex(
+ NotImplementedError,
+ "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+ raise e
+ try:
+ run_test_numpy(A, hermitian)
+ except NotImplementedError as e:
+ with self.assertRaisesRegex(
+ NotImplementedError,
+ "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+ raise e
+
+
class TestGatherScatter(TestCaseMPS):
def test_slicing_with_step(self):
# Slicing with step