# blob: 1cf31f9edc36255e00de5013f2f9f61c06e63747 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint
class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_exception(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError
except Exception:
x = torch.sigmoid(x)
return x
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception2(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError
except (NotImplementedError, AttributeError) as e:
x = torch.sigmoid(x)
return x
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception3(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError("Not implemented")
except AssertionError:
x = torch.sigmoid(x)
except NotImplementedError:
x = torch.cos(x)
finally:
x = torch.cos(x)
return x
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception_with_another_exception(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError("Not implemented")
except NotImplementedError as e:
x = torch.sigmoid(x)
try:
x = torch.cos(x)
raise AssertionError
except AssertionError:
x = torch.cos(x)
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception_else(self):
def gn(x):
return torch.cos(x)
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
x = gn(x)
except Exception:
x = torch.sigmoid(x)
else:
x = torch.cos(x)
return x
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
# TODO(anijain2305) - does not work with fullgraph=True
def test_exception_with_another_exception2(self):
def gn(x):
try:
x = torch.cos(x)
raise NotImplementedError("Not implemented")
except NotImplementedError as e:
x = torch.sigmoid(x)
raise
def fn(x):
try:
x = torch.cos(x)
gn(x)
except Exception:
pass
return x
x = torch.randn(4)
ref = fn(x)
# Cant use fullgraph=True because RERAISE is not supported
opt_fn = torch.compile(fn, backend="eager")
res = opt_fn(x)
# TODO(anijain2305) - does not work with fullgraph=True
def test_exception_with_ctx_manager(self):
def fn(x):
x = torch.cos(x)
try:
with torch.no_grad():
x = torch.sin(x)
raise NotImplementedError("Not implemented")
except NotImplementedError as e:
x = torch.sigmoid(x)
return x
x = torch.randn(4)
ref = fn(x)
# Cant use fullgraph=True because WITH_EXCEPT_START is not supported
opt_fn = torch.compile(fn, backend="eager")
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception_raised_from_child(self):
def gn():
raise NotImplementedError("foo")
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
gn()
x = torch.sin(x)
except Exception:
x = torch.sigmoid(x)
return x
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nn_module_getattr(self):
class A:
def __init__(self):
self._b = 20
def __getattr__(self, name):
fixed_name = "_" + name
if fixed_name in self.__dict__:
return self.__dict__[fixed_name]
raise AttributeError(f"{name} absent")
class B(A):
def __init__(self):
self.a = 10
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return 30
obj = B()
def fn(x):
return x * obj.a * obj.b * obj.c
x = torch.ones(4)
ref = fn(x)
print(ref)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_custom_getattr_on_module_exception(self):
class Foo(torch.nn.Module):
def __init__(self, a=3):
super().__init__()
self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))
def __getattr__(self, name):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "a_copy":
return self.a
raise
def forward(self, x):
return x * self.a * self.a_copy
mod = Foo()
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
x = torch.ones(4)
self.assertEqual(mod(x), opt_mod(x))
if __name__ == "__main__":
    # Script entry point: hand control to the Dynamo test harness.
    import torch._dynamo.test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()