Move _run_autocast_outofplace to a base class named TestAutocast to reduce redundancy (#134460)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134460
Approved by: https://github.com/EikanWang, https://github.com/ezyang
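
This patch deletes the per-device copies of `_run_autocast_outofplace` (and `args_maybe_kwargs`) from `TestAutocastCPU` and `TestCuda`, and has both suites inherit them from a shared `TestAutocast` base class exported by `torch.testing._internal.autocast_test_lists`. The base class itself is not part of the hunks shown here; the sketch below is a hedged reconstruction of its likely shape, inferred from the two removed copies and from the new call sites, which now pass an explicit `device=` keyword ("cpu" or "cuda"). It is illustrative only, not the authoritative implementation.

```python
# Hedged sketch of the shared helper, assuming it is essentially the removed
# CPU copy with the hard-coded "cpu" string replaced by a `device` parameter.
import collections

import torch
from torch.testing._internal.common_utils import TestCase


class TestAutocast(TestCase):
    def _run_autocast_outofplace(
        self,
        op,
        args,
        run_as_type,
        device,  # new: "cpu" or "cuda", supplied by each subclass
        out_type=None,
        module=torch,
        add_kwargs=None,
        amp_dtype=torch.bfloat16,
    ):
        # Cast floating-point tensors (including nested iterables of them) to `to_type`.
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_enabled(device_type=device))
        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_enabled(device_type=device))
            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try the module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(out_type == output.dtype)
            # Try the Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(out_type == output_method.dtype)

            self.assertTrue((output is not None) or (output_method is not None))

            # Handles ops that return Tensors, iterables, and plain Python values.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both variants ran, their results should match exactly.
            if (output is not None) and (output_method is not None):
                self.assertTrue(compare(output, output_method))

            # Compare against an autocast-disabled control computed in run_as_type.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast(device_type=device, enabled=False):
                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(compare(output_to_compare, control))
        self.assertFalse(torch.is_autocast_enabled(device_type=device))

    def args_maybe_kwargs(self, op_with_args):
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        return op_with_args[0], op_with_args[1], op_with_args[2]
```

With this base class in place, `TestAutocastCPU`, and the new `TestCudaAutocast` split out of `TestCuda`, only keep device-specific setup (`AutocastCPUTestLists` / `AutocastTestLists`) and pass `device="cpu"` or `device="cuda"` at each call site, as the hunks below show.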
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 4c3f031..8e702f9 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -1,10 +1,12 @@
 # Owner(s): ["module: unknown"]
 
-import collections
 import unittest
 
 import torch
-from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
+from torch.testing._internal.autocast_test_lists import (
+    AutocastCPUTestLists,
+    TestAutocast,
+)
 from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     run_tests,
@@ -14,7 +16,7 @@
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
-class TestAutocastCPU(TestCase):
+class TestAutocastCPU(TestAutocast):
     def setUp(self):
         super().setUp()
         self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))
@@ -23,100 +25,6 @@
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(
-        self,
-        op,
-        args,
-        run_as_type,
-        out_type=None,
-        module=torch,
-        add_kwargs=None,
-        amp_dtype=torch.bfloat16,
-    ):
-        # helper to cast args
-        def cast(val, to_type):
-            if isinstance(val, torch.Tensor):
-                return val.to(to_type) if val.is_floating_point() else val
-            elif isinstance(val, collections.abc.Iterable):
-                return type(val)(cast(v, to_type) for v in val)
-            else:
-                return val
-
-        if add_kwargs is None:
-            add_kwargs = {}
-
-        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-        with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
-            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
-            out_type = out_type if out_type is not None else run_as_type
-            output = output_method = None
-
-            # Try module.* variant, if requested:
-            if module is not None and hasattr(module, op):
-                output = getattr(module, op)(*args, **add_kwargs)
-                if isinstance(output, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output.dtype,
-                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
-                    )
-            # Try Tensor.* variant:
-            if hasattr(torch.Tensor, op):
-                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
-                if isinstance(output_method, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output_method.dtype,
-                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
-                    )
-
-            self.assertTrue(
-                (output is not None) or (output_method is not None),
-                f"{op} not found as an attribute on either Tensor or the requested module {module}",
-            )
-
-            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
-            # For example, lstm_cell returns a tuple and equal returns bool.
-            def compare(first, second):
-                if isinstance(first, torch.Tensor):
-                    return torch.equal(first, second)
-                elif isinstance(first, collections.abc.Iterable):
-                    return all(compare(f, s) for f, s in zip(first, second))
-                else:
-                    return first == second
-
-            # If both torch.* and Tensor.* variants were found, check outputs are identical
-            if (output is not None) and (output_method is not None):
-                self.assertTrue(type(output) == type(output_method))
-                comparison = compare(output, output_method)
-                self.assertTrue(
-                    comparison, f"torch.{op} result did not match Tensor.{op} result"
-                )
-
-            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
-            # as the C++-side autocasting, and should be bitwise accurate.
-            output_to_compare = output if output is not None else output_method
-            with torch.amp.autocast(device_type="cpu", enabled=False):
-                self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
-                if module is not None and hasattr(module, op):
-                    control = getattr(module, op)(
-                        *cast(args, run_as_type), **add_kwargs
-                    )
-                else:
-                    control = getattr(args[0].to(run_as_type), op)(
-                        *cast(args[1:], run_as_type), **add_kwargs
-                    )
-                self.assertTrue(type(output_to_compare) == type(control))
-                comparison = compare(output_to_compare, control)
-                self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
-        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
-    def args_maybe_kwargs(self, op_with_args):
-        if len(op_with_args) == 2:
-            return op_with_args[0], op_with_args[1], {}
-        else:
-            return op_with_args[0], op_with_args[1], op_with_args[2]
-
     @skipIfTorchDynamo()
     def test_autocast_torch_expect_builtin_promote(self):
         for (
@@ -125,9 +33,16 @@
             args2,
             out_type,
         ) in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
             self._run_autocast_outofplace(
-                op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
+                op, args1, torch.float32, device="cpu", out_type=out_type
+            )
+            self._run_autocast_outofplace(
+                op,
+                args2,
+                torch.float32,
+                device="cpu",
+                out_type=out_type,
+                amp_dtype=torch.float16,
             )
 
     @skipIfTorchDynamo()
@@ -139,12 +54,13 @@
             out_type,
         ) in self.autocast_lists.methods_expect_builtin_promote:
             self._run_autocast_outofplace(
-                op, args1, torch.float32, module=None, out_type=out_type
+                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
             )
             self._run_autocast_outofplace(
                 op,
                 args2,
                 torch.float32,
+                device="cpu",
                 module=None,
                 out_type=out_type,
                 amp_dtype=torch.float16,
@@ -155,12 +71,13 @@
         for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.bfloat16, add_kwargs=maybe_kwargs
+                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float16,
+                device="cpu",
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
             )
@@ -170,12 +87,18 @@
         for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+                op,
+                args,
+                torch.bfloat16,
+                device="cpu",
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float16,
+                device="cpu",
                 module=torch._C._nn,
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
@@ -186,12 +109,13 @@
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.float32, add_kwargs=maybe_kwargs
+                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float32,
+                device="cpu",
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
             )
@@ -201,12 +125,18 @@
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+                op,
+                args,
+                torch.float32,
+                device="cpu",
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float32,
+                device="cpu",
                 module=torch._C._nn,
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
@@ -215,9 +145,9 @@
     @skipIfTorchDynamo()
     def test_autocast_torch_need_autocast_promote(self):
         for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
             self._run_autocast_outofplace(
-                op, args2, torch.float32, amp_dtype=torch.float16
+                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
             )
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
diff --git a/test/test_cuda.py b/test/test_cuda.py
index c6a61e1..2195e71 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1,6 +1,5 @@
 # Owner(s): ["module: cuda"]
 
-import collections
 import contextlib
 import ctypes
 import gc
@@ -29,7 +28,7 @@
     segment_plot,
     trace_plot,
 )
-from torch.testing._internal.autocast_test_lists import AutocastTestLists
+from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
 from torch.testing._internal.common_cuda import (
     _create_scaling_case,
     _get_torch_cuda_version,
@@ -61,7 +60,6 @@
     IS_WINDOWS,
     load_tests,
     NO_MULTIPROCESSING_SPAWN,
-    NoTest,
     parametrize,
     run_tests,
     serialTest,
@@ -85,10 +83,6 @@
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
 
-if not TEST_CUDA:
-    print("CUDA not available, skipping tests", file=sys.stderr)
-    TestCase = NoTest  # noqa: F811
-
 try:
     import torchvision.models  # noqa: F401
     from torchvision.models import resnet18  # noqa: F401
@@ -113,6 +107,7 @@
 _cycles_per_ms = None
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCuda(TestCase):
     _do_cuda_memory_leak_check = True
@@ -121,10 +116,8 @@
 
     def setUp(self):
         super().setUp()
-        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
 
     def tearDown(self):
-        del self.autocast_lists
         super().tearDown()
 
     @property
@@ -1558,545 +1551,6 @@
             for t in range(num_threads):
                 self.assertEqual(results[t].sum().item(), size * size)
 
-    def _run_autocast_outofplace(
-        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
-    ):
-        # helper to cast args
-        def cast(val, to_type):
-            if isinstance(val, torch.Tensor):
-                return val.to(to_type) if val.is_floating_point() else val
-            elif isinstance(val, collections.abc.Iterable):
-                return type(val)(cast(v, to_type) for v in val)
-            else:
-                return val
-
-        if add_kwargs is None:
-            add_kwargs = {}
-        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
-        self.assertFalse(torch.is_autocast_enabled())
-        with torch.autocast("cuda", dtype=fast_dtype):
-            self.assertTrue(torch.is_autocast_enabled())
-
-            out_type = out_type if out_type is not None else run_as_type
-            output = output_method = None
-
-            # Try module.* variant, if requested:
-            if module is not None and hasattr(module, op):
-                output = getattr(module, op)(*args, **add_kwargs)
-                if isinstance(output, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output.dtype,
-                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
-                    )
-
-            # Try Tensor.* variant:
-            if hasattr(torch.Tensor, op):
-                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
-                if isinstance(output_method, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output_method.dtype,
-                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
-                    )
-
-            self.assertTrue(
-                (output is not None) or (output_method is not None),
-                f"{op} not found as an attribute on either Tensor or the requested module {module}",
-            )
-
-            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
-            # For example, lstm_cell returns a tuple and equal returns bool.
-            def compare(first, second):
-                if isinstance(first, torch.Tensor):
-                    return torch.equal(first, second)
-                elif isinstance(first, collections.abc.Iterable):
-                    return all(compare(f, s) for f, s in zip(first, second))
-                else:
-                    return first == second
-
-            # If both torch.* and Tensor.* variants were found, check outputs are identical
-            if (output is not None) and (output_method is not None):
-                self.assertTrue(type(output) == type(output_method))
-                comparison = compare(output, output_method)
-                self.assertTrue(
-                    comparison, f"torch.{op} result did not match Tensor.{op} result"
-                )
-
-            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
-            # as the C++-side autocasting, and should be bitwise accurate.
-            output_to_compare = output if output is not None else output_method
-            with torch.autocast("cuda", enabled=False):
-                self.assertFalse(torch.is_autocast_enabled())
-
-                if module is not None and hasattr(module, op):
-                    control = getattr(module, op)(
-                        *cast(args, run_as_type), **add_kwargs
-                    )
-                else:
-                    control = getattr(args[0].to(run_as_type), op)(
-                        *cast(args[1:], run_as_type), **add_kwargs
-                    )
-                self.assertTrue(type(output_to_compare) == type(control))
-                comparison = compare(output_to_compare, control)
-                self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_enabled())
-        self.assertFalse(torch.is_autocast_enabled())
-
-    def args_maybe_kwargs(self, op_with_args):
-        if len(op_with_args) == 2:
-            return op_with_args[0], op_with_args[1], {}
-        else:
-            return op_with_args[0], op_with_args[1], op_with_args[2]
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_torch_fp16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op_with_args in self.autocast_lists.torch_fp16:
-                skip_test = False
-                op, args = op_with_args[0], op_with_args[1]
-                if len(op_with_args) == 3:
-                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
-                if not skip_test:
-                    self._run_autocast_outofplace(op, args, torch.float16)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_torch_bf16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op_with_args in self.autocast_lists.torch_fp16:
-                skip_test = False
-                op, args = op_with_args[0], op_with_args[1]
-                if len(op_with_args) == 3:
-                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
-                should_error_from_cudnn = "cudnn" in op and (
-                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
-                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
-                    or torch.cuda.get_device_capability() < (8, 0)
-                )
-                should_error_from_not_implemented = should_error_from_cudnn
-                if not skip_test:
-                    if should_error_from_not_implemented:
-                        with self.assertRaises(
-                            RuntimeError,
-                            msg=str(op) + " should not be supported for bfloat16!",
-                        ):
-                            self._run_autocast_outofplace(op, args, torch.bfloat16)
-                    else:
-                        if torch.cuda.is_bf16_supported():
-                            self._run_autocast_outofplace(op, args, torch.bfloat16)
-                        else:
-                            with self.assertRaisesRegex(
-                                RuntimeError, "Device does not support bfloat16"
-                            ):
-                                self._run_autocast_outofplace(op, args, torch.bfloat16)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_torch_fp32(self):
-        for op_with_args in self.autocast_lists.torch_fp32:
-            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(
-                op, args, torch.float32, add_kwargs=maybe_kwargs
-            )
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_nn_fp16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op, args in self.autocast_lists.nn_fp16:
-                self._run_autocast_outofplace(
-                    op, args, torch.float16, module=torch._C._nn
-                )
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_nn_bf16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op, args in self.autocast_lists.nn_fp16:
-                if torch.cuda.is_bf16_supported():
-                    self._run_autocast_outofplace(
-                        op, args, torch.bfloat16, module=torch._C._nn
-                    )
-                else:
-                    with self.assertRaisesRegex(
-                        RuntimeError, "Device does not support bfloat16"
-                    ):
-                        self._run_autocast_outofplace(
-                            op, args, torch.bfloat16, module=torch._C._nn
-                        )
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_nn_fp32(self):
-        for op, args in self.autocast_lists.nn_fp32:
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_linalg_fp16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op, args in self.autocast_lists.linalg_fp16:
-                self._run_autocast_outofplace(
-                    op, args, torch.float16, module=torch._C._linalg
-                )
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_methods_fp16(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            for op, args in self.autocast_lists.methods_fp16:
-                self._run_autocast_outofplace(op, args, torch.float16, module=None)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_methods_fp32(self):
-        for op, args in self.autocast_lists.methods_fp32:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None)
-
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(
-                op, args, torch.float32, module=None, out_type=out_type
-            )
-
-    def test_autocast_banned(self):
-        with torch.autocast("cuda"):
-            for op, args, module in self.autocast_lists.banned:
-                with self.assertRaises(RuntimeError):
-                    getattr(module, op)(*args)
-
-    def test_autocast_ignored_types(self):
-        with torch.autocast("cuda"):
-            for ignore_type in (torch.double, torch.int32):
-                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
-                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
-                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
-
-                # Tests if CastPolicy::fp16 ops ignore double and int
-                # Currently, no ops belonging to this policy support integer inputs.
-                if ignore_type is torch.double:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a_ignore, c_16)
-                    with torch.autocast("cuda", enabled=False):
-                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
-                    self.assertTrue(
-                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
-                    )
-
-                # Tests if CastPolicy::fp32 ops ignore double and int
-                with torch.autocast("cuda", enabled=False):
-                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
-                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
-
-                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
-                with torch.autocast("cuda", enabled=False):
-                    type_no_autocast = torch.sum(a_ignore).dtype
-                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
-
-                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
-                # Currently, no ops belonging to this policy support integer inputs.
-                if ignore_type is torch.double:
-                    with torch.autocast("cuda", enabled=False):
-                        type_no_autocast = torch.norm(a_ignore).dtype
-                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
-
-    def test_autocast_custom_enabled(self):
-        class MyMM(torch.autograd.Function):
-            @staticmethod
-            @torch.amp.custom_fwd(device_type="cuda")
-            def forward(ctx, a, b):
-                self.assertTrue(a.dtype is torch.float32)
-                self.assertTrue(b.dtype is torch.float32)
-                self.assertTrue(torch.is_autocast_enabled())
-                ctx.save_for_backward(a, b)
-                return a.mm(b)
-
-            @staticmethod
-            @torch.amp.custom_bwd(device_type="cuda")
-            def backward(ctx, grad):
-                self.assertTrue(torch.is_autocast_enabled())
-                a, b = ctx.saved_tensors
-                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
-                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
-                return a_grad, b_grad
-
-        mymm = MyMM.apply
-
-        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
-        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
-
-        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
-        for dtype in dtypes:
-            with torch.cuda.amp.autocast(dtype=dtype):
-                output = mymm(x, y)
-                self.assertTrue(output.dtype is dtype)
-                loss = output.sum()
-            loss.backward()
-
-    def test_autocast_custom_cast_inputs(self):
-        class MyMM(torch.autograd.Function):
-            @staticmethod
-            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
-            def forward(ctx, a, container, expect_type):
-                b = container[1][0]
-                self.assertTrue(a.dtype is expect_type)
-                self.assertTrue(b.dtype is expect_type)
-                self.assertFalse(torch.is_autocast_enabled())
-                ctx.save_for_backward(a, b)
-                return a.mm(b)
-
-            @staticmethod
-            @torch.amp.custom_bwd(device_type="cuda")
-            def backward(ctx, grad):
-                self.assertFalse(torch.is_autocast_enabled())
-                a, b = ctx.saved_tensors
-                return grad.mm(b.t()), None, None
-
-        mymm = MyMM.apply
-
-        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
-        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
-        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
-        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
-        y = (
-            0,
-            {
-                0: torch.randn(
-                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
-                )
-            },
-        )
-
-        with torch.autocast("cuda"):
-            output = mymm(x, y, torch.float32)
-            self.assertTrue(output.dtype is torch.float32)
-            loss = output.sum()
-        loss.backward()
-
-        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
-        output = mymm(x, y, torch.float16)
-        self.assertTrue(output.dtype is torch.float16)
-        loss = output.sum()
-        loss.backward()
-
-    def test_autocast_custom_deprecated_warning(self):
-        with warnings.catch_warnings(record=True) as w:
-
-            class MyMM(torch.autograd.Function):
-                @staticmethod
-                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
-                def forward(ctx, x, y):
-                    ctx.save_for_backward(x, y)
-                    self.assertFalse(torch.is_autocast_enabled())
-                    return x + y
-
-                @staticmethod
-                @torch.cuda.amp.custom_bwd
-                def backward(ctx, grad):
-                    _, _ = ctx.saved_tensors
-                    self.assertFalse(torch.is_autocast_enabled())
-                    return grad, grad
-
-        self.assertRegex(
-            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
-        )
-        self.assertRegex(
-            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
-        )
-
-        mymm = MyMM.apply
-        x = torch.randn(3, 3, requires_grad=True)
-        y = torch.randn(3, 3, requires_grad=True)
-        with torch.amp.autocast("cuda"):
-            output = mymm(x, y)
-            loss = output.sum()
-        loss.backward()
-
-    def test_autocast_cat_jit(self):
-        # Reported at https://github.com/pytorch/pytorch/issues/38958
-
-        class Model(torch.nn.Module):
-            def forward(self):
-                a = torch.randn(1)
-                b = torch.randn(1)
-                c = torch.cat((a, b), 0)
-                d = torch.stack([c, c], 0)
-                return d
-
-        # The JIT here doesn't really matter, we just need to call
-        # cat via the boxed API
-        model = Model()
-        model_jit_script = torch.jit.script(model)
-
-        with torch.autocast("cuda", enabled=True):
-            model()
-            model_jit_script()
-
-    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
-    # so they get a dedicated test.
-    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
-    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
-    def test_autocast_rnn(self):
-        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
-            # seq, batch, features, hidden size
-            clses = ("RNN", "GRU", "LSTM")
-            T, B, F, H = 3, 4, 5, 6
-            dtypes = (torch.float16, torch.float32)
-            input_layouts = ("seq_first", "batch_first", "packed")
-
-            for (
-                cls,
-                num_layers,
-                bias,
-                input_layout,
-                bidirectional,
-                try_nonpreflattened_weights,
-                input_dtype,
-                hidden_dtype,
-                weight_dtype,
-            ) in product(
-                clses,
-                (1, 2),
-                (True, False),
-                input_layouts,
-                (True, False),
-                (True, False),
-                dtypes,
-                dtypes,
-                dtypes,
-            ):
-                if input_layout == "seq_first":
-                    batch_first = False
-                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
-                elif input_layout == "batch_first":
-                    batch_first = True
-                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
-                elif input_layout == "packed":
-                    batch_first = False
-                    x = torch.nn.utils.rnn.pack_padded_sequence(
-                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
-                        lengths=(3, 2, 1, 3),
-                        enforce_sorted=False,
-                    )
-
-                rnn = (
-                    getattr(torch.nn, cls)(
-                        F,
-                        H,
-                        num_layers=num_layers,
-                        bidirectional=bidirectional,
-                        bias=bias,
-                        batch_first=batch_first,
-                    )
-                    .cuda()
-                    .to(dtype=weight_dtype)
-                )
-
-                if try_nonpreflattened_weights:
-                    for p in rnn.parameters():
-                        with torch.no_grad():
-                            p.set_(p.clone())
-
-                h = torch.randn(
-                    (num_layers * (2 if bidirectional else 1), B, H),
-                    device="cuda",
-                    dtype=hidden_dtype,
-                )
-                if cls == "LSTM":
-                    c = torch.randn(
-                        (num_layers * (2 if bidirectional else 1), B, H),
-                        device="cuda",
-                        dtype=hidden_dtype,
-                    )
-                    h = (h, c)
-
-                with torch.autocast("cuda"):
-                    out, h_out = rnn(x, h)
-                out = out.data if input_layout == "packed" else out
-                self.assertEqual(out.dtype, torch.float16)
-                # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed.  This check can't guarantee
-                # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
-                # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
-                self.assertEqual(
-                    out.grad_fn.name(),
-                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
-                )
-                out.sum().backward()
-                grads = [p.grad.clone() for p in rnn.parameters()]
-
-                rnn.zero_grad()
-
-                if cls == "LSTM":
-                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
-                        x.half(), (h[0].half(), h[1].half())
-                    )
-                else:
-                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
-                        x.half(), h.half()
-                    )
-                out_control = (
-                    out_control.data if input_layout == "packed" else out_control
-                )
-                out_control.sum().backward()
-                grads_control = [p.grad.clone() for p in rnn.parameters()]
-
-                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
-                # autocast and control results should be bitwise identical.
-                self.assertEqual(out, out_control)
-
-                if cls == "LSTM":
-                    self.assertTrue(
-                        h_out[0].dtype is torch.float16
-                        and h_out[1].dtype is torch.float16
-                    )
-                    self.assertEqual(h_out[0], h_out_control[0])
-                    self.assertEqual(h_out[1], h_out_control[1])
-                else:
-                    self.assertEqual(h_out.dtype, torch.float16)
-                    self.assertEqual(h_out, h_out_control)
-                for grad, grad_control in zip(grads, grads_control):
-                    self.assertEqual(grad.half(), grad_control)
-
-    def test_autocast_cache_leak(self):
-        # Reported at https://github.com/pytorch/pytorch/issues/48049
-        # Test is used to check, if autocast recaches the same parameters
-        # when executed in a `torch.no_grad()` block.
-
-        linear = torch.nn.Linear(10, 10).to("cuda")
-        data = torch.randn(1, 10, device="cuda")
-
-        with torch.autocast("cuda"):
-            with torch.no_grad():
-                out = linear(data)
-                first_iter_mem = torch.cuda.memory_allocated()
-                for _ in range(3):
-                    out = linear(data)
-                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
-
-    def test_autocast_checkpointing(self):
-        model = torch.nn.Sequential(
-            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
-        ).cuda()
-        input = torch.rand(
-            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
-        )
-        for reentrant in (True, False):
-            with torch.autocast("cuda"):
-                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
-            self.assertTrue(output.requires_grad)
-            self.assertTrue(output.dtype is torch.float16)
-            output.sum().backward()
-
-    def test_cuda_autocast_deprecated_warning(self):
-        with self.assertWarnsRegex(
-            FutureWarning,
-            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
-        ):
-            with torch.cuda.amp.autocast():
-                _ = torch.ones(10)
-
     @slowTest
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     @serialTest()
@@ -3242,7 +2696,9 @@
         relu_control = torch.nn.functional.relu
 
         # This is a good stress test. It graphs four callables: two Modules and two python functions.
-        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+        with torch.amp.autocast(
+            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+        ):
             (
                 model_graphed[0],
                 model_graphed[1],
@@ -3282,7 +2738,9 @@
             torch.cuda.manual_seed(5)
             for data, target in zip(real_inputs, real_targets):
                 opt.zero_grad(set_to_none=True)
-                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+                with torch.amp.autocast(
+                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+                ):
                     y_pred = m({"x": data, "unused_input": unused_input})["output"]
                     y_pred = relu(y_pred)
                     loss = loss_fn(y_pred, target)
@@ -3351,7 +2809,9 @@
         y = torch.randn(N, D_in, device="cuda")
 
         # This is a good stress test. It graphs four callables: two Modules and two python functions.
-        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+        with torch.amp.autocast(
+            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+        ):
             model_graphed[0] = torch.cuda.make_graphed_callables(
                 model_graphed[0],
                 ({"x": x, "unused_input": unused_input},),
@@ -3366,8 +2826,10 @@
             # so dropout math should be bitwise identical for both.
             torch.manual_seed(5)
             torch.cuda.manual_seed(5)
-            for data, target in zip(real_inputs, real_targets):
-                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+            for data, _ in zip(real_inputs, real_targets):
+                with torch.amp.autocast(
+                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+                ):
                     out = m({"x": data, "unused_input": unused_input})["output"]
 
         # We graphed the models in training mode. Eval should still run ungraphed.
@@ -3826,6 +3288,7 @@
                 file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCudaMallocAsync(TestCase):
     @unittest.skipIf(
@@ -4492,7 +3955,7 @@
     return t
 
 
-@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
+@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestBlockStateAbsorption(TestCase):
     @property
@@ -4855,6 +4318,7 @@
         self.assertEqual(rc, "False", "Triton was imported when importing torch!")
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 class TestMemPool(TestCase):
     def test_mempool_id(self):
         pool1 = torch.cuda.graph_pool_handle()
@@ -4980,6 +4444,7 @@
         self.assertEqual(len(set(active_pool_ids)), 4)
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCudaOptims(TestCase):
     # These tests will be instantiate with instantiate_device_type_tests
@@ -5322,6 +4787,7 @@
             self.assertEqual(scaler._growth_tracker, growth_tracker)
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 class TestGDS(TestCase):
     def _get_tmp_dir_fs_type(self):
         my_path = os.path.realpath("/tmp")
@@ -5356,6 +4822,526 @@
         torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
 
 
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
+class TestCudaAutocast(TestAutocast):
+    def setUp(self):
+        super().setUp()
+        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
+
+    def tearDown(self):
+        del self.autocast_lists
+        super().tearDown()
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_torch_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op_with_args in self.autocast_lists.torch_fp16:
+                skip_test = False
+                op, args = op_with_args[0], op_with_args[1]
+                if len(op_with_args) == 3:
+                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
+                if not skip_test:
+                    self._run_autocast_outofplace(
+                        op, args, torch.float16, device="cuda", amp_dtype=torch.float16
+                    )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_torch_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op_with_args in self.autocast_lists.torch_fp16:
+                skip_test = False
+                op, args = op_with_args[0], op_with_args[1]
+                if len(op_with_args) == 3:
+                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
+                should_error_from_cudnn = "cudnn" in op and (
+                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
+                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
+                    or torch.cuda.get_device_capability() < (8, 0)
+                )
+                should_error_from_not_implemented = should_error_from_cudnn
+                if not skip_test:
+                    if should_error_from_not_implemented:
+                        with self.assertRaises(
+                            RuntimeError,
+                            msg=str(op) + " should not be supported for bfloat16!",
+                        ):
+                            self._run_autocast_outofplace(
+                                op, args, torch.bfloat16, device="cuda"
+                            )
+                    else:
+                        if torch.cuda.is_bf16_supported():
+                            self._run_autocast_outofplace(
+                                op, args, torch.bfloat16, device="cuda"
+                            )
+                        else:
+                            with self.assertRaisesRegex(
+                                RuntimeError, "Device does not support bfloat16"
+                            ):
+                                self._run_autocast_outofplace(
+                                    op, args, torch.bfloat16, device="cuda"
+                                )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_torch_fp32(self):
+        for op_with_args in self.autocast_lists.torch_fp32:
+            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="cuda",
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_torch_need_autocast_promote(self):
+        for op, args in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(
+                op, args, torch.float32, device="cuda", amp_dtype=torch.float16
+            )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_torch_expect_builtin_promote(self):
+        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="cuda",
+                out_type=out_type,
+                amp_dtype=torch.float16,
+            )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_nn_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.nn_fp16:
+                self._run_autocast_outofplace(
+                    op,
+                    args,
+                    torch.float16,
+                    device="cuda",
+                    module=torch._C._nn,
+                    amp_dtype=torch.float16,
+                )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_nn_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.nn_fp16:
+                if torch.cuda.is_bf16_supported():
+                    self._run_autocast_outofplace(
+                        op, args, torch.bfloat16, device="cuda", module=torch._C._nn
+                    )
+                else:
+                    with self.assertRaisesRegex(
+                        RuntimeError, "Device does not support bfloat16"
+                    ):
+                        self._run_autocast_outofplace(
+                            op, args, torch.bfloat16, device="cuda", module=torch._C._nn
+                        )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_nn_fp32(self):
+        for op, args in self.autocast_lists.nn_fp32:
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="cuda",
+                module=torch._C._nn,
+                amp_dtype=torch.float16,
+            )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_linalg_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.linalg_fp16:
+                self._run_autocast_outofplace(
+                    op,
+                    args,
+                    torch.float16,
+                    device="cuda",
+                    module=torch._C._linalg,
+                    amp_dtype=torch.float16,
+                )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_methods_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.methods_fp16:
+                self._run_autocast_outofplace(
+                    op,
+                    args,
+                    torch.float16,
+                    device="cuda",
+                    module=None,
+                    amp_dtype=torch.float16,
+                )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_methods_fp32(self):
+        for op, args in self.autocast_lists.methods_fp32:
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="cuda",
+                module=None,
+                amp_dtype=torch.float16,
+            )
+
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_methods_expect_builtin_promote(self):
+        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="cuda",
+                module=None,
+                out_type=out_type,
+                amp_dtype=torch.float16,
+            )
+
+    def test_autocast_banned(self):
+        with torch.autocast("cuda"):
+            for op, args, module in self.autocast_lists.banned:
+                with self.assertRaises(RuntimeError):
+                    getattr(module, op)(*args)
+
+    def test_autocast_ignored_types(self):
+        with torch.autocast("cuda"):
+            for ignore_type in (torch.double, torch.int32):
+                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
+                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
+                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
+
+                # Tests if CastPolicy::fp16 ops ignore double and int
+                # Currently, no ops belonging to this policy support integer inputs.
+                if ignore_type is torch.double:
+                    with self.assertRaises(RuntimeError):
+                        torch.mm(a_ignore, c_16)
+                    with torch.autocast("cuda", enabled=False):
+                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
+                    self.assertTrue(
+                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
+                    )
+
+                # Tests if CastPolicy::fp32 ops ignore double and int
+                with torch.autocast("cuda", enabled=False):
+                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
+                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
+
+                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
+                with torch.autocast("cuda", enabled=False):
+                    type_no_autocast = torch.sum(a_ignore).dtype
+                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
+
+                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
+                # Currently, no ops belonging to this policy support integer inputs.
+                if ignore_type is torch.double:
+                    with torch.autocast("cuda", enabled=False):
+                        type_no_autocast = torch.norm(a_ignore).dtype
+                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
+
+    def test_autocast_custom_enabled(self):
+        class MyMM(torch.autograd.Function):
+            @staticmethod
+            @torch.amp.custom_fwd(device_type="cuda")
+            def forward(ctx, a, b):
+                self.assertTrue(a.dtype is torch.float32)
+                self.assertTrue(b.dtype is torch.float32)
+                self.assertTrue(torch.is_autocast_enabled())
+                ctx.save_for_backward(a, b)
+                return a.mm(b)
+
+            @staticmethod
+            @torch.amp.custom_bwd(device_type="cuda")
+            def backward(ctx, grad):
+                self.assertTrue(torch.is_autocast_enabled())
+                a, b = ctx.saved_tensors
+                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
+                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
+                return a_grad, b_grad
+
+        mymm = MyMM.apply
+
+        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
+        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
+
+        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
+        for dtype in dtypes:
+            with torch.cuda.amp.autocast(dtype=dtype):
+                output = mymm(x, y)
+                self.assertTrue(output.dtype is dtype)
+                loss = output.sum()
+            loss.backward()
+
+    def test_autocast_custom_cast_inputs(self):
+        class MyMM(torch.autograd.Function):
+            @staticmethod
+            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+            def forward(ctx, a, container, expect_type):
+                b = container[1][0]
+                self.assertTrue(a.dtype is expect_type)
+                self.assertTrue(b.dtype is expect_type)
+                self.assertFalse(torch.is_autocast_enabled())
+                ctx.save_for_backward(a, b)
+                return a.mm(b)
+
+            @staticmethod
+            @torch.amp.custom_bwd(device_type="cuda")
+            def backward(ctx, grad):
+                self.assertFalse(torch.is_autocast_enabled())
+                a, b = ctx.saved_tensors
+                return grad.mm(b.t()), None, None
+
+        mymm = MyMM.apply
+
+        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
+        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
+        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
+        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
+        y = (
+            0,
+            {
+                0: torch.randn(
+                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
+                )
+            },
+        )
+
+        with torch.autocast("cuda"):
+            output = mymm(x, y, torch.float32)
+            self.assertTrue(output.dtype is torch.float32)
+            loss = output.sum()
+        loss.backward()
+
+        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
+        output = mymm(x, y, torch.float16)
+        self.assertTrue(output.dtype is torch.float16)
+        loss = output.sum()
+        loss.backward()
+
+    def test_autocast_custom_deprecated_warning(self):
+        with warnings.catch_warnings(record=True) as w:
+
+            class MyMM(torch.autograd.Function):
+                @staticmethod
+                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+                def forward(ctx, x, y):
+                    ctx.save_for_backward(x, y)
+                    self.assertFalse(torch.is_autocast_enabled())
+                    return x + y
+
+                @staticmethod
+                @torch.cuda.amp.custom_bwd
+                def backward(ctx, grad):
+                    _, _ = ctx.saved_tensors
+                    self.assertFalse(torch.is_autocast_enabled())
+                    return grad, grad
+
+        self.assertRegex(
+            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
+        )
+        self.assertRegex(
+            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
+        )
+
+        mymm = MyMM.apply
+        x = torch.randn(3, 3, requires_grad=True)
+        y = torch.randn(3, 3, requires_grad=True)
+        with torch.amp.autocast("cuda"):
+            output = mymm(x, y)
+            loss = output.sum()
+        loss.backward()
+
+    def test_autocast_cat_jit(self):
+        # Reported at https://github.com/pytorch/pytorch/issues/38958
+
+        class Model(torch.nn.Module):
+            def forward(self):
+                a = torch.randn(1)
+                b = torch.randn(1)
+                c = torch.cat((a, b), 0)
+                d = torch.stack([c, c], 0)
+                return d
+
+        # The JIT here doesn't really matter, we just need to call
+        # cat via the boxed API
+        model = Model()
+        model_jit_script = torch.jit.script(model)
+
+        with torch.autocast("cuda", enabled=True):
+            model()
+            model_jit_script()
+
+    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
+    # so they get a dedicated test.
+    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_autocast_rnn(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            # seq, batch, features, hidden size
+            clses = ("RNN", "GRU", "LSTM")
+            T, B, F, H = 3, 4, 5, 6
+            dtypes = (torch.float16, torch.float32)
+            input_layouts = ("seq_first", "batch_first", "packed")
+
+            for (
+                cls,
+                num_layers,
+                bias,
+                input_layout,
+                bidirectional,
+                try_nonpreflattened_weights,
+                input_dtype,
+                hidden_dtype,
+                weight_dtype,
+            ) in product(
+                clses,
+                (1, 2),
+                (True, False),
+                input_layouts,
+                (True, False),
+                (True, False),
+                dtypes,
+                dtypes,
+                dtypes,
+            ):
+                if input_layout == "seq_first":
+                    batch_first = False
+                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
+                elif input_layout == "batch_first":
+                    batch_first = True
+                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
+                elif input_layout == "packed":
+                    batch_first = False
+                    x = torch.nn.utils.rnn.pack_padded_sequence(
+                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
+                        lengths=(3, 2, 1, 3),
+                        enforce_sorted=False,
+                    )
+
+                rnn = (
+                    getattr(torch.nn, cls)(
+                        F,
+                        H,
+                        num_layers=num_layers,
+                        bidirectional=bidirectional,
+                        bias=bias,
+                        batch_first=batch_first,
+                    )
+                    .cuda()
+                    .to(dtype=weight_dtype)
+                )
+
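+                # Cloning each parameter's storage detaches it from cudnn's flattened weight
+                # buffer, exercising the path where weights must be re-flattened.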
+                if try_nonpreflattened_weights:
+                    for p in rnn.parameters():
+                        with torch.no_grad():
+                            p.set_(p.clone())
+
+                h = torch.randn(
+                    (num_layers * (2 if bidirectional else 1), B, H),
+                    device="cuda",
+                    dtype=hidden_dtype,
+                )
+                if cls == "LSTM":
+                    c = torch.randn(
+                        (num_layers * (2 if bidirectional else 1), B, H),
+                        device="cuda",
+                        dtype=hidden_dtype,
+                    )
+                    h = (h, c)
+
+                with torch.autocast("cuda"):
+                    out, h_out = rnn(x, h)
+                out = out.data if input_layout == "packed" else out
+                self.assertEqual(out.dtype, torch.float16)
+                # The autocast wrapper requires that at::_cudnn_rnn be autograd-exposed.  This check can't
+                # guarantee that, but if it fails, something unexpected has changed and we should double
+                # check that at::_cudnn_rnn remains autograd-exposed.
+                self.assertEqual(
+                    out.grad_fn.name(),
+                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
+                )
+                out.sum().backward()
+                grads = [p.grad.clone() for p in rnn.parameters()]
+
+                rnn.zero_grad()
+
+                if cls == "LSTM":
+                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
+                        x.half(), (h[0].half(), h[1].half())
+                    )
+                else:
+                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
+                        x.half(), h.half()
+                    )
+                out_control = (
+                    out_control.data if input_layout == "packed" else out_control
+                )
+                out_control.sum().backward()
+                grads_control = [p.grad.clone() for p in rnn.parameters()]
+
+                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
+                # autocast and control results should be bitwise identical.
+                self.assertEqual(out, out_control)
+
+                if cls == "LSTM":
+                    self.assertTrue(
+                        h_out[0].dtype is torch.float16
+                        and h_out[1].dtype is torch.float16
+                    )
+                    self.assertEqual(h_out[0], h_out_control[0])
+                    self.assertEqual(h_out[1], h_out_control[1])
+                else:
+                    self.assertEqual(h_out.dtype, torch.float16)
+                    self.assertEqual(h_out, h_out_control)
+                for grad, grad_control in zip(grads, grads_control):
+                    self.assertEqual(grad.half(), grad_control)
+
+    def test_autocast_cache_leak(self):
+        # Reported at https://github.com/pytorch/pytorch/issues/48049
+        # Checks that autocast does not repeatedly re-cache the same parameters
+        # when executed inside a `torch.no_grad()` block.
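+        # Memory allocated after the first forward pass should stay constant across
+        # subsequent forward passes.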
+
+        linear = torch.nn.Linear(10, 10).to("cuda")
+        data = torch.randn(1, 10, device="cuda")
+
+        with torch.autocast("cuda"):
+            with torch.no_grad():
+                out = linear(data)
+                first_iter_mem = torch.cuda.memory_allocated()
+                for _ in range(3):
+                    out = linear(data)
+                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
+
+    def test_autocast_checkpointing(self):
+        model = torch.nn.Sequential(
+            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
+        ).cuda()
+        input = torch.rand(
+            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
+        )
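+        # Both the reentrant and non-reentrant implementations should yield a float16 output
+        # that requires grad and can be backpropagated through under autocast.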
+        for reentrant in (True, False):
+            with torch.autocast("cuda"):
+                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
+            self.assertTrue(output.requires_grad)
+            self.assertTrue(output.dtype is torch.float16)
+            output.sum().backward()
+
+    def test_cuda_autocast_deprecated_warning(self):
+        with self.assertWarnsRegex(
+            FutureWarning,
+            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
+        ):
+            with torch.cuda.amp.autocast():
+                _ = torch.ones(10)
+
+
 instantiate_parametrized_tests(TestCuda)
 instantiate_parametrized_tests(TestCudaMallocAsync)
 instantiate_device_type_tests(TestCudaOptims, globals())
diff --git a/test/test_xpu.py b/test/test_xpu.py
index 9dde7d8..e77a1e7 100644
--- a/test/test_xpu.py
+++ b/test/test_xpu.py
@@ -1,6 +1,5 @@
 # Owner(s): ["module: intel"]
 
-import collections
 import subprocess
 import sys
 import tempfile
@@ -8,7 +7,7 @@
 
 import torch
 import torch.xpu._gpu_trace as gpu_trace
-from torch.testing._internal.autocast_test_lists import AutocastTestLists
+from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyXPU,
@@ -371,7 +370,7 @@
 instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
 
 
-class TestXpuAutocast(TestCase):
+class TestXpuAutocast(TestAutocast):
     # These operators are not implemented on XPU backend and we can NOT fall back
     # them to CPU. So we have to skip them at this moment.
     # TODO: remove these operators from skip list when they are implemented on XPU backend.
@@ -385,89 +384,6 @@
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(
-        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
-    ):
-        # helper to cast args
-        def cast(val, to_type):
-            if isinstance(val, torch.Tensor):
-                return val.to(to_type) if val.is_floating_point() else val
-            elif isinstance(val, collections.abc.Iterable):
-                return type(val)(cast(v, to_type) for v in val)
-            else:
-                return val
-
-        if add_kwargs is None:
-            add_kwargs = {}
-        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
-        self.assertFalse(torch.is_autocast_enabled("xpu"))
-        with torch.amp.autocast("xpu", dtype=fast_dtype):
-            self.assertTrue(torch.is_autocast_enabled("xpu"))
-
-            out_type = out_type if out_type is not None else run_as_type
-            output = output_method = None
-
-            # Try module.* variant, if requested:
-            if module is not None and hasattr(module, op):
-                output = getattr(module, op)(*args, **add_kwargs)
-                if isinstance(output, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output.dtype,
-                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
-                    )
-
-            # Try Tensor.* variant:
-            if hasattr(torch.Tensor, op):
-                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
-                if isinstance(output_method, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output_method.dtype,
-                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
-                    )
-
-            self.assertTrue(
-                (output is not None) or (output_method is not None),
-                f"{op} not found as an attribute on either Tensor or the requested module {module}",
-            )
-
-            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
-            # For example, lstm_cell returns a tuple and equal returns bool.
-            def compare(first, second):
-                if isinstance(first, torch.Tensor):
-                    return torch.equal(first, second)
-                elif isinstance(first, collections.abc.Iterable):
-                    return all(compare(f, s) for f, s in zip(first, second))
-                else:
-                    return first == second
-
-            # If both torch.* and Tensor.* variants were found, check outputs are identical
-            if (output is not None) and (output_method is not None):
-                self.assertTrue(type(output) == type(output_method))
-                comparison = compare(output, output_method)
-                self.assertTrue(
-                    comparison, f"torch.{op} result did not match Tensor.{op} result"
-                )
-
-            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
-            # as the C++-side autocasting, and should be bitwise accurate.
-            output_to_compare = output if output is not None else output_method
-            with torch.amp.autocast("xpu", enabled=False):
-                self.assertFalse(torch.is_autocast_enabled("xpu"))
-
-                if module is not None and hasattr(module, op):
-                    control = getattr(module, op)(
-                        *cast(args, run_as_type), **add_kwargs
-                    )
-                else:
-                    control = getattr(args[0].to(run_as_type), op)(
-                        *cast(args[1:], run_as_type), **add_kwargs
-                    )
-                self.assertTrue(type(output_to_compare) == type(control))
-                comparison = compare(output_to_compare, control)
-                self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_enabled("xpu"))
-        self.assertFalse(torch.is_autocast_enabled("xpu"))
-
     def test_autocast_torch_fp16(self):
         for op_with_args in self.autocast_lists.torch_fp16:
             skip_test = False
@@ -477,7 +393,9 @@
             if len(op_with_args) == 3:
                 skip_test = True  # skip cudnn op
             if not skip_test:
-                self._run_autocast_outofplace(op, args, torch.float16)
+                self._run_autocast_outofplace(
+                    op, args, torch.float16, device="xpu", amp_dtype=torch.float16
+                )
 
     def test_autocast_torch_bf16(self):
         for op_with_args in self.autocast_lists.torch_fp16:
@@ -488,15 +406,24 @@
             if len(op_with_args) == 3:
                 skip_test = True  # skip cudnn op
             if not skip_test:
-                self._run_autocast_outofplace(op, args, torch.bfloat16)
+                self._run_autocast_outofplace(op, args, torch.bfloat16, device="xpu")
 
     def test_autocast_torch_need_autocast_promote(self):
         for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, device="xpu", amp_dtype=torch.float16
+            )
 
     def test_autocast_torch_expect_builtin_promote(self):
         for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="xpu",
+                out_type=out_type,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_checkpointing(self):
         model = torch.nn.Sequential(
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index 8527084..c9789f1 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -1,7 +1,10 @@
 # mypy: ignore-errors
 
+import collections
+
 import torch
 from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
 
 
 class AutocastTestLists:
@@ -234,6 +237,7 @@
                                       torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
         ]
 
+
 class AutocastCPUTestLists:
     # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
     def __init__(self, dev):
@@ -368,3 +372,103 @@
             ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
             ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
         ]
+
+
+class TestAutocast(TestCase):
+    def args_maybe_kwargs(self, op_with_args):
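+        # Test-list entries are either (op, args) or (op, args, add_kwargs); normalize to a 3-tuple.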
+        if len(op_with_args) == 2:
+            return op_with_args[0], op_with_args[1], {}
+        else:
+            return op_with_args[0], op_with_args[1], op_with_args[2]
+
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        device,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
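+        # `device` is the autocast device_type ("cpu", "cuda", "xpu", ...), `run_as_type` the dtype
+        # the op is expected to execute in under autocast, `out_type` the expected output dtype
+        # (defaults to `run_as_type`), and `amp_dtype` the dtype passed to torch.amp.autocast.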
+        # helper to cast args
+        def cast(val, to_type):
+            if isinstance(val, torch.Tensor):
+                return val.to(to_type) if val.is_floating_point() else val
+            elif isinstance(val, collections.abc.Iterable):
+                return type(val)(cast(v, to_type) for v in val)
+            else:
+                return val
+
+        if add_kwargs is None:
+            add_kwargs = {}
+
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
+        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+            out_type = out_type if out_type is not None else run_as_type
+            output = output_method = None
+
+            # Try module.* variant, if requested:
+            if module is not None and hasattr(module, op):
+                output = getattr(module, op)(*args, **add_kwargs)
+                if isinstance(output, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output.dtype,
+                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                    )
+            # Try Tensor.* variant:
+            if hasattr(torch.Tensor, op):
+                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+                if isinstance(output_method, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output_method.dtype,
+                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                    )
+
+            self.assertTrue(
+                (output is not None) or (output_method is not None),
+                f"{op} not found as an attribute on either Tensor or the requested module {module}",
+            )
+
+            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+            # For example, lstm_cell returns a tuple and equal returns bool.
+            def compare(first, second):
+                if isinstance(first, torch.Tensor):
+                    return torch.equal(first, second)
+                elif isinstance(first, collections.abc.Iterable):
+                    return all(compare(f, s) for f, s in zip(first, second))
+                else:
+                    return first == second
+
+            # If both torch.* and Tensor.* variants were found, check outputs are identical
+            if (output is not None) and (output_method is not None):
+                self.assertTrue(type(output) == type(output_method))
+                comparison = compare(output, output_method)
+                self.assertTrue(
+                    comparison, f"torch.{op} result did not match Tensor.{op} result"
+                )
+
+            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+            # as the C++-side autocasting, and should be bitwise accurate.
+            output_to_compare = output if output is not None else output_method
+            with torch.amp.autocast(device_type=device, enabled=False):
+                self.assertFalse(torch.is_autocast_enabled(device_type=device))
+
+                if module is not None and hasattr(module, op):
+                    control = getattr(module, op)(
+                        *cast(args, run_as_type), **add_kwargs
+                    )
+                else:
+                    control = getattr(args[0].to(run_as_type), op)(
+                        *cast(args[1:], run_as_type), **add_kwargs
+                    )
+                self.assertTrue(type(output_to_compare) == type(control))
+                comparison = compare(output_to_compare, control)
+                self.assertTrue(comparison, f"torch.{op} result did not match control")
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))