| # Owner(s): ["module: dynamo"] |
| |
| import torch |
| import torch._dynamo.config |
| |
| import torch._dynamo.test_case |
| import torch._functorch.config |
| import torch.utils.checkpoint |
| |
| |
| class ExceptionTests(torch._dynamo.test_case.TestCase): |
| def test_exception(self): |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| raise NotImplementedError |
| except Exception: |
| x = torch.sigmoid(x) |
| |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_exception2(self): |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| raise NotImplementedError |
| except (NotImplementedError, AttributeError) as e: |
| x = torch.sigmoid(x) |
| |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_exception3(self): |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| raise NotImplementedError("Not implemented") |
| except AssertionError: |
| x = torch.sigmoid(x) |
| except NotImplementedError: |
| x = torch.cos(x) |
| finally: |
| x = torch.cos(x) |
| |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_exception_with_another_exception(self): |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| raise NotImplementedError("Not implemented") |
| except NotImplementedError as e: |
| x = torch.sigmoid(x) |
| try: |
| x = torch.cos(x) |
| raise AssertionError |
| except AssertionError: |
| x = torch.cos(x) |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_exception_else(self): |
| def gn(x): |
| return torch.cos(x) |
| |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| x = gn(x) |
| except Exception: |
| x = torch.sigmoid(x) |
| else: |
| x = torch.cos(x) |
| |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| # TODO(anijain2305) - does not work with fullgraph=True |
| def test_exception_with_another_exception2(self): |
| def gn(x): |
| try: |
| x = torch.cos(x) |
| raise NotImplementedError("Not implemented") |
| except NotImplementedError as e: |
| x = torch.sigmoid(x) |
| raise |
| |
| def fn(x): |
| try: |
| x = torch.cos(x) |
| gn(x) |
| except Exception: |
| pass |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| # Cant use fullgraph=True because RERAISE is not supported |
| opt_fn = torch.compile(fn, backend="eager") |
| res = opt_fn(x) |
| |
| # TODO(anijain2305) - does not work with fullgraph=True |
| def test_exception_with_ctx_manager(self): |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| with torch.no_grad(): |
| x = torch.sin(x) |
| raise NotImplementedError("Not implemented") |
| except NotImplementedError as e: |
| x = torch.sigmoid(x) |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| # Cant use fullgraph=True because WITH_EXCEPT_START is not supported |
| opt_fn = torch.compile(fn, backend="eager") |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_exception_raised_from_child(self): |
| def gn(): |
| raise NotImplementedError("foo") |
| |
| def fn(x): |
| x = torch.cos(x) |
| try: |
| x = torch.sin(x) |
| gn() |
| x = torch.sin(x) |
| except Exception: |
| x = torch.sigmoid(x) |
| |
| return x |
| |
| x = torch.randn(4) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_nn_module_getattr(self): |
| class A: |
| def __init__(self): |
| self._b = 20 |
| |
| def __getattr__(self, name): |
| fixed_name = "_" + name |
| if fixed_name in self.__dict__: |
| return self.__dict__[fixed_name] |
| raise AttributeError(f"{name} absent") |
| |
| class B(A): |
| def __init__(self): |
| self.a = 10 |
| |
| def __getattr__(self, name): |
| try: |
| return super().__getattr__(name) |
| except AttributeError: |
| return 30 |
| |
| obj = B() |
| |
| def fn(x): |
| return x * obj.a * obj.b * obj.c |
| |
| x = torch.ones(4) |
| ref = fn(x) |
| print(ref) |
| opt_fn = torch.compile(fn, backend="eager", fullgraph=True) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) |
| def test_custom_getattr_on_module_exception(self): |
| class Foo(torch.nn.Module): |
| def __init__(self, a=3): |
| super().__init__() |
| self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) |
| |
| def __getattr__(self, name): |
| try: |
| return super().__getattr__(name) # defer to nn.Module's logic |
| except AttributeError: |
| if name == "a_copy": |
| return self.a |
| raise |
| |
| def forward(self, x): |
| return x * self.a * self.a_copy |
| |
| mod = Foo() |
| opt_mod = torch.compile(mod, backend="eager", fullgraph=True) |
| |
| x = torch.ones(4) |
| self.assertEqual(mod(x), opt_mod(x)) |
| |
| |
if __name__ == "__main__":
    # Running this file as a script hands control to dynamo's test runner.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()