Remove special handling of step with closure (#123620)
Implements https://github.com/pytorch/pytorch/issues/123479. Instead of unconditionally raising Unsupported when Optimizer.step() is called with a closure (previously checked both at InstructionTranslator setup in symbolic_convert.py and in OptimizerVariable.call_method), Dynamo now traces the closure and only graph breaks, via the new graph_break_if_pending_mutation helper, when a parameter has pending mutations, so that _init_group runs with up-to-date parameter values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123620
Approved by: https://github.com/anijain2305
ghstack dependencies: #123496, #123497, #123551, #123552, #123618
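
For context, a minimal sketch of the pattern this unlocks (hypothetical model, optimizer, and data, not taken from this PR):

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(2, 4)

    @torch.compile
    def train_step():
        def closure():
            # re-evaluate the loss, as optimizers expect from a closure
            opt.zero_grad()
            loss = model(x).sum()
            loss.backward()
            return loss

        # Before this PR, Dynamo raised
        # "Optimizer step with closure not supported by torch.compile()" here;
        # now the call is traced, graph breaking only on pending parameter
        # mutations.
        return opt.step(closure)

    train_step()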
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 51974d3..4f52367 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2151,8 +2151,6 @@
             if k in f_locals
         }

-        self._throw_if_unsupported_optimizer_step()
-
         self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
         if export:
             # export gets confused if we never realize unused inputs
@@ -2166,13 +2164,6 @@
             if name in f_locals:
                 self._freevars_ids[name] = id(f_locals[name])

-    def _throw_if_unsupported_optimizer_step(self):
-        from .variables import OptimizerVariable
-
-        OptimizerVariable.throw_if_unsupported_step(
-            self.symbolic_locals, self.code_options["co_name"]
-        )
-
     def _throw_if_in_functorch(self):
         # Fallback to eager in case of a graph break inside vmap
         eager = torch._dynamo.lookup_backend("eager")
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index d925188..62120b1 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -6,8 +6,6 @@
 import torch
 from torch.utils._pytree import tree_map_only

-from ..exc import unimplemented, Unsupported
-
 from ..guards import GuardBuilder, install_guard
 from ..source import (
     AttrSource,
@@ -42,23 +40,6 @@
         *UserDefinedObjectVariable._nonvar_fields,
     }

-    @classmethod
-    def throw_if_unsupported_step(cls, symbolic_locals, f_name):
-        """
-        We don't support calling step with a closure argument, so graph break
-        if that's the case.
-        """
-        if (
-            "closure" in symbolic_locals
-            and not isinstance(symbolic_locals["closure"], ConstantVariable)
-            and "self" in symbolic_locals
-            and isinstance(symbolic_locals["self"], OptimizerVariable)
-            and f_name == "step"
-        ):
-            unimplemented(
-                "Optimizer step with closure not supported by torch.compile()"
-            )
-
     def __init__(
         self,
         value,
@@ -89,6 +70,7 @@
"""This is an optimization to avoid tracing the very slow initialization of the optimizer"""
if name == "_init_group":
try:
+ self.graph_break_if_pending_mutation(tx)
self.move_step_if_cpu()
py_args, py_kwargs = self.get_python_args(*args, **kwargs)
ret_val = self.value._init_group(*py_args, **py_kwargs)
@@ -109,17 +91,6 @@
                 # trace normally if we can't map args or install guards correctly
                 pass

-        if name == "step":
-            if (
-                "closure" in kwargs
-                and not isinstance(kwargs["closure"], ConstantVariable)
-                or len(args) == 1
-                and not isinstance(args[0], ConstantVariable)
-            ):
-                raise Unsupported(
-                    "Optimizer step with closure not supported by torch.compile()"
-                )
-
         return super().call_method(tx, name, args, kwargs)

     def var_getattr(self, tx, name):
@@ -134,6 +105,21 @@
         return super().var_getattr(tx, name)

+    def graph_break_if_pending_mutation(self, tx):
+        # If there are pending mutations on a parameter (due to using a closure),
+        # we need to graph break to allow the python version of the parameter
+        # to update, so that running _init_group will initialize the states with
+        # the correct values
+        for g in self.value.param_groups:
+            for p in g["params"]:
+                side_effects = tx.output.side_effects
+                variable = side_effects.id_to_variable.get(id(p), None)
+                # guard against params Dynamo has not yet wrapped in a VariableTracker
+                if variable is not None and side_effects.has_pending_mutation(variable):
+                    from ..exc import Unsupported
+
+                    raise Unsupported("Pending mutation on parameter")
+
     def _set_capturable(self, tx):
         from . import LazyVariableTracker
         from .builder import VariableBuilder