Fix public API tests (#131386)
This PR fixes a bug in `test_correct_module_names` introduced in #130497: the loop assigned each walked module's name to `mod` while the check read `modname`, so the walked modules never actually reached `test_module` (see the sketch after the file list below). It also addresses the public API test failures that surface once the loop is fixed:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`
* `torch/fx/experimental/graph_gradual_typechecker.py`, `torch/fx/experimental/migrate_gradual_types/constraint_generator.py`, `torch/onnx/symbolic_helper.py` - alias `TypeVar`/`ParamSpec` (and `Concatenate` in the ONNX helper) to private names so they are not re-exported as public APIs
* `torch/optim/__init__.py` - set the correct `__module__` for `Adafactor`
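For reference, here is a minimal sketch (not the actual test) of the kind of check `test_correct_module_names` performs. The `pkgutil` traversal mirrors the loop touched in `test/test_public_bindings.py` below; `check_module` and its public-name heuristics are illustrative assumptions, not the real test logic.

```python
# Minimal sketch (not the real test) of a public-API check in the spirit of
# test_correct_module_names. The pkgutil traversal mirrors the loop touched by
# this PR; check_module and its heuristics are illustrative only.
import importlib
import pkgutil

import torch


def check_module(modname: str) -> list:
    """Collect public names of `modname` whose __module__ points outside torch."""
    try:
        mod = importlib.import_module(modname)
    except Exception:
        return []
    public = getattr(mod, "__all__", None)
    if public is None:
        # Without __all__, every non-underscore attribute counts as public,
        # which is how a re-exported helper (e.g. typing's Self) can leak out.
        public = [name for name in dir(mod) if not name.startswith("_")]
    failures = []
    for name in public:
        declared = getattr(getattr(mod, name, None), "__module__", None)
        if declared is not None and not declared.startswith("torch"):
            failures.append(f"{modname}.{name} has __module__={declared!r}")
    return failures


all_failures = []
for mod_info in pkgutil.walk_packages(torch.__path__, "torch."):
    # The bug fixed here: the loop assigned to `mod` but the check used
    # `modname`, so the walked module name never reached the check.
    all_failures.extend(check_module(mod_info.name))
print("\n".join(all_failures))
```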
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
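As a quick sanity check, the per-file fixes can be exercised against a build that includes this patch; the assertions below simply restate the hunks that follow, and the submodule imports assume the build ships them (e.g. `torch.distributed`).

```python
# Quick sanity check against a build that includes this patch; each assertion
# restates one of the hunks below. Submodule imports assume the build ships
# them (e.g. torch.distributed).
import torch
import torch.ao.quantization
import torch.storage
from torch.distributed.pipelining import schedules
from torch.nn.attention import flex_attention

assert torch.optim.Adafactor.__module__ == "torch.optim"
assert torch.ao.quantization.compare_results.__module__ == "torch.ao.quantization"
assert "register_vmap" in torch.library.__all__
assert "ZeroBubbleAlgorithm" in schedules.__all__
assert torch.storage.__all__ == ["TypedStorage", "UntypedStorage"]
assert not hasattr(flex_attention, "round_up_to_multiple")
```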
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 25b4afd..806a9bb 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -594,7 +594,7 @@
)
for mod in pkgutil.walk_packages(torch.__path__, "torch."):
- mod = mod.name
+ modname = mod.name
test_module(modname)
test_module("torch")
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 1f62f7f..adf70f7 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -30,8 +30,16 @@
from .stubs import * # noqa: F403
+# ensure __module__ is set correctly for public APIs
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
+for _f in [
+ compare_results,
+ extract_results_from_loggers,
+ generate_numeric_debug_handle,
+ prepare_for_propagation_comparison,
+]:
+ _f.__module__ = "torch.ao.quantization"
__all__ = [
"DeQuantStub",
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py
index b887a4d..f974fbf 100644
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -27,6 +27,7 @@
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
+ "ZeroBubbleAlgorithm",
]
logger = logging.getLogger(__name__)
diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py
index 70d1c84..76cadd3 100644
--- a/torch/fx/experimental/graph_gradual_typechecker.py
+++ b/torch/fx/experimental/graph_gradual_typechecker.py
@@ -3,8 +3,8 @@
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
-from typing import Callable, Dict, TypeVar
-from typing_extensions import ParamSpec
+from typing import Callable, Dict, TypeVar as _TypeVar
+from typing_extensions import ParamSpec as _ParamSpec
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
@@ -15,8 +15,8 @@
import sympy
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_P = _ParamSpec("_P")
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index a90939f..cf46399 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -2,8 +2,8 @@
import torch
import operator
import warnings
-from typing import Callable, Dict, Iterable, TypeVar
-from typing_extensions import ParamSpec
+from typing import Callable, Dict, Iterable, TypeVar as _TypeVar
+from typing_extensions import ParamSpec as _ParamSpec
from torch.fx._symbolic_trace import _assert_is_none
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
@@ -19,8 +19,8 @@
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.batchnorm import BatchNorm2d
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_P = _ParamSpec("_P")
_INFERENCE_RULES: Dict[Target, Callable] = {}
diff --git a/torch/library.py b/torch/library.py
index 960967f..abf7db5 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -43,6 +43,7 @@
"impl_abstract",
"register_fake",
"register_torch_dispatch",
+ "register_vmap",
"get_ctx",
"custom_op",
"infer_schema",
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index bc7c83f..a1a2df8 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -450,7 +450,7 @@
return x
-def round_up_to_multiple(x, multiple):
+def _round_up_to_multiple(x, multiple):
return (x + multiple - 1) // multiple * multiple
@@ -684,8 +684,8 @@
mod_type == _ModificationType.MASK_MOD
), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
inner_func = _create_block_mask_inner
- Q_LEN = Q_LEN if Q_LEN < 128 else round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
- KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
+ Q_LEN = Q_LEN if Q_LEN < 128 else _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
+ KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
if _compile:
inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
with TransformGetItemToIndex():
@@ -702,8 +702,8 @@
of the query and key tensors.
"""
device = query.device
- kv_len = round_up_to_multiple(key.size()[-2], 128)
- q_len = round_up_to_multiple(query.size()[-2], 128)
+ kv_len = _round_up_to_multiple(key.size()[-2], 128)
+ q_len = _round_up_to_multiple(query.size()[-2], 128)
return BlockMask(
kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 234341e..eb853d3 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -7,8 +7,8 @@
import sys
import typing
import warnings
-from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar
-from typing_extensions import Concatenate, ParamSpec
+from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar as _TypeVar
+from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec
import torch
import torch._C._onnx as _C_onnx
@@ -22,9 +22,9 @@
if typing.TYPE_CHECKING:
from torch.types import Number
-_T = TypeVar("_T")
-_U = TypeVar("_U")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_U = _TypeVar("_U")
+_P = _ParamSpec("_P")
# ---------------------------------------------------------------------------------
# Helper functions
@@ -204,7 +204,7 @@
def parse_args(
*arg_descriptors: _ValueDescriptor,
-) -> Callable[[Callable[Concatenate[_U, _P], _T]], Callable[Concatenate[_U, _P], _T]]:
+) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]:
"""A decorator which converts args from torch._C.Value to built-in types.
For example:
@@ -233,8 +233,8 @@
"""
def decorator(
- fn: Callable[Concatenate[_U, _P], _T]
- ) -> Callable[Concatenate[_U, _P], _T]:
+ fn: Callable[_Concatenate[_U, _P], _T]
+ ) -> Callable[_Concatenate[_U, _P], _T]:
fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined]
@functools.wraps(fn)
diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
index c0ec679..49c0739 100644
--- a/torch/optim/__init__.py
+++ b/torch/optim/__init__.py
@@ -23,6 +23,8 @@
from torch.optim.sgd import SGD
from torch.optim.sparse_adam import SparseAdam
+Adafactor.__module__ = "torch.optim"
+
del adadelta # type: ignore[name-defined] # noqa: F821
del adagrad # type: ignore[name-defined] # noqa: F821
diff --git a/torch/storage.py b/torch/storage.py
index 59023a7..b6ba608 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -16,6 +16,9 @@
from torch.types import _bool, _int, Storage
+__all__ = ["TypedStorage", "UntypedStorage"]
+
+
try:
import numpy as np