| from collections import namedtuple |
| from torch.testing._internal.common_utils import run_tests |
| from torch.testing._internal.jit_utils import JitTestCase |
| from torch.testing import FileCheck |
| from torch import jit |
| from typing import NamedTuple, List, Optional, Dict, Tuple, Any |
| from jit.test_module_interface import TestModuleInterface # noqa: F401 |
| import unittest |
| import sys |
| import torch |
| import torch.testing._internal.jit_utils |
| import torch.nn as nn |
| import types |
| |
| class TestScriptPy3(JitTestCase): |
| def test_joined_str(self): |
| def func(x): |
| hello, test = "Hello", "test" |
| print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999 |
| print(f"format blank") # noqa F541 |
| hi = 'hi' |
| print(f"stuff before {hi}") |
| print(f"{hi} stuff after") |
| return x + 1 |
| |
| x = torch.arange(4., requires_grad=True) |
| # TODO: Add support for f-strings in string parser frontend |
| # self.checkScript(func, [x], optimize=True, capture_output=True) |
| |
| with self.capture_stdout() as captured: |
| out = func(x) |
| |
| scripted = torch.jit.script(func) |
| with self.capture_stdout() as captured_script: |
| out_script = func(x) |
| |
| self.assertEqual(out, out_script) |
| self.assertEqual(captured, captured_script) |
| |
| @unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7") |
| def test_dataclass_error(self): |
| from dataclasses import dataclass |
| |
| @dataclass |
| class NormalizationInfo(object): |
| mean: float = 0.0 |
| |
| def compute(self, total_rows): |
| return self.mean |
| |
| def fn(): |
| return NormalizationInfo(1, 2, 3, 4, 5) |
| |
| with self.assertRaisesRegex(OSError, "NormalizationInfo"): |
| torch.jit.script(fn) |
| |
| def test_optional_dict_construct(self): |
| class M(torch.nn.Module): |
| def use(self, buffer: Dict[str, Optional[torch.Tensor]]): |
| return buffer["prev_key"] |
| |
| def forward(self, x): |
| prev_key = torch.rand(2, 3) |
| next_key = torch.rand(2, 3) |
| saved_state: Dict[str, Optional[torch.Tensor]] = { |
| "prev_key": prev_key, |
| "next_key": next_key, |
| } |
| |
| return self.use(saved_state) |
| |
| self.checkModule(M(), (torch.rand(2, 2),)) |
| |
| def test_kwarg_support(self): |
| with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"): |
| class M(torch.nn.Module): |
| def forward(self, *, n_tokens: int, device_name: str = 2): |
| pass |
| torch.jit.script(M()) |
| |
| class M(torch.nn.Module): |
| def forward(self, *, n_tokens: int, device_name: str): |
| return n_tokens, device_name |
| |
| sm = torch.jit.script(M()) |
| |
| with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"): |
| sm() |
| |
| input = (3, 'hello') |
| self.assertEqual(sm(*input), input) |
| |
| def test_named_tuple(self): |
| class FeatureVector(NamedTuple): |
| float_features: float |
| sequence_features: List[float] |
| time_since_first: float |
| |
| @torch.jit.script |
| def foo(x) -> float: |
| fv = FeatureVector(3.0, [3.0], 3.0) # noqa |
| rv = fv.float_features |
| for val in fv.sequence_features: |
| rv += val |
| rv *= fv.time_since_first |
| return rv |
| |
| self.assertEqual(foo(torch.rand(3, 4)), 18.0) |
| |
| def test_named_tuple_constant(self): |
| class Tup(NamedTuple): |
| a: int |
| b: int |
| |
| @torch.jit.script |
| def foo(): |
| return Tup(1, 2) |
| |
| self.assertEqual(foo(), Tup(1, 2)) |
| |
| def test_dict_preserves_order(self): |
| def dict_ordering(): |
| a : Dict[int, int] = {} |
| for i in range(1000): |
| a[i] = i + 1 |
| return a |
| |
| self.checkScript(dict_ordering, ()) |
| di = torch.jit.script(dict_ordering)() |
| res = list(di.items()) |
| for i in range(1000): |
| key, value = res[i] |
| self.assertTrue(key == i and value == i + 1) |
| |
| def test_list_unification_hint(self): |
| with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"): |
| @torch.jit.script |
| def x(): |
| b : int = [2, 3] |
| return b |
| |
| def test_return_named_tuple(self): |
| class FeatureVector(NamedTuple): |
| float_features: float |
| sequence_features: List[float] |
| time_since_first: float |
| |
| @torch.jit.script |
| def foo(x): |
| fv = FeatureVector(3.0, [3.0], 3.0) |
| return fv |
| |
| out = foo(torch.rand(3, 4)) |
| out = foo(torch.rand(3, 4)) |
| self.assertEqual(out.float_features, 3.0) |
| self.assertEqual(out.sequence_features, [3.0]) |
| self.assertEqual(out.time_since_first, 3.0) |
| |
| def test_named_tuple_as_attr(self): |
| class Config(NamedTuple): |
| size: int |
| |
| class MyMod(nn.Module): |
| configs: Dict[int, Config] |
| |
| def __init__(self, configs): |
| super().__init__() |
| self.configs = configs |
| |
| def forward(self, x): |
| for _id, config in self.configs.items(): |
| x += config.size |
| return x |
| |
| s = torch.jit.script(MyMod({0: Config(size=16)})) |
| |
| def test_types_as_values(self): |
| def fn(m: torch.Tensor) -> torch.device: |
| return m.device |
| |
| self.checkScript(fn, [torch.randn(2, 2)]) |
| |
| GG = namedtuple('GG', ['f', 'g']) |
| |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| @torch.jit.ignore |
| def foo(self, x, z): |
| # type: (Tensor, Tensor) -> Tuple[GG, GG] |
| return GG(x, z), GG(x, z) |
| |
| def forward(self, x, z): |
| return self.foo(x, z) |
| |
| foo = torch.jit.script(Foo()) |
| y = foo(torch.randn(2, 2), torch.randn(2, 2)) |
| |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| @torch.jit.ignore |
| def foo(self, x, z) -> Tuple[GG, GG]: |
| return GG(x, z) |
| |
| def forward(self, x, z): |
| return self.foo(x, z) |
| |
| foo = torch.jit.script(Foo()) |
| y = foo(torch.randn(2, 2), torch.randn(2, 2)) |
| |
| |
| def test_named_tuple_resolution(self): |
| class TheType(NamedTuple): |
| t: int |
| |
| class MyModule(types.ModuleType): |
| def __init__(self): |
| super(MyModule, self).__init__('MyModule') |
| |
| def __getattr__(self, attr): |
| return TheType |
| |
| some_module = MyModule() |
| |
| def fn() -> some_module.Type: |
| return some_module.Type(1) |
| |
| self.checkScript(fn, []) |
| |
| def test_ignore_with_types(self): |
| @torch.jit.ignore |
| def fn(x: Dict[str, Optional[torch.Tensor]]): |
| return x + 10 |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor: |
| self.dropout_modality(in_batch) |
| fn(in_batch) |
| return torch.tensor(1) |
| |
| @torch.jit.ignore |
| def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]: |
| return in_batch |
| |
| sm = torch.jit.script(M()) |
| FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph)) |
| |
| def test_python_callable(self): |
| class MyPythonClass(object): |
| @torch.jit.ignore |
| def __call__(self, *args) -> str: |
| return str(type(args[0])) |
| |
| the_class = MyPythonClass() |
| |
| @torch.jit.script |
| def fn(x): |
| return the_class(x) |
| |
| # This doesn't involve the string frontend, so don't use checkScript |
| x = torch.ones(2) |
| self.assertEqual(fn(x), the_class(x)) |
| |
| def test_bad_types(self): |
| @torch.jit.ignore |
| def fn(my_arg): |
| return my_arg + 10 |
| |
| with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"): |
| @torch.jit.script |
| def other_fn(x): |
| return fn('2') |
| |
| def test_named_tuple_slice_unpack(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| @torch.jit.script |
| def foo(a : int, b : float, c : List[int]): |
| tup = MyCoolNamedTuple(a, b, c) # noqa |
| my_a, my_b, my_c = tup |
| return tup[:1], my_a, my_c |
| |
| self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6])) |
| |
| def test_named_tuple_lower(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| @torch.jit.script |
| def foo(a : int): |
| tup = MyCoolNamedTuple(a, 3.14, [9]) # noqa |
| return tup |
| |
| FileCheck().check('TupleConstruct').run(foo.graph) |
| torch._C._jit_pass_lower_all_tuples(foo.graph) |
| FileCheck().check_not('TupleConstruct').run(foo.graph) |
| |
| def test_named_tuple_type_annotation(self): |
| global MyCoolNamedTuple # see [local resolution in python] |
| |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| @torch.jit.script |
| def foo(x : MyCoolNamedTuple) -> MyCoolNamedTuple: |
| return x |
| |
| mnt = MyCoolNamedTuple(42, 420.0, [666]) |
| self.assertEqual(foo(mnt), mnt) |
| |
| def test_named_tuple_wrong_types(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int' for argument 'a'" |
| " but instead found type 'str'"): |
| @torch.jit.script |
| def foo(): |
| tup = MyCoolNamedTuple('foo', 'bar', 'baz') # noqa |
| return tup |
| |
| def test_named_tuple_kwarg_construct(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| @torch.jit.script |
| def foo(): |
| tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa |
| return tup |
| |
| tup = foo() |
| self.assertEqual(tup.a, 9) |
| self.assertEqual(tup.b, 3.5) |
| self.assertEqual(tup.c, [1, 2, 3]) |
| |
| def test_named_tuple_default_error(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] = [3, 4, 5] |
| |
| with self.assertRaisesRegex(RuntimeError, 'Default values are currently not supported'): |
| @torch.jit.script |
| def foo(): |
| tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa |
| return tup |
| |
| @unittest.skipIf(True, "broken while these tests were not in CI") |
| def test_named_tuple_serialization(self): |
| class MyCoolNamedTuple(NamedTuple): |
| a : int |
| b : float |
| c : List[int] |
| |
| class MyMod(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self): |
| return MyCoolNamedTuple(3, 3.5, [3, 4, 5]) |
| |
| mm = MyMod() |
| mm.save('foo.zip') |
| torch.testing._internal.jit_utils.clear_class_registry() |
| loaded = torch.jit.load('foo.zip') |
| |
| out = mm() |
| out_loaded = loaded() |
| |
| for name in ['a', 'b', 'c']: |
| self.assertEqual(getattr(out_loaded, name), getattr(out, name)) |
| |
| def test_type_annotate_py3(self): |
| def fn(): |
| a : List[int] = [] |
| b : torch.Tensor = torch.ones(2, 2) |
| c : Optional[torch.Tensor] = None |
| d : Optional[torch.Tensor] = torch.ones(3, 4) |
| for _ in range(10): |
| a.append(4) |
| c = torch.ones(2, 2) |
| d = None |
| return a, b, c, d |
| |
| self.checkScript(fn, ()) |
| |
| def wrong_type(): |
| wrong : List[int] = [0.5] |
| return wrong |
| |
| with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"): |
| torch.jit.script(wrong_type) |
| |
| def test_subexpression_List_Future(self): |
| |
| @torch.jit.script |
| def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]: |
| return x[0] |
| |
| FileCheck().check('Future[int]').check('Future[int]').run(fn.graph) |
| |
| def test_subexpression_Future_annotate(self): |
| @torch.jit.script |
| def fn() -> torch.jit.Future[int]: |
| x: List[torch.jit.Future[int]] = [] |
| return x[0] |
| |
| FileCheck().check("Future[int][]").run(fn.graph) |
| |
| def test_future_isinstance(self): |
| @torch.jit.script |
| def fn(x: Any) -> torch.jit.Future[int]: |
| assert isinstance(x, jit.Future[int]) |
| return x |
| |
| FileCheck().check("Future[int]").run(fn.graph) |
| |
| def test_str_refine_any(self): |
| def forward(x: Any) -> str: |
| if isinstance(x, str): |
| return x |
| return "foo" |
| forward = torch.jit.script(forward) |
| self.assertEqual(forward(1), "foo") |
| self.assertEqual(forward("bar"), "bar") |
| |
| def test_subexpression_Tuple_int_int_Future(self): |
| |
| @torch.jit.script |
| def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]: |
| return x[0], x[2] |
| |
| FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph) |
| |
| def test_subexpression_Dict_int_Future(self): |
| |
| @torch.jit.script |
| def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]: |
| return x[y] |
| |
| FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph) |
| |
| def test_subexpression_Optional(self): |
| |
| @torch.jit.script |
| def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]: |
| if x is not None: |
| return x[0] |
| else: |
| return None |
| |
| FileCheck().check('Dict(int, Future(int))?').run(fn.graph) |
| |
| def test_unimported_type_resolution(self): |
| # verify fallback from the python resolver to the c++ resolver |
| |
| @ torch.jit.script |
| def fn(x): |
| # type: (number) -> number |
| return x + 1 |
| |
| FileCheck().check('Scalar').run(fn.graph) |
| |
| def test_parser_bug(self): |
| def parser_bug(o: Optional[torch.Tensor]): |
| pass |
| |
| def test_mismatched_annotation(self): |
| with self.assertRaisesRegex(RuntimeError, 'annotated with type'): |
| @torch.jit.script |
| def foo(): |
| x : str = 4 |
| return x |
| |
| def test_reannotate(self): |
| with self.assertRaisesRegex(RuntimeError, 'declare and annotate'): |
| @torch.jit.script |
| def foo(): |
| x = 5 |
| if True: |
| x : Optional[int] = 7 |
| |
| def test_module_inplace_construct(self): |
| class M(nn.Module): |
| def __init__(self, start: int): |
| super().__init__() |
| self.linear = nn.Linear(3, 3) |
| self.attribute = start |
| self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float)) |
| |
| def method(self) -> int: |
| return self.attribute |
| |
| @torch.jit.unused |
| def unused_method(self): |
| return self.attribute + self.attribute |
| |
| def forward(self, x): |
| return self.linear(self.linear(x)) |
| |
| |
| class N(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(4, 4) |
| |
| @torch.jit.ignore |
| def ignored_method(self, x): |
| return x |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| m = torch.jit.script(M(3)) |
| n = torch.jit.script(N()) |
| |
| n._reconstruct(m._c) |
| |
| inp = torch.rand((3)) |
| |
| # Check that both modules produce the same output. |
| with torch.no_grad(): |
| m_out = m(inp) |
| n_out = n(inp) |
| self.assertEqual(m_out, n_out) |
| |
| # Check that ignored method is still intact. |
| self.assertEqual(inp, n.ignored_method(inp)) |
| |
| def test_if_returning_any(self): |
| """ |
| Check that an if statement can return different |
| types early from each branch when the return |
| type of the function is Any. |
| """ |
| def if_function(inp: torch.Tensor) -> Any: |
| if inp.shape[0] == 1: |
| return inp * inp |
| else: |
| return "str" |
| |
| self.checkScript(if_function, (torch.randn(5),)) |
| |
| def test_export_opnames_interface(self): |
| global OneTwoModule |
| |
| @torch.jit.interface |
| class OneTwoModule(nn.Module): |
| def one(self, x, y): |
| # type: (Tensor, Tensor) -> Tensor |
| pass |
| |
| def two(self, x): |
| # type: (Tensor) -> Tensor |
| pass |
| |
| def forward(self, x): |
| # type: (Tensor) -> Tensor |
| pass |
| |
| class FooMod(nn.Module): |
| def one(self, x, y): |
| # type: (Tensor, Tensor) -> Tensor |
| return x + y |
| |
| def two(self, x): |
| # type: (Tensor) -> Tensor |
| return 2 * x |
| |
| def forward(self, x): |
| # type: (Tensor) -> Tensor |
| return self.one(self.two(x), x) |
| |
| class BarMod(nn.Module): |
| def one(self, x, y): |
| # type: (Tensor, Tensor) -> Tensor |
| return x * y |
| |
| def two(self, x): |
| # type: (Tensor) -> Tensor |
| return 2 / x |
| |
| def forward(self, x): |
| # type: (Tensor) -> Tensor |
| return self.two(self.one(x, x)) |
| |
| class M(nn.Module): |
| sub : OneTwoModule |
| |
| def __init__(self): |
| super(M, self).__init__() |
| self.sub = BarMod() |
| |
| def forward(self, x): |
| # type: (Tensor) -> Tensor |
| return self.sub.forward(x) |
| |
| def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): |
| return mod_list[0].forward(x) + mod_list[1].forward(x) |
| |
| scripted_M_mod = torch.jit.script(M()) |
| # Temporarily test empty output because lite interpreter does not support interface call |
| # Replace it with the issubset call when interface call is supported. |
| self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0) |
| # self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset( |
| # set(torch.jit.export_opnames(scripted_M_mod)))) |
| |
| scripted_M_mod.sub = torch.jit.script(FooMod()) |
| self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0) |
| # self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset( |
| # set(torch.jit.export_opnames(scripted_M_mod)))) |
| |
| |
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    run_tests()