Fix wrong class inheritance in pyi (#116404)
As the title stated.
https://github.com/pytorch/pytorch/blob/f6dfbffb3bb46ada6fe66b5da4f989f9d4d69b3c/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L153
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116404
Approved by: https://github.com/ezyang, https://github.com/wconstab
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 8fc44f1..dc6787d 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -251,6 +251,19 @@
@staticmethod
def unbox(obj: ScriptObject) -> Work: ...
+class Backend:
+ def __init__(
+ self,
+ rank: int,
+ size: int,
+ ): ...
+ @property
+ def supports_splitting(self) -> bool: ...
+ def rank(self) -> int: ...
+ def size(self) -> int: ...
+ def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
+ def _set_sequence_number_for_group(self) -> None: ...
+
class ProcessGroup:
class Options:
def __init__(self, backend: str, timeout: timedelta = ...): ...
@@ -461,7 +474,7 @@
self,
device: torch.device,
backend_type: BackendType,
- backend: Optional[ProcessGroup],
+ backend: Optional[Backend],
) -> None: ...
def _set_group_name(self, name: str) -> None: ...
def name(self) -> str: ...
@@ -479,7 +492,7 @@
process_groups: List[ProcessGroup],
) -> ProcessGroupRoundRobin: ...
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupGloo(Backend):
class Device: ...
class Options: ...
@@ -496,11 +509,11 @@
def create_default_device() -> Device: ...
def _set_default_timeout(self, timeout) -> None: ...
-class _ProcessGroupWrapper(ProcessGroup):
- def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
+class _ProcessGroupWrapper(Backend):
+ def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ...
wrapped_pg: Backend
-class ProcessGroupNCCL(ProcessGroup):
+class ProcessGroupNCCL(Backend):
class Options:
def __init__(self, timeout: Optional[timedelta] = None): ...
@property
@@ -525,7 +538,7 @@
def _group_end(self) -> None: ...
def _set_default_timeout(self, timeout) -> None: ...
-class ProcessGroupUCC(ProcessGroup):
+class ProcessGroupUCC(Backend):
def __init__(
self,
store: Store,
@@ -534,7 +547,7 @@
timeout: timedelta,
): ...
-class ProcessGroupMPI(ProcessGroup):
+class ProcessGroupMPI(Backend):
def __init__(
self,
rank: int,
@@ -563,13 +576,3 @@
logger: Optional[Logger],
): ...
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
-
-class Backend:
- def __init__(
- self,
- rank: int,
- size: int,
- ): ...
- @property
- def supports_splitting(self) -> bool: ...
- def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 41a12a6..2aa9ea4 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1380,7 +1380,7 @@
if device_id:
pg.bound_device_id = device_id
backend_config = BackendConfig(backend)
- backend_class: ProcessGroup
+ backend_class: torch._C._distributed_c10d.Backend
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
@@ -1473,7 +1473,7 @@
# TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
# ProcessGroup instance
if issubclass(type(backend_class), ProcessGroup):
- pg = backend_class
+ pg = backend_class # type: ignore[assignment]
break
# Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
@@ -3629,7 +3629,7 @@
def _create_process_group_wrapper(
- wrapped_pg: ProcessGroup,
+ wrapped_pg: torch._C._distributed_c10d.Backend,
store_prefix: str,
store: Store,
rank: int,