Remove special handling of step with closure (#123620)

Implements https://github.com/pytorch/pytorch/issues/123479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123620
Approved by: https://github.com/anijain2305
ghstack dependencies: #123496, #123497, #123551, #123552, #123618
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 51974d3..4f52367 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2151,8 +2151,6 @@
                 if k in f_locals
             }
 
-            self._throw_if_unsupported_optimizer_step()
-
             self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
             if export:
                 # export gets confused if we never realize unused inputs
@@ -2166,13 +2164,6 @@
                 if name in f_locals:
                     self._freevars_ids[name] = id(f_locals[name])
 
-    def _throw_if_unsupported_optimizer_step(self):
-        from .variables import OptimizerVariable
-
-        OptimizerVariable.throw_if_unsupported_step(
-            self.symbolic_locals, self.code_options["co_name"]
-        )
-
     def _throw_if_in_functorch(self):
         # Fallback to eager in case of a graph break inside vmap
         eager = torch._dynamo.lookup_backend("eager")
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index d925188..62120b1 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -6,8 +6,6 @@
 import torch
 from torch.utils._pytree import tree_map_only
 
-from ..exc import unimplemented, Unsupported
-
 from ..guards import GuardBuilder, install_guard
 from ..source import (
     AttrSource,
@@ -42,23 +40,6 @@
         *UserDefinedObjectVariable._nonvar_fields,
     }
 
-    @classmethod
-    def throw_if_unsupported_step(cls, symbolic_locals, f_name):
-        """
-        We don't support calling the step with closure argument, so graph break if
-        if that's the case.
-        """
-        if (
-            "closure" in symbolic_locals
-            and not isinstance(symbolic_locals["closure"], ConstantVariable)
-            and "self" in symbolic_locals
-            and isinstance(symbolic_locals["self"], OptimizerVariable)
-            and f_name == "step"
-        ):
-            unimplemented(
-                "Optimizer step with closure not supported by torch.compile()"
-            )
-
     def __init__(
         self,
         value,
@@ -89,6 +70,7 @@
         """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
         if name == "_init_group":
             try:
+                self.graph_break_if_pending_mutation(tx)
                 self.move_step_if_cpu()
                 py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                 ret_val = self.value._init_group(*py_args, **py_kwargs)
@@ -109,17 +91,6 @@
                 # trace normally if we can't map args or install guards correctly
                 pass
 
-        if name == "step":
-            if (
-                "closure" in kwargs
-                and not isinstance(kwargs["closure"], ConstantVariable)
-                or len(args) == 1
-                and not isinstance(args[0], ConstantVariable)
-            ):
-                raise Unsupported(
-                    "Optimizer step with closure not supported by torch.compile()"
-                )
-
         return super().call_method(tx, name, args, kwargs)
 
     def var_getattr(self, tx, name):
@@ -134,6 +105,21 @@
 
         return super().var_getattr(tx, name)
 
+    def graph_break_if_pending_mutation(self, tx):
+    # If a parameter has pending mutations (e.g. a closure mutated it during
+    # tracing), graph break so the eager (Python) copy of the parameter is
+    # updated first; otherwise running _init_group would initialize the
+    # optimizer states from stale values.
+        for g in self.value.param_groups:
+            for p in g["params"]:
+                side_effects = tx.output.side_effects
+                if side_effects.has_pending_mutation(
+                    side_effects.id_to_variable.get(id(p), None)
+                ):
+                    from ..exc import Unsupported
+
+                    raise Unsupported("Pending mutation on parameter")
+
     def _set_capturable(self, tx):
         from . import LazyVariableTracker
         from .builder import VariableBuilder