Fix wrong class inheritance in pyi (#116404)

As the title states: the type stubs in `torch/_C/_distributed_c10d.pyi` declared `ProcessGroupGloo`, `_ProcessGroupWrapper`, `ProcessGroupNCCL`, `ProcessGroupUCC`, and `ProcessGroupMPI` as subclasses of `ProcessGroup`, but in C++ these backends inherit from `Backend` (see the line linked below). This PR moves the `Backend` stub ahead of its subclasses, adds its missing `rank()`, `size()`, and `_set_sequence_number_for_group()` methods, corrects the base classes, and updates the matching annotations in `torch/distributed/distributed_c10d.py`.

https://github.com/pytorch/pytorch/blob/f6dfbffb3bb46ada6fe66b5da4f989f9d4d69b3c/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L153

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116404
Approved by: https://github.com/ezyang, https://github.com/wconstab
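For context, a minimal sketch of the runtime class hierarchy that the corrected stubs now describe (assuming a PyTorch build with distributed and Gloo support):

```python
# Minimal sketch: the pybind11 runtime hierarchy the corrected stubs match.
# Assumes a PyTorch build with distributed (Gloo) support.
from torch._C._distributed_c10d import Backend, ProcessGroup, ProcessGroupGloo

# Concrete backends derive from Backend, not from ProcessGroup...
assert issubclass(ProcessGroupGloo, Backend)
assert not issubclass(ProcessGroupGloo, ProcessGroup)

# ...while a ProcessGroup aggregates per-device backends, which is why
# _register_backend in the diff below now takes Optional[Backend] instead
# of Optional[ProcessGroup].
```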
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 8fc44f1..dc6787d 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -251,6 +251,19 @@
     @staticmethod
     def unbox(obj: ScriptObject) -> Work: ...
 
+class Backend:
+    def __init__(
+        self,
+        rank: int,
+        size: int,
+    ): ...
+    @property
+    def supports_splitting(self) -> bool: ...
+    def rank(self) -> int: ...
+    def size(self) -> int: ...
+    def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
+    def _set_sequence_number_for_group(self) -> None: ...
+
 class ProcessGroup:
     class Options:
         def __init__(self, backend: str, timeout: timedelta = ...): ...
@@ -461,7 +474,7 @@
         self,
         device: torch.device,
         backend_type: BackendType,
-        backend: Optional[ProcessGroup],
+        backend: Optional[Backend],
     ) -> None: ...
     def _set_group_name(self, name: str) -> None: ...
     def name(self) -> str: ...
@@ -479,7 +492,7 @@
     process_groups: List[ProcessGroup],
 ) -> ProcessGroupRoundRobin: ...
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupGloo(Backend):
     class Device: ...
     class Options: ...
 
@@ -496,11 +509,11 @@
     def create_default_device() -> Device: ...
     def _set_default_timeout(self, timeout) -> None: ...
 
-class _ProcessGroupWrapper(ProcessGroup):
-    def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
+class _ProcessGroupWrapper(Backend):
+    def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ...
     wrapped_pg: Backend
 
-class ProcessGroupNCCL(ProcessGroup):
+class ProcessGroupNCCL(Backend):
     class Options:
         def __init__(self, timeout: Optional[timedelta] = None): ...
         @property
@@ -525,7 +538,7 @@
     def _group_end(self) -> None: ...
     def _set_default_timeout(self, timeout) -> None: ...
 
-class ProcessGroupUCC(ProcessGroup):
+class ProcessGroupUCC(Backend):
     def __init__(
         self,
         store: Store,
@@ -534,7 +547,7 @@
         timeout: timedelta,
     ): ...
 
-class ProcessGroupMPI(ProcessGroup):
+class ProcessGroupMPI(Backend):
     def __init__(
         self,
         rank: int,
@@ -563,13 +576,3 @@
     logger: Optional[Logger],
 ): ...
 def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
-
-class Backend:
-    def __init__(
-        self,
-        rank: int,
-        size: int,
-    ): ...
-    @property
-    def supports_splitting(self) -> bool: ...
-    def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 41a12a6..2aa9ea4 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1380,7 +1380,7 @@
     if device_id:
         pg.bound_device_id = device_id
     backend_config = BackendConfig(backend)
-    backend_class: ProcessGroup
+    backend_class: torch._C._distributed_c10d.Backend
     for device, backend_str in backend_config.get_device_backend_map().items():
         # Use the group name as prefix in the default store, such that
         # a single store can be reused by multiple groups.
@@ -1473,7 +1473,7 @@
         # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
         # ProcessGroup instance
         if issubclass(type(backend_class), ProcessGroup):
-            pg = backend_class
+            pg = backend_class  # type: ignore[assignment]
             break
 
         # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
@@ -3629,7 +3629,7 @@
 
 
 def _create_process_group_wrapper(
-    wrapped_pg: ProcessGroup,
+    wrapped_pg: torch._C._distributed_c10d.Backend,
     store_prefix: str,
     store: Store,
     rank: int,
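For illustration, a hypothetical helper (not part of the PR) showing the typing consequence that drives the `distributed_c10d.py` annotation changes above:

```python
# Hypothetical sketch: with the corrected stubs, a concrete backend such
# as ProcessGroupGloo only type-checks against Backend, so annotating
# backend_class as ProcessGroup would now be rejected by mypy.
from torch._C._distributed_c10d import Backend, ProcessGroupGloo

def _pick_backend(gloo_pg: ProcessGroupGloo) -> Backend:
    backend_class: Backend = gloo_pg  # OK: ProcessGroupGloo derives from Backend
    return backend_class
```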