[dynamo] GetItemSource - restrict the supported index Source to be GlobalWeakRefSource (#117138)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117138
Approved by: https://github.com/jansel
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index e4f461a..e9756d8 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -325,15 +325,23 @@
return slice_class(*slice_args)
def name(self):
+ # Index can be of following types
+ # 1) GlobalWeakRefSource - for parameters
+ # 2) enum.Enum
+ # 3) index is a slice - example 1:4
+ # 4) index is a constant - example string, integer
if isinstance(self.index, Source):
+ if not isinstance(self.index, GlobalWeakRefSource):
+ raise ValueError(
+ "GetItemSource index must be a constant,enum or a GlobalWeakRefSource"
+ )
return f"{self.base.name()}[{self.index.name()}]"
+ elif self.index_is_slice:
+ return f"{self.base.name()}[{self.unpack_slice()!r}]"
+ elif isinstance(self.index, enum.Enum):
+ return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
else:
- if self.index_is_slice:
- return f"{self.base.name()}[{self.unpack_slice()!r}]"
- elif isinstance(self.index, enum.Enum):
- return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
- else:
- return f"{self.base.name()}[{self.index!r}]"
+ return f"{self.base.name()}[{self.index!r}]"
@dataclasses.dataclass(frozen=True)
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index 0374335..5cf720d 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -130,7 +130,9 @@
install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
for p, value in self.value.state.items():
tx.store_global_weakref(global_key_name(p), p)
- p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
+ p_state_source = GetItemSource(
+ state_source, GlobalWeakRefSource(global_key_name(p))
+ )
install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
for k, v in value.items():
if (