Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8225f3eec5e30693c9287342798854a.
Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index 8744d43..7a2133b 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -548,6 +548,9 @@
"List",
"Template"
],
+ "torch.distributed.elastic.utils.data.elastic_distributed_sampler": [
+ "DistributedSampler"
+ ],
"torch.distributed.elastic.utils.logging": [
"Optional",
"get_log_level"
@@ -556,6 +559,22 @@
"List",
"timedelta"
],
+ "torch.distributed.fsdp.flatten_params_wrapper": [
+ "Any",
+ "Dict",
+ "Generator",
+ "Iterator",
+ "List",
+ "NamedTuple",
+ "Optional",
+ "ParamOffset",
+ "Sequence",
+ "SharedParamInfo",
+ "Tensor",
+ "Tuple",
+ "Union",
+ "accumulate"
+ ],
"torch.distributed.fsdp.utils": [
"Any",
"Callable",
@@ -692,6 +711,19 @@
"Copy",
"Wait"
],
+ "torch.distributed.pipeline.sync.dependency": [
+ "Fork",
+ "Join",
+ "fork",
+ "join"
+ ],
+ "torch.distributed.pipeline.sync.microbatch": [
+ "Batch",
+ "NoChunk",
+ "check",
+ "gather",
+ "scatter"
+ ],
"torch.distributed.pipeline.sync.phony": [
"get_phony"
],
@@ -975,6 +1007,22 @@
"dataclass",
"map_arg"
],
+ "torch.fx.graph_module": [
+ "Any",
+ "Dict",
+ "Graph",
+ "Importer",
+ "List",
+ "Optional",
+ "PackageExporter",
+ "PackageImporter",
+ "Path",
+ "PythonCode",
+ "Set",
+ "Type",
+ "Union",
+ "compatibility"
+ ],
"torch.fx.immutable_collections": [
"Any",
"Context",
@@ -1040,9 +1088,46 @@
"map_aggregate",
"map_arg"
],
+ "torch.fx.passes.param_fetch": [
+ "Any",
+ "Callable",
+ "Dict",
+ "GraphModule",
+ "List",
+ "Tuple",
+ "Type",
+ "compatibility"
+ ],
+ "torch.fx.passes.shape_prop": [
+ "Any",
+ "Dict",
+ "NamedTuple",
+ "Node",
+ "Optional",
+ "Tuple",
+ "compatibility",
+ "map_aggregate"
+ ],
+ "torch.fx.passes.split_module": [
+ "Any",
+ "Callable",
+ "Dict",
+ "GraphModule",
+ "List",
+ "Optional",
+ "compatibility"
+ ],
"torch.fx.proxy": [
"assert_fn"
],
+ "torch.hub": [
+ "HTTPError",
+ "Path",
+ "Request",
+ "tqdm",
+ "urlopen",
+ "urlparse"
+ ],
"torch.jit": [
"Attribute",
"Final",
diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py
index 5413c46..be955d2 100644
--- a/torch/distributed/elastic/multiprocessing/errors/__init__.py
+++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py
@@ -65,6 +65,7 @@
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
+
log = get_logger()
@@ -75,6 +76,7 @@
T = TypeVar("T")
+
@dataclass
class ProcessFailure:
"""
diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py
index 597d16b..dcf20dc 100644
--- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py
+++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py
@@ -11,7 +11,6 @@
import torch
from torch.utils.data.distributed import DistributedSampler
-__all__ = ['ElasticDistributedSampler']
class ElasticDistributedSampler(DistributedSampler):
"""
diff --git a/torch/distributed/fsdp/flatten_params_wrapper.py b/torch/distributed/fsdp/flatten_params_wrapper.py
index 4292d48..97e086f 100644
--- a/torch/distributed/fsdp/flatten_params_wrapper.py
+++ b/torch/distributed/fsdp/flatten_params_wrapper.py
@@ -32,7 +32,6 @@
FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"
-__all__ = ['ParamInfo', 'ShardMetadata', 'FlatParameter', 'FlattenParamsWrapper']
def _post_state_dict_hook(
module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any
diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py
index c27b577..a5a7ba5 100644
--- a/torch/distributed/pipeline/sync/dependency.py
+++ b/torch/distributed/pipeline/sync/dependency.py
@@ -5,14 +5,14 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Arbitrary dependency between two autograd lanes."""
-from typing import Tuple
+from typing import List, Tuple
import torch
from torch import Tensor
from .phony import get_phony
-__all__ = ['fork', 'Fork', 'join', 'Join']
+__all__: List[str] = []
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py
index 0a17936..3612332 100644
--- a/torch/distributed/pipeline/sync/microbatch.py
+++ b/torch/distributed/pipeline/sync/microbatch.py
@@ -12,7 +12,7 @@
from torch import Tensor
import torch.cuda.comm
-__all__ = ['NoChunk', 'Batch', 'check', 'scatter', 'gather']
+__all__: List[str] = []
Tensors = Sequence[Tensor]
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index a9d0e43..8912518 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -16,8 +16,6 @@
import os
import warnings
-__all__ = ['reduce_graph_module', 'reduce_package_graph_module', 'reduce_deploy_graph_module', 'GraphModule']
-
# Normal exec loses the source code, however we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
diff --git a/torch/fx/passes/param_fetch.py b/torch/fx/passes/param_fetch.py
index 5979e29..41d7599 100644
--- a/torch/fx/passes/param_fetch.py
+++ b/torch/fx/passes/param_fetch.py
@@ -5,7 +5,6 @@
from torch.fx._compatibility import compatibility
-__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
# Matching method matches the attribute name of current version to the attribute name of `target_version`
@compatibility(is_backward_compatible=False)
diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py
index 9c3a036..f7feadd 100644
--- a/torch/fx/passes/shape_prop.py
+++ b/torch/fx/passes/shape_prop.py
@@ -6,7 +6,6 @@
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
-__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index 1795de3..1bd5918 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -4,8 +4,6 @@
from torch.fx._compatibility import compatibility
import inspect
-__all__ = ['Partition', 'split_module']
-
@compatibility(is_backward_compatible=True)
class Partition:
def __init__(self, name: str):
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index fb065c4..e030896 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -2,6 +2,7 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple
+
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
out = []