[dynamo] Support dict unpack of MutableMapping objects (#131961)

Fixes https://github.com/pytorch/pytorch/issues/128067

The basic functionality was alredy introduced earlier. This just ensures
that we support UserDefinedObjectVariable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131961
Approved by: https://github.com/williamwen42, https://github.com/mlazos, https://github.com/yanboliang
ghstack dependencies: #131827, #131956
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index cef416b..ebc76bb 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1059,6 +1059,51 @@
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
+    def test_unpack_mutable_map(self):
+        from collections.abc import MutableMapping
+
+        class TensorDict(MutableMapping):
+            def __init__(self):
+                self._dict = {}
+
+            def add(self, key, value):
+                self._dict[key] = value
+
+            def items(self):
+                return self._dict.items()
+
+            def __delitem__(self, key):
+                del self._dict[key]
+
+            def __getitem__(self, key):
+                return self._dict[key]
+
+            def __iter__(self):
+                return iter(self._dict)
+
+            def __len__(self):
+                return len(self._dict)
+
+            def __setitem__(self, key, value):
+                self._dict[key] = value
+
+        tensor_dict = TensorDict()
+        tensor_dict.add("a", torch.ones(4) * 2)
+
+        def gn(x, a=1):
+            return x * a
+
+        def fn(x):
+            return gn(x, **tensor_dict)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+
+        x = torch.randn(4)
+
+        ref = fn(x)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+
     def _test_default_dict_helper(self, factory):
         dd = collections.defaultdict(factory)
         param = torch.nn.Parameter(torch.ones([2, 2]))
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index c1bc2ae..4cc8f01 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1590,6 +1590,10 @@
         ) and argsvars.has_force_unpack_var_sequence(self):
             argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
 
+        # Unpack for cases like fn(**obj) where obj is a map
+        if isinstance(kwargsvars, UserDefinedObjectVariable):
+            kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]
+
         if not isinstance(argsvars, BaseListVariable) or not isinstance(
             kwargsvars, ConstDictVariable
         ):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 30dd19a..1b2af4d 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1337,9 +1337,10 @@
                 # This is applicable for user defined objects which seem like dict, but are not really dicts. For
                 # example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
                 # method and create a new dict.
-                out = tx.inline_user_function_return(
-                    arg.var_getattr(tx, "items"), args, kwargs
-                )
+                func_var = arg.var_getattr(tx, "items")
+                if not isinstance(func_var, variables.UserFunctionVariable):
+                    unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
+                out = tx.inline_user_function_return(func_var, args, kwargs)
                 if isinstance(out, ConstDictVariable):
                     return out
                 return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 24b4140..278e155 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -239,6 +239,7 @@
             ListIteratorVariable,
             ListVariable,
             TupleVariable,
+            UserDefinedObjectVariable,
         )
 
         Hashable = ConstDictVariable._HashableTracker
@@ -304,6 +305,7 @@
                     TupleVariable,
                     ListIteratorVariable,
                     variables.IteratorVariable,
+                    UserDefinedObjectVariable,
                 ),
             )
             and self.mutable_local