[BE] typing for decorators - optim/optimizer (#131583)

See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131583
Approved by: https://github.com/janeyx99
ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577, #131578, #131579, #131580, #131581, #131582
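
This change replaces the per-file "# mypy: allow-untyped-decorators" suppressions with real annotations on the decorator itself: _disable_dynamo_if_unsupported in torch/optim/optimizer.py is now typed with a ParamSpec/TypeVar pair, so functions wrapped by it keep their original signatures under mypy. A minimal sketch of that pattern follows; the decorator and function names are illustrative only, not part of the PR:

    import functools
    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec  # typing.ParamSpec on Python 3.10+

    _T = TypeVar("_T")
    _P = ParamSpec("_P")

    def passthrough(func: Callable[_P, _T]) -> Callable[_P, _T]:
        """No-op decorator whose annotations preserve the wrapped signature."""

        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            return func(*args, **kwargs)

        return wrapper

    @passthrough
    def scale(x: float, factor: float = 2.0) -> float:
        return x * factor

    result: float = scale(3.0)  # still checked against (float, float) -> float

Because the decorated callable is no longer seen as untyped by mypy, the optimizer modules below can drop their allow-untyped-decorators comments.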
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index fb719b3..bb2335b 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from typing import Any, Dict, List, Optional, Union
 
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 6610839..235880f 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from typing import List, Optional, Tuple, Union
 
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index b74c9e0..639a7f5 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from typing import List, Optional, Tuple, Union
 
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 0687d13..21b90fe 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from typing import cast, List, Optional, Tuple, Union
 
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index 163dbc0..025fb50 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from typing import List, Optional, Tuple, Union
 
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 6ec27df..4de66ef 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 r"""Implementation for the NAdam algorithm."""
 from typing import cast, List, Optional, Tuple, Union
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 6cbc18b..217fa60 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -17,6 +17,7 @@
     List,
     Optional,
     overload,
+    Sequence,
     Set,
     Tuple,
     TypeVar,
@@ -35,6 +36,9 @@
 )
 from torch.utils.hooks import RemovableHandle
 
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
 Args: TypeAlias = Tuple[Any, ...]
 Kwargs: TypeAlias = Dict[str, Any]
 StateDict: TypeAlias = Dict[str, Any]
@@ -112,7 +116,9 @@
         return x
 
 
-def _disable_dynamo_if_unsupported(single_tensor_fn=None):
+def _disable_dynamo_if_unsupported(
+    single_tensor_fn: Optional[Callable[..., object]] = None
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     # workaround for torchscript BC
     # it requires all called functions to be in the
     # global environment at the site at which the
@@ -120,7 +126,7 @@
     if single_tensor_fn:
         globals()[single_tensor_fn.__name__] = single_tensor_fn
 
-    def wrapper(func):
+    def wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
         import inspect
 
         disabled_func = torch._disable_dynamo(func)
@@ -137,15 +143,18 @@
         # but this only occurs in the rare case that the user explicitly deletes
         # the capturable flag. If capturable=True, this is not a problem.
         @functools.wraps(func)
-        def maybe_fallback(*args, **kwargs):
+        def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs):
             if is_compiling() and (
                 not kwargs.get("capturable", False)
                 and has_state_steps
-                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
+                and (
+                    isinstance(arg := args[state_steps_ind], Sequence)
+                    and arg[0].is_cuda
+                )
                 or (
                     "state_steps" in kwargs
-                    and kwargs["state_steps"]
-                    and kwargs["state_steps"][0].is_cuda
+                    and isinstance(arg := kwargs["state_steps"], Sequence)
+                    and arg[0].is_cuda
                 )
             ):
                 return disabled_func(*args, **kwargs)
@@ -300,7 +309,6 @@
 
 ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
 
-_P = ParamSpec("_P")
 R = TypeVar("R")
 T = TypeVar("T")
 
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 6ec7ca2..20a06be 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 r"""Implementation for the RAdam algorithm."""
 from typing import cast, List, Optional, Tuple, Union
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 860d1f6..5f929c0 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 r"""Implementation for the RMSprop algorithm."""
 from typing import List, Optional, Union
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 395479b..bba6d53 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 r"""Implementation for the Resilient backpropagation."""
 from typing import List, Optional, Tuple, Union
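
The other substantive change sits inside maybe_fallback in the optimizer.py hunk above: once *args is annotated via _P.args, indexing the positional state_steps argument directly would no longer type-check, so the diff binds it with the walrus operator and narrows it to Sequence before reading [0].is_cuda. A hedged, self-contained sketch of that narrowing pattern; the FakeTensor class and first_is_cuda helper are stand-ins, not PyTorch APIs:

    from typing import Any, Sequence

    class FakeTensor:
        is_cuda: bool = True

    def first_is_cuda(*args: Any) -> bool:
        # Binding the element with the walrus operator and narrowing it to
        # Sequence is what lets mypy accept the subsequent arg[0] lookup.
        return bool(
            args
            and isinstance(arg := args[0], Sequence)
            and arg  # guard against an empty sequence in this sketch
            and arg[0].is_cuda
        )

    print(first_is_cuda([FakeTensor()]))  # True
    print(first_is_cuda([]))              # False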