Make ModuleList derive from Sequence[T] and type it appropriately (#89135)

I see https://github.com/pytorch/pytorch/issues/53103 says this might be problematic, but I'm a bit confused at this point, because it looks like ModuleList does in fact already adhere to the Sequence API

The big win here is that for homogenous ModuleLists, you now get typing for individual members, e.g.
`ModuleList([Linear(), Linear(), Linear()])[1]` properly has type `Linear`

If this looks good, I can do a followup PR to do similarly for `ModuleDict` and `Parameter[List,Dict]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89135
Approved by: https://github.com/albanD
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index 079a878..6281531 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -8,7 +8,7 @@
 from ..parameter import Parameter
 from torch._jit_internal import _copy_to_script_wrapper
 
-from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
+from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Sequence, Tuple, TypeVar, Union
 
 __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
 
@@ -247,7 +247,10 @@
         return self
 
 
-class ModuleList(Module):
+# `Sequence` but not a `MutableSequence` since as currently implemented, the signatures
+# of the mutable methods are incompatible, e.g. `append` returns `self` and `pop` takes
+# a slice.
+class ModuleList(Module, Sequence[T]):
     r"""Holds submodules in a list.
 
     :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
@@ -271,9 +274,9 @@
                 return x
     """
 
-    _modules: Dict[str, Module]  # type: ignore[assignment]
+    _modules: Dict[str, T]  # type: ignore[assignment]
 
-    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
+    def __init__(self, modules: Optional[Iterable[T]] = None) -> None:
         super(ModuleList, self).__init__()
         if modules is not None:
             self += modules
@@ -287,14 +290,22 @@
             idx += len(self)
         return str(idx)
 
+    @overload
+    def __getitem__(self, idx: int) -> T:
+        ...
+
+    @overload
+    def __getitem__(self, idx: slice) -> 'ModuleList[T]':
+        ...
+
     @_copy_to_script_wrapper
-    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
+    def __getitem__(self, idx: Union[int, slice]) -> Union[T, 'ModuleList[T]']:
         if isinstance(idx, slice):
             return self.__class__(list(self._modules.values())[idx])
         else:
             return self._modules[self._get_abs_string_index(idx)]
 
-    def __setitem__(self, idx: int, module: Module) -> None:
+    def __setitem__(self, idx: int, module: T) -> None:
         idx = self._get_abs_string_index(idx)
         return setattr(self, str(idx), module)
 
@@ -313,14 +324,14 @@
         return len(self._modules)
 
     @_copy_to_script_wrapper
-    def __iter__(self) -> Iterator[Module]:
+    def __iter__(self) -> Iterator[T]:
         return iter(self._modules.values())
 
-    def __iadd__(self, modules: Iterable[Module]) -> 'ModuleList':
+    def __iadd__(self, modules: Iterable[T]) -> 'ModuleList[T]':
         return self.extend(modules)
 
-    def __add__(self, other: Iterable[Module]) -> 'ModuleList':
-        combined = ModuleList()
+    def __add__(self, other: Iterable[T]) -> 'ModuleList[T]':
+        combined: ModuleList[T] = ModuleList()
         for i, module in enumerate(chain(self, other)):
             combined.add_module(str(i), module)
         return combined
@@ -363,7 +374,7 @@
         keys = [key for key in keys if not key.isdigit()]
         return keys
 
-    def insert(self, index: int, module: Module) -> None:
+    def insert(self, index: int, module: T) -> None:
         r"""Insert a given module before a given index in the list.
 
         Args:
@@ -374,7 +385,7 @@
             self._modules[str(i)] = self._modules[str(i - 1)]
         self._modules[str(index)] = module
 
-    def append(self, module: Module) -> 'ModuleList':
+    def append(self, module: T) -> 'ModuleList[T]':
         r"""Appends a given module to the end of the list.
 
         Args:
@@ -383,12 +394,20 @@
         self.add_module(str(len(self)), module)
         return self
 
-    def pop(self, key: Union[int, slice]) -> Module:
+    @overload
+    def pop(self, key: int) -> T:
+        ...
+
+    @overload
+    def pop(self, key: slice) -> 'ModuleList[T]':
+        ...
+
+    def pop(self, key: Union[int, slice]) -> Union[T, 'ModuleList[T]']:
         v = self[key]
         del self[key]
         return v
 
-    def extend(self, modules: Iterable[Module]) -> 'ModuleList':
+    def extend(self, modules: Iterable[T]) -> 'ModuleList[T]':
         r"""Appends modules from a Python iterable to the end of the list.
 
         Args: