[BE] typing for decorators - optim/optimizer (#131583)
See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131583
Approved by: https://github.com/janeyx99
ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577, #131578, #131579, #131580, #131581, #131582
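
The `torch/optim/optimizer.py` hunk below annotates the `_disable_dynamo_if_unsupported` decorator factory with `ParamSpec`/`TypeVar` so that decorated functions keep their full signatures under mypy, which is what allows the per-optimizer `# mypy: allow-untyped-decorators` suppressions to be dropped. A minimal standalone sketch of that typing pattern follows; `with_logging` and `scale_lr` are hypothetical names, not PyTorch APIs, and the example assumes Python 3.10+ for `typing.ParamSpec`:

    import functools
    from typing import Callable, Optional, TypeVar
    from typing import ParamSpec  # Python 3.10+; older versions: typing_extensions

    _T = TypeVar("_T")
    _P = ParamSpec("_P")

    def with_logging(  # hypothetical factory; only the annotations mirror the PR
        prefix: Optional[str] = None,
    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        def wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
            @functools.wraps(func)
            def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                # Preserve the wrapped function's parameters and return type.
                print(f"{prefix or func.__name__} called")
                return func(*args, **kwargs)

            return inner

        return wrapper

    @with_logging("step")
    def scale_lr(lr: float, maximize: bool = False) -> float:
        return -lr if maximize else lr

    scale_lr(1e-2)      # mypy still sees (lr: float, maximize: bool = False) -> float
    # scale_lr("1e-2")  # would be rejected by mypy instead of silently allowed

Because the factory returns `Callable[[Callable[_P, _T]], Callable[_P, _T]]`, mypy propagates the decorated function's parameter list and return type through the decorator instead of treating the result as an untyped callable.
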
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index fb719b3..bb2335b 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Union
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 6610839..235880f 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index b74c9e0..639a7f5 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 0687d13..21b90fe 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index 163dbc0..025fb50 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 6ec27df..4de66ef 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 6cbc18b..217fa60 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -17,6 +17,7 @@
List,
Optional,
overload,
+ Sequence,
Set,
Tuple,
TypeVar,
@@ -35,6 +36,9 @@
)
from torch.utils.hooks import RemovableHandle
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
@@ -112,7 +116,9 @@
return x
-def _disable_dynamo_if_unsupported(single_tensor_fn=None):
+def _disable_dynamo_if_unsupported(
+ single_tensor_fn: Optional[Callable[..., object]] = None
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# workaround for torchscript BC
# it requires all called functions to be in the
# global environment at the site at which the
@@ -120,7 +126,7 @@
if single_tensor_fn:
globals()[single_tensor_fn.__name__] = single_tensor_fn
- def wrapper(func):
+ def wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
import inspect
disabled_func = torch._disable_dynamo(func)
@@ -137,15 +143,18 @@
# but this only occurs in the rare case that the user explicitly deletes
# the capturable flag. If capturable=True, this is not a problem.
@functools.wraps(func)
- def maybe_fallback(*args, **kwargs):
+ def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs):
if is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
- and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
+ and (
+ isinstance(arg := args[state_steps_ind], Sequence)
+ and arg[0].is_cuda
+ )
or (
"state_steps" in kwargs
- and kwargs["state_steps"]
- and kwargs["state_steps"][0].is_cuda
+ and isinstance(arg := kwargs["state_steps"], Sequence)
+ and arg[0].is_cuda
)
):
return disabled_func(*args, **kwargs)
@@ -300,7 +309,6 @@
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
-_P = ParamSpec("_P")
R = TypeVar("R")
T = TypeVar("T")
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 6ec7ca2..20a06be 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 860d1f6..5f929c0 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import List, Optional, Union
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 395479b..bba6d53 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import List, Optional, Tuple, Union
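
A note on the `maybe_fallback` hunk in `torch/optim/optimizer.py` above: once `*args` is typed via `_P.args`, each positional argument is only known to mypy as a plain object, so the check on `state_steps` uses a walrus assignment plus `isinstance(..., Sequence)` to narrow the value before its first element is inspected. A standalone sketch of that narrowing pattern, with hypothetical names and stand-in objects rather than real tensors:

    from typing import Sequence

    class _FakeStep:  # hypothetical stand-in for a state_steps tensor
        is_cuda = True

    def first_step_is_cuda(candidate: object) -> bool:
        # `candidate` is only known as `object`; the walrus binds it and the
        # isinstance check narrows it to Sequence so that indexing type-checks.
        return bool(
            isinstance(steps := candidate, Sequence)
            and steps
            and getattr(steps[0], "is_cuda", False)
        )

    print(first_step_is_cuda([_FakeStep()]))  # True
    print(first_step_is_cuda(3.0))            # False: not a Sequence

The emptiness guard here is illustrative; the point of the pattern is that narrowing to `Sequence` makes the subsequent `[0]` access acceptable to mypy while keeping the runtime device probe.
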