[MPS] Fix float32 error on MPS in linalg.matrix_rank and linalg.pinv (#114771)

Fixes #114285

However, a `NotImplementedError` is still raised because the downstream SVD op is not yet implemented natively on MPS:

```
NotImplementedError: The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
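
For reference, a minimal sketch of the now-working calls (illustrative only; it assumes an MPS-capable machine, and `PYTORCH_ENABLE_MPS_FALLBACK=1` must be set so the still-missing SVD kernel can fall back to the CPU):

```python
# Reproduction sketch for #114285 (assumes an Apple Silicon machine with the
# MPS backend available). Export PYTORCH_ENABLE_MPS_FALLBACK=1 before launching
# Python so the unimplemented 'aten::_linalg_svd.U' op falls back to the CPU.
import torch

a = torch.randn(5, 3, device="mps", dtype=torch.float32)

# Before this patch both calls failed because the internal atol/rtol tensors
# were created as float64, which the MPS backend does not support.
rank = torch.linalg.matrix_rank(a)
pinv = torch.linalg.pinv(a, rtol=1e-5)

print(rank.item(), pinv.shape)  # e.g. 3 torch.Size([3, 5])
```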

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114771
Approved by: https://github.com/lezcano
diff --git a/.gitignore b/.gitignore
index 20019ec..458d476 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,6 +126,7 @@
 .circleci/scripts/COMMIT_MSG
 scripts/release_notes/*.json
 sccache-stats*.json
+lint.json
 
 # These files get copied over on invoking setup.py
 torchgen/packaged/*
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 444fac5..eb95456 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -444,7 +444,12 @@
     const optional<Tensor>& atol_opt,
     const optional<Tensor>& rtol_opt,
     const c10::string_view function_name) {
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
   checkNotComplexTolerance(atol, function_name, "atol");
   Tensor rtol;
@@ -465,7 +470,7 @@
     const Tensor& input,
     optional<double> atol_opt,
     optional<double> rtol_opt) {
-  double atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
+  auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
   c10::SymFloat rtol;
   if (rtol_opt.has_value()) {
     rtol = rtol_opt.value();
@@ -476,7 +481,12 @@
            ? 0.0
            : default_rtol;
   }
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   auto atol_tensor = at::full({}, atol, options);
   auto rtol_tensor = at::full({}, rtol, options);
   return std::make_tuple(atol_tensor, rtol_tensor);
@@ -545,7 +555,12 @@
 Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
   // For NumPy compatibility the rcond argument is used as relative tolerance
   checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
 }
 
diff --git a/test/test_mps.py b/test/test_mps.py
index 6a8807b..d27681b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -189,6 +189,12 @@
         'msort': [torch.float16],
     }
 
+    ON_MPS_XFAILLIST = {
+        # Failures due to lack of implementation of downstream functions on the MPS backend
+        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
+        'linalg.matrix_rank': None,
+    }
+
     def addDecorator(op, d) -> None:
         op.decorators = list(op.decorators) if op.decorators is not None else []
         op.decorators.append(d)
@@ -205,6 +211,11 @@
                          unittest.skip,
                          dtypes=SKIPLIST_GRAD[key]))
 
+        if key in ON_MPS_XFAILLIST:
+            addDecorator(op, DecorateInfo(
+                         unittest.expectedFailure,
+                         dtypes=ON_MPS_XFAILLIST[key]))
+
         if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
             addDecorator(op, DecorateInfo(
                          unittest.expectedFailure,
@@ -722,7 +733,6 @@
         'nn.functional.norm': None,
         'ormqr': None,
         'pca_lowrank': None,
-        'pinverse': None,
         'qr': None,
         'quantile': None,
         'rsub': None,
@@ -792,9 +802,7 @@
         'softmaxwith_dtype': None,
         'float_power': None,
         'full_like': None,
-        'linalg.matrix_rank': None,
         'linalg.matrix_rankhermitian': None,
-        'linalg.pinv': None,
         'linalg.pinvhermitian': None,
         'nonzero_static': None,
 
@@ -918,6 +926,12 @@
         'logit': [torch.float16],
     }
 
+    ON_MPS_XFAILLIST = {
+        # Failures due to lack of implementation of downstream functions on the MPS backend
+        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
+        'linalg.matrix_rank': None,
+    }
+
     EMPTY_OPS_SKIPLIST = {
         # Fill tensors with uninitialized data, causing mismatch with CPU.
         # They occasionally match, thus skipping them.
@@ -954,7 +968,7 @@
                          dtypes=EMPTY_OPS_SKIPLIST[key]))
         if key in SKIPLIST:
             addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
-        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST]:
+        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
             if key in xfaillist:
                 addDecorator(op, DecorateInfo(
                              unittest.expectedFailure,
@@ -8729,6 +8743,129 @@
         m2 = torch.randn(25, device=device).to(dtype)
         self._test_addr(torch.addr, M, m1, m2, beta=0)
 
+    def test_matrix_rank(self, device="mps", dtype=torch.float32):
+        matrix_rank = torch.linalg.matrix_rank
+
+        def run_test(shape0, shape1, batch):
+            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
+            rank_a = matrix_rank(a)
+
+            self.assertEqual(rank_a, matrix_rank(a.mH))
+            aaH = torch.matmul(a, a.mH)
+            rank_aaH = matrix_rank(aaH)
+            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
+            self.assertEqual(rank_aaH, rank_aaH_hermitian)
+            aHa = torch.matmul(a.mH, a)
+            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
+
+            # check against NumPy
+            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
+            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
+
+            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
+            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
+
+            # hermitian flag for NumPy was added in 1.14.0
+            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
+                self.assertEqual(rank_aaH_hermitian,
+                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
+                self.assertEqual(matrix_rank(aaH, 0.01, True),
+                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
+
+            # check out= variant
+            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
+            ans = matrix_rank(a, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, rank_a)
+
+        shapes = (3, 13)
+        batches = ((), (0, ), (4, ), (3, 5, ))
+        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
+            # Only tolerate the NotImplementedError raised by the unimplemented downstream function
+            # TODO: remove this once the required function is implemented
+            try:
+                run_test(shape0, shape1, batch)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
+                    raise e
+
+    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
+        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+        def run_test_main(A, hermitian):
+            # Testing against definition for pseudo-inverses
+            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
+            np_A = A.cpu().numpy()
+            np_A_pinv = A_pinv.cpu().numpy()
+            if A.numel() > 0:
+                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
+                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
+                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+            else:
+                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
+
+            # Check out= variant
+            out = torch.empty_like(A_pinv)
+            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, A_pinv)
+
+        def run_test_numpy(A, hermitian):
+            # Check against NumPy output
+            # Test float rcond, and specific value for each matrix
+            rconds = [float(torch.rand(1)), ]
+            # Test different types of rcond tensor
+            for rcond_type in MPS_DTYPES:
+                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
+            # Test broadcasting of rcond
+            if A.ndim > 2:
+                rconds.append(torch.rand(A.shape[-3], device=device))
+            for rcond in rconds:
+                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
+                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
+                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
+                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
+                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
+                self.assertEqual(actual, expected, atol=precision, rtol=precision)
+
+        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
+                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
+                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
+                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
+            A = torch.randn(*sizes, dtype=dtype, device=device)
+            hermitian = False
+            run_test_main(A, hermitian)
+            run_test_numpy(A, hermitian)
+
+        # Check hermitian = True
+        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
+                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
+            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
+            hermitian = True
+            # Only tolerate the NotImplementedError raised by the unimplemented downstream function
+            # TODO: remove this once the required function is implemented
+            try:
+                run_test_main(A, hermitian)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+                    raise e
+            try:
+                run_test_numpy(A, hermitian)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+                    raise e
+
+
+
+
+
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):
         # Slicing with step