[easy][dynamo] Add tx as an arg in getitem_const (#132899)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132899
Approved by: https://github.com/yanboliang
ghstack dependencies: #132806
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 621e26a..8ce98b0 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2122,7 +2122,7 @@
assert isinstance(tos1, ConstDictVariable)
if all(k in tos1 for k in tos): # type: ignore[attr-defined]
- self.push(TupleVariable([tos1.getitem_const(k) for k in tos])) # type: ignore[attr-defined]
+ self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type]
if sys.version_info < (3, 11):
self.push(ConstantVariable.create(True))
else:
diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py
index f1803e5..eb06996 100644
--- a/torch/_dynamo/variables/constant.py
+++ b/torch/_dynamo/variables/constant.py
@@ -101,7 +101,7 @@
"""
return self.unpack_var_sequence(tx=None)
- def getitem_const(self, arg: VariableTracker):
+ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
return ConstantVariable.create(
self.value[arg.as_python_constant()],
)
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 1013ab7..18e6c33 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -224,7 +224,7 @@
raise_observed_exception(KeyError, tx, self)
return self.items[key]
- def getitem_const(self, arg: VariableTracker):
+ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
unimplemented(f"dict KeyError: {arg.value}")
@@ -332,7 +332,7 @@
self.items.update(kwargs)
return ConstantVariable.create(None)
elif name in ("get", "__getattr__") and args[0] in self:
- return self.getitem_const(args[0])
+ return self.getitem_const(tx, args[0])
elif name == "__contains__" and len(args) == 1:
return ConstantVariable.create(args[0] in self)
else:
@@ -378,7 +378,7 @@
assert len(args) == 1
if args[0] in self:
- return self.getitem_const(args[0])
+ return self.getitem_const(tx, args[0])
else:
if self.default_factory is None:
raise KeyError(f"{args[0]}")
@@ -497,7 +497,7 @@
return super().call_method(tx, "update", (arg,), kwargs)
return super().call_method(tx, name, args, kwargs)
- def getitem_const(self, arg: VariableTracker):
+ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
raise RuntimeError("Illegal to getitem on a set")
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index 3a4b96f..b29a30b 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -86,7 +86,7 @@
assert self.python_type() is not SizeVariable
return self.python_type()(self._as_proxy())
- def getitem_const(self, arg: VariableTracker):
+ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@@ -132,7 +132,7 @@
unimplemented("__getitem__ with non-constant tensor")
else:
value = args[0]
- return self.getitem_const(value)
+ return self.getitem_const(tx, value)
elif name == "__contains__":
assert len(args) == 1
assert not kwargs
@@ -285,7 +285,7 @@
def as_python_constant(self):
return range(*[x.as_python_constant() for x in self.items])
- def getitem_const(self, arg: VariableTracker):
+ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index 04fcb8f..7139bfb 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -255,7 +255,7 @@
break
group_source = group_vt.source
- params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
+ params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
for p_ind, (p, p_vt) in enumerate(
zip(group["params"], params_vt.unpack_var_sequence(tx))
):