Jane Xu | 71b7182 | 2021-10-26 07:43:01 -0700 | [diff] [blame] | 1 | # Owner(s): ["oncall: package/deploy"] |
| 2 | |
Shunting Zhang | 0d7036f | 2021-09-28 19:20:42 -0700 | [diff] [blame] | 3 | import textwrap |
| 4 | import types |
| 5 | |
| 6 | from torch.utils._freeze import Freezer, PATH_MARKER |
| 7 | from torch.testing._internal.common_utils import run_tests, TestCase |
| 8 | |
| 9 | |
| 10 | class TestFreezer(TestCase): |
| 11 | """Tests the freeze.py script""" |
| 12 | |
| 13 | def test_compile_string(self): |
| 14 | freezer = Freezer(True) |
| 15 | code_str = textwrap.dedent( |
| 16 | """ |
| 17 | class MyCls: |
| 18 | def __init__(self): |
| 19 | pass |
| 20 | """ |
| 21 | ) |
| 22 | co = freezer.compile_string(code_str) |
| 23 | num_co = 0 |
| 24 | |
| 25 | def verify_filename(co: types.CodeType): |
| 26 | nonlocal num_co |
| 27 | |
| 28 | if not isinstance(co, types.CodeType): |
| 29 | return |
| 30 | |
| 31 | self.assertEqual(PATH_MARKER, co.co_filename) |
| 32 | num_co += 1 |
| 33 | |
| 34 | for nested_co in co.co_consts: |
| 35 | verify_filename(nested_co) |
| 36 | |
| 37 | verify_filename(co) |
| 38 | # there is at least one nested code object besides the top level one |
| 39 | self.assertTrue(num_co >= 2) |
| 40 | |
| 41 | |
if __name__ == "__main__":
    # Executed directly as a script: hand off to PyTorch's shared test runner.
    run_tests()