Improve print stack/locals printing in comptime (#133651)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133651
Approved by: https://github.com/anijain2305
diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py
index 61a3a08..28e8f15 100644
--- a/test/dynamo/test_comptime.py
+++ b/test/dynamo/test_comptime.py
@@ -160,7 +160,7 @@
self.assertExpectedInline(
FILE.getvalue(),
"""\
-- TensorVariable()
+- FakeTensor(..., size=(2,))
""",
)
@@ -186,8 +186,8 @@
self.assertExpectedInline(
FILE.getvalue(),
"""\
-x = TensorVariable()
-y = TensorVariable()
+x = FakeTensor(..., size=(2,))
+y = FakeTensor(..., size=(2,))
""",
)
diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py
index f1a91e3..972d79d 100644
--- a/torch/_dynamo/comptime.py
+++ b/torch/_dynamo/comptime.py
@@ -233,10 +233,9 @@
NB: Stack grows downwards in our print
"""
- # TODO: improve printing
tx = self.__get_tx(stacklevel)
for s in tx.stack:
- print(f"- {s}", file=file)
+ print(f"- {s.debug_repr()}", file=file)
def print_locals(self, *, file=None, stacklevel=0):
"""
@@ -244,10 +243,9 @@
By default this view is very limited; you can get more information
about any individual local using get_local().
"""
- # TODO: improve by improving the VariableTracker printing
tx = self.__get_tx(stacklevel)
for k, v in tx.symbolic_locals.items():
- print(f"{k} = {v}", file=file)
+ print(f"{k} = {v.debug_repr()}", file=file)
def print_bt(self, *, file=None, stacklevel=0):
"""