Strengthen partially supported invariant of base for chained sources (#103445)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103445
Approved by: https://github.com/ezyang
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index cb372cc..0a748e7 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -3,7 +3,7 @@
import enum
from typing import Any, Optional, Union
-from torch._guards import GuardSource, Source
+from torch._guards import ChainedSource, GuardSource, Source
from . import utils
from .bytecode_transformation import create_call_function, create_instruction
@@ -158,8 +158,7 @@
@dataclasses.dataclass(frozen=True)
-class AttrSource(Source):
- base: Source
+class AttrSource(ChainedSource):
member: str
def __post_init__(self):
@@ -204,8 +203,7 @@
@dataclasses.dataclass(frozen=True)
-class TensorPropertySource(Source):
- base: Source
+class TensorPropertySource(ChainedSource):
prop: TensorProperty
idx: Optional[int] = None # None for STORAGE_OFFSET
@@ -244,9 +242,7 @@
@dataclasses.dataclass(frozen=True)
-class NegateSource(Source):
- base: Source
-
+class NegateSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
@@ -262,8 +258,7 @@
@dataclasses.dataclass(frozen=True)
-class DefaultsSource(Source):
- base: Source
+class DefaultsSource(ChainedSource):
idx_key: Union[int, str]
is_kw: bool = False
field: str = dataclasses.field(init=False, repr=False, compare=False)
@@ -305,8 +300,7 @@
@dataclasses.dataclass(frozen=True)
-class GetItemSource(Source):
- base: Source
+class GetItemSource(ChainedSource):
index: Any
index_is_slice: bool = False
@@ -358,9 +352,7 @@
@dataclasses.dataclass(frozen=True)
-class TypeSource(Source):
- base: Source
-
+class TypeSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
@@ -375,33 +367,38 @@
return f"type({self.base.name()})"
+# NB - SuperSource is a weird one.
+# it is our only source with 2 bases, so we use the objec
+# as the base, rather than the type, since an invocation
+# like super(Foo, foo) is represented here, the source object base is more spiritually
+# aligned with the instance, rather than the type.
+# This whole construction is questionable tho, and we should probably find a way to
+# avoid this exception to our otherwise nice source parentage invariant.
@dataclasses.dataclass(frozen=True)
-class SuperSource(Source):
+class SuperSource(ChainedSource):
type: Source
- obj: Source
def __post_init__(self):
assert self.type is not None
- assert self.obj is not None
+ assert self.base is not None
def reconstruct(self, codegen):
codegen.load_import_from("builtins", "super")
return (
self.type.reconstruct(codegen)
- + self.obj.reconstruct(codegen)
+ + self.base.reconstruct(codegen)
+ create_call_function(2, True)
)
def guard_source(self):
- return self.obj.guard_source()
+ return self.base.guard_source()
def name(self):
- return f"super({self.type.name()}, {self.obj.name()})"
+ return f"super({self.type.name()}, {self.base.name()})"
@dataclasses.dataclass(frozen=True)
-class ODictGetItemSource(Source):
- base: Source
+class ODictGetItemSource(ChainedSource):
index: Any
def __post_init__(self):
@@ -428,29 +425,27 @@
@dataclasses.dataclass(frozen=True)
-class NNModuleSource(Source):
- inner: Source
-
+class NNModuleSource(ChainedSource):
def reconstruct(self, codegen):
- return self.inner.reconstruct(codegen)
+ return self.base.reconstruct(codegen)
def guard_source(self):
- return _GUARD_SOURCE_NN_MODULE[self.inner.guard_source()]
+ return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
def name(self):
- return self.inner.name()
+ return self.base.name()
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
def guard_source(self):
- return _GUARD_SOURCE_NOT_NN_MODULE[self.inner.guard_source()]
+ return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
def guard_source(self):
- return _GUARD_SOURCE_FSDP_MODULE[self.inner.guard_source()]
+ return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index f70e52c..cda2ad7 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -916,7 +916,7 @@
source = (
None
if a.source is None or b.source is None
- else SuperSource(a.source, b.source)
+ else SuperSource(type=a.source, base=b.source)
)
return variables.SuperVariable(a, b, source=source)
diff --git a/torch/_guards.py b/torch/_guards.py
index dd85192..a5a4a3c 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -541,6 +541,7 @@
# Subclasses can be found in torch/_dynamo/source.py
+# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def reconstruct(self, codegen):
@@ -561,6 +562,14 @@
return self.guard_source().is_nn_module()
+# Subclasses can be found in torch/_dynamo/source.py
+# Note - there is an odd exception to this invariant of a single base,
+# see class SuperSource
+@dataclasses.dataclass(frozen=True)
+class ChainedSource(Source):
+ base: Source
+
+
def detect_fake_mode(inputs: Any = None):
"""
Attempts to "detect" what the current fake mode is. If there is one ambiently