Add Cudagraphs disable checking (#121018)

Adds the same cudagraphs disable checking used by inductor's cudagraph trees to the cudagraphs backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121018
Approved by: https://github.com/ezyang
ghstack dependencies: #121017
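
A minimal repro sketch of the new behavior (it mirrors test_mutation_on_inp below; the config flag is the one added to torch/_dynamo/config.py in this diff, and TORCH_LOGS=perf_hints is assumed to surface the new perf_hints artifact logger):

    import torch

    # Keep input mutations in the graph so the new skip check can see them
    # (flag added by this PR; default is False).
    torch._dynamo.config.cudagraph_backend_keep_input_mutation = True

    @torch.compile(backend="cudagraphs")
    def foo(x):
        x.add_(2)  # mutation on a graph input
        return x

    # With TORCH_LOGS="perf_hints", this should log
    # "skipping cudagraphs due to mutation on input. Found from ..."
    # and fall back to running without cudagraphs.
    foo(torch.ones([10], device="cuda"))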
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index d001b94..1ac37eb 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -19,9 +19,11 @@
 
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
     IS_CI,
     IS_LINUX,
     IS_WINDOWS,
+    parametrize,
     skipIfRocm,
     TEST_CUDA_GRAPH,
     TEST_WITH_ASAN,
@@ -50,6 +52,13 @@
 from io import StringIO
 
 
+def get_compile_fn(backend):
+    if backend == "cudagraphs":
+        return functools.partial(torch.compile, backend="cudagraphs")
+    else:
+        return functools.partial(torch.compile, mode="reduce-overhead")
+
+
 class capture_stderr(list):
     """
     Replace sys.stderr with a temporary StringIO
@@ -244,15 +253,16 @@
                 opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
 
             FileCheck().check(
-                "skipping cudagraphs due to mutaton on input. Found from"
+                "skipping cudagraphs due to mutation on input. Found from"
             ).check("torch.logical_xor").run(captured_output[0])
 
         @requires_multigpu()
-        def test_multiple_devices_msg(self):
-            @torch.compile()
+        @parametrize("backend", ("inductor", "cudagraphs"))
+        def test_multiple_devices_msg(self, backend):
             def foo(x, y):
                 return (x + 1, y + 2)
 
+            foo = get_compile_fn(backend)(foo)
             with capture_stderr() as captured_output:
                 foo(torch.ones([10], device="cuda"), torch.ones([20]))
 
@@ -269,19 +279,22 @@
                 captured_output[0]
             )
 
-        def test_mutation(self):
-            @torch.compile()
+        @parametrize("backend", ("inductor", "cudagraphs"))
+        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
+        def test_mutation_on_inp(self, backend):
             def foo(x):
                 x.add_(2)
                 return x
 
+            foo = get_compile_fn(backend)(foo)
+
             def inp():
                 return torch.ones([10], device="cuda")
 
             with capture_stderr() as captured_output:
                 foo(inp())
 
-            FileCheck().check("skipping cudagraphs due to mutaton on input.").check(
+            FileCheck().check("skipping cudagraphs due to mutation on input.").check(
                 ".add_(2)"
             ).run(captured_output[0])
 
@@ -1454,6 +1467,7 @@
             with self.assertRaisesRegex(Exception, "custom error msg"):
                 device = x.untyped_storage()
 
+    instantiate_parametrized_tests(CudaGraphTreeTests)
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index 5b894a7..999917c 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -1,13 +1,20 @@
 # mypy: ignore-errors
 
-import logging
 import operator
 from collections import defaultdict
-from typing import Set
+from typing import Dict, Optional, Set
 
 import torch
-from torch._inductor.utils import BoxedBool
-
+from torch._inductor.cudagraph_utils import (
+    check_multiple_devices_or_any_cpu_nodes,
+    get_mutation_stack_trace,
+)
+from torch._inductor.utils import (
+    BoxedBool,
+    count_tangents,
+    has_incompatible_cudagraph_ops,
+    num_fw_fixed_arguments,
+)
 from torch.fx import GraphModule
 from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -16,7 +23,7 @@
 from .common import aot_autograd
 from .registry import register_backend
 
-log = logging.getLogger(__name__)
+perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 
 
 def cloner(t):
@@ -95,7 +102,8 @@
     mutated_inputs = set()
     for n in g.nodes:
         if n.op == "placeholder":
-            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
+            if isinstance(meta_fk(n.meta), torch.Tensor):
+                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
             input_idx += 1
         elif n.op == "call_function":
             if n.target is operator.getitem:
@@ -118,6 +126,7 @@
                     mutated_inputs |= inputs[
                         StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                     ]
+
         # TODO: error on unrecognized nodes
     return mutated_inputs
 
@@ -134,25 +143,69 @@
     # NB: we didn't actually change the graph, no need for recompile
 
 
+def get_device_node_mapping(gm: torch.fx.GraphModule):
+    device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
+    for n in gm.graph.nodes:
+        t = n.meta.get("val", None)
+        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
+            device_node_mapping[t.device] = n
+    return device_node_mapping
+
+
+def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+    mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
+    if not mutation_indices:
+        return None
+
+    return get_mutation_stack_trace(aot_model, mutation_indices)
+
+
+def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+    if mut_skip := check_for_mutation(aot_model, num_fixed):
+        return mut_skip
+
+    if skip := check_multiple_devices_or_any_cpu_nodes(
+        get_device_node_mapping(aot_model)
+    ):
+        return skip
+
+    if has_incompatible_cudagraph_ops(aot_model):
+        return "skipping cudagraphs due to incompatible op"
+
+    return None
+
+
 def cudagraphs(dynamo_model, dynamo_inputs):
     do_cudagraphs = BoxedBool(True)
 
     def forward_cudagraphs(aot_model, aot_inputs):
-        fixed = torch._inductor.utils.num_fw_fixed_arguments(
-            len(dynamo_inputs), len(aot_inputs)
-        )
+        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
+        if skip_msg := check_for_skip(aot_model, fixed):
+            BoxedBool.disable(do_cudagraphs)
+            perf_log.warning(skip_msg)
+            return aot_model
+
         model = partition_cudagraphs(aot_model, aot_inputs)
         apply_cuda_graphs(model)
         return model
 
     def backward_cudagraphs(aot_model, aot_inputs):
-        fixed = torch._inductor.utils.count_tangents(aot_model)
+        if not do_cudagraphs:
+            return aot_model
+
+        fixed = count_tangents(aot_model)
+        if skip_msg := check_for_skip(aot_model, fixed):
+            perf_log.warning(skip_msg)
+            return aot_model
+
         model = partition_cudagraphs(aot_model, aot_inputs)
         apply_cuda_graphs(model)
         return model
 
     aot_cudagraphs = aot_autograd(
-        fw_compiler=forward_cudagraphs, bw_compiler=backward_cudagraphs
+        fw_compiler=forward_cudagraphs,
+        bw_compiler=backward_cudagraphs,
+        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
     )
     return aot_cudagraphs(dynamo_model, dynamo_inputs)
 
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 3179381..a41094f 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -353,6 +353,11 @@
     "skipfiles_inline_module_allowlist",
 }
 
+# For backend="cudagraphs", mutations on inputs can either be sent to the
+# cudagraph backend or replayed in the aot_autograd epilogue. Default is False
+# because mutation on inputs can prevent cudagraphing.
+cudagraph_backend_keep_input_mutation = False
+
 # When True, only ops that have the torch.Tag.pt2_compliant tag
 # will be allowed into the graph; all other ops will be disallowed
 # and will fall back to eager-mode PyTorch. Useful to ensure
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index cceaabc..3bab6f5 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Iterable, Optional
 
 import torch
 from torch._inductor.codecache import CompiledFxGraph
@@ -21,6 +21,24 @@
     return f"skipping cudagraphs due to {reason}"
 
 
+def get_mutation_stack_trace(
+    gm: torch.fx.GraphModule, mutation_indices: Iterable[int]
+) -> str:
+    stack_trace: Optional[str] = ""
+    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+
+    for idx in mutation_indices:
+        placeholder = placeholders[idx]
+        if stack_trace := get_mutating_use_stack_trace(placeholder):
+            break
+
+    if stack_trace:
+        msg = f"skipping cudagraphs due to mutation on input. Found from : \n {stack_trace}"
+        return msg
+
+    return format_default_skip_message("mutated inputs")
+
+
 def check_for_mutation(
     gm: torch.fx.GraphModule, compiled_graph: CompiledFxGraph, num_fixed: int
 ) -> Optional[str]:
@@ -33,23 +51,10 @@
             idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
         ]
         has_mutation = len(mutation_indices) != 0
-
         if not has_mutation:
             return None
 
-        stack_trace: Optional[str] = ""
-        placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
-
-        for idx in mutation_indices:
-            placeholder = placeholders[idx]
-            if stack_trace := get_mutating_use_stack_trace(placeholder):
-                break
-
-        if stack_trace:
-            msg = f"skipping cudagraphs due to mutaton on input. Found from : \n {stack_trace}"
-            return msg
-
-        return default_msg
+        return get_mutation_stack_trace(gm, mutation_indices)
 
     else:
         has_mutation = len(compiled_graph.mutated_inputs) != 0
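
For contrast, a sketch of the default path (flag left at False, per the config.py comment above): aot_autograd functionalizes the input mutation and replays it in its epilogue, so the graph handed to the cudagraphs backend should have no input mutation and no skip is expected:

    import torch

    # Default: torch._dynamo.config.cudagraph_backend_keep_input_mutation == False,
    # so aot_autograd functionalizes the mutation out of the graph.
    @torch.compile(backend="cudagraphs")
    def bar(x):
        x.add_(2)  # replayed in the aot_autograd epilogue
        return x

    inp = torch.ones([10], device="cuda")
    bar(inp)  # no "skipping cudagraphs" perf hint expected
    assert inp[0].item() == 3.0  # the mutation is still visible to the caller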