# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint
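
# These tests exercise Dynamo's handling of Python exception control flow
# (try/except/else/finally, nested handlers, re-raise, custom __getattr__,
# StopIteration/KeyError/AttributeError) under torch.compile with the
# "eager" backend.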


class ExceptionTests(torch._dynamo.test_case.TestCase):
    def test_exception(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except Exception:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception2(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except (NotImplementedError, AttributeError) as e:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception3(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except AssertionError:
                x = torch.sigmoid(x)
            except NotImplementedError:
                x = torch.cos(x)
            finally:
                x = torch.cos(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception4(self):
        def fn(x):
            for i in range(10):
                if i == 5:
                    return x

                try:
                    x = torch.sin(x)
                    raise NotImplementedError
                except Exception:
                    x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_with_another_exception(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                try:
                    x = torch.cos(x)
                    raise AssertionError
                except AssertionError:
                    x = torch.cos(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_else(self):
        def gn(x):
            return torch.cos(x)

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                x = gn(x)
            except Exception:
                x = torch.sigmoid(x)
            else:
                x = torch.cos(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_another_exception2(self):
        def gn(x):
            try:
                x = torch.cos(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                raise

        def fn(x):
            try:
                x = torch.cos(x)
                gn(x)
            except Exception:
                pass
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Can't use fullgraph=True because RERAISE is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        self.assertEqual(ref, res)
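
        # Hedged follow-up sketch (an addition, not part of the original test):
        # use CompileCounter to confirm that at least one frame still compiles
        # even though the re-raise forces a graph break. The exact frame count
        # depends on where Dynamo breaks the graph, so only a lower bound is
        # asserted here.
        from torch._dynamo.testing import CompileCounter

        torch._dynamo.reset()
        cnt = CompileCounter()
        torch.compile(fn, backend=cnt)(x)
        self.assertGreaterEqual(cnt.frame_count, 1)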

    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_ctx_manager(self):
        def fn(x):
            x = torch.cos(x)
            try:
                with torch.no_grad():
                    x = torch.sin(x)
                    raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Can't use fullgraph=True because WITH_EXCEPT_START is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_raised_from_child(self):
        def gn():
            raise NotImplementedError("foo")

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                gn()
                x = torch.sin(x)
            except Exception:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_dynamo_undo_kw_names(self):
        def g(x, k=None):
            if k:
                raise TypeError("error")
            return x.sin()

        def fn(x):
            d = {"a": x}
            try:
                g(x, k=True)
            except Exception:
                y = 0
                for _, b in d.items():  # noqa: PERF102
                    y += b.sum()
            return y

        x = torch.randn(2, 3)
        expected = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        got = opt_fn(x)
        self.assertEqual(expected, got)

    def test_nn_module_getattr(self):
        class A:
            def __init__(self) -> None:
                self._b = 20

            def __getattr__(self, name):
                fixed_name = "_" + name
                if fixed_name in self.__dict__:
                    return self.__dict__[fixed_name]
                raise AttributeError(f"{name} absent")

        class B(A):
            def __init__(self) -> None:
                self.a = 10

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)
                except AttributeError:
                    return 30

        obj = B()

        def fn(x):
            # B() never calls A.__init__, so obj.b and obj.c both fall through
            # to B.__getattr__ and resolve to 30; obj.a is a plain attribute.
            return x * obj.a * obj.b * obj.c

        x = torch.ones(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_custom_getattr_on_module_exception(self):
        class Foo(torch.nn.Module):
            def __init__(self, a=3):
                super().__init__()
                self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)  # defer to nn.Module's logic
                except AttributeError:
                    if name == "a_copy":
                        return self.a
                    raise

            def forward(self, x):
                return x * self.a * self.a_copy

        mod = Foo()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        x = torch.ones(4)
        self.assertEqual(mod(x), opt_mod(x))
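
        # Hedged sanity check (an addition, not from the original test): with
        # the "a" parameter initialized to 2 * ones(4) and a_copy aliasing it,
        # the eager forward on a ones input is 1 * 2 * 2 == 4 * ones(4).
        self.assertTrue(torch.equal(mod(x), torch.ones(4) * 4))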

    def test_stop_iteration(self):
        def zip_longest(*iterables, fillvalue=None):
            # Simplified stand-in for itertools.zip_longest: it flattens values
            # into one list and returns at the first exhausted iterator. Its
            # purpose here is to exercise StopIteration handling inside the
            # compiled region.
            # Get the iterators for each iterable
            iterators = [iter(it) for it in iterables]

            result = []
            while True:
                for it in iterators:
                    try:
                        value = next(it)
                    except StopIteration:
                        result.append(fillvalue)
                        return result
                    result.append(value)

        def fn(x, y):
            torch.cos(torch.randn(4))
            return tuple(zip_longest(x, y))

        x = [1, 2, 3, 4]
        y = [10, 11, 12]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_key_error(self):
        def fn(x, d):
            try:
                a = d["b"]
            except KeyError:
                a = 2
            return x * a

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d = {"a": 1}
        ref = fn(x, d)
        res = opt_fn(x, d)
        self.assertEqual(ref, res)

    def test_attribute_error(self):
        class Mock:
            def __init__(self):
                self.a = 1

        mock = Mock()

        def fn(x):
            try:
                c = 2
                mock.b
            except AttributeError:
                c = 3
            return torch.sin(x) * c

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()