[optim] Rectify capturable testing and fix bugs! (#118326)

This PR fixes several bugs, listed in priority order:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, since we did not create the step tensors on the right device in `__setstate__`. This has been fixed (a sketch of the pattern follows this list).
2. The most recently added capturable implementations were missing the eager-mode check that all tensors should be on CUDA. Those checks have now been added.
3. The most recent change to Adamax only added capturable for foreach, and the forloop/single-tensor path would have been silently incorrect. I've added erroring for that case and modified testing with many, many skips. Honestly, my preference after this PR has only been further cemented: we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable at all! This also stands rectified and was the trigger for this PR in the first place (see the `_get_device_type` sketch after the logs below).
5. In a similar way, the conditional in `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.
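
For item 1, the core of the `__setstate__` fix (it appears verbatim in the adam.py/adamw.py hunks in the diff below) is roughly the pattern sketched here. This is a minimal illustration, not the actual code: the `_restore_step_tensor` helper name is made up, and `_get_scalar_dtype` is the private helper from `torch/optim/optimizer.py` that the optimizers themselves use.

```python
import torch
# Private helper used by the real optimizers (see the adam.py hunk below).
from torch.optim.optimizer import _get_scalar_dtype


def _restore_step_tensor(p_state, param, capturable, fused=False):
    # A state_dict saved before "step" became a tensor stores it as a plain float.
    # The capturable and fused code paths need the step tensor on the param's
    # device; previously __setstate__ always created it on CPU, which broke them.
    if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
        step_val = float(p_state["step"])
        if capturable or fused:
            p_state["step"] = torch.tensor(
                step_val, dtype=_get_scalar_dtype(is_fused=fused), device=param.device
            )
        else:
            p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())
    return p_state
```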

The following is not a bug fix, but a simplification that removes the need to handle Nones: `optim_input_funcs` must now take a `device`, which can be a string or a `torch.device` (see the signature sketch below).
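
As an illustration, here is a minimal sketch of what an optim input func looks like after this change, modeled on the `optim_inputs_func_adam` hunk in the diff; the function name and the particular configs are made up for illustration.

```python
from torch.testing._internal.common_optimizers import OptimizerInput


def optim_inputs_func_example(device):
    # `device` may be a string like "cuda:0" or a torch.device; branching on the
    # stringified device lets the func append CUDA-only configs (e.g. capturable)
    # only when they can actually run.
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
    ] + (cuda_supported_configs if "cuda" in str(device) else [])
```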

Details for posterity:
4. Running the `test_foreach_matches_forloop` test and printing the configs that get run shows that capturable is now included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the `test_optimizer_can_be_printed` test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing the configs that get run shows they are also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```
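
For items 4 and 5 in the list above, the root cause appears to be that a device string like `"cuda:0"` does not compare equal to `"cuda"`, so the old `str(device) == "cuda"`-style conditionals silently dropped the CUDA-only configs. The `_get_device_type` helper added to common_optimizers.py in the diff below normalizes the device first; it is reproduced here (minus type hints) with a small usage check.

```python
import torch


def _get_device_type(device):
    # Same helper as in the common_optimizers.py hunk below: normalize either a
    # device string ("cuda:0") or a torch.device to its bare type ("cuda").
    if isinstance(device, torch.device):
        device = str(device.type)
    assert isinstance(device, str)
    return device.split(":")[0]


# Why the old check failed: "cuda:0" is not equal to "cuda", but its type is.
assert str(torch.device("cuda:0")) != "cuda"
assert _get_device_type("cuda:0") == "cuda"
assert _get_device_type(torch.device("cpu")) == "cpu"
```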

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py
index d3d7d3b..74b964f 100644
--- a/test/dynamo/test_optimizers.py
+++ b/test/dynamo/test_optimizers.py
@@ -44,6 +44,10 @@
 
 
 def make_test(optim_cls, closure=None, **kwargs):
+    # Remove this conditional when #118230 is fixed
+    if optim_cls.__name__ == "Adamax":
+        kwargs["foreach"] = True
+
     opt = optim_cls(model.parameters(), **kwargs)
 
     def test_fn(self):
diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py
index 95f19f7..2fc16cc 100644
--- a/test/inductor/test_compiled_optimizers.py
+++ b/test/inductor/test_compiled_optimizers.py
@@ -74,10 +74,10 @@
     SGD: KernelCounts(multitensor=2, singletensor=8),
     RAdam: KernelCounts(
         multitensor=2, singletensor=None
-    ),  # Single tensor eager needs to be refactored to enable tracing
+    ),  # Single tensor eager needs to be refactored to enable tracing (#118230)
     Adamax: KernelCounts(
         multitensor=2, singletensor=None
-    ),  # Single tensor eager needs to be refactored to enable tracing
+    ),  # Single tensor eager needs to be refactored to enable tracing (#117836)
 }
 
 
@@ -87,12 +87,9 @@
         if optim_info.optim_cls not in KERNEL_COUNTS:
             continue
 
-        for optim_inputs in optim_info.optim_inputs_func():
-            for device in ["cpu", "cuda"]:
+        for device in ["cpu", "cuda"]:
+            for optim_inputs in optim_info.optim_inputs_func(device):
                 for foreach in [True, False]:
-                    if device == "cpu" and "capturable" in optim_inputs.kwargs:
-                        continue
-
                     kwargs = dict(optim_inputs.kwargs)
                     name = (
                         f"test_{optim_info.optim_cls.__name__.lower()}"
@@ -107,7 +104,21 @@
                     name += f"_{device}"
 
                     # Eager for-loop impl doesn't support capturable ASGD
-                    if name == "test_asgd_capturable_cuda":
+                    if name in [
+                        "test_asgd_capturable_cuda",
+                        "test_asgd_maximize_capturable_cuda",
+                        "test_asgd_weight_decay_capturable_cuda",
+                        "test_asgd_weight_decay_maximize_capturable_cuda",
+                    ]:
+                        continue
+
+                    # Adam(W) capturable cudagraphs manager is unexpectedly None, #119026
+                    if name in [
+                        "test_adam_amsgrad_capturable_cuda",
+                        "test_adam_foreach_amsgrad_capturable_cuda",
+                        "test_adamw_amsgrad_capturable_cuda",
+                        "test_adamw_foreach_amsgrad_capturable_cuda",
+                    ]:
                         continue
 
                     kwargs["foreach"] = foreach
diff --git a/test/nn/test_lazy_modules.py b/test/nn/test_lazy_modules.py
index bac1684..8dd0125 100644
--- a/test/nn/test_lazy_modules.py
+++ b/test/nn/test_lazy_modules.py
@@ -583,9 +583,9 @@
 
     @suppress_warnings
     def test_optimizer_pass(self):
+        # Add Adamax and RAdam when #118230 and #117836 are complete
         optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
-                      torch.optim.AdamW, torch.optim.Adamax,
-                      torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
+                      torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
                       torch.optim.RMSprop, torch.optim.LBFGS]
 
         def run_step(module, optim):
diff --git a/test/test_optim.py b/test/test_optim.py
index 0eec142..20b3799 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -19,10 +19,8 @@
 FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
 
 
-def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
-    # Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
-    if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
-        # Radam does not support capturable single tensor
+def _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs):
+    if optim_info.only_supports_capturable_on_foreach and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
         kwargs["capturable"] = False
 
 @markDynamoStrictTest
@@ -71,6 +69,9 @@
         for optim_input in optim_inputs:
             if "foreach" in optim_info.supported_impls:
                 optim_input.kwargs["foreach"] = False  # force forloop
+
+            _force_capturable_False_for_unsupported_single_tensor(optim_info, optim_input.kwargs)
+
             if contiguous:
                 weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                 bias = Parameter(torch.randn((10), device=device, dtype=dtype))
@@ -79,8 +80,6 @@
                 bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
             input = torch.randn(5, device=device, dtype=dtype)
 
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
             optimizer = optim_cls([weight, bias], **optim_input.kwargs)
 
             def closure():
@@ -109,13 +108,14 @@
     @optims(optim_db, dtypes=[torch.float32])
     def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
-        optim_inputs = optim_info.optim_inputs_func(device="cuda")
+        optim_inputs = optim_info.optim_inputs_func(device=device)
         for optim_input in optim_inputs:
             if "foreach" in optim_info.supported_impls:
                 optim_input.kwargs["foreach"] = False  # force forloop
 
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
 
             weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
             bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
@@ -148,13 +148,19 @@
     def test_complex(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
-        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
+        # Also skip fused, since our fused kernels do not support complex
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+            device, dtype, optim_info, skip=("differentiable", "fused"))
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             # Last param is intentionally real to test that we can mix real and complex
             complex_params = [
-                torch.randn(10, 5, dtype=dtype, requires_grad=True),
-                torch.randn(10, dtype=dtype, requires_grad=True),
-                torch.randn(10, 5, dtype=torch.float32, requires_grad=True),
+                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
+                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
+                torch.randn(10, 5, device=device, dtype=torch.float32, requires_grad=True),
             ]
             real_params = [
                 (
@@ -164,8 +170,7 @@
                 )
                 for param in complex_params
             ]
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
             complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
             real_optimizer = optim_cls(real_params, **optim_input.kwargs)
             real_steps = []
@@ -234,14 +239,13 @@
         for optim_input in optim_inputs:
             updated_params, state = [], []
             kwargs = deepcopy(optim_input.kwargs)
-            if (kwargs.get("capturable", False) and str(device) == "cpu"):
+            if kwargs.get("capturable", False) and str(device) == "cpu":
                 # capturable is not supported on CPU
                 continue
             for flag_value in (False, True):
                 kwargs[flag] = flag_value
 
-                # https://github.com/pytorch/pytorch/issues/118230
-                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+                _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
 
                 input = torch.tensor(
                     [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
@@ -344,7 +348,10 @@
         for optim_input in optim_inputs:
             updated_params, state = [], []
             kwargs = deepcopy(optim_input.kwargs)
-            if kwargs.get("capturable", False) and str(device) == "cpu":
+
+            _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
+
+            if kwargs.get("capturable", False) and str(device) == "cpu" :
                 # capturable is not supported on CPU
                 continue
             for use_impl in (False, True):
@@ -357,8 +364,6 @@
                         p_clone.grad = p.grad.clone().detach()
                         params_clone.append(p_clone)
 
-                # https://github.com/pytorch/pytorch/issues/118230
-                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
                 optimizer = optim_cls(params_clone, **kwargs)
                 for _ in range(kIterations):
                     optimizer.step()
@@ -393,16 +398,18 @@
         # default dtype is higher prec float64
         old_default_dtype = torch.get_default_dtype()
         for default_dtype in [torch.float64, torch.float16]:
-            torch.set_default_dtype(default_dtype)
-            self._test_derived_optimizers(
-                device,
-                dtype,
-                optim_info,
-                "foreach",
-                reduced_precision=default_dtype == torch.float16,
-                assert_step_dtype=torch.float64 if default_dtype == torch.float64 else torch.float32,
-            )
-            torch.set_default_dtype(old_default_dtype)
+            try:
+                torch.set_default_dtype(default_dtype)
+                self._test_derived_optimizers(
+                    device,
+                    dtype,
+                    optim_info,
+                    "foreach",
+                    reduced_precision=default_dtype == torch.float16,
+                    assert_step_dtype=torch.float64 if default_dtype == torch.float64 else torch.float32,
+                )
+            finally:
+                torch.set_default_dtype(old_default_dtype)
 
 
 
@@ -431,8 +438,7 @@
             for flag_value in (False, True):
                 kwargs["foreach"] = flag_value
 
-                # https://github.com/pytorch/pytorch/issues/118230
-                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+                _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
 
                 # The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
                 # meaning any tensor that occupies <512 bytes of memory will allocate a whole
@@ -539,6 +545,11 @@
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            # See https://github.com/pytorch/pytorch/issues/117836 and #118230
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             weight_kwargs = optim_input.kwargs
             bias_kwargs = deepcopy(optim_input.kwargs)
             bias_kwargs["weight_decay"] = 0.0
@@ -575,6 +586,11 @@
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            # See https://github.com/pytorch/pytorch/issues/117836 and #118230
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             # optim_input.kwargs will be the param group kwargs, which should have >0 lr
             if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
                 optim_input.kwargs["lr"] = 1e-3
@@ -630,10 +646,12 @@
             return torch.tensor([1], device=device, dtype=dtype)
 
         for optim_input in all_optim_inputs:
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.step(closure)
-            self.assertEqual(old_params, params)
 
 
     @optims(optim_db, dtypes=[torch.float32])
@@ -648,7 +666,10 @@
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
+            if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
+                    and not kwargs.get("foreach", False)):
+                continue
 
             # params will decay even if grads are empty if weight_decay != 0,
             # and capturable doesn't work for CPU tensors
@@ -657,7 +678,7 @@
 
             # AdamW params will be updated regardless of grads due to lr, so make lr smaller
             if optim_cls.__name__ == "AdamW":
-                kwargs["lr"] = torch.tensor(1e-4) if isinstance(kwargs.get("lr", 1e-4), torch.Tensor) else 1e-4
+                kwargs["lr"] = torch.tensor(1e-5) if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) else 1e-5
 
             if kwargs.get("differentiable", False):
                 params = [param.clone()]
@@ -684,6 +705,10 @@
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
         params = [Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)) for _ in range(2)]
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.__repr__()
 
@@ -706,6 +731,10 @@
             return loss
 
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
@@ -749,6 +778,10 @@
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             torch.manual_seed(1)
             model = torch.nn.Sequential(
                 torch.nn.Conv2d(4, 2, 1, stride=2),
@@ -811,9 +844,8 @@
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
-            # See https://github.com/pytorch/pytorch/issues/117836 for Adamax
-            # See https://github.com/pytorch/pytorch/issues/118230 for RAdam
-            if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
+            if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
+                    and not kwargs.get("foreach", False)):
                 continue
 
             optimizer = optim_cls(params, **optim_input.kwargs)
@@ -843,6 +875,10 @@
             return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
 
         for optim_input in cpu_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
             for p in params:
                 p.grad = torch.randn_like(p)
@@ -906,6 +942,10 @@
             return {k for k in obj.__dict__ if not k.startswith("_")}
 
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
 
             # Make some state
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 90e0f78..a2bf3af 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -1568,7 +1568,7 @@
         }
 
         excluded_single_tensor = {
-            radam,  # https://github.com/pytorch/pytorch/issues/117807
+            radam,  # https://github.com/pytorch/pytorch/issues/118230
         }
 
         for opt_mod in optimizer_modules:
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 15abef4..386bbc2 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -74,7 +74,10 @@
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
-                    p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
+                                       if group['capturable'] or group['fused']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(
         self,
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index 8f0ee6c..cba664d 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -34,6 +34,9 @@
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
+        if foreach is False and capturable:
+            raise ValueError("Capturable not supported with single tensor Adamax")
+
         defaults = dict(
             lr=lr,
             betas=betas,
@@ -53,13 +56,12 @@
             group.setdefault("maximize", False)
             group.setdefault("differentiable", False)
             group.setdefault("capturable", False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["step"]
-        )
-        if not step_is_tensor:
-            for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
         has_complex = False
@@ -265,6 +267,8 @@
     capturable: bool,
     has_complex: bool,
 ):
+    if capturable:
+        raise RuntimeError("capturable is not supported for single tensor Adamax (when foreach=False)")
 
     for i, param in enumerate(params):
         grad = grads[i]
@@ -272,6 +276,7 @@
         exp_avg = exp_avgs[i]
         exp_inf = exp_infs[i]
         step_t = state_steps[i]
+
         # update step
         step_t += 1
 
@@ -328,6 +333,11 @@
     if len(params) == 0:
         return
 
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if (not torch._utils.is_compiling() and capturable
+            and not all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps))):
+        raise RuntimeError("If capturable=True, params and state_steps must be CUDA tensors.")
+
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
     for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
         if has_complex:
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index bbf5fe7..f97e66e 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -83,7 +83,10 @@
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
-                    p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
+                                       if group['capturable'] or group['fused']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(
         self,
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index c65411a..355c1bd 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -34,7 +34,7 @@
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if foreach is False and capturable:
+        if foreach is False and capturable and not is_compiling():
             raise ValueError("Capturable not supported with single tensor ASGD")
 
         defaults = dict(
@@ -57,25 +57,20 @@
             group.setdefault("maximize", False)
             group.setdefault("differentiable", False)
             group.setdefault("capturable", False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["step"]
-        )
-        if not step_is_tensor:
-            for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
-        eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["eta"]
-        )
-        if not eta_is_tensor:
-            for s in state_values:
-                s["eta"] = torch.tensor(s["eta"], dtype=_get_scalar_dtype())
-        mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["mu"]
-        )
-        if not mu_is_tensor:
-            for s in state_values:
-                s["mu"] = torch.tensor(float(s["mu"]), dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0:
+                    if not torch.is_tensor(p_state['step']):
+                        step_val = float(p_state["step"])
+                        p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
+                                           if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state["eta"]):
+                        p_state["eta"] = (torch.tensor(p_state["eta"], dtype=_get_scalar_dtype(), device=p.device)
+                                          if group["capturable"] else torch.tensor(p_state["eta"], dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state["mu"]):
+                        p_state["mu"] = (torch.tensor(p_state["mu"], dtype=_get_scalar_dtype(), device=p.device)
+                                         if group["capturable"] else torch.tensor(p_state["mu"], dtype=_get_scalar_dtype()))
+
 
     def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
         has_complex = False
@@ -206,8 +201,6 @@
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_asgd
     else:
-        if capturable and not is_compiling():
-            raise RuntimeError("Capturable not supported with single tensor ASGD")
         func = _single_tensor_asgd
 
     func(
@@ -247,6 +240,9 @@
     capturable: bool,
     has_complex: bool,
 ):
+    if capturable and not is_compiling():
+        raise RuntimeError("capturable is not supported for single tensor ASGD (when foreach=False)")
+
     for i, param in enumerate(params):
         grad = grads[i]
         grad = grad if not maximize else -grad
@@ -304,12 +300,17 @@
     capturable: bool,
     has_complex: bool,
 ):
-
     if len(params) == 0:
         return
 
     assert not differentiable, "_foreach ops don't support autograd"
 
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
+        assert all(p.is_cuda and mu.is_cuda and eta.is_cuda and step.is_cuda
+                   for p, mu, eta, step in zip(params, mus, etas, state_steps)), \
+            "If capturable=True, params, mu_products, and state_steps must be CUDA tensors."
+
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
     for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
          grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 3526402..f05b6b0 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -37,15 +37,18 @@
             group.setdefault('capturable', False)
             group.setdefault('differentiable', False)
             group.setdefault('decoupled_weight_decay', False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
-        if not step_is_tensor:
-            for s in state_values:
-                s['step'] = torch.tensor(float(s['step']), dtype=_get_scalar_dtype())
-        mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
-        if not mu_product_is_tensor:
-            for s in state_values:
-                s['mu_product'] = torch.tensor(s['mu_product'], dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0:
+                    if not torch.is_tensor(p_state['step']):
+                        step_val = float(p_state["step"])
+                        p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
+                                           if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state['mu_product']):
+                        mu_prod_val = p_state["mu_product"]
+                        p_state["mu_product"] = (torch.tensor(mu_prod_val, dtype=_get_scalar_dtype(), device=p.device)
+                                                 if group['capturable'] else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()))
+
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
         has_complex = False
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 4184450..5fe7010 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -44,6 +44,10 @@
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if foreach is False and capturable:
+            raise ValueError("Capturable not supported with single tensor RAdam")
+
         defaults = dict(
             lr=lr,
             betas=betas,
@@ -208,7 +212,7 @@
             decay as in AdamW to obtain RAdamW (default: False)
         {_foreach_doc}
         {_differentiable_doc}
-        {_capturable_doc}
+        {_capturable_doc} For RAdam, capturable is only supported when foreach=True.
 
     .. _On the variance of the adaptive learning rate and beyond:
         https://arxiv.org/abs/1908.03265
@@ -297,7 +301,7 @@
     has_complex: bool,
 ):
     if capturable:
-        raise RuntimeError("capturable is not supported for single tensor radam")
+        raise RuntimeError("capturable is not supported for single tensor RAdam (when foreach=False)")
 
     for i, param in enumerate(params):
         grad = grads[i]
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index e27eda5..6f782f8 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -114,6 +114,8 @@
         supports_param_groups: bool = True,
         # whether the optimizer supports parameters on multiple devices
         supports_multiple_devices: bool = True,
+        # whether the optimizer ONLY supports capturable on foreach vs. both foreach and forloop
+        only_supports_capturable_on_foreach: bool = False,
         skips=(),  # Indicates which tests to skip
         decorators=None,  # Additional decorators to apply to generated tests
         optim_error_inputs_func=None,  # Function to generate optim inputs that error
@@ -126,6 +128,7 @@
         self.step_requires_closure = step_requires_closure
         self.supports_param_groups = supports_param_groups
         self.supports_multiple_devices = supports_multiple_devices
+        self.only_supports_capturable_on_foreach = only_supports_capturable_on_foreach
         self.decorators = (
             *(decorators if decorators else []),
             *(skips if skips else []),
@@ -262,7 +265,7 @@
 # global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
 
 
-def optim_inputs_func_adadelta(device=None):
+def optim_inputs_func_adadelta(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@@ -297,7 +300,7 @@
     return error_inputs
 
 
-def optim_inputs_func_adagrad(device=None):
+def optim_inputs_func_adagrad(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(
@@ -341,7 +344,7 @@
 
 # TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
 # with all implementation code paths...
-def optim_inputs_func_adam(device=None):
+def optim_inputs_func_adam(device):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
         OptimizerInput(
@@ -370,7 +373,7 @@
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
         ),
-    ] + (cuda_supported_configs if str(device) == "cuda" else [])
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_adam(device, dtype):
@@ -405,7 +408,7 @@
                 error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
             ),
         ]
-    if str(device) == "cuda":
+    if "cuda" in str(device):
         sample_tensor = torch.empty((), device=device, dtype=dtype)
         error_inputs += [
             ErrorOptimizerInput(
@@ -430,7 +433,7 @@
     return error_inputs
 
 
-def optim_inputs_func_adamax(device=None):
+def optim_inputs_func_adamax(device):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
         OptimizerInput(
@@ -461,11 +464,23 @@
             kwargs={"weight_decay": 0.1, "maximize": True},
             desc="maximize",
         ),
-    ] + (cuda_supported_configs if str(device) == "cuda" else [])
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_adamax(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor Adamax",
+            )
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
@@ -481,15 +496,33 @@
     return error_inputs
 
 
-def optim_inputs_func_adamw(device=None):
-    return optim_inputs_func_adam(device=device)
+def optim_inputs_func_adamw(device):
+    return optim_inputs_func_adam(device)
 
 
 def optim_error_inputs_func_adamw(device, dtype):
     return optim_error_inputs_func_adam(device, dtype)
 
 
-def optim_inputs_func_asgd(device=None):
+def optim_inputs_func_asgd(device):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True, "capturable": True},
+            desc="maximize, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="maximize, weight_decay, capturable",
+        ),
+    ]
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
@@ -501,13 +534,25 @@
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
+            desc="maximize, nonzero weight_decay",
         ),
-    ]
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_asgd(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor ASGD",
+            )
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
@@ -523,7 +568,7 @@
     return error_inputs
 
 
-def optim_inputs_func_lbfgs(device=None):
+def optim_inputs_func_lbfgs(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@@ -544,7 +589,7 @@
 
 
 # Weird story bro, NAdam and RAdam do not have maximize.
-def optim_inputs_func_nadam(device=None):
+def optim_inputs_func_nadam(device):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
         OptimizerInput(
@@ -585,7 +630,7 @@
             },
             desc="decoupled_weight_decay",
         ),
-    ] + (cuda_supported_configs if str(device) == "cuda" else [])
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_nadam(device, dtype):
@@ -653,6 +698,18 @@
 
 def optim_error_inputs_func_radam(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor RAdam",
+            ),
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
@@ -677,7 +734,7 @@
     return error_inputs
 
 
-def optim_inputs_func_rmsprop(device=None):
+def optim_inputs_func_rmsprop(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
@@ -724,7 +781,7 @@
     return error_inputs
 
 
-def optim_inputs_func_rprop(device=None):
+def optim_inputs_func_rprop(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
@@ -757,7 +814,7 @@
     return error_inputs
 
 
-def optim_inputs_func_sgd(device=None):
+def optim_inputs_func_sgd(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
@@ -802,7 +859,7 @@
     return error_inputs
 
 
-def optim_inputs_func_sparseadam(device=None):
+def optim_inputs_func_sparseadam(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(
@@ -879,6 +936,14 @@
     return error_inputs
 
 
+def _get_device_type(device: Union[str, torch.device]) -> str:
+    # Returns the device type as a string, e.g., "cpu" or "cuda"
+    if isinstance(device, torch.device):
+        device = str(device.type)
+    assert isinstance(device, str)
+    return device.split(":")[0]
+
+
 def _get_optim_inputs_including_global_cliquey_kwargs(
     device, dtype, optim_info, skip=()
 ) -> List[OptimizerInput]:
@@ -897,14 +962,20 @@
         x in ["foreach", "fused", "differentiable"] for x in skip
     ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
 
-    optim_inputs = optim_info.optim_inputs_func(device=device)
+    optim_inputs = optim_info.optim_inputs_func(device)
 
     supported_impls = tuple(
         x
         for x in optim_info.supported_impls
         if x not in skip
-        and (str(device) in _get_fused_kernels_supported_devices() or x != "fused")
-        and (str(device) in _get_foreach_kernels_supported_devices() or x != "foreach")
+        and (
+            _get_device_type(device) in _get_fused_kernels_supported_devices()
+            or x != "fused"
+        )
+        and (
+            _get_device_type(device) in _get_foreach_kernels_supported_devices()
+            or x != "foreach"
+        )
     )
 
     all_optim_inputs = []
@@ -1131,6 +1202,7 @@
         optim_inputs_func=optim_inputs_func_adamax,
         optim_error_inputs_func=optim_error_inputs_func_adamax,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True,  # Remove this line when #117836 is done!
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1197,6 +1269,62 @@
                 "TestOptimRenewed",
                 "test_deepcopy_copies_all_public_attrs",
             ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "cpu fails due to #115607; both devices fail cuz #117836"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_when_params_have_no_grad",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
         ),
     ),
     OptimizerInfo(
@@ -1267,6 +1395,7 @@
         optim_inputs_func=optim_inputs_func_asgd,
         optim_error_inputs_func=optim_error_inputs_func_asgd,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True,  # Remove this line when #116052 is done!
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo(
@@ -1455,6 +1584,7 @@
         optim_inputs_func=optim_inputs_func_radam,
         optim_error_inputs_func=optim_error_inputs_func_radam,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True,  # Remove this line when #118230 is done!
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo(
@@ -1545,6 +1675,13 @@
                     "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
                 ),
                 "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
                 "test_step_is_noop_when_params_have_no_grad",
             ),
             DecorateInfo(
@@ -1573,13 +1710,6 @@
                     "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
                 ),
                 "TestOptimRenewed",
-                "test_step_is_noop_for_zero_grads",
-            ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
-                ),
-                "TestOptimRenewed",
                 "test_state_dict_with_cuda_params",
             ),
             DecorateInfo(