Strengthen partially supported invariant of base for chained sources (#103445)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103445
Approved by: https://github.com/ezyang
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index cb372cc..0a748e7 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -3,7 +3,7 @@
 import enum
 from typing import Any, Optional, Union
 
-from torch._guards import GuardSource, Source
+from torch._guards import ChainedSource, GuardSource, Source
 
 from . import utils
 from .bytecode_transformation import create_call_function, create_instruction
@@ -158,8 +158,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class AttrSource(Source):
-    base: Source
+class AttrSource(ChainedSource):
     member: str
 
     def __post_init__(self):
@@ -204,8 +203,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class TensorPropertySource(Source):
-    base: Source
+class TensorPropertySource(ChainedSource):
     prop: TensorProperty
     idx: Optional[int] = None  # None for STORAGE_OFFSET
 
@@ -244,9 +242,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class NegateSource(Source):
-    base: Source
-
+class NegateSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None
 
@@ -262,8 +258,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class DefaultsSource(Source):
-    base: Source
+class DefaultsSource(ChainedSource):
     idx_key: Union[int, str]
     is_kw: bool = False
     field: str = dataclasses.field(init=False, repr=False, compare=False)
@@ -305,8 +300,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class GetItemSource(Source):
-    base: Source
+class GetItemSource(ChainedSource):
     index: Any
     index_is_slice: bool = False
 
@@ -358,9 +352,7 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class TypeSource(Source):
-    base: Source
-
+class TypeSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None
 
@@ -375,33 +367,38 @@
         return f"type({self.base.name()})"
 
 
+# NB - SuperSource is a weird one.
+# it is our only source with 2 bases, so we use the objec
+# as the base, rather than the type, since an invocation
+# like super(Foo, foo) is represented here, the source object base is more spiritually
+# aligned with the instance, rather than the type.
+# This whole construction is questionable tho, and we should probably find a way to
+# avoid this exception to our otherwise nice source parentage invariant.
 @dataclasses.dataclass(frozen=True)
-class SuperSource(Source):
+class SuperSource(ChainedSource):
     type: Source
-    obj: Source
 
     def __post_init__(self):
         assert self.type is not None
-        assert self.obj is not None
+        assert self.base is not None
 
     def reconstruct(self, codegen):
         codegen.load_import_from("builtins", "super")
         return (
             self.type.reconstruct(codegen)
-            + self.obj.reconstruct(codegen)
+            + self.base.reconstruct(codegen)
             + create_call_function(2, True)
         )
 
     def guard_source(self):
-        return self.obj.guard_source()
+        return self.base.guard_source()
 
     def name(self):
-        return f"super({self.type.name()}, {self.obj.name()})"
+        return f"super({self.type.name()}, {self.base.name()})"
 
 
 @dataclasses.dataclass(frozen=True)
-class ODictGetItemSource(Source):
-    base: Source
+class ODictGetItemSource(ChainedSource):
     index: Any
 
     def __post_init__(self):
@@ -428,29 +425,27 @@
 
 
 @dataclasses.dataclass(frozen=True)
-class NNModuleSource(Source):
-    inner: Source
-
+class NNModuleSource(ChainedSource):
     def reconstruct(self, codegen):
-        return self.inner.reconstruct(codegen)
+        return self.base.reconstruct(codegen)
 
     def guard_source(self):
-        return _GUARD_SOURCE_NN_MODULE[self.inner.guard_source()]
+        return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
 
     def name(self):
-        return self.inner.name()
+        return self.base.name()
 
 
 @dataclasses.dataclass(frozen=True)
 class NotNNModuleSource(NNModuleSource):
     def guard_source(self):
-        return _GUARD_SOURCE_NOT_NN_MODULE[self.inner.guard_source()]
+        return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]
 
 
 @dataclasses.dataclass(frozen=True)
 class FSDPNNModuleSource(NNModuleSource):
     def guard_source(self):
-        return _GUARD_SOURCE_FSDP_MODULE[self.inner.guard_source()]
+        return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
 
 
 @dataclasses.dataclass(frozen=True)
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index f70e52c..cda2ad7 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -916,7 +916,7 @@
         source = (
             None
             if a.source is None or b.source is None
-            else SuperSource(a.source, b.source)
+            else SuperSource(type=a.source, base=b.source)
         )
         return variables.SuperVariable(a, b, source=source)
 
diff --git a/torch/_guards.py b/torch/_guards.py
index dd85192..a5a4a3c 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -541,6 +541,7 @@
 
 
 # Subclasses can be found in torch/_dynamo/source.py
+# TODO(voz): Consider a toplevel torch/_source.py
 @dataclasses.dataclass(frozen=True)
 class Source:
     def reconstruct(self, codegen):
@@ -561,6 +562,14 @@
         return self.guard_source().is_nn_module()
 
 
+# Subclasses can be found in torch/_dynamo/source.py
+# Note - there is an odd exception to this invariant of a single base,
+# see class SuperSource
+@dataclasses.dataclass(frozen=True)
+class ChainedSource(Source):
+    base: Source
+
+
 def detect_fake_mode(inputs: Any = None):
     """
     Attempts to "detect" what the current fake mode is.  If there is one ambiently