Fix missing extra_traceback in InterpreterShim (#97615)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97615
Approved by: https://github.com/Chillee, https://github.com/desertfire
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 94cbc76..558c1788 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3679,18 +3679,20 @@
class InterpreterShim(torch.fx.Interpreter):
+ @staticmethod
+ @functools.lru_cache(None)
+ def _dummy_gm():
+ return torch.fx.symbolic_trace(identity)
+
def __init__(self, graph, submodules):
- """
- We don't call super() here to avoid constructing a
- GraphModule which is very expensive (it does codegen).
- """
+ # call super() with a placeholder to avoid constructing a
+ # GraphModule which is very expensive (it does codegen).
+ super().__init__(self._dummy_gm(), garbage_collect_values=False)
self.module = self
self.graph = graph
self.submodules = submodules
- self.garbage_collect_values = False
- self.env = {}
+ self.extra_traceback = False
self.fetch_attr = submodules.__getitem__
- self.name = "InterpreterShim"
self.current_node = None
def run_node(self, n: torch.fx.Node) -> Any: