Parameterized test_graph_optims and test_graph_scaling_fused_optimizers (#133749)

Fixes #123451

This is a rework of the reverted pull request https://github.com/pytorch/pytorch/pull/125127, with the test failure that caused the revert now fixed. test_graph_optims and test_graph_scaling_fused_optimizers in test/test_cuda.py are rewritten as OptimizerInfo-parameterized tests driven by optim_db and the @optims decorator, and OptimizerInfo gains a has_capturable_arg flag so the graph tests can select only optimizers whose constructors accept capturable.
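For reference, here is a minimal sketch (not part of the patch) of which optimizers the two parameterized tests now pick up. It assumes a PyTorch build that already contains this change, plus the internal test dependencies (e.g. expecttest) that the torch.testing._internal modules require:

```python
from torch.testing._internal.common_optimizers import optim_db

# test_graph_optims runs once per optimizer whose constructor accepts
# `capturable`, i.e. every OptimizerInfo flagged with has_capturable_arg.
capturable_optims = sorted(
    info.optim_cls.__name__ for info in optim_db if info.has_capturable_arg
)
print(capturable_optims)

# test_graph_scaling_fused_optimizers runs once per optimizer that has a
# fused implementation supported on CUDA.
fused_cuda_optims = sorted(
    info.optim_cls.__name__
    for info in optim_db
    if "fused" in info.supported_impls and "cuda" in info.supports_fused_on
)
print(fused_cuda_optims)
```

The per-optimizer kwargs no longer come from hand-written itertools.product tables: test_graph_optims pulls them from _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",)) and test_graph_scaling_fused_optimizers from optim_info.optim_inputs_func(device=device), so new configurations added to common_optimizers.py are covered automatically.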

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133749
Approved by: https://github.com/janeyx99
diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py
index a1a38a9..0fcba0d 100644
--- a/test/inductor/test_compiled_optimizers.py
+++ b/test/inductor/test_compiled_optimizers.py
@@ -146,6 +146,8 @@
     "test_sgd_weight_decay_maximize_cuda": 4,
     "test_sgd_weight_decay_maximize_xpu": 4,
     "test_sgd_weight_decay_maximize_cpu": 4,
+    "test_sgd_weight_decay_cpu": 4,
+    "test_sgd_weight_decay_cuda": 4,
     "test_sgd_momentum_weight_decay_foreach_cuda": 2,
     "test_sgd_momentum_weight_decay_foreach_xpu": 2,
     "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e5e25a6..bb52e8c 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -40,7 +40,12 @@
     onlyCUDA,
     onlyNativeDeviceTypes,
 )
-from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker
+from torch.testing._internal.common_optimizers import (
+    _get_optim_inputs_including_global_cliquey_kwargs,
+    optim_db,
+    optims,
+    TensorTracker,
+)
 from torch.testing._internal.common_utils import (
     EXPANDABLE_SEGMENTS,
     freeze_rng_state,
@@ -3457,93 +3462,6 @@
     @unittest.skipIf(
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
-    def test_graph_optims(self):
-        # Needs generalization if we want to extend this test to non-Adam-like optimizers.
-        cases = (
-            [
-                (
-                    optimizer_ctor,
-                    {
-                        "lr": 0.1,
-                        "betas": (0.8, 0.7),
-                        "foreach": foreach,
-                        "decoupled_weight_decay": decoupled_weight_decay,
-                        "weight_decay": weight_decay,
-                    },
-                )
-                for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product(
-                    (torch.optim.NAdam, torch.optim.RAdam),
-                    (False, True),
-                    (False, True),
-                    (0.0, 0.1),
-                )
-            ]
-            + [
-                (
-                    torch.optim.Rprop,
-                    {"lr": 0.1, "foreach": foreach, "maximize": maximize},
-                )
-                for foreach, maximize in product(
-                    (False, True),
-                    (False, True),
-                )
-            ]
-            + [
-                (
-                    optimizer_ctor,
-                    {
-                        "lr": 0.1,
-                        "betas": (0.8, 0.7),
-                        "foreach": foreach,
-                        "amsgrad": amsgrad,
-                    },
-                )
-                for optimizer_ctor, foreach, amsgrad in product(
-                    (torch.optim.Adam, torch.optim.AdamW),
-                    (False, True),
-                    (False, True),
-                )
-            ]
-            + [
-                (
-                    optimizer_ctor,
-                    {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
-                )
-                for optimizer_ctor, amsgrad in product(
-                    (torch.optim.Adam, torch.optim.AdamW), (False, True)
-                )
-            ]
-            + [
-                (
-                    optimizer_ctor,
-                    {
-                        "lr": 0.1,
-                        "foreach": foreach,
-                        "maximize": maximize,
-                        "weight_decay": weight_decay,
-                    },
-                )
-                for optimizer_ctor, foreach, maximize, weight_decay in product(
-                    (
-                        torch.optim.Adamax,
-                        torch.optim.ASGD,
-                        torch.optim.Adadelta,
-                        torch.optim.RMSprop,
-                    ),
-                    (False, True),
-                    (False, True),
-                    (0, 0.1),
-                )
-            ]
-        )
-
-        for optimizer_ctor, kwargs in cases:
-            with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
-                self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)
-
-    @unittest.skipIf(
-        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
-    )
     def test_graph_optims_with_explicitly_capturable_param_groups(self):
         # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__
         n_warmup, n_replay = 3, 2
@@ -3615,125 +3533,6 @@
     @unittest.skipIf(
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
-    def test_graph_scaling_fused_optimizers(self):
-        cases = [
-            (
-                optimizer_ctor,
-                {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
-            )
-            for optimizer_ctor, amsgrad in product(
-                (torch.optim.Adam, torch.optim.AdamW), (False, True)
-            )
-        ] + list(
-            product(
-                (torch.optim.SGD,),
-                [
-                    {
-                        "lr": 0.1,
-                        "momentum": 0.0,
-                        "dampening": d,
-                        "weight_decay": w,
-                        "nesterov": n,
-                        "fused": True,
-                    }
-                    for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
-                ]
-                + [
-                    {
-                        "lr": 0.1,
-                        "momentum": 0.5,
-                        "dampening": d,
-                        "weight_decay": w,
-                        "nesterov": n,
-                        "fused": True,
-                    }
-                    for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
-                ],
-            )
-        )
-
-        steps_warmup = 3
-        steps_train = 2
-
-        for OptClass, kwargs in cases:
-            has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW)
-            for actually_do_graphs in (True, False) if has_capturable_arg else (True,):
-                params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
-                params_control = [p.clone().requires_grad_() for p in params]
-                params_graphed = [p.clone().requires_grad_() for p in params]
-
-                # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
-                grads = [
-                    [torch.randn_like(p) for p in params]
-                    for _ in range(steps_warmup + steps_train)
-                ]
-                with torch.no_grad():
-                    grads_control = [[g.clone() for g in gs] for gs in grads]
-                    grads_graphed = [[g.clone() for g in gs] for gs in grads]
-
-                # Gradient Scaler
-                scaler_for_control = torch.amp.GradScaler(
-                    device="cuda", init_scale=128.0
-                )
-                with torch.no_grad():
-                    scaler_for_control._lazy_init_scale_growth_tracker(
-                        torch.device("cuda")
-                    )
-
-                scaler_for_graphed = torch.amp.GradScaler(device="cuda")
-                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
-                with torch.no_grad():
-                    scaler_for_graphed._lazy_init_scale_growth_tracker(
-                        torch.device("cuda")
-                    )
-
-                # Control (capturable=False)
-                if has_capturable_arg:
-                    kwargs["capturable"] = False
-                opt = OptClass(params_control, **kwargs)
-
-                for i in range(steps_warmup + steps_train):
-                    for j, p in enumerate(params_control):
-                        p.grad = grads_control[i][j]
-                    scaler_for_control.step(opt)
-                    scaler_for_control.update()
-
-                # capturable=True
-                if has_capturable_arg:
-                    kwargs["capturable"] = True
-                opt = OptClass(params_graphed, **kwargs)
-
-                for i in range(steps_warmup):
-                    for j, p in enumerate(params_graphed):
-                        p.grad = grads_graphed[i][j]
-                    scaler_for_graphed.step(opt)
-                    scaler_for_graphed.update()
-
-                if actually_do_graphs:
-                    g = torch.cuda.CUDAGraph()
-                    with torch.cuda.graph(g):
-                        scaler_for_graphed.step(opt)
-                        scaler_for_graphed.update()
-
-                for i in range(steps_train):
-                    if actually_do_graphs:
-                        for j, p in enumerate(params_graphed):
-                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
-                        g.replay()
-                    else:
-                        # Passing capturable=True to the constructor and running without graphs should still be
-                        # numerically correct, even if it's not ideal for performance.
-                        for j, p in enumerate(params_graphed):
-                            p.grad = grads_graphed[i + steps_warmup][j]
-                        scaler_for_graphed.step(opt)
-                        scaler_for_graphed.update()
-
-                for p_control, p_graphed in zip(params_control, params_graphed):
-                    self.assertEqual(p_control, p_graphed)
-
-    @unittest.skipIf(
-        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
-    )
     def test_cuda_graph_error_options(self):
         def fn():
             x = torch.zeros([2000], device="cuda")
@@ -5082,10 +4881,179 @@
         self.assertEqual(len(set(active_pool_ids)), 4)
 
 
+@torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCudaOptims(TestCase):
     # These tests will be instantiated with instantiate_device_type_tests
     # to apply the new OptimizerInfo structure.
 
+    @onlyCUDA
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @optims(
+        [optim for optim in optim_db if optim.has_capturable_arg],
+        dtypes=[torch.float32],
+    )
+    def test_graph_optims(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+            device, dtype, optim_info, skip=("differentiable",)
+        )
+
+        steps_warmup = 3
+        steps_train = 2
+
+        for optim_input in all_optim_inputs:
+            kwargs = optim_input.kwargs
+
+            # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
+            # and torch.optim.adamw
+            kwargs["lr"] = 0.1
+
+            for actually_do_graphs in (True, False):
+                params = [
+                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
+                ] + [torch.randn((), device=device)]
+                params_control = [p.clone().requires_grad_() for p in params]
+                params_graphed = [p.clone().requires_grad_() for p in params]
+
+                grads = [
+                    [torch.randn_like(p) for p in params]
+                    for _ in range(steps_warmup + steps_train)
+                ]
+
+                # Control (capturable=False)
+                kwargs["capturable"] = False
+
+                opt = optim_cls(params_control, **kwargs)
+                for i in range(steps_warmup + steps_train):
+                    for j, p in enumerate(params_control):
+                        p.grad = grads[i][j]
+                    opt.step()
+
+                # capturable=True
+                kwargs["capturable"] = True
+                opt = optim_cls(params_graphed, **kwargs)
+
+                for i in range(steps_warmup):
+                    for j, p in enumerate(params_graphed):
+                        p.grad = grads[i][j]
+                    opt.step()
+
+                if actually_do_graphs:
+                    g = torch.cuda.CUDAGraph()
+                    with torch.cuda.graph(g):
+                        opt.step()
+
+                for i in range(steps_train):
+                    if actually_do_graphs:
+                        for j, p in enumerate(params_graphed):
+                            p.grad.copy_(grads[i + steps_warmup][j])
+                        g.replay()
+                    else:
+                        # Passing capturable=True to the constructor and running without graphs should still be
+                        # numerically correct, even if it's not ideal for performance.
+                        for j, p in enumerate(params_graphed):
+                            p.grad = grads[i + steps_warmup][j]
+                        opt.step()
+
+                for p_control, p_graphed in zip(params_control, params_graphed):
+                    self.assertEqual(p_control, p_graphed)
+
+    @onlyCUDA
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @optims(
+        [
+            optim
+            for optim in optim_db
+            if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
+        ],
+        dtypes=[torch.float32],
+    )
+    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+
+        steps_warmup = 3
+        steps_train = 2
+
+        optim_inputs = optim_info.optim_inputs_func(device=device)
+
+        for optim_input in optim_inputs:
+            kwargs = optim_input.kwargs
+            kwargs["fused"] = True
+
+            for actually_do_graphs in (
+                (True, False) if optim_info.has_capturable_arg else (True,)
+            ):
+                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
+                params_control = [p.clone().requires_grad_() for p in params]
+                params_graphed = [p.clone().requires_grad_() for p in params]
+
+                # `GradScaler` updates gradients in place, so it's necessary to duplicate them.
+                grads = [
+                    [torch.randn_like(p) for p in params]
+                    for _ in range(steps_warmup + steps_train)
+                ]
+                with torch.no_grad():
+                    grads_control = [[g.clone() for g in gs] for gs in grads]
+                    grads_graphed = [[g.clone() for g in gs] for gs in grads]
+
+                # Gradient Scaler
+                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
+                with torch.no_grad():
+                    scaler_for_control._lazy_init_scale_growth_tracker(device)
+
+                scaler_for_graphed = torch.cuda.amp.GradScaler()
+                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
+                with torch.no_grad():
+                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)
+
+                # Control (capturable=False)
+                if optim_info.has_capturable_arg:
+                    kwargs["capturable"] = False
+                opt = optim_cls(params_control, **kwargs)
+
+                for i in range(steps_warmup + steps_train):
+                    for j, p in enumerate(params_control):
+                        p.grad = grads_control[i][j]
+                    scaler_for_control.step(opt)
+                    scaler_for_control.update()
+
+                # capturable=True
+                if optim_info.has_capturable_arg:
+                    kwargs["capturable"] = True
+                opt = optim_cls(params_graphed, **kwargs)
+
+                for i in range(steps_warmup):
+                    for j, p in enumerate(params_graphed):
+                        p.grad = grads_graphed[i][j]
+                    scaler_for_graphed.step(opt)
+                    scaler_for_graphed.update()
+
+                if actually_do_graphs:
+                    g = torch.cuda.CUDAGraph()
+                    with torch.cuda.graph(g):
+                        scaler_for_graphed.step(opt)
+                        scaler_for_graphed.update()
+
+                for i in range(steps_train):
+                    if actually_do_graphs:
+                        for j, p in enumerate(params_graphed):
+                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
+                        g.replay()
+                    else:
+                        # Passing capturable=True to the constructor and running without graphs should still be
+                        # numerically correct, even if it's not ideal for performance.
+                        for j, p in enumerate(params_graphed):
+                            p.grad = grads_graphed[i + steps_warmup][j]
+                        scaler_for_graphed.step(opt)
+                        scaler_for_graphed.update()
+
+                for p_control, p_graphed in zip(params_control, params_graphed):
+                    self.assertEqual(p_control, p_graphed)
+
     @onlyNativeDeviceTypes
     @optims(
         [optim for optim in optim_db if "fused" in optim.supported_impls],
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 7d4d225..bd8f1f2 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -132,6 +132,8 @@
         ),
         # the optim supports passing in sparse gradients as well as dense grads
         supports_sparse: bool = False,
+        # the optimizer constructor supports passing in capturable as a kwarg
+        has_capturable_arg: bool = False,
         # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
         only_supports_sparse_grads: bool = False,
         # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
@@ -157,6 +159,7 @@
         self.supported_impls = supported_impls
         self.not_og_supported_flags = not_og_supported_flags
         self.supports_sparse = supports_sparse
+        self.has_capturable_arg = has_capturable_arg
         self.metadata_for_sparse = metadata_for_sparse
         self.only_supports_sparse_grads = only_supports_sparse_grads
         self.supports_complex = supports_complex
@@ -330,10 +333,11 @@
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
+            desc="maximize, weight_decay",
         ),
         OptimizerInput(
             params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
@@ -631,9 +635,14 @@
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "maximize": True},
+            kwargs={"maximize": True},
             desc="maximize",
         ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
     ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
 
 
@@ -788,11 +797,18 @@
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            kwargs={
+                "weight_decay": 0.1,
+            },
             desc="weight_decay",
         ),
         OptimizerInput(
             params=None,
+            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            desc="weight_decay, momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
             kwargs={
                 "weight_decay": 0.1,
                 "momentum_decay": 6e-3,
@@ -935,11 +951,26 @@
         ),
         OptimizerInput(
             params=None,
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
             kwargs={"weight_decay": 0.1, "centered": True},
             desc="centered",
         ),
         OptimizerInput(
             params=None,
+            kwargs={
+                "maximize": True,
+                "weight_decay": 0.1,
+            },
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
             kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
             desc="momentum",
         ),
@@ -951,7 +982,7 @@
                 "momentum": 0.1,
                 "maximize": True,
             },
-            desc="maximize",
+            desc="maximize, centered, weight_decay, w/ momentum",
         ),
     ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
 
@@ -1022,27 +1053,30 @@
         OptimizerInput(
             params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
         ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
         OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
         OptimizerInput(
             params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
             kwargs={"momentum": 0.9, "dampening": 0.5},
             desc="dampening",
         ),
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "weight_decay": 0.1},
-            desc="non-zero weight_decay",
+            desc="weight_decay w/ momentum",
         ),
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
             desc="nesterov",
         ),
-        OptimizerInput(
-            params=None,
-            kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
-        ),
     ]
 
 
@@ -1208,6 +1242,7 @@
         optim_inputs_func=optim_inputs_func_adadelta,
         optim_error_inputs_func=optim_error_inputs_func_adadelta,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1493,6 +1528,7 @@
         ),
         optim_error_inputs_func=optim_error_inputs_func_adam,
         supported_impls=("foreach", "differentiable", "fused"),
+        has_capturable_arg=True,
         not_og_supported_flags=(
             "foreach",
             "differentiable",
@@ -1578,6 +1614,7 @@
         optim_inputs_func=optim_inputs_func_adamax,
         optim_error_inputs_func=optim_error_inputs_func_adamax,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1630,6 +1667,7 @@
             "capturable",
         ),
         supports_fused_on=("cpu", "cuda", "mps"),
+        has_capturable_arg=True,
         decorators=(
             # Expected error between compiled forloop and fused optimizers
             DecorateInfo(
@@ -1710,6 +1748,7 @@
         optim_inputs_func=optim_inputs_func_asgd,
         optim_error_inputs_func=optim_error_inputs_func_asgd,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1822,6 +1861,7 @@
         optim_inputs_func=optim_inputs_func_nadam,
         optim_error_inputs_func=optim_error_inputs_func_nadam,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1870,6 +1910,7 @@
         optim_inputs_func=optim_inputs_func_radam,
         optim_error_inputs_func=optim_error_inputs_func_radam,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1915,6 +1956,7 @@
         optim_inputs_func=optim_inputs_func_rmsprop,
         optim_error_inputs_func=optim_error_inputs_func_rmsprop,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1964,6 +2006,7 @@
         optim_inputs_func=optim_inputs_func_rprop,
         optim_error_inputs_func=optim_error_inputs_func_rprop,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # Rprop doesn't update for non-contiguous, see #118117