[dynamo][reland][inline-inbuilt-nn-modules] Mark attributes of nn mod… (#133714)
Relands https://github.com/pytorch/pytorch/pull/132539
Relands https://github.com/pytorch/pytorch/pull/132736
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133714
Approved by: https://github.com/jansel
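
In short: with inline_inbuilt_nn_modules enabled, plain tensor attributes of nn modules (in addition to parameters) are lifted as graph inputs and marked static, matching the pre-inlining behavior. Below is a minimal sketch of the observable effect, closely mirroring the new test added in this patch; the backend name and exact placeholder names are illustrative and not part of the patch.

```python
import torch

class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A plain tensor attribute (not a Parameter or registered buffer).
        self.buf = torch.ones(2, 2)

    def forward(self, x):
        return self.buf * x

def inspect_backend(gm, example_inputs):
    # With inline_inbuilt_nn_modules, the tensor attribute is lifted as a
    # placeholder whose tensor_dict meta is tagged as a static (unguarded) input.
    for n in gm.graph.nodes:
        if n.op == "placeholder" and "tensor_dict" in n.meta:
            print(n.name, n.meta["tensor_dict"].get("_dynamo_static_input_type"))
    return gm

with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
    mod = Mod()
    opt = torch.compile(mod, backend=inspect_backend)
    opt(torch.ones(2))
```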
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index a4de786..6e63389 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -2772,6 +2772,49 @@
self.assertEqual(num_compiles, 1)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+ def test_mark_static_nn_module_tensor(self):
+ # This test verifies that Dynamo marks
+ # tensor attributes of nn modules as static
+ num_compiles = 0
+
+ def debug_compiler(gm, _):
+ nonlocal num_compiles
+ num_compiles += 1
+
+ input_nodes = [
+ n
+ for n in gm.graph.nodes
+ if n.op == "placeholder" and n.name == "l_mod_buf"
+ ]
+
+ self.assertGreater(len(input_nodes), 0)
+ for input_node in input_nodes:
+ self.assertEqual(
+ input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
+ "unguarded",
+ )
+
+ return gm
+
+ class TestModule(torch.nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.buf = torch.ones(2, 2)
+
+ def forward(self, x):
+ return self.buf * x
+
+ mod = TestModule()
+
+ @torch._dynamo.optimize(backend=debug_compiler)
+ def fn(x):
+ return x * mod(x)
+
+ inp = torch.ones(2)
+ fn(inp)
+ self.assertEqual(num_compiles, 1)
+
+ @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@torch._inductor.config.patch("freezing", True)
@torch.no_grad()
def test_mark_static_with_freezing(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index d53e049..a94a0aa 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -5549,6 +5549,64 @@
self.assertTrue(cnt.frame_count <= 2)
+ @torch._dynamo.config.patch(guard_nn_modules=False)
+ @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
+ def test_inlining_cornercase(self):
+ """
+ nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the
+ tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.
+
+ But there is a cornercase. Suppose you have NNModuleVariable with a submodule that is
+ UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
+ guard.source().is_nn_module()). In retrospect, this is a mistake but there are dependencies of export and also
+ cudagraphs which make it harder to fix the corner case right away. The long term solution is
+ inline_inbuilt_nn_modules anyways, so we might have to live with this cornercase in the short term.
+
+ We are starting to annotate the source of each nn module more precisely - NNModuleVariable attribute is marked
+ as NNModuleSource, UnspecilaizedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this
+ changes the behavior for the cornercase. And fails some tests which have unfortunately relied on this behavior.
+
+
+ To solve this, we tag the source only when inline_inbuilt_nn_module flag is turned on.
+
+ In this test, we purposely turn the flag off, testing that the tagging is disabled.
+ """
+
+ class SubMod(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(1, 1)
+ self.a = torch.randn(1, 1)
+ self.counter = 0
+ self.multipliers = [2.2, 3.3]
+
+ def forward(self, x):
+ self.counter += 1
+ return (
+ self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
+ )
+
+ class Mod(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.submod = SubMod()
+
+ def forward(self, x):
+ return self.submod(x)
+
+ mod = Mod()
+ opt_mod = torch.compile(mod, backend="eager")
+
+ x = torch.randn(1, 1)
+ ref = mod(x)
+ res = opt_mod(x)
+
+ mod.submod.multipliers = [3.3, 4.4]
+ # Since guard_nn_modules is False, this will not recompile
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ ref = mod(x)
+ res = opt_mod(x)
+
def test_optimized_module_training(self):
mod = torch.nn.Linear(3, 3)
mod.eval()
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 8f11f2e..397bb28 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -99,6 +99,7 @@
TypeSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
+ UnspecializedParamBufferSource,
WeakRefCallSource,
)
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
@@ -875,7 +876,7 @@
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
- elif istype(source, AttrSource):
+ elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
assert base_guard_manager # to make mypy happy
if (
@@ -1930,7 +1931,7 @@
#
assert guard.source is not None
static, reason = tensor_always_has_static_shape(
- value, is_tensor=True, guard_source=guard.source
+ value, is_tensor=True, tensor_source=guard.originating_source
)
if not static:
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index 725aecf..a2850de 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -236,6 +236,12 @@
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
+# Special AttrSource to differentiate attributes accessed via module._buffers and module._parameters
+@dataclasses.dataclass(frozen=True)
+class UnspecializedParamBufferSource(AttrSource):
+ pass
+
+
# This source is intended to be used in places where a source is needed but it is expected
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
@@ -679,6 +685,14 @@
return True
+def is_from_unspecialized_param_buffer_source(source: Source):
+ if isinstance(source, UnspecializedParamBufferSource):
+ return True
+ if isinstance(source, ChainedSource):
+ return is_from_unspecialized_param_buffer_source(source.base)
+ return False
+
+
def is_from_flatten_script_object_source(source: Source):
if isinstance(source, FlattenScriptObjectSource):
return True
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index daf1979..9bac71a 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -62,7 +62,7 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
-from torch._guards import TracingContext
+from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
@@ -2248,7 +2248,7 @@
def tensor_always_has_static_shape(
tensor: Union[torch.Tensor, Any],
is_tensor: bool,
- guard_source: torch._guards.GuardSource,
+ tensor_source: Source,
) -> Tuple[bool, Optional[TensorStaticReason]]:
"""
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -2261,12 +2261,20 @@
Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
"""
+ from .source import is_from_unspecialized_param_buffer_source
+
if (
- guard_source.is_specialized_nn_module()
- and config.force_nn_module_property_static_shapes
- ):
+ tensor_source.guard_source().is_specialized_nn_module()
+ # Mark the tensor attributes of nn modules static to keep the behavior the same as before
+ # the inline_inbuilt_nn_modules flag was introduced.
+ or tensor_source.guard_source().is_unspecialized_nn_module()
+ ) and config.force_nn_module_property_static_shapes:
return True, TensorStaticReason.NN_MODULE_PROPERTY
- if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
+
+ if (
+ type(tensor) is torch.nn.Parameter
+ or is_from_unspecialized_param_buffer_source(tensor_source)
+ ) and config.force_parameter_static_shapes:
return True, TensorStaticReason.PARAMETER
if not is_tensor:
return True, TensorStaticReason.NOT_TENSOR
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 08a0cfa..d8f9714 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -82,7 +82,11 @@
TypingVariable,
UnknownVariable,
)
-from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+ NNModuleVariable,
+ UnspecializedBuiltinNNModuleVariable,
+ UnspecializedNNModuleVariable,
+)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 72263d5..c18285c 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -170,7 +170,11 @@
TorchVersionVariable,
TypingVariable,
)
-from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+ FSDPManagedNNModuleVariable,
+ UnspecializedBuiltinNNModuleVariable,
+ UnspecializedNNModuleVariable,
+)
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
@@ -1312,7 +1316,11 @@
# this will get cleaned up once compile ends
self.tx.output.nn_modules[self.name] = value
- result = UnspecializedNNModuleVariable(value, source=self.source)
+ if value.__module__.startswith(("torch.nn.", "torch.ao.")):
+ result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+ else:
+ result = UnspecializedNNModuleVariable(value, source=self.source)
+
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
@@ -1341,6 +1349,7 @@
# specialized (as we don't expect users to be changing the
# NN modules on the fly)
or self.source.guard_source().is_specialized_nn_module()
+ or self.source.guard_source().is_unspecialized_builtin_nn_module()
or is_from_defaults(self.source)
or is_cell_contents(self.source)
# TODO: Delete this condition when rollout is done. NB: this
@@ -1381,7 +1390,12 @@
if (
config.inline_inbuilt_nn_modules
and not is_static_input
- and isinstance(value, torch.nn.Parameter)
+ and (
+ isinstance(value, torch.nn.Parameter)
+ # Mark tensor attributes of nn modules static. This keeps inline_inbuilt_nn_modules
+ # behavior compatible with the previous (non-inlined) behavior.
+ or (source and source.guard_source().is_unspecialized_nn_module())
+ )
):
self.mark_static_input(value, guard=False)
@@ -2574,7 +2588,9 @@
):
assert source is not None
static_shapes, reason = tensor_always_has_static_shape(
- e, is_tensor, guard_source=source.guard_source()
+ e,
+ is_tensor,
+ tensor_source=source,
)
if not parent_context:
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 4a484b2..3bd845d 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -5,7 +5,7 @@
import itertools
import types
from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING
import torch.nn
@@ -24,6 +24,7 @@
FSDPNNModuleSource,
GetItemSource,
NNModuleSource,
+ UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
)
from ..utils import (
@@ -800,6 +801,11 @@
# nn_module_stack_source appropriately to resemble mod.linear.
self.nn_module_stack_source = self.source
+ def _wrap_source(self, attr_source):
+ if not isinstance(attr_source, UnspecializedNNModuleSource):
+ return UnspecializedNNModuleSource(attr_source)
+ return attr_source
+
def get_nn_module_stack_source(self):
return self.nn_module_stack_source or self.source
@@ -1131,6 +1137,17 @@
return out
+class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
+ """
+ Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user-defined nn modules.
+ """
+
+ def _wrap_source(self, attr_source):
+ if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
+ return UnspecializedBuiltinNNModuleSource(attr_source)
+ return attr_source
+
+
class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
"""
Tracing behavior: trace into submodules and treat them as Unspecialized, do not
@@ -1152,19 +1169,12 @@
super().__init__(value=value, **kwargs)
self.source = source
- @staticmethod
- def _wrap_source(source):
- if not isinstance(source, (FSDPNNModuleSource, UnspecializedNNModuleSource)):
+ def _wrap_source(self, attr_source):
+ if not isinstance(
+ attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
+ ):
if torch._dynamo.config.skip_fsdp_guards:
- return FSDPNNModuleSource(source)
+ return FSDPNNModuleSource(attr_source)
else:
- # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
- return UnspecializedNNModuleSource(source)
- else:
- return source
-
- def __setattr__(self, name: str, value: Any) -> None:
- if name == "source":
- value = FSDPManagedNNModuleVariable._wrap_source(value)
-
- return super().__setattr__(name, value)
+ return UnspecializedNNModuleSource(attr_source)
+ return attr_source
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 2432aa5..42ba358 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -27,6 +27,7 @@
GetItemSource,
ODictGetItemSource,
RandomValueSource,
+ UnspecializedParamBufferSource,
WeakRefCallSource,
)
from ..utils import (
@@ -1051,6 +1052,19 @@
else:
return trace_rules.lookup(func)(func)
+ if (
+ torch._dynamo.config.inline_inbuilt_nn_modules
+ and source
+ and isinstance(self, variables.UnspecializedNNModuleVariable)
+ # Export has some awkwardness around specialized and unspecialized modules. Skip wrapping the source for the
+ # export use case for now.
+ and not tx.output.export
+ ):
+ # Recalculate source for params/buffers
+ if name in ("_buffers", "_parameters"):
+ source = UnspecializedParamBufferSource(self.source, name)
+ source = self._wrap_source(source)
+
if subobj is not NO_SUCH_SUBOBJ:
if is_wrapper_or_member_descriptor(subobj):
options = {"source": source}
diff --git a/torch/_guards.py b/torch/_guards.py
index d4d1d6b..ccc0773 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -129,6 +129,7 @@
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
GuardSource.LOCAL_FSDP_MODULE,
GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
+ GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
)