[dynamo] Minifier fixes for reproducing segfault (#89712)
Helped with minifying the segfault in https://github.com/pytorch/torchdynamo/issues/1928
Tests not really needed. It improves quality of life as segfault can fail anywhere (when CUDA_LAUNCH_BLOCKING is off)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89712
Approved by: https://github.com/mlazos, https://github.com/ngimel
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index 29d8301..36dd15e 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -222,6 +222,12 @@
def save_graph_repro(fd, gm, args, compiler_name):
+ sync_line = ""
+ for arg in args:
+ if arg.is_cuda:
+ sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
+ break
+
if "inductor" in compiler_name:
fd.write(f"import {config.inductor_import}.overrides\n")
fd.write(generate_compiler_repro_string(gm, args))
@@ -243,7 +249,8 @@
textwrap.dedent(
f"""
compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
- compiled(args)
+ ref = compiled(args)
+ {sync_line}
"""
)
)
@@ -296,27 +303,41 @@
stderr.seek(0)
print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "))
print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "))
+ # print(f"Isolated test failed - {file_name}")
return True
return False
def inductor_fails(fx_g, args, check_str=None):
+ has_cuda = False
+ for arg in args:
+ if arg.is_cuda:
+ has_cuda = True
+ break
+
+ def sync():
+ if has_cuda:
+ # Ensures that segfaults are surfaced
+ torch.cuda.synchronize()
+
compile_fx_inner = import_module(
f"{config.inductor_import}.compile_fx"
).compile_fx_inner
- import_module(f"{config.inductor_import}.config").triton.autotune = False
-
try:
result = fx_g(*args)
assert isinstance(result, (tuple, list))
assert not any([isinstance(x, (tuple, list)) for x in result])
except Exception:
return False
+ result = None
+
+ sync()
try:
compile_mod = compile_fx_inner(fx_g, args)
compile_mod(args)
+ sync()
except Exception as e:
if check_str is not None and check_str not in repr(e):
return False