| # Owner(s): ["oncall: jit"] |
| |
| import io |
| import os |
| import sys |
| import warnings |
| from contextlib import redirect_stderr |
| |
| import torch |
| from torch.testing import FileCheck |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| if __name__ == "__main__": |
| raise RuntimeError( |
| "This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead." |
| ) |
| |
| |
| class TestWarn(JitTestCase): |
| def test_warn(self): |
| @torch.jit.script |
| def fn(): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=1, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_only_once(self): |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=1, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_only_once_in_loop_func(self): |
| def w(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| w() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=1, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_once_per_func(self): |
| def w1(): |
| warnings.warn("I am warning you") |
| |
| def w2(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| w1() |
| w2() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=2, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_once_per_func_in_loop(self): |
| def w1(): |
| warnings.warn("I am warning you") |
| |
| def w2(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| w1() |
| w2() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=2, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_multiple_calls_multiple_warnings(self): |
| @torch.jit.script |
| def fn(): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| fn() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you", count=2, exactly=True |
| ).run(f.getvalue()) |
| |
| def test_warn_multiple_calls_same_func_diff_stack(self): |
| def warn(caller: str): |
| warnings.warn("I am warning you from " + caller) |
| |
| @torch.jit.script |
| def foo(): |
| warn("foo") |
| |
| @torch.jit.script |
| def bar(): |
| warn("bar") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| foo() |
| bar() |
| |
| FileCheck().check_count( |
| str="UserWarning: I am warning you from foo", count=1, exactly=True |
| ).check_count( |
| str="UserWarning: I am warning you from bar", count=1, exactly=True |
| ).run( |
| f.getvalue() |
| ) |