Refactor cudagraphs to use serializable placeholder info (#130252)

This PR refactors the placeholders used by cudagraphs to be serializable. We define a new PlaceholderInfo dataclass that contains only the parts of a placeholder needed for logging and debugging, and use it instead of `torch.fx.Node` directly. This allows us to later save PlaceholderInfo into FXGraphCache/AOTAutogradCache entries.
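
For context, here is a minimal sketch of the shape of the change, condensed from the new helpers added to `torch/_inductor/cudagraph_utils.py` in the diff below. Names match the PR, except `_mutating_use_stack_trace`, which is a simplified stand-in (not part of this PR) for the full `get_mutating_use_stack_trace_from_node`, which additionally inspects `aten.copy_` users:

```python
from __future__ import annotations

import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass(frozen=True)
class PlaceholderInfo:
    # Only the pieces of a placeholder torch.fx.Node that cudagraphs needs for
    # logging/debugging, so the object can be pickled into cache entries.
    name: str
    stack_trace: Optional[str]
    users: List[PlaceholderInfo]  # one level deep, never cyclic
    mutating_use_stack_trace: Optional[str]


def _mutating_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
    # Simplified stand-in for get_mutating_use_stack_trace_from_node in the diff.
    if len(node.users) == 1:
        return next(iter(node.users)).meta.get("stack_trace", None)
    return None


def to_placeholder_info(node: torch.fx.Node) -> PlaceholderInfo:
    users = []
    mutating_use_stack_trace = None
    if node.op == "placeholder":
        # Recurse into users only one level; we only need their stack traces.
        users = [to_placeholder_info(u) for u in node.users]
        mutating_use_stack_trace = _mutating_use_stack_trace(node)
    return PlaceholderInfo(
        node.name, node.meta.get("stack_trace", None), users, mutating_use_stack_trace
    )


def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
    # Drop-in replacement for the old get_placeholders(): call sites that used
    # to collect raw placeholder Nodes now receive serializable PlaceholderInfo.
    return [to_placeholder_info(n) for n in graph.nodes if n.op == "placeholder"]
```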

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130252
Approved by: https://github.com/eellison, https://github.com/masnesral
ghstack dependencies: #129384
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 9239ab0..5ff49f0 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -1270,7 +1270,10 @@
                     foo(*inps)
                 except Exception as e:
                     thrown = True
-                    self.assertTrue("at::cuda::blas::gemm<float>" in str(e))
+                    self.assertTrue(
+                        "at::cuda::blas::gemm<float>" in str(e)
+                        or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
+                    )
                     self.assertTrue(
                         "getCurrentCUDABlasHandle" in str(e)
                         or "getNewWorkspace" in str(e)
diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index f84ab65..719d792 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -13,7 +13,7 @@
     check_multiple_devices_or_any_cpu_nodes,
     format_default_skip_message,
     get_mutation_stack_trace,
-    get_placeholders,
+    get_placeholder_info,
     log_cudagraph_skip_and_bump_counter,
 )
 from torch._inductor.utils import (
@@ -83,7 +83,7 @@
     if not mutation_indices:
         return None
 
-    placeholders = [node for node in aot_model.graph.nodes if node.op == "placeholder"]
+    placeholders = get_placeholder_info(aot_model.graph)
     return get_mutation_stack_trace(placeholders, mutation_indices)
 
 
@@ -145,7 +145,7 @@
             is_backward=False,
             is_inference=False,
             stack_traces=get_stack_traces(aot_model),
-            placeholders=get_placeholders(aot_model.graph),
+            placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
         out._boxed_call = True
@@ -183,7 +183,7 @@
             is_backward=True,
             is_inference=False,
             stack_traces=get_stack_traces(aot_model),
-            placeholders=get_placeholders(aot_model.graph),
+            placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
         out._boxed_call = True
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 922d06b..8eb8933 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -38,8 +38,9 @@
 )
 from torch._inductor.cudagraph_utils import (
     BoxedDeviceIndex,
-    get_placeholders,
+    get_placeholder_info,
     log_cudagraph_skip_and_bump_counter,
+    PlaceholderInfo,
 )
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.runtime.runtime_utils import cache_dir
@@ -52,7 +53,6 @@
 )
 from torch._logging import trace_structured
 from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 
@@ -443,6 +443,7 @@
 
 def cudagraph_post_compile(
     cudagraphs: BoxedBool,
+    example_inputs: List[Any],
     compiled_graph: CompiledFxGraph,
     cudagraph_fail_reasons: List[str],
     inputs_to_check: Sequence[int],
@@ -450,8 +451,7 @@
     is_inference: bool,
     is_backward: bool,
     stack_traces: List[Optional[str]],
-    placeholders: Tuple[torch.fx.Node, ...],
-    example_inputs: List[Any],
+    placeholders: Sequence[PlaceholderInfo],
     static_input_idxs: Sequence[int],
 ):
     """
@@ -476,7 +476,6 @@
 
         compiled_graph.current_callable = cudagraphify(
             compiled_graph.current_callable,
-            example_inputs,
             static_input_idxs=static_input_idxs,
             device_index=next(iter(compiled_graph.device_idxs)),
             stack_traces=stack_traces,
@@ -734,9 +733,10 @@
                 for arg in output.args[0]
             ]
             cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
-            placeholders = tuple(get_placeholders(gm.graph))
+            placeholders = tuple(get_placeholder_info(gm.graph))
             cudagraph_post_compile(
                 cudagraphs,
+                example_inputs,
                 compiled_graph,
                 cudagraph_fail_reasons,
                 inputs_to_check,
@@ -745,7 +745,6 @@
                 is_backward,
                 stack_traces,
                 placeholders,
-                example_inputs,
                 static_input_idxs,
             )
 
@@ -1048,7 +1047,6 @@
 @dynamo_utils.dynamo_timed
 def cudagraphify(
     model: Callable[..., Any],
-    inputs: List[torch.Tensor],
     static_input_idxs: Sequence[int] = (),
     *,
     device_index: int,
@@ -1056,7 +1054,7 @@
     is_backward: bool,
     is_inference: bool,
     constants: Tuple[torch.Tensor, ...] = (),
-    placeholders: Tuple[torch.fx.Node, ...] = (),
+    placeholders: Sequence[PlaceholderInfo] = (),
     mutated_input_idxs: Tuple[int, ...] = (),
 ) -> Callable[..., Any]:
     from torch._inductor.cudagraph_trees import (
@@ -1078,10 +1076,6 @@
     else:
         cudagraphify_fn = cudagraphify_impl
 
-    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
-    if not any(isinstance(inp, FakeTensor) for inp in inputs):
-        return cudagraphify_fn(model, inputs, static_input_idxs)
-
     compiled_fn = None
 
     def run(new_inputs):
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 2034313..c959d6e 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -83,6 +83,7 @@
     FunctionID,
     log_cudagraph_skip_and_bump_counter,
     log_data_ptr_mismatch,
+    PlaceholderInfo,
     WrappedFunction,
 )
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -390,7 +391,7 @@
     is_inference: bool,
     stack_traces: Optional[StackTraces] = None,
     constants: Tuple[torch.Tensor, ...] = (),
-    placeholders: Tuple[torch.fx.Node, ...] = (),
+    placeholders: Tuple[PlaceholderInfo, ...] = (),
     mutated_input_idxs: Tuple[int, ...] = (),
 ):
     manager = get_container(device_index).get_tree_manager()
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index 988881c..184c2fc 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import dataclasses
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@@ -17,6 +19,21 @@
 
 
 @dataclasses.dataclass(frozen=True)
+class PlaceholderInfo:
+    """
+    A serializable version of torch.fx.Node that contains information
+    pertinent to placeholder stack traces. We use these in logging and error messages
+    related to cudagraphs, and will cache these results.
+    """
+
+    name: str
+    stack_trace: Optional[str]
+    # This field is recursive, but never cyclic (since a node never uses itself)
+    users: List[PlaceholderInfo]
+    mutating_use_stack_trace: Optional[str]
+
+
+@dataclasses.dataclass(frozen=True)
 class WrappedFunction:
     """
     Represents a function that you want to record for CUDA graph replay,
@@ -28,15 +45,13 @@
     static_input_idxs: List[int]
     id: FunctionID
     constants: Tuple[torch.Tensor, ...]
-    placeholders: List[torch.fx.Node]
+    placeholders: List[PlaceholderInfo]
     mutated_input_idxs: List[int]
 
 
-def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
-    return [node for node in graph.nodes if node.op == "placeholder"]
-
-
-def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]:
+def get_mutating_use_stack_trace_from_node(
+    placeholder_node: torch.fx.Node,
+) -> Optional[str]:
     # reinplaced uses might have a single, non-copy_ use
     if len(placeholder_node.users) == 1:
         return next(iter(placeholder_node.users)).meta.get("stack_trace", None)
@@ -49,12 +64,37 @@
     return None
 
 
+def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
+    return placeholder_info.mutating_use_stack_trace
+
+
+def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
+    name = placeholder_node.name
+    stack_trace = placeholder_node.meta.get("stack_trace", None)
+    users = []
+    mutating_use_stack_trace = None
+    # Only recurse into users once, since we only care about the users' stack traces
+    if placeholder_node.op == "placeholder":
+        users = [to_placeholder_info(i) for i in placeholder_node.users]
+        mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
+            placeholder_node
+        )
+
+    return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)
+
+
+def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
+    return [
+        to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
+    ]
+
+
 def format_default_skip_message(reason: str) -> str:
     return f"skipping cudagraphs due to {reason}"
 
 
 def get_mutation_stack_trace(
-    placeholders: List[torch.fx.Node], mutation_indices: List[int]
+    placeholders: List[PlaceholderInfo], mutation_indices: List[int]
 ) -> str:
     stack_trace: Optional[str] = ""
 
@@ -98,7 +138,7 @@
     )
 
 
-def get_use_stack_trace(node) -> Optional[str]:
+def _get_use_stack_trace(node) -> Optional[str]:
     for use in node.users:
         if stack_trace := use.meta.get("stack_trace", None):
             return stack_trace
@@ -110,7 +150,7 @@
 ) -> Optional[str]:
     if cpu_node := device_node_mapping.get(torch.device("cpu")):
         msg = f"cpu device ({cpu_node.name})"
-        if stack_trace := get_use_stack_trace(cpu_node):
+        if stack_trace := _get_use_stack_trace(cpu_node):
             return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")
 
         return format_default_skip_message(msg)
@@ -160,7 +200,7 @@
         has_mutation = len(mutation_indices) != 0
         if not has_mutation:
             return None
-        placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+        placeholders = get_placeholder_info(gm.graph)
         return get_mutation_stack_trace(placeholders, mutation_indices)
 
     else:
@@ -168,7 +208,7 @@
         return None if not has_mutation else default_msg
 
 
-def get_placeholder_stack_trace(placeholder: torch.fx.Node) -> Optional[str]:
+def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
     """
     Gets the first non-empty stack trace of a placeholder or its users.
     """
@@ -207,7 +247,7 @@
 
 
 def log_data_ptr_mismatch(
-    placeholders: List[torch.fx.Node],
+    placeholders: List[PlaceholderInfo],
     inputs: List[torch.Tensor],
     recorded_data_ptr: List[Optional[int]],
     target_idxs: List[int],