`nn.ModuleList.__getitem__` overloads (#132834)
Overloads so that you can get more specific type info based on how you are indexing.
```python
from torch import nn
module_list = nn.ModuleList(32 * [nn.Linear(2, 2)])
# before:
reveal_type(module_list[0]) # Type of "module_list[0]" is "Module | ModuleList"
reveal_type(module_list[:1]) # Type of "module_list[: 1]" is "Module | ModuleList"
# now:
reveal_type(module_list[0]) # Type of "module_list[0]" is "Module"
reveal_type(module_list[:1]) # Type of "module_list[: 1]" is "ModuleList"
```
Co-authored-by: Skylion007 <Skylion007@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132834
Approved by: https://github.com/Skylion007, https://github.com/albanD
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index 585f4ef..61074a9 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -318,6 +318,14 @@
idx += len(self)
return str(idx)
+ @overload
+ def __getitem__(self, idx: slice) -> "ModuleList":
+ ...
+
+ @overload
+ def __getitem__(self, idx: int) -> Module:
+ ...
+
@_copy_to_script_wrapper
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
if isinstance(idx, slice):