blob: 5d810ee769758e7007a41f9c0cc054c236b22ea3 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import logging
import unittest
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
class ExcTests(LoggingTestCase):
maxDiff = None
def test_unsupported_real_stack(self):
# exercise Unsupported constructor and augment_exc_message
def fn002(x):
torch._dynamo.graph_break()
def fn001(x):
x = x + 1
fn002(x)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
torch.randn(1)
),
"""\
'skip function graph_break in file _dynamo/decorators.py'
from user code:
File "test_exc.py", line N, in fn001
fn002(x)
File "test_exc.py", line N, in fn002
torch._dynamo.graph_break()""",
)
@torch._dynamo.config.patch(verbose=True, suppress_errors=True)
@make_logging_test()
@unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
def test_internal_error_suppress_errors(self, records):
def fn001(x):
def f(ctx):
raise AssertionError()
comptime(f)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "WON'T CONVERT")
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise AssertionError()
AssertionError:
from user code:
File "test_exc.py", line N, in fn001
comptime(f)
========== The above exception occurred while processing the following code ==========
File "test_exc.py", line N, in test_internal_error_suppress_errors
torch.compile(fn001, backend="eager")(torch.randn(1))
File "test_exc.py", line N, in fn001
comptime(f)
==========""",
)
@make_logging_test()
def test_not_implemented_error(self, records):
def fn001(x):
def f(ctx):
raise NotImplementedError()
# Ensure graph break is not possible
for i in range(3):
comptime(f)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "WON'T CONVERT")
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise NotImplementedError()
torch._dynamo.exc.InternalTorchDynamoError:
from user code:
File "test_exc.py", line N, in fn001
comptime(f)""",
)
@unittest.expectedFailure
@torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
@make_logging_test(dynamo=logging.DEBUG)
def test_unsupported_error(self, records):
def fn001(x):
return {1, 2}
torch.compile(fn001, backend="eager")(torch.randn(1))
# TODO: There is no graph break log! This is because the graph break
# logging is not in a centralized location; unsupported
# instruction bypasses it
self.getRecord(records, "Graph break:")
@torch._dynamo.config.patch(suppress_errors=False)
def test_internal_error_no_suppress(self):
def fn001(x):
# NB: avoid decorator, as 3.11 changed the line number attributed
# in this situation
def f(ctx):
raise AssertionError()
comptime(f)
# NB: OK for user code to be truncated here, because the regular
# exception backtrace has the rest of the crumbs
self.assertExpectedInlineMunged(
AssertionError,
lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
"""\
from user code:
File "test_exc.py", line N, in fn001
comptime(f)""",
)
@make_logging_test(graph_breaks=True)
def test_graph_break_log(self, records):
def fn002(x):
x = x + 1
torch._dynamo.graph_break()
x = x + 1
return x
def fn001(x):
return fn002(x)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "Graph break:")
# TODO: This should also report the enclosing frames; need to plumb
# frame object to it
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
Graph break: from user code at:
File "test_exc.py", line N, in fn001
return fn002(x)
File "test_exc.py", line N, in fn002
torch._dynamo.graph_break()
""", # noqa: B950
)
@torch._dynamo.config.patch(suppress_errors=False)
def test_backend_suppress_line(self):
def fn001(x):
x = torch.relu(x)
return x + 1
# Do NOT let this get attributed to x + 1
self.assertExpectedInlineMunged(
torch._dynamo.exc.BackendCompilerFailed,
lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
torch.randn(1)
),
"""\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
)
@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
assume_static_by_default=False,
suppress_errors=False,
)
@torch.fx.experimental._config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
translation_validation=True,
translation_validation_no_bisect=True,
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException
@torch.compile
def fn(x, shape):
return x.split(shape)
self.assertExpectedInlineMunged(
ValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed.
Model:
==> L['shape'][0]: 0
==> L['shape'][1]: 0
==> L['shape'][2]: 0
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 0
==> s2: 0
==> s3: 0
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (True)
Target Expressions:
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3)
==> (<= 2 s0)
==> (== 0 L['shape'][0])
==> (== 0 L['shape'][1])
==> (== 0 L['shape'][2])
==> (== 0 L['x'].storage_offset())
==> (== 0 s1)
==> (== 0 s2)
==> (== 0 s3)
==> (== 1 L['x'].stride()[0])
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
==> (>= 9223372036854775806 s0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
)
@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
assume_static_by_default=False,
suppress_errors=False,
)
@torch.fx.experimental._config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
translation_validation=True,
)
def test_trigger_bisect_on_error(self):
from torch.fx.experimental.validator import BisectValidationException
@torch.compile
def fn(x, shape):
return x.split(shape)
self.assertExpectedInlineMunged(
BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
Failure occurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 2
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 2
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
Target Expressions:
==> (!= (+ s1 s2 s3) s0)
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3)
==> (<= 2 s0)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()