[Dynamo] Support typing.Union and typing.Optional (#98384)
Fixes #98265
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98384
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index fdf7a74..8740cf8 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1577,6 +1577,20 @@
res = opt_fn(x)
self.assertTrue(same(ref, res))
+ def test_typing_union_and_optional(self):
+ def fn(x):
+ a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {})
+ b = torch.jit.annotate(
+ typing.Dict[str, typing.Union[torch.Tensor, None]], {}
+ )
+ return a, b, x + 1
+
+ x = torch.randn(3)
+ ref = fn(x)
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+
def test_optimize_on_module(self):
class MockModule(torch.nn.Module):
def __init__(self):
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 4c5fbda..77bb927 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -382,7 +382,9 @@
if sys.version_info < (3, 9):
return isinstance(value, typing._GenericAlias)
else:
- return isinstance(value, typing._SpecialGenericAlias)
+ return isinstance(
+ value, (typing._SpecialGenericAlias, typing._UnionGenericAlias)
+ )
def is_numpy_int_type(value):