[dynamo] GetItemSource - restrict the supported index Source to be GlobalWeakRefSource (#117138)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117138
Approved by: https://github.com/jansel
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index e4f461a..e9756d8 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -325,15 +325,23 @@
         return slice_class(*slice_args)
 
     def name(self):
+        # Index can be of following types
+        # 1) GlobalWeakRefSource - for parameters
+        # 2) enum.Enum
+        # 3) index is a slice - example 1:4
+        # 4) index is a constant - example string, integer
         if isinstance(self.index, Source):
+            if not isinstance(self.index, GlobalWeakRefSource):
+                raise ValueError(
+                    "GetItemSource index must be a constant,enum or a GlobalWeakRefSource"
+                )
             return f"{self.base.name()}[{self.index.name()}]"
+        elif self.index_is_slice:
+            return f"{self.base.name()}[{self.unpack_slice()!r}]"
+        elif isinstance(self.index, enum.Enum):
+            return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
         else:
-            if self.index_is_slice:
-                return f"{self.base.name()}[{self.unpack_slice()!r}]"
-            elif isinstance(self.index, enum.Enum):
-                return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
-            else:
-                return f"{self.base.name()}[{self.index!r}]"
+            return f"{self.base.name()}[{self.index!r}]"
 
 
 @dataclasses.dataclass(frozen=True)
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index 0374335..5cf720d 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -130,7 +130,9 @@
         install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
         for p, value in self.value.state.items():
             tx.store_global_weakref(global_key_name(p), p)
-            p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
+            p_state_source = GetItemSource(
+                state_source, GlobalWeakRefSource(global_key_name(p))
+            )
             install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
             for k, v in value.items():
                 if (