Refactor cudagraphs to use serializable placeholder info (#130252)
This PR refactors cudagraphs to use serializable placeholder information. We define a new PlaceholderInfo object that contains only the parts of a placeholder needed for logging and debugging, and use it in place of `torch.fx.Node` directly. This lets us later save PlaceholderInfo into the FXGraphCache/AOTAutogradCache.
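A minimal sketch of how the new helper is intended to be consumed (not part of this diff): `get_placeholder_info` and `PlaceholderInfo` are the definitions added in `torch/_inductor/cudagraph_utils.py` below; the use of `pickle` to stand in for the cache serialization is an assumption for illustration only.

```python
# Illustrative sketch only -- not part of this PR's diff.
import pickle

import torch
from torch._inductor.cudagraph_utils import get_placeholder_info


def collect_serializable_placeholders(gm: torch.fx.GraphModule):
    # Gather lightweight PlaceholderInfo records instead of raw torch.fx.Node
    # objects, which hold references back into the live graph.
    placeholders = get_placeholder_info(gm.graph)

    # PlaceholderInfo is a frozen dataclass of strings, Optionals, and nested
    # PlaceholderInfo lists, so it round-trips through pickle. (Using pickle
    # here is an assumption; the actual cache serialization may differ.)
    return pickle.loads(pickle.dumps(placeholders))
```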
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130252
Approved by: https://github.com/eellison, https://github.com/masnesral
ghstack dependencies: #129384
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 9239ab0..5ff49f0 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -1270,7 +1270,10 @@
foo(*inps)
except Exception as e:
thrown = True
- self.assertTrue("at::cuda::blas::gemm<float>" in str(e))
+ self.assertTrue(
+ "at::cuda::blas::gemm<float>" in str(e)
+ or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
+ )
self.assertTrue(
"getCurrentCUDABlasHandle" in str(e)
or "getNewWorkspace" in str(e)
diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index f84ab65..719d792 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -13,7 +13,7 @@
check_multiple_devices_or_any_cpu_nodes,
format_default_skip_message,
get_mutation_stack_trace,
- get_placeholders,
+ get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
@@ -83,7 +83,7 @@
if not mutation_indices:
return None
- placeholders = [node for node in aot_model.graph.nodes if node.op == "placeholder"]
+ placeholders = get_placeholder_info(aot_model.graph)
return get_mutation_stack_trace(placeholders, mutation_indices)
@@ -145,7 +145,7 @@
is_backward=False,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
- placeholders=get_placeholders(aot_model.graph),
+ placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
@@ -183,7 +183,7 @@
is_backward=True,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
- placeholders=get_placeholders(aot_model.graph),
+ placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 922d06b..8eb8933 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -38,8 +38,9 @@
)
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
- get_placeholders,
+ get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
+ PlaceholderInfo,
)
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.runtime.runtime_utils import cache_dir
@@ -52,7 +53,6 @@
)
from torch._logging import trace_structured
from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -443,6 +443,7 @@
def cudagraph_post_compile(
cudagraphs: BoxedBool,
+ example_inputs: List[Any],
compiled_graph: CompiledFxGraph,
cudagraph_fail_reasons: List[str],
inputs_to_check: Sequence[int],
@@ -450,8 +451,7 @@
is_inference: bool,
is_backward: bool,
stack_traces: List[Optional[str]],
- placeholders: Tuple[torch.fx.Node, ...],
- example_inputs: List[Any],
+ placeholders: Sequence[PlaceholderInfo],
static_input_idxs: Sequence[int],
):
"""
@@ -476,7 +476,6 @@
compiled_graph.current_callable = cudagraphify(
compiled_graph.current_callable,
- example_inputs,
static_input_idxs=static_input_idxs,
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
@@ -734,9 +733,10 @@
for arg in output.args[0]
]
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
- placeholders = tuple(get_placeholders(gm.graph))
+ placeholders = tuple(get_placeholder_info(gm.graph))
cudagraph_post_compile(
cudagraphs,
+ example_inputs,
compiled_graph,
cudagraph_fail_reasons,
inputs_to_check,
@@ -745,7 +745,6 @@
is_backward,
stack_traces,
placeholders,
- example_inputs,
static_input_idxs,
)
@@ -1048,7 +1047,6 @@
@dynamo_utils.dynamo_timed
def cudagraphify(
model: Callable[..., Any],
- inputs: List[torch.Tensor],
static_input_idxs: Sequence[int] = (),
*,
device_index: int,
@@ -1056,7 +1054,7 @@
is_backward: bool,
is_inference: bool,
constants: Tuple[torch.Tensor, ...] = (),
- placeholders: Tuple[torch.fx.Node, ...] = (),
+ placeholders: Sequence[PlaceholderInfo] = (),
mutated_input_idxs: Tuple[int, ...] = (),
) -> Callable[..., Any]:
from torch._inductor.cudagraph_trees import (
@@ -1078,10 +1076,6 @@
else:
cudagraphify_fn = cudagraphify_impl
- # if using fake tensors, defer cudagraphs until we get real inputs at runtime
- if not any(isinstance(inp, FakeTensor) for inp in inputs):
- return cudagraphify_fn(model, inputs, static_input_idxs)
-
compiled_fn = None
def run(new_inputs):
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 2034313..c959d6e 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -83,6 +83,7 @@
FunctionID,
log_cudagraph_skip_and_bump_counter,
log_data_ptr_mismatch,
+ PlaceholderInfo,
WrappedFunction,
)
from torch.multiprocessing.reductions import StorageWeakRef
@@ -390,7 +391,7 @@
is_inference: bool,
stack_traces: Optional[StackTraces] = None,
constants: Tuple[torch.Tensor, ...] = (),
- placeholders: Tuple[torch.fx.Node, ...] = (),
+ placeholders: Tuple[PlaceholderInfo, ...] = (),
mutated_input_idxs: Tuple[int, ...] = (),
):
manager = get_container(device_index).get_tree_manager()
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index 988881c..184c2fc 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
+from __future__ import annotations
+
import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@@ -17,6 +19,21 @@
@dataclasses.dataclass(frozen=True)
+class PlaceholderInfo:
+ """
+ A serializable version of torch.fx.Node that contains information
+ pertinent to placeholder stack traces. We use these in logging and error messages
+ related to cudagraphs, and will cache these results.
+ """
+
+ name: str
+ stack_trace: Optional[str]
+ # This field is recursive, but never cyclic (since a node never uses itself)
+ users: List[PlaceholderInfo]
+ mutating_use_stack_trace: Optional[str]
+
+
+@dataclasses.dataclass(frozen=True)
class WrappedFunction:
"""
Represents a function that you want to record for CUDA graph replay,
@@ -28,15 +45,13 @@
static_input_idxs: List[int]
id: FunctionID
constants: Tuple[torch.Tensor, ...]
- placeholders: List[torch.fx.Node]
+ placeholders: List[PlaceholderInfo]
mutated_input_idxs: List[int]
-def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
- return [node for node in graph.nodes if node.op == "placeholder"]
-
-
-def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]:
+def get_mutating_use_stack_trace_from_node(
+ placeholder_node: torch.fx.Node,
+) -> Optional[str]:
# reinplaced uses might have a single, non-copy_ use
if len(placeholder_node.users) == 1:
return next(iter(placeholder_node.users)).meta.get("stack_trace", None)
@@ -49,12 +64,37 @@
return None
+def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
+ return placeholder_info.mutating_use_stack_trace
+
+
+def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
+ name = placeholder_node.name
+ stack_trace = placeholder_node.meta.get("stack_trace", None)
+ users = []
+ mutating_use_stack_trace = None
+ # Only recurse to users once, since we only care about user's stack traces
+ if placeholder_node.op == "placeholder":
+ users = [to_placeholder_info(i) for i in placeholder_node.users]
+ mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
+ placeholder_node
+ )
+
+ return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)
+
+
+def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
+ return [
+ to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
+ ]
+
+
def format_default_skip_message(reason: str) -> str:
return f"skipping cudagraphs due to {reason}"
def get_mutation_stack_trace(
- placeholders: List[torch.fx.Node], mutation_indices: List[int]
+ placeholders: List[PlaceholderInfo], mutation_indices: List[int]
) -> str:
stack_trace: Optional[str] = ""
@@ -98,7 +138,7 @@
)
-def get_use_stack_trace(node) -> Optional[str]:
+def _get_use_stack_trace(node) -> Optional[str]:
for use in node.users:
if stack_trace := use.meta.get("stack_trace", None):
return stack_trace
@@ -110,7 +150,7 @@
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
- if stack_trace := get_use_stack_trace(cpu_node):
+ if stack_trace := _get_use_stack_trace(cpu_node):
return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")
return format_default_skip_message(msg)
@@ -160,7 +200,7 @@
has_mutation = len(mutation_indices) != 0
if not has_mutation:
return None
- placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+ placeholders = get_placeholder_info(gm.graph)
return get_mutation_stack_trace(placeholders, mutation_indices)
else:
@@ -168,7 +208,7 @@
return None if not has_mutation else default_msg
-def get_placeholder_stack_trace(placeholder: torch.fx.Node) -> Optional[str]:
+def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
"""
Gets the first non-empty stack trace of a placeholder or its users.
"""
@@ -207,7 +247,7 @@
def log_data_ptr_mismatch(
- placeholders: List[torch.fx.Node],
+ placeholders: List[PlaceholderInfo],
inputs: List[torch.Tensor],
recorded_data_ptr: List[Optional[int]],
target_idxs: List[int],