| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| import unittest |
| import warnings |
| from typing import Dict, List, Optional |
| |
| import torch |
| |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| |
# This module is only a collection of test cases; direct execution is refused
# and the error message points at the supported entry point (test/test_jit.py).
if __name__ == "__main__":
    _usage_message = "".join(
        [
            "This test file is not meant to be run directly, use:\n\n",
            "\tpython test/test_jit.py TESTNAME\n\n",
            "instead.",
        ]
    )
    raise RuntimeError(_usage_message)
| |
| |
| class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): |
| # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact, |
| # reassigning a non-empty Tuple to an attribute previously typed |
| # as containing an empty Tuple SHOULD fail. See note in `_check.py` |
| |
| def test_annotated_falsy_base_type(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: int = 0 |
| |
| def forward(self, x: int): |
| self.x = x |
| return 1 |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), (1,)) |
| assert len(w) == 0 |
| |
| def test_annotated_nonempty_container(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: List[int] = [1, 2, 3] |
| |
| def forward(self, x: List[int]): |
| self.x = x |
| return 1 |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), ([1, 2, 3],)) |
| assert len(w) == 0 |
| |
| def test_annotated_empty_tensor(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: torch.Tensor = torch.empty(0) |
| |
| def forward(self, x: torch.Tensor): |
| self.x = x |
| return self.x |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), (torch.rand(2, 3),)) |
| assert len(w) == 0 |
| |
| def test_annotated_with_jit_attribute(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.Attribute([], List[int]) |
| |
| def forward(self, x: List[int]): |
| self.x = x |
| return self.x |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), ([1, 2, 3],)) |
| assert len(w) == 0 |
| |
| def test_annotated_class_level_annotation_only(self): |
| class M(torch.nn.Module): |
| x: List[int] |
| |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = [] |
| |
| def forward(self, y: List[int]): |
| self.x = y |
| return self.x |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), ([1, 2, 3],)) |
| assert len(w) == 0 |
| |
| def test_annotated_class_level_annotation_and_init_annotation(self): |
| class M(torch.nn.Module): |
| x: List[int] |
| |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: List[int] = [] |
| |
| def forward(self, y: List[int]): |
| self.x = y |
| return self.x |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), ([1, 2, 3],)) |
| assert len(w) == 0 |
| |
| def test_annotated_class_level_jit_annotation(self): |
| class M(torch.nn.Module): |
| x: List[int] |
| |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: List[int] = torch.jit.annotate(List[int], []) |
| |
| def forward(self, y: List[int]): |
| self.x = y |
| return self.x |
| |
| with warnings.catch_warnings(record=True) as w: |
| self.checkModule(M(), ([1, 2, 3],)) |
| assert len(w) == 0 |
| |
| def test_annotated_empty_list(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: List[int] = [] |
| |
| def forward(self, x: List[int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| @unittest.skipIf( |
| sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" |
| ) |
| def test_annotated_empty_list_lowercase(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: list[int] = [] |
| |
| def forward(self, x: list[int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_empty_dict(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: Dict[str, int] = {} |
| |
| def forward(self, x: Dict[str, int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| @unittest.skipIf( |
| sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" |
| ) |
| def test_annotated_empty_dict_lowercase(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: dict[str, int] = {} |
| |
| def forward(self, x: dict[str, int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_empty_optional(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x: Optional[str] = None |
| |
| def forward(self, x: Optional[str]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Wrong type for attribute assignment", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_with_jit_empty_list(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.annotate(List[int], []) |
| |
| def forward(self, x: List[int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| @unittest.skipIf( |
| sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" |
| ) |
| def test_annotated_with_jit_empty_list_lowercase(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.annotate(list[int], []) |
| |
| def forward(self, x: list[int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_with_jit_empty_dict(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.annotate(Dict[str, int], {}) |
| |
| def forward(self, x: Dict[str, int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| @unittest.skipIf( |
| sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" |
| ) |
| def test_annotated_with_jit_empty_dict_lowercase(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.annotate(dict[str, int], {}) |
| |
| def forward(self, x: dict[str, int]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Tried to set nonexistent attribute", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_with_jit_empty_optional(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.jit.annotate(Optional[str], None) |
| |
| def forward(self, x: Optional[str]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Wrong type for attribute assignment", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |
| |
| def test_annotated_with_torch_jit_import(self): |
| from torch import jit |
| |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = jit.annotate(Optional[str], None) |
| |
| def forward(self, x: Optional[str]): |
| self.x = x |
| return 1 |
| |
| with self.assertRaisesRegexWithHighlight( |
| RuntimeError, "Wrong type for attribute assignment", "self.x = x" |
| ): |
| with self.assertWarnsRegex( |
| UserWarning, |
| "doesn't support " |
| "instance-level annotations on " |
| "empty non-base types", |
| ): |
| torch.jit.script(M()) |