[dynamo] Support dict unpack of MutableMapping objects (#131961)
Fixes https://github.com/pytorch/pytorch/issues/128067
The basic functionality was alredy introduced earlier. This just ensures
that we support UserDefinedObjectVariable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131961
Approved by: https://github.com/williamwen42, https://github.com/mlazos, https://github.com/yanboliang
ghstack dependencies: #131827, #131956
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index cef416b..ebc76bb 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1059,6 +1059,51 @@
res = opt_fn(x)
self.assertEqual(ref, res)
+ def test_unpack_mutable_map(self):
+ from collections.abc import MutableMapping
+
+ class TensorDict(MutableMapping):
+ def __init__(self):
+ self._dict = {}
+
+ def add(self, key, value):
+ self._dict[key] = value
+
+ def items(self):
+ return self._dict.items()
+
+ def __delitem__(self, key):
+ del self._dict[key]
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ def __iter__(self):
+ return iter(self._dict)
+
+ def __len__(self):
+ return len(self._dict)
+
+ def __setitem__(self, key, value):
+ self._dict[key] = value
+
+ tensor_dict = TensorDict()
+ tensor_dict.add("a", torch.ones(4) * 2)
+
+ def gn(x, a=1):
+ return x * a
+
+ def fn(x):
+ return gn(x, **tensor_dict)
+
+ opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+
+ x = torch.randn(4)
+
+ ref = fn(x)
+ res = opt_fn(x)
+ self.assertEqual(ref, res)
+
def _test_default_dict_helper(self, factory):
dd = collections.defaultdict(factory)
param = torch.nn.Parameter(torch.ones([2, 2]))
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index c1bc2ae..4cc8f01 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1590,6 +1590,10 @@
) and argsvars.has_force_unpack_var_sequence(self):
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
+ # Unpack for cases like fn(**obj) where obj is a map
+ if isinstance(kwargsvars, UserDefinedObjectVariable):
+ kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type]
+
if not isinstance(argsvars, BaseListVariable) or not isinstance(
kwargsvars, ConstDictVariable
):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 30dd19a..1b2af4d 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1337,9 +1337,10 @@
# This is applicable for user defined objects which seem like dict, but are not really dicts. For
# example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
# method and create a new dict.
- out = tx.inline_user_function_return(
- arg.var_getattr(tx, "items"), args, kwargs
- )
+ func_var = arg.var_getattr(tx, "items")
+ if not isinstance(func_var, variables.UserFunctionVariable):
+ unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
+ out = tx.inline_user_function_return(func_var, args, kwargs)
if isinstance(out, ConstDictVariable):
return out
return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 24b4140..278e155 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -239,6 +239,7 @@
ListIteratorVariable,
ListVariable,
TupleVariable,
+ UserDefinedObjectVariable,
)
Hashable = ConstDictVariable._HashableTracker
@@ -304,6 +305,7 @@
TupleVariable,
ListIteratorVariable,
variables.IteratorVariable,
+ UserDefinedObjectVariable,
),
)
and self.mutable_local