# blob: a6e85dbccf01c36fb8c729dee25a0dec0a51a740 [file] [log] [blame]
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from typing import NamedTuple, List, Optional, Any, Dict
from jit.test_module_interface import TestModuleInterface # noqa: F401
import unittest
import sys
import torch
import torch.testing._internal.jit_utils
import torch.nn as nn
class TestScriptPy3(JitTestCase):
def test_joined_str(self):
def func(x):
hello, test = "Hello", "test"
print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999
print(f"format blank")
hi = 'hi'
print(f"stuff before {hi}")
print(f"{hi} stuff after")
return x + 1
x = torch.arange(4., requires_grad=True)
# TODO: Add support for f-strings in string parser frontend
# self.checkScript(func, [x], optimize=True, capture_output=True)
with self.capture_stdout() as captured:
out = func(x)
scripted = torch.jit.script(func)
with self.capture_stdout() as captured_script:
out_script = func(x)
self.assertAlmostEqual(out, out_script)
self.assertEqual(captured, captured_script)
def test_optional_dict_construct(self):
class M(torch.nn.Module):
def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
return buffer["prev_key"]
def forward(self, x):
prev_key = torch.rand(2, 3)
next_key = torch.rand(2, 3)
saved_state: Dict[str, Optional[torch.Tensor]] = {
"prev_key": prev_key,
"next_key": next_key,
}
return self.use(saved_state)
self.checkModule(M(), (torch.rand(2, 2),))
def test_kwarg_support(self):
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
class M(torch.nn.Module):
def forward(self, *, n_tokens: int, device_name: str = 2):
pass
torch.jit.script(M())
class M(torch.nn.Module):
def forward(self, *, n_tokens: int, device_name: str):
return n_tokens, device_name
sm = torch.jit.script(M())
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
sm()
input = (3, 'hello')
self.assertEqual(sm(*input), input)
def test_named_tuple(self):
class FeatureVector(NamedTuple):
float_features: float
sequence_features: List[float]
time_since_first: float
@torch.jit.script
def foo(x) -> float:
fv = FeatureVector(3.0, [3.0], 3.0) # noqa
rv = fv.float_features
for val in fv.sequence_features:
rv += val
rv *= fv.time_since_first
return rv
self.assertEqual(foo(torch.rand(3, 4)), 18.0)
def test_named_tuple_constant(self):
class Tup(NamedTuple):
a: int
b: int
@torch.jit.script
def foo():
return Tup(1, 2)
self.assertEqual(foo(), Tup(1, 2))
@unittest.skipIf(sys.version_info[0] < 3 and sys.version_info[1] < 6, "dict not ordered")
def test_dict_preserves_order(self):
def dict_ordering():
a : Dict[int, int] = {}
for i in range(1000):
a[i] = i + 1
return a
self.checkScript(dict_ordering, ())
di = torch.jit.script(dict_ordering)()
res = list(di.items())
for i in range(1000):
key, value = res[i]
self.assertTrue(key == i and value == i + 1)
def test_list_unification_hint(self):
with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
@torch.jit.script
def x():
b : int = [2, 3]
return b
def test_return_named_tuple(self):
class FeatureVector(NamedTuple):
float_features: float
sequence_features: List[float]
time_since_first: float
@torch.jit.script
def foo(x):
fv = FeatureVector(3.0, [3.0], 3.0)
return fv
out = foo(torch.rand(3, 4))
out = foo(torch.rand(3, 4))
self.assertEqual(out.float_features, 3.0)
self.assertEqual(out.sequence_features, [3.0])
self.assertEqual(out.time_since_first, 3.0)
def test_ignore_with_types(self):
@torch.jit.ignore
def fn(x: Dict[str, Optional[torch.Tensor]]):
return x + 10
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
self.dropout_modality(in_batch)
fn(in_batch)
return torch.tensor(1)
@torch.jit.ignore
def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
return in_batch
sm = torch.jit.script(M())
FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
def test_python_callable(self):
class MyPythonClass(object):
@torch.jit.ignore
def __call__(self, *args) -> str:
return str(type(args[0]))
the_class = MyPythonClass()
@torch.jit.script
def fn(x):
return the_class(x)
# This doesn't involve the string frontend, so don't use checkScript
x = torch.ones(2)
self.assertEqual(fn(x), the_class(x))
def test_bad_types(self):
@torch.jit.ignore
def fn(my_arg):
return my_arg + 10
with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
@torch.jit.script
def other_fn(x):
return fn('2')
def test_named_tuple_slice_unpack(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
@torch.jit.script
def foo(a : int, b : float, c : List[int]):
tup = MyCoolNamedTuple(a, b, c) # noqa
my_a, my_b, my_c = tup
return tup[:1], my_a, my_c
self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6]))
def test_named_tuple_lower(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
@torch.jit.script
def foo(a : int):
tup = MyCoolNamedTuple(a, 3.14, [9]) # noqa
return tup
FileCheck().check('TupleConstruct').run(foo.graph)
torch._C._jit_pass_lower_all_tuples(foo.graph)
FileCheck().check_not('TupleConstruct').run(foo.graph)
def test_named_tuple_type_annotation(self):
global MyCoolNamedTuple # see [local resolution in python]
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
@torch.jit.script
def foo(x : MyCoolNamedTuple) -> MyCoolNamedTuple:
return x
mnt = MyCoolNamedTuple(42, 420.0, [666])
self.assertEqual(foo(mnt), mnt)
def test_named_tuple_wrong_types(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int' for argument 'a'"
" but instead found type 'str'"):
@torch.jit.script
def foo():
tup = MyCoolNamedTuple('foo', 'bar', 'baz') # noqa
return tup
def test_named_tuple_kwarg_construct(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
@torch.jit.script
def foo():
tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa
return tup
tup = foo()
self.assertEqual(tup.a, 9)
self.assertEqual(tup.b, 3.5)
self.assertEqual(tup.c, [1, 2, 3])
def test_named_tuple_default_error(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int] = [3, 4, 5]
with self.assertRaisesRegex(RuntimeError, 'Default values are currently not supported'):
@torch.jit.script
def foo():
tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa
return tup
@unittest.skipIf(True, "broken while these tests were not in CI")
def test_named_tuple_serialization(self):
class MyCoolNamedTuple(NamedTuple):
a : int
b : float
c : List[int]
class MyMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
return MyCoolNamedTuple(3, 3.5, [3, 4, 5])
mm = MyMod()
mm.save('foo.zip')
torch.testing._internal.jit_utils.clear_class_registry()
loaded = torch.jit.load('foo.zip')
out = mm()
out_loaded = loaded()
for name in ['a', 'b', 'c']:
self.assertEqual(getattr(out_loaded, name), getattr(out, name))
def test_type_annotate_py3(self):
def fn():
a : List[int] = []
b : torch.Tensor = torch.ones(2, 2)
c : Optional[torch.Tensor] = None
d : Optional[torch.Tensor] = torch.ones(3, 4)
for _ in range(10):
a.append(4)
c = torch.ones(2, 2)
d = None
return a, b, c, d
self.checkScript(fn, ())
def wrong_type():
wrong : List[int] = [0.5]
return wrong
with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
torch.jit.script(wrong_type)
def test_parser_bug(self):
def parser_bug(o: Optional[torch.Tensor]):
pass
def test_mismatched_annotation(self):
with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
@torch.jit.script
def foo():
x : str = 4
return x
def test_reannotate(self):
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
@torch.jit.script
def foo():
x = 5
if True:
x : Optional[int] = 7
def test_any_in_class_fails(self):
class MyCoolNamedTuple(NamedTuple):
a : Any
b : float
c : List[int]
with self.assertRaisesRegex(RuntimeError, "contains an Any"):
@torch.jit.script
def foo():
return MyCoolNamedTuple(4, 5.5, [3])
print(foo.graph)
def test_export_opnames_interface(self):
global OneTwoModule
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass
def two(self, x):
# type: (Tensor) -> Tensor
pass
def forward(self, x):
# type: (Tensor) -> Tensor
pass
class FooMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
return x + y
def two(self, x):
# type: (Tensor) -> Tensor
return 2 * x
def forward(self, x):
# type: (Tensor) -> Tensor
return self.one(self.two(x), x)
class BarMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
return x * y
def two(self, x):
# type: (Tensor) -> Tensor
return 2 / x
def forward(self, x):
# type: (Tensor) -> Tensor
return self.two(self.one(x, x))
class M(nn.Module):
sub : OneTwoModule
def __init__(self):
super(M, self).__init__()
self.sub = BarMod()
def forward(self, x):
# type: (Tensor) -> Tensor
return self.sub.forward(x)
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
return mod_list[0].forward(x) + mod_list[1].forward(x)
scripted_M_mod = torch.jit.script(M())
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])
scripted_M_mod.sub = torch.jit.script(FooMod())
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
['aten::add.Tensor', 'aten::mul.Scalar'])
if __name__ == '__main__':
    # Entry point when this file is executed directly: hand off to the shared
    # runner torch.testing._internal.common_utils.run_tests.
    run_tests()