| # Owner(s): ["module: dynamo"] |
| |
| import logging |
| import unittest |
| |
| import torch |
| import torch._dynamo |
| import torch._dynamo.config |
| import torch._dynamo.test_case |
| from torch._dynamo.comptime import comptime |
| from torch._dynamo.exc import Unsupported |
| from torch.testing._internal.common_device_type import skipIf |
| from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3 |
| from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test |
| |
| |
class ExcTests(LoggingTestCase):
    """Golden-output tests for TorchDynamo error messages and log records.

    Each test compiles a small nested user function with ``torch.compile``
    and compares either the raised exception's message or a captured log
    record against an inline expected string.

    The expected strings show ``line N`` and bare file names such as
    ``test_exc.py``; this is presumably the normalization performed by
    ``munge_exc`` (called directly in several tests, and presumably also
    inside ``assertExpectedInlineMunged``) -- confirm in
    ``torch.testing._internal.common_utils``.

    NOTE: many expected strings quote source lines of the nested user
    functions (e.g. ``fn002(x)``, ``comptime(f)``) and their names
    (``fn001``, ``fn002``, ``f``). Renaming those functions or editing the
    quoted lines requires updating the expected output to match.
    """

    # unittest: None disables truncation of assertion diffs, so mismatches
    # in the long expected strings below are shown in full.
    maxDiff = None

    def test_unsupported_real_stack(self):
        # exercise Unsupported constructor and augment_exc_message
        #
        # With fullgraph=True the explicit graph break in fn002 is expected
        # to raise Unsupported (compare test_graph_break_log, where the same
        # call without fullgraph only logs a graph break). The message must
        # list the user frames fn001 -> fn002 under "from user code:".
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
   File "test_exc.py", line N, in fn001
    fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
        # With suppress_errors=True an internal error raised while tracing
        # is logged instead of raised: the compile call below is not wrapped
        # in any assert-raises helper and is expected to return normally.
        # With verbose=True the "WON'T CONVERT" record carries the banner-
        # delimited TorchDynamo stack trace plus the code being processed
        # (compare the shorter "due to:" form in test_not_implemented_error,
        # which does not patch verbose).
        def fn001(x):
            def f(ctx):
                raise AssertionError()

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # Presumably selects the captured record whose message contains the
        # given marker -- see LoggingTestCase.getRecord.
        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError()
AssertionError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
        # A NotImplementedError raised during tracing: the compile call is
        # expected to return normally (no assert-raises), and the
        # "WON'T CONVERT" record reports the failure as
        # torch._dynamo.exc.InternalTorchDynamoError, in the short
        # "due to:" form (verbose is not patched here).
        def fn001(x):
            def f(ctx):
                raise NotImplementedError()

            # Ensure graph break is not possible
            # (the comptime call is placed inside a loop; the loop index is
            # intentionally unused)
            for i in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError()
torch._dynamo.exc.InternalTorchDynamoError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @unittest.expectedFailure
    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
        # The test-only config flag presumably makes Dynamo treat the
        # BUILD_SET instruction (emitted for the set literal below) as
        # unimplemented. Marked expectedFailure because the final getRecord
        # lookup does not find a "Graph break:" record (see the TODO).
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log! This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
        # With suppress_errors=False the AssertionError raised inside the
        # comptime callback is expected to propagate out of the compiled
        # call. The bare AssertionError() has no message text, so the
        # expected string is two blank lines followed by the appended
        # "from user code:" section.
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation (i.e. call comptime(f) below rather than
            # decorating f with @comptime)
            def f(ctx):
                raise AssertionError()

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
        # Without fullgraph an explicit graph break is not an error: the
        # compile call returns normally, and a "Graph break:" record with
        # the user stack is captured (graph_breaks=True presumably enables
        # that log channel).
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break:")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        # NOTE(review): the expected output below already includes the
        # enclosing fn001 frame, so this TODO may be stale -- confirm.
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break: from user code at:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        # The test-only backend presumably fails when the graph contains
        # relu; with suppress_errors=False its failure is expected to
        # surface as BackendCompilerFailed.
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        # (the expected message deliberately has no "from user code:"
        # section pointing at a user source line)
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    def test_trigger_on_error(self):
        # assume_static_by_default=False presumably makes x's size and the
        # shape entries symbolic (s0..s3 in the expected model). The
        # test-only flag presumably flips the result of an equality
        # evaluated during tracing -- here the check that the split sizes
        # sum to x.size(0), listed under "Failed Source Expressions".
        # Translation validation (z3-backed, hence the skipIf) then reports
        # the inconsistency. With translation_validation_no_bisect=True the
        # plain ValidationException is raised, rather than the
        # BisectValidationException checked in test_trigger_bisect_on_error.
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 0
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 0
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['shape'][0])
  ==> (== 0 L['shape'][1])
  ==> (== 0 L['shape'][2])
  ==> (== 0 L['x'].storage_offset())
  ==> (== 0 s1)
  ==> (== 0 s2)
  ==> (== 0 s3)
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 0 s1)
  ==> (>= 0 s2)
  ==> (>= 0 s3)
  ==> (>= 9223372036854775806 s0)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
        # Same setup as test_trigger_on_error but without
        # translation_validation_no_bisect, so the failure is reported as a
        # BisectValidationException. Its message additionally names the
        # expression being evaluated and the FX node (the split call) that
        # was running at the time. The flipped check shows up as
        # "(!= (+ s1 s2 s3) s0)" among the target expressions.
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: 1
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 2
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 1
  ==> s2: 1
  ==> s3: 2

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 9223372036854775806 s0)
  ==> (>= 9223372036854775807 s1)
  ==> (>= 9223372036854775807 s2)
  ==> (>= 9223372036854775807 s3)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )
| |
| |
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test-case runner.
    from torch._dynamo.test_case import run_tests as _dynamo_run_tests

    _dynamo_run_tests()