[dynamo][reland][inline-inbuilt-nn-modules] Mark attributes of nn mod… (#133714)

Relands https://github.com/pytorch/pytorch/pull/132539
Relands https://github.com/pytorch/pytorch/pull/132736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133714
Approved by: https://github.com/jansel
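For readers skimming the diff: with `inline_inbuilt_nn_modules` enabled, plain tensor attributes of an `nn.Module` are treated as static inputs when the module is inlined, which is what the new test below checks. A minimal usage sketch (the eager backend is just for illustration):

```python
import torch

class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A plain tensor attribute (not a registered buffer or parameter).
        self.buf = torch.ones(2, 2)

    def forward(self, x):
        return self.buf * x

mod = TestModule()

# With the flag on, Dynamo inlines the module and marks `self.buf` as a
# static input instead of treating its shape as potentially dynamic.
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
    compiled = torch.compile(mod, backend="eager")
    out = compiled(torch.ones(2))
```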
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index a4de786..6e63389 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -2772,6 +2772,49 @@
         self.assertEqual(num_compiles, 1)
 
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_mark_static_nn_module_tensor(self):
+        # This test verifies that Dynamo marks the tensor
+        # attributes of nn modules as static.
+        num_compiles = 0
+
+        def debug_compiler(gm, _):
+            nonlocal num_compiles
+            num_compiles += 1
+
+            input_nodes = [
+                n
+                for n in gm.graph.nodes
+                if n.op == "placeholder" and n.name == "l_mod_buf"
+            ]
+
+            self.assertGreater(len(input_nodes), 0)
+            for input_node in input_nodes:
+                self.assertEqual(
+                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
+                    "unguarded",
+                )
+
+            return gm
+
+        class TestModule(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buf = torch.ones(2, 2)
+
+            def forward(self, x):
+                return self.buf * x
+
+        mod = TestModule()
+
+        @torch._dynamo.optimize(backend=debug_compiler)
+        def fn(x):
+            return x * mod(x)
+
+        inp = torch.ones(2)
+        fn(inp)
+        self.assertEqual(num_compiles, 1)
+
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     @torch._inductor.config.patch("freezing", True)
     @torch.no_grad()
     def test_mark_static_with_freezing(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index d53e049..a94a0aa 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -5549,6 +5549,64 @@
 
         self.assertTrue(cnt.frame_count <= 2)
 
+    @torch._dynamo.config.patch(guard_nn_modules=False)
+    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
+    def test_inlining_cornercase(self):
+        """
+        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the
+        tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.
+
+        But there is a corner case. Suppose you have an NNModuleVariable with a submodule that is an
+        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule specialized (courtesy of
+        guard.source().is_nn_module()). In retrospect, this is a mistake, but export and cudagraphs depend on the
+        current behavior, which makes it harder to fix the corner case right away. The long-term solution is
+        inline_inbuilt_nn_modules anyway, so we may have to live with this corner case in the short term.
+
+        We are starting to annotate the source of each nn module more precisely: an NNModuleVariable attribute is
+        marked as NNModuleSource, and an UnspecializedNNModuleVariable attribute is marked as
+        UnspecializedNNModuleSource. But this changes the behavior for the corner case and fails some tests that have
+        unfortunately relied on the old behavior.
+
+        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.
+
+        In this test, we purposely turn the flag off and check that the tagging is disabled.
+        """
+
+        class SubMod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(1, 1)
+                self.a = torch.randn(1, 1)
+                self.counter = 0
+                self.multipliers = [2.2, 3.3]
+
+            def forward(self, x):
+                self.counter += 1
+                return (
+                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
+                )
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.submod = SubMod()
+
+            def forward(self, x):
+                return self.submod(x)
+
+        mod = Mod()
+        opt_mod = torch.compile(mod, backend="eager")
+
+        x = torch.randn(1, 1)
+        ref = mod(x)
+        res = opt_mod(x)
+
+        mod.submod.multipliers = [3.3, 4.4]
+        # Since guard_nn_modules is False, this will not recompile
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            ref = mod(x)
+            res = opt_mod(x)
+
     def test_optimized_module_training(self):
         mod = torch.nn.Linear(3, 3)
         mod.eval()
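A practical note on the corner case described in the test's docstring: with `guard_nn_modules=False` and the inline flag off, the specialized submodule's non-tensor attributes are effectively baked into the traced graph and nothing guards on them, so later mutations do not trigger a recompile. A short sketch of how that can surface (the staleness comment is an expectation drawn from the docstring, not something the test asserts):

```python
import torch

class SubMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.multipliers = [2.2, 3.3]

    def forward(self, x):
        return self.linear(x) * self.multipliers[0] * self.multipliers[1]

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submod = SubMod()

    def forward(self, x):
        return self.submod(x)

with torch._dynamo.config.patch(guard_nn_modules=False, inline_inbuilt_nn_modules=False):
    mod = Mod()
    opt_mod = torch.compile(mod, backend="eager")
    x = torch.randn(1, 1)
    opt_mod(x)

    # No guard covers the list attribute, so this mutation does not trigger a
    # recompile; the compiled module likely keeps using the old multipliers.
    mod.submod.multipliers = [3.3, 4.4]
    opt_mod(x)
```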
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 8f11f2e..397bb28 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -99,6 +99,7 @@
     TypeSource,
     UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401
@@ -875,7 +876,7 @@
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
-        elif istype(source, AttrSource):
+        elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
             assert base_guard_manager  # to make mypy happy
 
             if (
@@ -1930,7 +1931,7 @@
             #
             assert guard.source is not None
             static, reason = tensor_always_has_static_shape(
-                value, is_tensor=True, guard_source=guard.source
+                value, is_tensor=True, tensor_source=guard.originating_source
             )
 
             if not static:
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index 725aecf..a2850de 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -236,6 +236,12 @@
         return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
 
 
+# Special AttrSource to differentiate accesses to module._buffers and module._parameters
+@dataclasses.dataclass(frozen=True)
+class UnspecializedParamBufferSource(AttrSource):
+    pass
+
+
 # This source is intended to be used in places where a source is needed but it is expected
 # that the symbol will be simplified out later on. Symbols with ephemeral sources are
 # prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
@@ -679,6 +685,14 @@
     return True
 
 
+def is_from_unspecialized_param_buffer_source(source: Source):
+    if isinstance(source, UnspecializedParamBufferSource):
+        return True
+    if isinstance(source, ChainedSource):
+        return is_from_unspecialized_param_buffer_source(source.base)
+    return False
+
+
 def is_from_flatten_script_object_source(source: Source):
     if isinstance(source, FlattenScriptObjectSource):
         return True
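The new `is_from_unspecialized_param_buffer_source` follows the same pattern as the neighboring `is_from_*` helpers: check for the marker class, otherwise walk down the `ChainedSource.base` chain. A self-contained sketch of that pattern using stand-in dataclasses (not the real source classes):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Source:
    pass

@dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

@dataclass(frozen=True)
class LocalSource(Source):
    local_name: str

@dataclass(frozen=True)
class AttrSource(ChainedSource):
    member: str

# Marker subclass: same shape as AttrSource, only the type differs.
@dataclass(frozen=True)
class UnspecializedParamBufferSource(AttrSource):
    pass

def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
    if isinstance(source, UnspecializedParamBufferSource):
        return True
    # Recurse through wrapped sources until the chain bottoms out.
    if isinstance(source, ChainedSource):
        return is_from_unspecialized_param_buffer_source(source.base)
    return False

# Roughly models an access like mod._parameters["weight"]:
src = AttrSource(
    base=UnspecializedParamBufferSource(base=LocalSource("mod"), member="_parameters"),
    member="weight",
)
print(is_from_unspecialized_param_buffer_source(src))  # True
```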
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index daf1979..9bac71a 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -62,7 +62,7 @@
 import torch.utils._pytree as pytree
 from torch import fx
 from torch._dispatch.python import enable_python_dispatcher
-from torch._guards import TracingContext
+from torch._guards import Source, TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
 from torch._utils_internal import log_compilation_event
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code
@@ -2248,7 +2248,7 @@
 def tensor_always_has_static_shape(
     tensor: Union[torch.Tensor, Any],
     is_tensor: bool,
-    guard_source: torch._guards.GuardSource,
+    tensor_source: Source,
 ) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -2261,12 +2261,20 @@
     Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
     The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
     """
+    from .source import is_from_unspecialized_param_buffer_source
+
     if (
-        guard_source.is_specialized_nn_module()
-        and config.force_nn_module_property_static_shapes
-    ):
+        tensor_source.guard_source().is_specialized_nn_module()
+        # Mark the tensor attributes of nn modules static to keep the behavior the same as
+        # before the inline_inbuilt_nn_modules flag was introduced.
+        or tensor_source.guard_source().is_unspecialized_nn_module()
+    ) and config.force_nn_module_property_static_shapes:
         return True, TensorStaticReason.NN_MODULE_PROPERTY
-    if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
+
+    if (
+        type(tensor) is torch.nn.Parameter
+        or is_from_unspecialized_param_buffer_source(tensor_source)
+    ) and config.force_parameter_static_shapes:
         return True, TensorStaticReason.PARAMETER
     if not is_tensor:
         return True, TensorStaticReason.NOT_TENSOR
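After this change, the decision in `tensor_always_has_static_shape` can be summarized as below. This is a simplified paraphrase, not the actual implementation: the config flags and source predicates are passed in explicitly (the real function reads them from `torch._dynamo.config` and the `Source` object), and the reason strings stand in for the `TensorStaticReason` enum.

```python
from typing import Optional, Tuple

def static_shape_decision(
    is_parameter: bool,
    from_unspecialized_param_buffer_source: bool,
    source_is_specialized_nn_module: bool,
    source_is_unspecialized_nn_module: bool,
    is_tensor: bool,
    force_nn_module_property_static_shapes: bool,
    force_parameter_static_shapes: bool,
) -> Tuple[bool, Optional[str]]:
    # nn module attributes keep static shapes; unspecialized modules are now
    # included to preserve the pre-inline_inbuilt_nn_modules behavior.
    if (
        source_is_specialized_nn_module or source_is_unspecialized_nn_module
    ) and force_nn_module_property_static_shapes:
        return True, "NN_MODULE_PROPERTY"
    # Parameters, and tensors reached through module._parameters / module._buffers
    # on an unspecialized module, are also kept static.
    if (
        is_parameter or from_unspecialized_param_buffer_source
    ) and force_parameter_static_shapes:
        return True, "PARAMETER"
    if not is_tensor:
        return True, "NOT_TENSOR"
    return False, None
```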
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 08a0cfa..d8f9714 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -82,7 +82,11 @@
     TypingVariable,
     UnknownVariable,
 )
-from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+    NNModuleVariable,
+    UnspecializedBuiltinNNModuleVariable,
+    UnspecializedNNModuleVariable,
+)
 from .optimizer import OptimizerVariable
 from .sdpa import SDPAParamsVariable
 from .tensor import (
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 72263d5..c18285c 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -170,7 +170,11 @@
     TorchVersionVariable,
     TypingVariable,
 )
-from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+    FSDPManagedNNModuleVariable,
+    UnspecializedBuiltinNNModuleVariable,
+    UnspecializedNNModuleVariable,
+)
 from .optimizer import OptimizerVariable
 from .script_object import TorchScriptObjectVariable
 from .sdpa import SDPAParamsVariable
@@ -1312,7 +1316,11 @@
                     # this will get cleaned up once compile ends
                     self.tx.output.nn_modules[self.name] = value
 
-            result = UnspecializedNNModuleVariable(value, source=self.source)
+            if value.__module__.startswith(("torch.nn.", "torch.ao.")):
+                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+            else:
+                result = UnspecializedNNModuleVariable(value, source=self.source)
+
             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
                 return result
@@ -1341,6 +1349,7 @@
                 # specialized (as we don't expect users to be changing the
                 # NN modules on the fly)
                 or self.source.guard_source().is_specialized_nn_module()
+                or self.source.guard_source().is_unspecialized_builtin_nn_module()
                 or is_from_defaults(self.source)
                 or is_cell_contents(self.source)
                 # TODO: Delete this condition when rollout is done.  NB: this
@@ -1381,7 +1390,12 @@
         if (
             config.inline_inbuilt_nn_modules
             and not is_static_input
-            and isinstance(value, torch.nn.Parameter)
+            and (
+                isinstance(value, torch.nn.Parameter)
+                # Mark tensor attributes of nn modules static. This keeps the
+                # inline_inbuilt_nn_modules behavior compatible with the previous behavior.
+                or (source and source.guard_source().is_unspecialized_nn_module())
+            )
         ):
             self.mark_static_input(value, guard=False)
 
@@ -2574,7 +2588,9 @@
     ):
         assert source is not None
         static_shapes, reason = tensor_always_has_static_shape(
-            e, is_tensor, guard_source=source.guard_source()
+            e,
+            is_tensor,
+            tensor_source=source,
         )
 
         if not parent_context:
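The builtin/user-defined split added above keys off the class's `__module__` attribute. A quick check of the values involved (the user-defined module's `__module__` depends on where it is defined, e.g. `__main__` in a script):

```python
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x

builtin = torch.nn.Linear(1, 1)
custom = MyModule()

print(builtin.__module__)  # 'torch.nn.modules.linear'
print(custom.__module__)   # e.g. '__main__' when defined in a script

# Matches the check in the builder diff above: builtin modules live under torch.nn.* or torch.ao.*
print(builtin.__module__.startswith(("torch.nn.", "torch.ao.")))  # True
print(custom.__module__.startswith(("torch.nn.", "torch.ao.")))   # False
```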
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 4a484b2..3bd845d 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -5,7 +5,7 @@
 import itertools
 import types
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING
 
 import torch.nn
 
@@ -24,6 +24,7 @@
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
+    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (
@@ -800,6 +801,11 @@
         # nn_module_stack_source appropriately to resemble mod.linear.
         self.nn_module_stack_source = self.source
 
+    def _wrap_source(self, attr_source):
+        if not isinstance(attr_source, UnspecializedNNModuleSource):
+            return UnspecializedNNModuleSource(attr_source)
+        return attr_source
+
     def get_nn_module_stack_source(self):
         return self.nn_module_stack_source or self.source
 
@@ -1131,6 +1137,17 @@
         return out
 
 
+class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
+    """
+    Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user-defined nn modules.
+    """
+
+    def _wrap_source(self, attr_source):
+        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
+            return UnspecializedBuiltinNNModuleSource(attr_source)
+        return attr_source
+
+
 class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     """
     Tracing behavior: trace into submodules and treat them as Unspecialized, do not
@@ -1152,19 +1169,12 @@
         super().__init__(value=value, **kwargs)
         self.source = source
 
-    @staticmethod
-    def _wrap_source(source):
-        if not isinstance(source, (FSDPNNModuleSource, UnspecializedNNModuleSource)):
+    def _wrap_source(self, attr_source):
+        if not isinstance(
+            attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
+        ):
             if torch._dynamo.config.skip_fsdp_guards:
-                return FSDPNNModuleSource(source)
+                return FSDPNNModuleSource(attr_source)
             else:
-                # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
-                return UnspecializedNNModuleSource(source)
-        else:
-            return source
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "source":
-            value = FSDPManagedNNModuleVariable._wrap_source(value)
-
-        return super().__setattr__(name, value)
+                return UnspecializedNNModuleSource(attr_source)
+        return attr_source
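The refactor above replaces the static `_wrap_source` plus `__setattr__` hook on `FSDPManagedNNModuleVariable` with an instance method that each variable class overrides, so an attribute source gets tagged according to the variable's specialization level. A stripped-down sketch of the dispatch pattern (stand-in classes, not the real variable or source types):

```python
class UnspecializedNNModuleSource:
    def __init__(self, base):
        self.base = base

class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
    pass

class UnspecializedNNModuleVariable:
    def _wrap_source(self, attr_source):
        # Idempotent: never double-wrap an already tagged source.
        if not isinstance(attr_source, UnspecializedNNModuleSource):
            return UnspecializedNNModuleSource(attr_source)
        return attr_source

class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
            return UnspecializedBuiltinNNModuleSource(attr_source)
        return attr_source

# The caller (var_getattr in user_defined.py, below) simply does
# `source = self._wrap_source(source)` and gets the tag for the variable's kind.
```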
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 2432aa5..42ba358 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -27,6 +27,7 @@
     GetItemSource,
     ODictGetItemSource,
     RandomValueSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from ..utils import (
@@ -1051,6 +1052,19 @@
                 else:
                     return trace_rules.lookup(func)(func)
 
+        if (
+            torch._dynamo.config.inline_inbuilt_nn_modules
+            and source
+            and isinstance(self, variables.UnspecializedNNModuleVariable)
+            # Export has some awkwardness around specialized and unspecialized modules. Skip wrapping
+            # the source for the export use case for now.
+            and not tx.output.export
+        ):
+            # Recalculate source for params/buffers
+            if name in ("_buffers", "_parameters"):
+                source = UnspecializedParamBufferSource(self.source, name)
+            source = self._wrap_source(source)
+
         if subobj is not NO_SUCH_SUBOBJ:
             if is_wrapper_or_member_descriptor(subobj):
                 options = {"source": source}
diff --git a/torch/_guards.py b/torch/_guards.py
index d4d1d6b..ccc0773 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -129,6 +129,7 @@
             GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
             GuardSource.LOCAL_FSDP_MODULE,
             GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
+            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
         )