|  | # Owner(s): ["oncall: jit"] | 
|  |  | 
|  | import os | 
|  | import sys | 
|  | import torch | 
|  | from torch.testing._internal.jit_utils import JitTestCase, make_global | 
|  | from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED | 
|  | from typing import List, Dict, Tuple, Any, Optional, NamedTuple  # noqa: F401 | 
|  | from torch.testing._internal.common_utils import NoTest | 
|  |  | 
# Make the helper files in test/ importable
# realpath(__file__) -> containing directory -> its parent; that parent
# directory is what gets appended to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# When MonkeyType is not installed, report it on stderr and rebind the name
# JitTestCase to NoTest, so the TestPDT class defined below subclasses NoTest
# instead of the real JitTestCase.
# NOTE(review): NoTest is presumably a placeholder base that keeps these tests
# from running -- confirm in torch.testing._internal.common_utils.
if not _IS_MONKEYTYPE_INSTALLED:
    print("monkeytype is not installed. Skipping tests for Profile-Directed Typing", file=sys.stderr)
    JitTestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811

# Executing this file as a script is rejected; the error message points the
# user at test/test_jit.py as the entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
|  |  | 
class TestPDT(JitTestCase):
    """
    A suite of tests for profile directed typing in TorchScript.

    Every test hands an un-annotated (or ``Any``-returning) function, class or
    module to ``torch.jit.script`` together with ``example_inputs`` and then
    calls both the scripted and the eager version on further arguments,
    mostly comparing the two results with ``assertEqual``.

    NOTE(review): the nested functions/classes are the objects passed to
    ``torch.jit.script``, so their bodies are part of what is being tested;
    edit them with care. ``make_global`` presumably registers the local
    definitions at module scope so they can be resolved by name -- confirm
    in ``torch.testing._internal.jit_utils``.
    """
    def test_nn_module(self):
        """Script a module whose ``forward`` branches on ``isinstance``, using int, float and bool examples."""
        class TestPDTModel(torch.nn.Module):
            def forward(self, x) -> Any:
                if isinstance(x, int):
                    return x + 1
                elif isinstance(x, float):
                    return x - 1
                else:
                    return x

        make_global(TestPDTModel)
        pdt_model = TestPDTModel()
        # One single-argument tuple per example call: an int, a float and a bool.
        inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
        # example_inputs is keyed by the module instance itself.
        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
        # NOTE(review): assertTrue(expr, msg) -- the eager result is passed as
        # the failure message, so only the truthiness of the scripted result is
        # checked, not equality with eager. Possibly deliberate, since Python
        # treats bool as an int subclass while TorchScript may not -- confirm
        # before switching to assertEqual.
        self.assertTrue(scripted_pdt_model(True), pdt_model(True))

    def test_nested_nn_module_class(self):
        """Script a wrapper module that delegates to an inner module; examples are given for the wrapper only."""
        class NestedPDTInner(torch.nn.Module):
            def forward(self, x):
                if isinstance(x, int):
                    return x * 10
                return x

        class NestedModulePDTWrapper(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                return self.inner(x)

        make_global(NestedPDTInner, NestedModulePDTWrapper)
        inner_pdt_model = NestedPDTInner()
        wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
        # Examples cover int and bool only.
        inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
        scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
        self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
        # 1.9 is a float, a type that does not appear in the example inputs.
        self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
        # NOTE(review): assertTrue's second positional argument is the failure
        # message; this checks truthiness of the scripted result only.
        self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))

    def test_nested_nn_module_class_with_args(self):
        """Script an outer module calling a two-argument inner module, with separate examples for each."""
        class NestedModulePDTInner(torch.nn.Module):
            def forward(self, x, y):
                if isinstance(x, int):
                    return x * 10 + y
                return x

        class NestedModulePDTOuter(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                # The second argument of the inner module is fixed to 20.
                return self.inner(x, 20)

        make_global(NestedModulePDTInner, NestedModulePDTOuter)
        inner_pdt_model = NestedModulePDTInner()
        outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
        # Two-argument examples for the inner module, one-argument examples
        # for the outer module.
        inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
        outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
        scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
                                              outer_pdt_model: outer_input, })
        self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
        self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
        # NOTE(review): assertTrue's second positional argument is the failure
        # message; this checks truthiness of the scripted result only.
        self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))

    def test_nested_function_in_forward(self):
        """Script a module whose ``forward`` calls another method of the same module."""
        class NestedFunctionInForward(torch.nn.Module):
            def forward(self, x):
                return self.fun(x) + 10

            def fun(self, x):
                # bool is tested before int.
                if isinstance(x, bool):
                    return 0
                elif isinstance(x, int):
                    return x + 1
                return 0

        make_global(NestedFunctionInForward)
        pdt_model = NestedFunctionInForward()
        # Examples are supplied for the module only, not for ``fun`` directly.
        inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
        self.assertEqual(scripted_pdt_model(30), pdt_model(30))
        self.assertEqual(scripted_pdt_model(True), pdt_model(True))

    def test_nn_module_with_export_function(self):
        """Script a module that defines only a ``@torch.jit.export`` method, with examples keyed by that bound method."""
        class TestModelWithExport(torch.nn.Module):
            @torch.jit.export
            def fn(self, x, y) -> Any:
                assert not (isinstance(x, bool) and isinstance(y, bool))
                if isinstance(x, int) and isinstance(y, int):
                    return x + y
                elif isinstance(x, float) and isinstance(y, float):
                    return x - y
                else:
                    return -1


        make_global(TestModelWithExport)
        pdt_model = TestModelWithExport()
        # One (int, int) example and one (float, float) example.
        inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
        self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
        self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
        # NOTE(review): assertTrue's second positional argument is the failure
        # message; this checks truthiness of the scripted result only, for a
        # (Tensor, int) call that is not among the example inputs.
        self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))

    def test_class_methods(self):
        """Script a plain Python class; examples are keyed by a bound method of an eager instance."""
        class PDTModel:
            def test_sum(self, a):
                return sum(a)

        make_global(PDTModel)
        pdt_model = PDTModel()
        inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
        # The class itself is scripted; the result is then called to build
        # the instance used in the comparison.
        scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
        script_model = scripted_pdt_model()
        self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))

    def test_class_with_multiple_methods(self):
        """Script a plain class with two methods, each given its own example inputs."""
        class PDTModelWithManyMethods:
            def test_list_to_dict(self, a):
                new_dictionary: Dict[float, bool] = {}
                for element in a:
                    new_dictionary[element] = True
                return new_dictionary

            def test_substring(self, a, b):
                return b in a

        make_global(PDTModelWithManyMethods)
        pdt_model = PDTModelWithManyMethods()
        # A list-of-float example for test_list_to_dict and a (str, str)
        # example for test_substring.
        list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
        str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
        scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
                                              pdt_model.test_substring: str_inp})
        script_model = scripted_pdt_model()
        self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
        # One substring that is present and one that is absent.
        self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
        self.assertEqual(script_model.test_substring("helloworld", "def", ), pdt_model.test_substring("helloworld", "def", ))

    def test_multiple_class_with_same_method(self):
        """Script two classes that both define ``test_find`` but are profiled with different argument types."""
        class PDTModelOne:
            def test_find(self, a, b):
                return b in a.keys()

        class PDTModelTwo:
            def test_find(self, a, b):
                return b in a

        make_global(PDTModelOne, PDTModelTwo)
        pdt_model_one = PDTModelOne()
        pdt_model_two = PDTModelTwo()
        # Model one is profiled with (dict, float), model two with (list of str, str).
        dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
        list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
        scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
        scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})

        script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
        self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
                         pdt_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4))
        self.assertEqual(script_model_two.test_find(["hello", "world", ], "world"),
                         pdt_model_two.test_find(["hello", "world", ], "world"))

    def test_pdt(self):
        """Script free functions profiled with int, float, Tensor, bool and str examples."""
        # For plain functions example_inputs is a list of argument tuples
        # rather than a dict.
        def test_sum(a, b):
            return a + b

        make_global(test_sum)
        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
        self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))

        def test_sub(a, b):
            return a - b

        make_global(test_sub)
        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
        self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))

        def test_mul(a, b):
            return a * b

        make_global(test_mul)
        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
        self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))

        def test_args_complex(real, img):
            return torch.complex(real, img)

        make_global(test_args_complex)
        scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
        # The same random tensors are fed to both versions so the results are comparable.
        arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
        self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))

        def test_bool(a):
            if a:
                return -1
            else:
                return 0

        make_global(test_bool)
        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
        self.assertEqual(scripted_fn_bool(True), test_bool(True))

        def test_str(a):
            if a == "":
                return False
            else:
                return True

        make_global(test_str)
        # Profiled with the empty string, then called with a non-empty one.
        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
        self.assertEqual(scripted_fn_str("abc"), test_str("abc"))

    def test_pdt_list_and_tuple(self):
        """Script the same ``sum`` function repeatedly with list and tuple examples of float, bool and int."""
        def test_list_and_tuple(a):
            return sum(a)

        make_global(test_list_and_tuple)

        # Each section below re-scripts the same function with a different
        # example container / element type.
        scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
        self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))

        scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
        self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))

        scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
        self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))

        scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
        self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))

        scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
                                                        example_inputs=[((True, False, True),)])
        self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
                         test_list_and_tuple((True, True, True)))

        scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
        self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))

    def test_nested_list_and_tuple(self):
        """Script functions profiled with nested list/tuple examples."""
        def test_nested_list(inp):
            return [sum(v) for v in inp]

        # NOTE(review): ``ans`` starts at 0.0 and is only ever multiplied, so
        # the eager function returns 0.0 for every input used below; a
        # different start value or operator may have been intended -- confirm.
        def test_nested_tuple(inp):
            ans = 0.0
            for tup in inp:
                for val in tup:
                    if val > 0:
                        ans *= val
            return ans

        make_global(test_nested_list, test_nested_tuple)

        # List of lists of int.
        list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
        scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
        inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
        self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))

        # Tuple of lists of int (the name ``list_inp`` is reused although the
        # outer container is a tuple here).
        list_inp = ([1, 2, 3, ], [5, 6, 7, ])
        scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
        inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
        self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))

        # List of tuples of float.
        tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
        scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
        inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
        self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))

        # Tuple of tuples of bool.
        tup_inp = ((True, False, True, ), (False, False, False, ))
        scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
        inp = ((True, True, True, ), (False, False, True, ))
        self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))

    def test_pdt_dict(self):
        """Script dict-indexing functions profiled with str-keyed and int-keyed dict examples."""
        def test_dict(a):
            return a['foo']

        def test_dict_int_list(a):
            return a[1]

        make_global(test_dict, test_dict_int_list)

        # str -> bool dictionary.
        str_bool_inp = {'foo' : True, 'bar': False}
        scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
        self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))

        # Despite its name, ``str_list_inp`` maps int keys to lists of bool.
        str_list_inp = {0 : [True, False], 1: [False, True]}
        scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
        self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
                         test_dict_int_list({0 : [False, False], 1: [True, True]}, ))

    def test_any(self):
        """Script functions whose single argument is profiled with several unrelated types."""
        def test_multiple_types(a):
            assert not isinstance(a, bool)
            return a

        def test_multiple_type_refinement(a):
            # bool is tested before int.
            if isinstance(a, bool):
                return 1
            elif isinstance(a, int):
                return 1 + a
            elif isinstance(a, float):
                return 1 + int(a)
            else:
                return -1

        make_global(test_multiple_types, test_multiple_type_refinement)

        # Examples: int, str, float and list of int.
        scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
        self.assertEqual(scripted_fn(10), test_multiple_types(10))
        self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
        self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
        self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))

        # Examples: int, str, float, list of int, bool and a str -> bool dict.
        scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
                                       ([3, 4, 5],), (True, ), ({"a": True}, ), ])
        self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
        self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
        self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
        self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14]))
        self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
        self.assertEqual(scripted_fn({"abc" : True, "def": False}), test_multiple_type_refinement({"abc" : True, "def": False}))

    def test_class_as_profiled_types(self):
        """Script a function that receives an instance of a user-defined class as one of its example arguments."""
        class UserDefinedClass:
            def fn(self, b) -> Any:
                assert b is not None
                if isinstance(b, int):
                    return b if b > 0 else -1
                elif isinstance(b, float):
                    return b if b > 0.0 else -1.0
                return 0

        def test_model(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(UserDefinedClass, test_model)

        user_class = UserDefinedClass()
        # The same instance is used in both examples; the first argument is
        # an int in one and a float in the other.
        scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
        self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
        self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))

    def test_class_with_args_as_profiled_types(self):
        """Script a function whose example argument is an instance of a class with a constructor argument."""
        class ClassWithArgs:
            def __init__(self, a: bool):
                self.a = a

            def fn(self, b):
                if self.a:
                    return b
                else:
                    return -1

        def test_model_with_args(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(ClassWithArgs, test_model_with_args)

        # Profiled with an instance built from False, then called with
        # instances built from True.
        user_class = ClassWithArgs(False)
        scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
        self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))

    def test_nn_parameter_as_arg(self):
        """Script a module that passes its own ``nn.Parameter`` to a helper method."""
        class TestNNParameter(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inp = torch.nn.Parameter(torch.ones(2, 3))

            def add_nn_parameter_with_int(self, x, y):
                return torch.add(x, y)

            def forward(self, y):
                return self.add_nn_parameter_with_int(self.inp, y)

        make_global(TestNNParameter)
        pdt_model = TestNNParameter()
        # Only ``forward`` gets an example (a single int).
        scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
        self.assertEqual(scripted_fn(20), pdt_model(20))

    def test_fx_tracing_with_typing(self):
        """Script a module whose ``forward`` is annotated to return a ``NamedTuple``."""
        # NOTE(review): despite the test name, no torch.fx API is called in
        # this test body.
        class FXModelOutput(NamedTuple):
            result: List[int]

        class FXModel(torch.nn.Module):
            def forward(self, a) -> FXModelOutput:
                result = FXModelOutput(result=a)
                return result

        make_global(FXModel, FXModelOutput)
        pdt_model = FXModel()
        scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
        self.assertEqual(scripted_fn([20]), pdt_model([20]))

    def test_nonetype_as_optional_of_type(self):
        """Script a function whose examples mix ``None`` with exactly one other type."""
        def test_none(a) -> Any:
            if a is None:
                return 0
            else:
                return a + torch.ones(1)

        make_global(test_none)

        # None together with a float example.
        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
        self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))

        # None together with an int example.
        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
        self.assertEqual(scripted_fn(2, ), test_none(2, ))

        # None together with a Tensor example.
        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
        self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))