[inductor] Reland #95567 part 1 (#96023)
This is the non-problematic part of #95567. The errors came from the
IR printing changes, which will follow next in the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96023
Approved by: https://github.com/ngimel, https://github.com/mlazos
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 2089081..7c7c2cb 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1358,8 +1358,8 @@
total = psutil.virtual_memory().total
percentage = psutil.Process(os.getpid()).memory_percent()
peak_mem = percentage * total / 10**9
- except Exception as e:
- log.exception(f"Failed for {mode} {e}")
+ except Exception:
+ log.exception(f"Backend {mode} failed in warmup()")
return sys.exit(-1)
dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 320c16e..1fda9f5 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2202,6 +2202,7 @@
(v,),
)
+ @slow()
def test_conv_transpose2d_unary(self):
if self.device == "cuda":
raise unittest.SkipTest("only support cpu conv_transpose2d unary test")
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index b393a89..4747597 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -24,7 +24,7 @@
log_file_name = None
# Verbose will print full stack traces on warnings and errors
-verbose = False
+verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# If true, traced graph outputs will be outputted as Python GraphModule code.
# If false, traced graph outputs will be outputted in tabular form.
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py
index 56c867e..ed9b39d 100644
--- a/torch/_dynamo/exc.py
+++ b/torch/_dynamo/exc.py
@@ -44,7 +44,7 @@
def __init__(self, backend_fn, inner_exception):
self.backend_name = getattr(backend_fn, "__name__", "?")
self.inner_exception = inner_exception
- msg = f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}"
+ msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
super().__init__(msg)
@@ -103,8 +103,8 @@
msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
torch._dynamo.replay('{exc.record_filename}').\n"
- if not config.verbose:
- msg += "\nSet torch._dynamo.config.verbose=True for more information\n"
+ if not config.verbose and hasattr(exc, "real_stack"):
+ msg += "\nSet torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information\n"
if hasattr(exc, "inner_exception") and hasattr(
exc.inner_exception, "minifier_path"
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index e2377e5..b59ecae 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -705,8 +705,9 @@
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except Exception as e:
- compiled_fn = gm.forward
- raise BackendCompilerFailed(self.compiler_fn, e) from e
+ raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
+ e.__traceback__
+ ) from None
return compiled_fn
def fake_example_inputs(self) -> List[torch.Tensor]:
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index c231652..34206b9 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -596,7 +596,12 @@
key, path = write(source_code, "py", extra)
if key not in cls.cache:
with open(path) as f:
- code = compile(f.read(), path, "exec")
+ try:
+ code = compile(f.read(), path, "exec")
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to import {path}\n{type(e).__name__}: {e}"
+ )
mod = types.ModuleType(f"{__name__}.{key}")
mod.__file__ = path
mod.key = key
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 17bbd90..5a9b3ec 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -116,6 +116,7 @@
graph_id=None,
):
super().__init__(gm)
+ self.extra_traceback = False # we do our own error wrapping
if shape_env is None:
shape_env = ShapeEnv()
self.reuse_shape_env = False
@@ -353,8 +354,9 @@
out = lowerings[target](*args, **kwargs)
return out
except Exception as e:
- log.exception("Error from lowering")
- raise LoweringException(e, target, args, kwargs) from e
+ raise LoweringException(e, target, args, kwargs).with_traceback(
+ e.__traceback__
+ ) from None
def get_attr(self, target, args, kwargs):
# this is a constant
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index d1bc125..374c84b 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1,7 +1,6 @@
import collections
import contextlib
import functools
-import glob
import itertools
import logging
import math
@@ -530,33 +529,32 @@
torch._dynamo.config.debug_dir_root = self.prev_debug_name
-def run_and_get_triton_code(fn, *args, **kwargs):
- from torch._inductor.debug import DebugContext
- from torch._inductor.virtualized import V
+def run_and_get_code(fn, *args, **kwargs):
+ from .graph import GraphLowering
- torch._dynamo.reset()
+ compile_to_module = GraphLowering.compile_to_module
+ source_codes = []
- context = DebugContext()
+ def patched_compile_to_module(self):
+ mod = compile_to_module(self)
+ with open(mod.__file__, "r") as f:
+ source_codes.append(f.read())
+ return mod
- with DebugDirManager(), mock.patch.object(
- config.trace, "enabled", True
- ), context, V.set_debug_handler(context):
-
- dir_name = "/".join(context._path.split("/")[:-1]) + "/"
- fil = dir_name + "*inference*"
- existing_dirs = glob.glob(fil)
-
+ with mock.patch.object(
+ GraphLowering, "compile_to_module", patched_compile_to_module
+ ):
+ torch._dynamo.reset()
fn(*args, **kwargs)
+ return source_codes
- assert context._path is not None
- dir_dbg = [x for x in glob.glob(fil) if x not in existing_dirs]
-
- assert len(dir_dbg) == 1, f"{dir_dbg}, {context._path}"
-
- full_name = os.path.join(dir_dbg[0], "output_code.py")
- with open(full_name, "r") as f:
- return f.read()
+def run_and_get_triton_code(fn, *args, **kwargs):
+ source_codes = run_and_get_code(fn, *args, **kwargs)
+ assert (
+ len(source_codes) == 1
+ ), f"expected exactly one code output got {len(source_codes)}"
+ return source_codes[0]
def developer_warning(msg):
diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py
index 586dd3b..6cffc4c 100644
--- a/torch/fx/interpreter.py
+++ b/torch/fx/interpreter.py
@@ -76,6 +76,7 @@
self.env : Dict[Node, Any] = {}
self.name = "Interpreter"
self.garbage_collect_values = garbage_collect_values
+ self.extra_traceback = True
if self.garbage_collect_values:
# Run through reverse nodes and record the first instance of a use
@@ -135,12 +136,13 @@
try:
self.env[node] = self.run_node(node)
except Exception as e:
- msg = f"While executing {node.format_node()}"
- msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)
- msg += f"\nOriginal traceback:\n{node.stack_trace}"
- e.args = (msg,) + e.args[1:]
- if isinstance(e, KeyError):
- raise RuntimeError(*e.args) from e
+ if self.extra_traceback:
+ msg = f"While executing {node.format_node()}"
+ msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)
+ msg += f"\nOriginal traceback:\n{node.stack_trace}"
+ e.args = (msg,) + e.args[1:]
+ if isinstance(e, KeyError):
+ raise RuntimeError(*e.args) from e
raise
if self.garbage_collect_values: