Fix handling of torch.return_types in dynamo (#120826)
Handle quasi-namedtuples as a special case in dynamo
Fixes #120651
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120826
Approved by: https://github.com/anijain2305
diff --git a/test/dynamo_expected_failures/TestNamedTupleAPI.test_namedtuple_return b/test/dynamo_expected_failures/TestNamedTupleAPI.test_namedtuple_return
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestNamedTupleAPI.test_namedtuple_return
+++ /dev/null
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 91469945..ad96886 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -307,7 +307,12 @@
elif is_namedtuple_cls(self.value):
fields = namedtuple_fields(self.value)
- field_defaults = self.value._field_defaults
+ # check if this a quasi-namedtuple or a real one
+ if self.value.__module__ == "torch.return_types":
+ # create pseudo-defaults from values of the quasi-namedtuple
+ field_defaults = dict(zip(fields, args[0].items))
+ else:
+ field_defaults = self.value._field_defaults
items = list(args)
items.extend([None] * (len(fields) - len(items)))