MAINT/TST: pytorch-ify test_linalg, vendored from NumPy (#109775)

1. Inherit from TestCase
2. Use pytorch parametrization
2. Use unittest.expectedFailure to mark xfails, also unittest skips

All this to make pytest-less invocation work:

$ python test/torch_np/test_basic.py

cross-ref https://github.com/pytorch/pytorch/pull/109593, https://github.com/pytorch/pytorch/pull/109718

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109775
Approved by: https://github.com/ezyang
diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py
index 03401f5..aa7ad55 100644
--- a/test/torch_np/numpy_tests/linalg/test_linalg.py
+++ b/test/torch_np/numpy_tests/linalg/test_linalg.py
@@ -3,6 +3,7 @@
 """ Test functions for linalg module
 
 """
+import functools
 import itertools
 import os
 import subprocess
@@ -10,6 +11,8 @@
 import textwrap
 import traceback
 
+from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
+
 import pytest
 
 import torch._numpy as np
@@ -40,6 +43,17 @@
     suppress_warnings,
     #  assert_raises_regex, HAS_LAPACK64, IS_WASM
 )
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+    TestCase,
+)
+
+skip = functools.partial(skipif, True)
+
+# FIXME: slow tests have never run (= are broken)
+slow = skip
 
 IS_WASM = False
 HAS_LAPACK64 = False
@@ -374,32 +388,32 @@
 
 
 class LinalgGeneralizedSquareTestCase(LinalgTestCase):
-    @pytest.mark.slow
+    @slow
     def test_generalized_sq_cases(self):
         self.check_cases(require={"generalized", "square"}, exclude={"size-0"})
 
-    @pytest.mark.slow
+    @slow
     def test_generalized_empty_sq_cases(self):
         self.check_cases(require={"generalized", "square", "size-0"})
 
 
 class LinalgGeneralizedNonsquareTestCase(LinalgTestCase):
-    @pytest.mark.slow
+    @slow
     def test_generalized_nonsq_cases(self):
         self.check_cases(require={"generalized", "nonsquare"}, exclude={"size-0"})
 
-    @pytest.mark.slow
+    @slow
     def test_generalized_empty_nonsq_cases(self):
         self.check_cases(require={"generalized", "nonsquare", "size-0"})
 
 
 class HermitianGeneralizedTestCase(LinalgTestCase):
-    @pytest.mark.xfail(reason="sort complex")
-    @pytest.mark.slow
+    @xfail  # (reason="sort complex")
+    @slow
     def test_generalized_herm_cases(self):
         self.check_cases(require={"generalized", "hermitian"}, exclude={"size-0"})
 
-    @pytest.mark.slow
+    @slow
     def test_generalized_empty_herm_cases(self):
         self.check_cases(
             require={"generalized", "hermitian", "size-0"}, exclude={"none"}
@@ -443,13 +457,14 @@
         assert_(consistent_subclass(x, b))
 
 
-class TestSolve(SolveCases):
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestSolve(SolveCases, TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         assert_equal(linalg.solve(x, x).dtype, dtype)
 
-    @pytest.mark.skip(reason="subclass")
+    @skip(reason="subclass")
     def test_0_size(self):
         class ArraySubclass(np.ndarray):
             pass
@@ -484,7 +499,7 @@
         assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
         assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b)
 
-    @pytest.mark.skip(reason="subclass")
+    @skip(reason="subclass")
     def test_0_size_k(self):
         # test zero multiple equation (K=0) case.
         class ArraySubclass(np.ndarray):
@@ -512,13 +527,14 @@
         assert_(consistent_subclass(a_inv, a))
 
 
-class TestInv(InvCases):
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestInv(InvCases, TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         assert_equal(linalg.inv(x).dtype, dtype)
 
-    @pytest.mark.skip(reason="subclass")
+    @skip(reason="subclass")
     def test_0_size(self):
         # Check that all kinds of 0-sized arrays work
         class ArraySubclass(np.ndarray):
@@ -544,15 +560,16 @@
         assert_almost_equal(ev, evalues)
 
 
-class TestEigvals(EigvalsCases):
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestEigvals(EigvalsCases, TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         assert_equal(linalg.eigvals(x).dtype, dtype)
         x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
         assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
 
-    @pytest.mark.skip(reason="subclass")
+    @skip(reason="subclass")
     def test_0_size(self):
         # Check that all kinds of 0-sized arrays work
         class ArraySubclass(np.ndarray):
@@ -584,8 +601,9 @@
         assert_(consistent_subclass(evectors, a))
 
 
-class TestEig(EigCases):
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestEig(EigCases, TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         w, v = np.linalg.eig(x)
@@ -597,7 +615,7 @@
         assert_equal(w.dtype, get_complex_dtype(dtype))
         assert_equal(v.dtype, get_complex_dtype(dtype))
 
-    @pytest.mark.skip(reason="subclass")
+    @skip(reason="subclass")
     def test_0_size(self):
         # Check that all kinds of 0-sized arrays work
         class ArraySubclass(np.ndarray):
@@ -622,10 +640,11 @@
         assert_(isinstance(a, np.ndarray))
 
 
+@instantiate_parametrized_tests
 class SVDBaseTests:
     hermitian = False
 
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         u, s, vh = linalg.svd(x)
@@ -650,7 +669,7 @@
         assert_(consistent_subclass(vt, a))
 
 
-class TestSVD(SVDCases, SVDBaseTests):
+class TestSVD(SVDCases, SVDBaseTests, TestCase):
     def test_empty_identity(self):
         """Empty input should put an identity matrix in u or vh"""
         x = np.empty((4, 0))
@@ -694,7 +713,7 @@
         assert_(consistent_subclass(vt, a))
 
 
-class TestSVDHermitian(SVDHermitianCases, SVDBaseTests):
+class TestSVDHermitian(SVDHermitianCases, SVDBaseTests, TestCase):
     hermitian = True
 
 
@@ -759,7 +778,7 @@
         )
 
 
-class TestCond(CondCases):
+class TestCond(CondCases, TestCase):
     def test_basic_nonsvd(self):
         # Smoketest the non-svd norms
         A = array([[1.0, 0, 1], [0, -2.0, 0], [0, 0, 3.0]])
@@ -783,9 +802,9 @@
         for A, p in itertools.product(As, p_neg):
             linalg.cond(A, p)
 
-    @pytest.mark.xfail(
-        True, run=False, reason="Platform/LAPACK-dependent failure, see gh-18914"
-    )
+    @xfail  # (
+    #    True, run=False, reason="Platform/LAPACK-dependent failure, see gh-18914"
+    # )
     def test_nan(self):
         # nans should be passed through, not converted to infs
         ps = [None, 1, -1, 2, -2, "fro"]
@@ -842,7 +861,7 @@
         assert_(consistent_subclass(a_ginv, a))
 
 
-class TestPinv(PinvCases):
+class TestPinv(PinvCases, TestCase):
     pass
 
 
@@ -857,7 +876,7 @@
         assert_(consistent_subclass(a_ginv, a))
 
 
-class TestPinvHermitian(PinvHermitianCases):
+class TestPinvHermitian(PinvHermitianCases, TestCase):
     pass
 
 
@@ -880,7 +899,8 @@
         assert_equal(ld[~m], -inf)
 
 
-class TestDet(DetCases):
+@instantiate_parametrized_tests
+class TestDet(DetCases, TestCase):
     def test_zero(self):
         # NB: comment out tests of type(det) == double : we return zero-dim arrays
         assert_equal(linalg.det([[0.0]]), 0.0)
@@ -896,7 +916,7 @@
     #    assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
     #    assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
 
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         assert_equal(np.linalg.det(x).dtype, dtype)
@@ -954,8 +974,9 @@
         assert_(consistent_subclass(residuals, b))
 
 
-class TestLstsq(LstsqCases):
-    @pytest.mark.xfail(reason="Lstsq: we use the future default =None")
+@instantiate_parametrized_tests
+class TestLstsq(LstsqCases, TestCase):
+    @xfail  # (reason="Lstsq: we use the future default =None")
     def test_future_rcond(self):
         a = np.array(
             [
@@ -978,8 +999,8 @@
             # Warning should be raised exactly once (first command)
             assert_(len(w) == 1)
 
-    @pytest.mark.parametrize(
-        ["m", "n", "n_rhs"],
+    @parametrize(
+        "m, n, n_rhs",
         [
             (4, 2, 2),
             (0, 4, 1),
@@ -1015,20 +1036,22 @@
             linalg.lstsq(A, y, rcond=None)
 
 
-@pytest.mark.parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
-@pytest.mark.xfail(reason="no block()")
-class TestMatrixPower:
-    def setup_method(self):
+# @xfail  #(reason="no block()")
+@skip  # FIXME: otherwise fails in setUp calling np.block
+@instantiate_parametrized_tests
+class TestMatrixPower(TestCase):
+    def setUp(self):
         self.rshft_0 = np.eye(4)
-        self.rshft_1 = rshft_0[[3, 0, 1, 2]]
-        self.rshft_2 = rshft_0[[2, 3, 0, 1]]
-        self.rshft_3 = rshft_0[[1, 2, 3, 0]]
-        self.rshft_all = [rshft_0, rshft_1, rshft_2, rshft_3]
+        self.rshft_1 = self.rshft_0[[3, 0, 1, 2]]
+        self.rshft_2 = self.rshft_0[[2, 3, 0, 1]]
+        self.rshft_3 = self.rshft_0[[1, 2, 3, 0]]
+        self.rshft_all = [self.rshft_0, self.rshft_1, self.rshft_2, self.rshft_3]
         self.noninv = array([[1, 0], [0, 0]])
-        self.stacked = np.block([[[rshft_0]]] * 2)
+        self.stacked = np.block([[[self.rshft_0]]] * 2)
         # FIXME the 'e' dtype might work in future
         self.dtnoinv = [object, np.dtype("e"), np.dtype("g"), np.dtype("G")]
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_large_power(self, dt):
         rshft = self.rshft_1.astype(dt)
         assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
@@ -1036,6 +1059,7 @@
         assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2)
         assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3)
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_power_is_zero(self, dt):
         def tz(M):
             mz = matrix_power(M, 0)
@@ -1047,6 +1071,7 @@
             if dt != object:
                 tz(self.stacked.astype(dt))
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_power_is_one(self, dt):
         def tz(mat):
             mz = matrix_power(mat, 1)
@@ -1058,6 +1083,7 @@
             if dt != object:
                 tz(self.stacked.astype(dt))
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_power_is_two(self, dt):
         def tz(mat):
             mz = matrix_power(mat, 2)
@@ -1070,6 +1096,7 @@
             if dt != object:
                 tz(self.stacked.astype(dt))
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_power_is_minus_one(self, dt):
         def tz(mat):
             invmat = matrix_power(mat, -1)
@@ -1080,17 +1107,20 @@
             if dt not in self.dtnoinv:
                 tz(mat.astype(dt))
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_exceptions_bad_power(self, dt):
         mat = self.rshft_0.astype(dt)
         assert_raises(TypeError, matrix_power, mat, 1.5)
         assert_raises(TypeError, matrix_power, mat, [1])
 
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_exceptions_non_square(self, dt):
         assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1)
         assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1)
         assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1)
 
-    @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
+    @skipif(IS_WASM, reason="fp errors don't work in wasm")
+    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
     def test_exceptions_not_invertible(self, dt):
         if dt in self.dtnoinv:
             return
@@ -1112,8 +1142,9 @@
         assert_allclose(ev2, evalues, rtol=get_rtol(ev.dtype))
 
 
-class TestEigvalsh:
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestEigvalsh(TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         w = np.linalg.eigvalsh(x)
@@ -1193,8 +1224,9 @@
         )
 
 
-class TestEigh:
-    @pytest.mark.parametrize("dtype", [single, double, csingle, cdouble])
+@instantiate_parametrized_tests
+class TestEigh(TestCase):
+    @parametrize("dtype", [single, double, csingle, cdouble])
     def test_types(self, dtype):
         x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
         w, v = np.linalg.eigh(x)
@@ -1283,7 +1315,9 @@
             at = a.astype(each_type)
 
             if each_type == np.dtype("float16"):
-                pytest.xfail("float16**float64 => float64 (?)")
+                # FIXME: move looping to parametrize, add decorators=[xfail]
+                # pytest.xfail("float16**float64 => float64 (?)")
+                raise SkipTest("float16**float64 => float64 (?)")
 
             an = norm(at, -np.inf)
             self.check_dtype(at, an)
@@ -1563,7 +1597,7 @@
     pass
 
 
-class TestNorm_NonSystematic:
+class TestNorm_NonSystematic(TestCase):
     def test_intmin(self):
         # Non-regression test: p-norm of signed integer would previously do
         # float cast and abs in the wrong order.
@@ -1572,26 +1606,26 @@
 
 
 # Separate definitions so we can use them for matrix tests.
-class _TestNormDoubleBase(_TestNormBase):
+class _TestNormDoubleBase(_TestNormBase, TestCase):
     dt = np.double
     dec = 12
 
 
-class _TestNormSingleBase(_TestNormBase):
+class _TestNormSingleBase(_TestNormBase, TestCase):
     dt = np.float32
     dec = 6
 
 
-class _TestNormInt64Base(_TestNormBase):
+class _TestNormInt64Base(_TestNormBase, TestCase):
     dt = np.int64
     dec = 12
 
 
-class TestNormDouble(_TestNorm, _TestNormDoubleBase):
+class TestNormDouble(_TestNorm, _TestNormDoubleBase, TestCase):
     pass
 
 
-class TestNormSingle(_TestNorm, _TestNormSingleBase):
+class TestNormSingle(_TestNorm, _TestNormSingleBase, TestCase):
     pass
 
 
@@ -1599,7 +1633,7 @@
     pass
 
 
-class TestMatrixRank:
+class TestMatrixRank(TestCase):
     def test_matrix_rank(self):
         # Full rank matrix
         assert_equal(4, matrix_rank(np.eye(4)))
@@ -1633,22 +1667,22 @@
         assert_equal(4, matrix_rank(I, hermitian=True, tol=0.99e-8))
         assert_equal(3, matrix_rank(I, hermitian=True, tol=1.01e-8))
 
-
-def test_reduced_rank():
-    # Test matrices with reduced rank
-    #  rng = np.random.RandomState(20120714)
-    np.random.seed(20120714)
-    for i in range(100):
-        # Make a rank deficient matrix
-        X = np.random.normal(size=(40, 10))
-        X[:, 0] = X[:, 1] + X[:, 2]
-        # Assert that matrix_rank detected deficiency
-        assert_equal(matrix_rank(X), 9)
-        X[:, 3] = X[:, 4] + X[:, 5]
-        assert_equal(matrix_rank(X), 8)
+    def test_reduced_rank(self):
+        # Test matrices with reduced rank
+        #  rng = np.random.RandomState(20120714)
+        np.random.seed(20120714)
+        for i in range(100):
+            # Make a rank deficient matrix
+            X = np.random.normal(size=(40, 10))
+            X[:, 0] = X[:, 1] + X[:, 2]
+            # Assert that matrix_rank detected deficiency
+            assert_equal(matrix_rank(X), 9)
+            X[:, 3] = X[:, 4] + X[:, 5]
+            assert_equal(matrix_rank(X), 8)
 
 
-class TestQR:
+@instantiate_parametrized_tests
+class TestQR(TestCase):
     def check_qr(self, a):
         # This test expects the argument `a` to be an ndarray or
         # a subclass of an ndarray of inexact type.
@@ -1687,8 +1721,8 @@
         assert_(isinstance(r2, a_type))
         assert_almost_equal(r2, r1)
 
-    @pytest.mark.xfail(reason="torch does not allow qr(..., mode='raw'")
-    @pytest.mark.parametrize(["m", "n"], [(3, 0), (0, 3), (0, 0)])
+    @xfail  # (reason="torch does not allow qr(..., mode='raw'")
+    @parametrize("m, n", [(3, 0), (0, 3), (0, 0)])
     def test_qr_empty(self, m, n):
         k = min(m, n)
         a = np.empty((m, n))
@@ -1701,7 +1735,7 @@
         assert_equal(h.shape, (n, m))
         assert_equal(tau.shape, (k,))
 
-    @pytest.mark.xfail(reason="torch does not allow qr(..., mode='raw'")
+    @xfail  # (reason="torch does not allow qr(..., mode='raw'")
     def test_mode_raw(self):
         # The factorization is not unique and varies between libraries,
         # so it is not possible to check against known values. Functional
@@ -1783,9 +1817,9 @@
         assert_(isinstance(r2, a_type))
         assert_almost_equal(r2, r1)
 
-    @pytest.mark.parametrize("size", [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)])
-    @pytest.mark.parametrize("outer_size", [(2, 2), (2,), (2, 3, 4)])
-    @pytest.mark.parametrize("dt", [np.single, np.double, np.csingle, np.cdouble])
+    @parametrize("size", [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)])
+    @parametrize("outer_size", [(2, 2), (2,), (2, 3, 4)])
+    @parametrize("dt", [np.single, np.double, np.csingle, np.cdouble])
     def test_stacked_inputs(self, outer_size, size, dt):
         A = np.random.normal(size=outer_size + size).astype(dt)
         B = np.random.normal(size=outer_size + size).astype(dt)
@@ -1793,13 +1827,12 @@
         self.check_qr_stacked(A + 1.0j * B)
 
 
-class TestCholesky:
+@instantiate_parametrized_tests
+class TestCholesky(TestCase):
     # TODO: are there no other tests for cholesky?
 
-    @pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)])
-    @pytest.mark.parametrize(
-        "dtype", (np.float32, np.float64, np.complex64, np.complex128)
-    )
+    @parametrize("shape", [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)])
+    @parametrize("dtype", (np.float32, np.float64, np.complex64, np.complex128))
     def test_basic_property(self, shape, dtype):
         # Check A = L L^H
         np.random.seed(1)
@@ -1836,127 +1869,131 @@
         assert_(isinstance(res, np.ndarray))
 
 
-@pytest.mark.xfail(reason="endianness")
-def test_byteorder_check():
-    # Byte order check should pass for native order
-    if sys.byteorder == "little":
-        native = "<"
-    else:
-        native = ">"
+class TestMisc(TestCase):
+    @xfail  # (reason="endianness")
+    def test_byteorder_check(self):
+        # Byte order check should pass for native order
+        if sys.byteorder == "little":
+            native = "<"
+        else:
+            native = ">"
 
-    for dtt in (np.float32, np.float64):
-        arr = np.eye(4, dtype=dtt)
-        n_arr = arr.newbyteorder(native)
-        sw_arr = arr.newbyteorder("S").byteswap()
-        assert_equal(arr.dtype.byteorder, "=")
-        for routine in (linalg.inv, linalg.det, linalg.pinv):
-            # Normal call
-            res = routine(arr)
-            # Native but not '='
-            assert_array_equal(res, routine(n_arr))
-            # Swapped
-            assert_array_equal(res, routine(sw_arr))
+        for dtt in (np.float32, np.float64):
+            arr = np.eye(4, dtype=dtt)
+            n_arr = arr.newbyteorder(native)
+            sw_arr = arr.newbyteorder("S").byteswap()
+            assert_equal(arr.dtype.byteorder, "=")
+            for routine in (linalg.inv, linalg.det, linalg.pinv):
+                # Normal call
+                res = routine(arr)
+                # Native but not '='
+                assert_array_equal(res, routine(n_arr))
+                # Swapped
+                assert_array_equal(res, routine(sw_arr))
 
+    @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
+    def test_generalized_raise_multiloop(self):
+        # It should raise an error even if the error doesn't occur in the
+        # last iteration of the ufunc inner loop
 
-@pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
-def test_generalized_raise_multiloop():
-    # It should raise an error even if the error doesn't occur in the
-    # last iteration of the ufunc inner loop
+        invertible = np.array([[1, 2], [3, 4]])
+        non_invertible = np.array([[1, 1], [1, 1]])
 
-    invertible = np.array([[1, 2], [3, 4]])
-    non_invertible = np.array([[1, 1], [1, 1]])
+        x = np.zeros([4, 4, 2, 2])[1::2]
+        x[...] = invertible
+        x[0, 0] = non_invertible
 
-    x = np.zeros([4, 4, 2, 2])[1::2]
-    x[...] = invertible
-    x[0, 0] = non_invertible
+        assert_raises(np.linalg.LinAlgError, np.linalg.inv, x)
 
-    assert_raises(np.linalg.LinAlgError, np.linalg.inv, x)
+    def test_xerbla_override(self):
+        # Check that our xerbla has been successfully linked in. If it is not,
+        # the default xerbla routine is called, which prints a message to stdout
+        # and may, or may not, abort the process depending on the LAPACK package.
 
+        XERBLA_OK = 255
 
-def test_xerbla_override():
-    # Check that our xerbla has been successfully linked in. If it is not,
-    # the default xerbla routine is called, which prints a message to stdout
-    # and may, or may not, abort the process depending on the LAPACK package.
-
-    XERBLA_OK = 255
-
-    try:
-        pid = os.fork()
-    except (OSError, AttributeError):
-        # fork failed, or not running on POSIX
-        pytest.skip("Not POSIX or fork failed.")
-
-    if pid == 0:
-        # child; close i/o file handles
-        os.close(1)
-        os.close(0)
-        # Avoid producing core files.
-        import resource
-
-        resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
-        # These calls may abort.
         try:
-            np.linalg.lapack_lite.xerbla()
-        except ValueError:
-            pass
-        except Exception:
+            pid = os.fork()
+        except (OSError, AttributeError):
+            # fork failed, or not running on POSIX
+            raise SkipTest("Not POSIX or fork failed.")
+
+        if pid == 0:
+            # child; close i/o file handles
+            os.close(1)
+            os.close(0)
+            # Avoid producing core files.
+            import resource
+
+            resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
+            # These calls may abort.
+            try:
+                np.linalg.lapack_lite.xerbla()
+            except ValueError:
+                pass
+            except Exception:
+                os._exit(os.EX_CONFIG)
+
+            try:
+                a = np.array([[1.0]])
+                np.linalg.lapack_lite.dorgqr(
+                    1, 1, 1, a, 0, a, a, 0, 0
+                )  # <- invalid value
+            except ValueError as e:
+                if "DORGQR parameter number 5" in str(e):
+                    # success, reuse error code to mark success as
+                    # FORTRAN STOP returns as success.
+                    os._exit(XERBLA_OK)
+
+            # Did not abort, but our xerbla was not linked in.
             os._exit(os.EX_CONFIG)
+        else:
+            # parent
+            pid, status = os.wait()
+            if os.WEXITSTATUS(status) != XERBLA_OK:
+                raise SkipTest("Numpy xerbla not linked in.")
 
+    @pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess")
+    @slow
+    def test_sdot_bug_8577(self):
+        # Regression test that loading certain other libraries does not
+        # result to wrong results in float32 linear algebra.
+        #
+        # There's a bug gh-8577 on OSX that can trigger this, and perhaps
+        # there are also other situations in which it occurs.
+        #
+        # Do the check in a separate process.
+
+        bad_libs = ["PyQt5.QtWidgets", "IPython"]
+
+        template = textwrap.dedent(
+            """
+        import sys
+        {before}
         try:
-            a = np.array([[1.0]])
-            np.linalg.lapack_lite.dorgqr(1, 1, 1, a, 0, a, a, 0, 0)  # <- invalid value
-        except ValueError as e:
-            if "DORGQR parameter number 5" in str(e):
-                # success, reuse error code to mark success as
-                # FORTRAN STOP returns as success.
-                os._exit(XERBLA_OK)
-
-        # Did not abort, but our xerbla was not linked in.
-        os._exit(os.EX_CONFIG)
-    else:
-        # parent
-        pid, status = os.wait()
-        if os.WEXITSTATUS(status) != XERBLA_OK:
-            pytest.skip("Numpy xerbla not linked in.")
-
-
-@pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess")
-@pytest.mark.slow
-def test_sdot_bug_8577():
-    # Regression test that loading certain other libraries does not
-    # result to wrong results in float32 linear algebra.
-    #
-    # There's a bug gh-8577 on OSX that can trigger this, and perhaps
-    # there are also other situations in which it occurs.
-    #
-    # Do the check in a separate process.
-
-    bad_libs = ["PyQt5.QtWidgets", "IPython"]
-
-    template = textwrap.dedent(
+            import {bad_lib}
+        except ImportError:
+            sys.exit(0)
+        {after}
+        x = np.ones(2, dtype=np.float32)
+        sys.exit(0 if np.allclose(x.dot(x), 2.0) else 1)
         """
-    import sys
-    {before}
-    try:
-        import {bad_lib}
-    except ImportError:
-        sys.exit(0)
-    {after}
-    x = np.ones(2, dtype=np.float32)
-    sys.exit(0 if np.allclose(x.dot(x), 2.0) else 1)
-    """
-    )
+        )
 
-    for bad_lib in bad_libs:
-        code = template.format(before="import numpy as np", after="", bad_lib=bad_lib)
-        subprocess.check_call([sys.executable, "-c", code])
+        for bad_lib in bad_libs:
+            code = template.format(
+                before="import numpy as np", after="", bad_lib=bad_lib
+            )
+            subprocess.check_call([sys.executable, "-c", code])
 
-        # Swapped import order
-        code = template.format(after="import numpy as np", before="", bad_lib=bad_lib)
-        subprocess.check_call([sys.executable, "-c", code])
+            # Swapped import order
+            code = template.format(
+                after="import numpy as np", before="", bad_lib=bad_lib
+            )
+            subprocess.check_call([sys.executable, "-c", code])
 
 
-class TestMultiDot:
+class TestMultiDot(TestCase):
     def test_basic_function_with_three_arguments(self):
         # multi_dot with three arguments uses a fast hand coded algorithm to
         # determine the optimal order. Therefore test it separately.
@@ -2094,8 +2131,9 @@
         assert_raises((RuntimeError, ValueError), multi_dot, [np.random.random((3, 3))])
 
 
-class TestTensorinv:
-    @pytest.mark.parametrize(
+@instantiate_parametrized_tests
+class TestTensorinv(TestCase):
+    @parametrize(
         "arr, ind",
         [
             (np.ones((4, 6, 8, 2)), 2),
@@ -2106,7 +2144,7 @@
         with assert_raises((LinAlgError, RuntimeError)):
             linalg.tensorinv(arr, ind=ind)
 
-    @pytest.mark.parametrize(
+    @parametrize(
         "shape, ind",
         [
             # examples from docstring
@@ -2121,7 +2159,7 @@
         actual = ainv.shape
         assert_equal(actual, expected)
 
-    @pytest.mark.parametrize(
+    @parametrize(
         "ind",
         [
             0,
@@ -2141,8 +2179,9 @@
         assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
 
 
-class TestTensorsolve:
-    @pytest.mark.parametrize(
+@instantiate_parametrized_tests
+class TestTensorsolve(TestCase):
+    @parametrize(
         "a, axes",
         [
             (np.ones((4, 6, 8, 2)), None),
@@ -2154,7 +2193,7 @@
             b = np.ones(a.shape[:2])
             linalg.tensorsolve(a, b, axes=axes)
 
-    @pytest.mark.parametrize(
+    @parametrize(
         "shape",
         [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
     )
@@ -2165,61 +2204,58 @@
         assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
 
 
-@pytest.mark.xfail(reason="TODO")
-def test_unsupported_commontype():
-    # linalg gracefully handles unsupported type
-    arr = np.array([[1, -2], [2, 5]], dtype="float16")
-    # with assert_raises_regex(TypeError, "unsupported in linalg"):
-    with assert_raises(TypeError):
-        linalg.cholesky(arr)
+class TestMisc2(TestCase):
+    @xfail  # (reason="TODO")
+    def test_unsupported_commontype(self):
+        # linalg gracefully handles unsupported type
+        arr = np.array([[1, -2], [2, 5]], dtype="float16")
+        # with assert_raises_regex(TypeError, "unsupported in linalg"):
+        with assert_raises(TypeError):
+            linalg.cholesky(arr)
 
+    @xfail  # (reason="TODO")
+    # @slow
+    # @pytest.mark.xfail(not HAS_LAPACK64, run=False,
+    #                   reason="Numpy not compiled with 64-bit BLAS/LAPACK")
+    # @requires_memory(free_bytes=16e9)
+    @skip(reason="Bad memory reports lead to OOM in ci testing")
+    def test_blas64_dot(self):
+        n = 2**32
+        a = np.zeros([1, n], dtype=np.float32)
+        b = np.ones([1, 1], dtype=np.float32)
+        a[0, -1] = 1
+        c = np.dot(b, a)
+        assert_equal(c[0, -1], 1)
 
-@pytest.mark.xfail(reason="TODO")
-# @pytest.mark.slow
-# @pytest.mark.xfail(not HAS_LAPACK64, run=False,
-#                   reason="Numpy not compiled with 64-bit BLAS/LAPACK")
-# @requires_memory(free_bytes=16e9)
-@pytest.mark.skip(reason="Bad memory reports lead to OOM in ci testing")
-def test_blas64_dot():
-    n = 2**32
-    a = np.zeros([1, n], dtype=np.float32)
-    b = np.ones([1, 1], dtype=np.float32)
-    a[0, -1] = 1
-    c = np.dot(b, a)
-    assert_equal(c[0, -1], 1)
+    @skip(reason="lapack-lite specific")
+    @xfail  # (
+    #    not HAS_LAPACK64, reason="Numpy not compiled with 64-bit BLAS/LAPACK"
+    # )
+    def test_blas64_geqrf_lwork_smoketest(self):
+        # Smoke test LAPACK geqrf lwork call with 64-bit integers
+        dtype = np.float64
+        lapack_routine = np.linalg.lapack_lite.dgeqrf
 
+        m = 2**32 + 1
+        n = 2**32 + 1
+        lda = m
 
-@pytest.mark.skip(reason="lapack-lite specific")
-@pytest.mark.xfail(
-    not HAS_LAPACK64, reason="Numpy not compiled with 64-bit BLAS/LAPACK"
-)
-def test_blas64_geqrf_lwork_smoketest():
-    # Smoke test LAPACK geqrf lwork call with 64-bit integers
-    dtype = np.float64
-    lapack_routine = np.linalg.lapack_lite.dgeqrf
+        # Dummy arrays, not referenced by the lapack routine, so don't
+        # need to be of the right size
+        a = np.zeros([1, 1], dtype=dtype)
+        work = np.zeros([1], dtype=dtype)
+        tau = np.zeros([1], dtype=dtype)
 
-    m = 2**32 + 1
-    n = 2**32 + 1
-    lda = m
+        # Size query
+        results = lapack_routine(m, n, a, lda, tau, work, -1, 0)
+        assert_equal(results["info"], 0)
+        assert_equal(results["m"], m)
+        assert_equal(results["n"], m)
 
-    # Dummy arrays, not referenced by the lapack routine, so don't
-    # need to be of the right size
-    a = np.zeros([1, 1], dtype=dtype)
-    work = np.zeros([1], dtype=dtype)
-    tau = np.zeros([1], dtype=dtype)
-
-    # Size query
-    results = lapack_routine(m, n, a, lda, tau, work, -1, 0)
-    assert_equal(results["info"], 0)
-    assert_equal(results["m"], m)
-    assert_equal(results["n"], m)
-
-    # Should result to an integer of a reasonable size
-    lwork = int(work.item())
-    assert_(2**32 < lwork < 2**42)
+        # Should result to an integer of a reasonable size
+        lwork = int(work.item())
+        assert_(2**32 < lwork < 2**42)
 
 
 if __name__ == "__main__":
-    from torch._dynamo.test_case import run_tests
-
     run_tests()