Fix public API tests (#131386)

This PR fixes a bug in `test_correct_module_names` introduced in #130497: the loop over `pkgutil.walk_packages` assigned each walked module's name to `mod` instead of `modname`, so `test_module(modname)` never received it and the walked modules were not actually being checked. It also addresses the test failures that surfaced once the loop was fixed, in the modules below (a rough sketch of the check these changes satisfy follows the list):
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`
* `torch/optim/__init__.py` - set the correct `__module__` for `Adafactor`
* `torch/onnx/symbolic_helper.py`, `torch/fx/experimental/graph_gradual_typechecker.py`, `torch/fx/experimental/migrate_gradual_types/constraint_generator.py` - import `TypeVar`, `ParamSpec`, and `Concatenate` under private aliases so they are not picked up as public re-exports
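
For context, the invariant these per-module fixes restore is roughly the following. This is a minimal sketch, not the actual `test_correct_module_names` implementation; the function name and its heuristics are illustrative only:

```python
import importlib
import inspect


def find_api_mismatches(modname: str) -> list:
    """Flag names whose apparent visibility disagrees with the module's declared intent."""
    mod = importlib.import_module(modname)
    declared_all = getattr(mod, "__all__", None)
    problems = []
    for name in dir(mod):
        obj = getattr(mod, name)
        if inspect.ismodule(obj):
            continue
        home = getattr(obj, "__module__", "") or ""
        # "Looks public": no leading underscore, and the object's __module__
        # points back into a non-private torch module.
        looks_public = (
            not name.startswith("_")
            and home.startswith("torch")
            and "._" not in home
        )
        # "Intended public": listed in __all__; with no __all__, any name without
        # a leading underscore counts (hence the new __all__ in torch/storage.py).
        if declared_all is not None:
            intended_public = name in declared_all
        else:
            intended_public = not name.startswith("_")
        if looks_public != intended_public:
            problems.append(f"{modname}.{name}")
    return problems


# e.g. find_api_mismatches("torch.distributed.pipelining.schedules")
```

The real check in `test_public_bindings.py` is more lenient (an object may legitimately name another public torch module as its home), but the fix strategies above fall out of it directly: either point `__module__` at the public module, or adjust what counts as intended-public via `__all__` or a leading underscore.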

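A note on the `TypeVar as _TypeVar` / `ParamSpec as _ParamSpec` / `Concatenate as _Concatenate` renames in `torch/onnx/symbolic_helper.py` and the two `torch/fx/experimental` files below: those modules define no `__all__`, so a plain `from typing import TypeVar` leaves a non-underscored `TypeVar` exposed at module level, which the binding test treats as part of the public surface. A private alias keeps the name off that surface without changing behavior (illustrative snippet, not taken from the diff):

```python
# Before: with no __all__ in the module, the re-exported name "TypeVar" looks
# like a new public API to the binding test.
# from typing import TypeVar
# _T = TypeVar("_T")

# After: the module-level name is underscored, so the public-surface check
# ignores it; usage is otherwise identical.
from typing import TypeVar as _TypeVar

_T = _TypeVar("_T")
```
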
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 25b4afd..806a9bb 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -594,7 +594,7 @@
                         )
 
         for mod in pkgutil.walk_packages(torch.__path__, "torch."):
-            mod = mod.name
+            modname = mod.name
             test_module(modname)
         test_module("torch")
 
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 1f62f7f..adf70f7 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -30,8 +30,16 @@
 from .stubs import *  # noqa: F403
 
 
+# ensure __module__ is set correctly for public APIs
 ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
 ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
+for _f in [
+    compare_results,
+    extract_results_from_loggers,
+    generate_numeric_debug_handle,
+    prepare_for_propagation_comparison,
+]:
+    _f.__module__ = "torch.ao.quantization"
 
 __all__ = [
     "DeQuantStub",
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py
index b887a4d..f974fbf 100644
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -27,6 +27,7 @@
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ZeroBubbleAlgorithm",
 ]
 
 logger = logging.getLogger(__name__)
diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py
index 70d1c84..76cadd3 100644
--- a/torch/fx/experimental/graph_gradual_typechecker.py
+++ b/torch/fx/experimental/graph_gradual_typechecker.py
@@ -3,8 +3,8 @@
 import torch
 import operator
 from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
-from typing import Callable, Dict, TypeVar
-from typing_extensions import ParamSpec
+from typing import Callable, Dict, TypeVar as _TypeVar
+from typing_extensions import ParamSpec as _ParamSpec
 from torch.fx.node import Target, Node
 from torch.nn.modules.batchnorm import BatchNorm2d
 from torch.nn.modules.conv import Conv2d
@@ -15,8 +15,8 @@
 
 import sympy
 
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_P = _ParamSpec("_P")
 
 _INFERENCE_RULES: Dict[Target, Callable] = {}
 _REFINEMENT_RULES: Dict[Target, Callable] = {}
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index a90939f..cf46399 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -2,8 +2,8 @@
 import torch
 import operator
 import warnings
-from typing import Callable, Dict, Iterable, TypeVar
-from typing_extensions import ParamSpec
+from typing import Callable, Dict, Iterable, TypeVar as _TypeVar
+from typing_extensions import ParamSpec as _ParamSpec
 
 from torch.fx._symbolic_trace import _assert_is_none
 from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
@@ -19,8 +19,8 @@
 from torch.nn.modules.conv import Conv2d
 from torch.nn.modules.batchnorm import BatchNorm2d
 
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_P = _ParamSpec("_P")
 
 _INFERENCE_RULES: Dict[Target, Callable] = {}
 
diff --git a/torch/library.py b/torch/library.py
index 960967f..abf7db5 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -43,6 +43,7 @@
     "impl_abstract",
     "register_fake",
     "register_torch_dispatch",
+    "register_vmap",
     "get_ctx",
     "custom_op",
     "infer_schema",
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index bc7c83f..a1a2df8 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -450,7 +450,7 @@
     return x
 
 
-def round_up_to_multiple(x, multiple):
+def _round_up_to_multiple(x, multiple):
     return (x + multiple - 1) // multiple * multiple
 
 
@@ -684,8 +684,8 @@
         mod_type == _ModificationType.MASK_MOD
     ), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
     inner_func = _create_block_mask_inner
-    Q_LEN = Q_LEN if Q_LEN < 128 else round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
-    KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
+    Q_LEN = Q_LEN if Q_LEN < 128 else _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
+    KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
     if _compile:
         inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
     with TransformGetItemToIndex():
@@ -702,8 +702,8 @@
     of the query and key tensors.
     """
     device = query.device
-    kv_len = round_up_to_multiple(key.size()[-2], 128)
-    q_len = round_up_to_multiple(query.size()[-2], 128)
+    kv_len = _round_up_to_multiple(key.size()[-2], 128)
+    q_len = _round_up_to_multiple(query.size()[-2], 128)
     return BlockMask(
         kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
         kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 234341e..eb853d3 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -7,8 +7,8 @@
 import sys
 import typing
 import warnings
-from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar
-from typing_extensions import Concatenate, ParamSpec
+from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar as _TypeVar
+from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec
 
 import torch
 import torch._C._onnx as _C_onnx
@@ -22,9 +22,9 @@
 if typing.TYPE_CHECKING:
     from torch.types import Number
 
-_T = TypeVar("_T")
-_U = TypeVar("_U")
-_P = ParamSpec("_P")
+_T = _TypeVar("_T")
+_U = _TypeVar("_U")
+_P = _ParamSpec("_P")
 
 # ---------------------------------------------------------------------------------
 # Helper functions
@@ -204,7 +204,7 @@
 
 def parse_args(
     *arg_descriptors: _ValueDescriptor,
-) -> Callable[[Callable[Concatenate[_U, _P], _T]], Callable[Concatenate[_U, _P], _T]]:
+) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]:
     """A decorator which converts args from torch._C.Value to built-in types.
 
     For example:
@@ -233,8 +233,8 @@
     """
 
     def decorator(
-        fn: Callable[Concatenate[_U, _P], _T]
-    ) -> Callable[Concatenate[_U, _P], _T]:
+        fn: Callable[_Concatenate[_U, _P], _T]
+    ) -> Callable[_Concatenate[_U, _P], _T]:
         fn._arg_descriptors = arg_descriptors  # type: ignore[attr-defined]
 
         @functools.wraps(fn)
diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
index c0ec679..49c0739 100644
--- a/torch/optim/__init__.py
+++ b/torch/optim/__init__.py
@@ -23,6 +23,8 @@
 from torch.optim.sgd import SGD
 from torch.optim.sparse_adam import SparseAdam
 
+Adafactor.__module__ = "torch.optim"
+
 
 del adadelta  # type: ignore[name-defined] # noqa: F821
 del adagrad  # type: ignore[name-defined] # noqa: F821
diff --git a/torch/storage.py b/torch/storage.py
index 59023a7..b6ba608 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -16,6 +16,9 @@
 from torch.types import _bool, _int, Storage
 
 
+__all__ = ["TypedStorage", "UntypedStorage"]
+
+
 try:
     import numpy as np