Add Cudagraphs disable checking (#121018)
Adds the same cudagraphs disable checking from inductor's cudagraph trees to the cudagraphs backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121018
Approved by: https://github.com/ezyang
ghstack dependencies: #121017
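
For context, a minimal repro sketch (not part of the PR) of the behavior this change enables: with `backend="cudagraphs"`, an in-place mutation on a graph input now produces the same "skipping cudagraphs due to mutation on input" perf hint that the inductor cudagraph trees path already emits. This assumes a CUDA device and a build that includes this change; the hint is only printed when the `perf_hints` logging artifact is enabled (e.g. `TORCH_LOGS="perf_hints"`).

```python
import torch

# Hypothetical standalone repro; mirrors test_mutation_on_inp from the diff below.
# Run with TORCH_LOGS="perf_hints" to see the skip message on stderr.

@torch.compile(backend="cudagraphs")
def foo(x):
    x.add_(2)  # in-place mutation on a graph input -> cudagraphs backend is skipped
    return x

foo(torch.ones(10, device="cuda"))
# expected perf hint: "skipping cudagraphs due to mutation on input. Found from ..."
```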
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index d001b94..1ac37eb 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -19,9 +19,11 @@
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
IS_CI,
IS_LINUX,
IS_WINDOWS,
+ parametrize,
skipIfRocm,
TEST_CUDA_GRAPH,
TEST_WITH_ASAN,
@@ -50,6 +52,13 @@
from io import StringIO
+def get_compile_fn(backend):
+ if backend == "cudagraphs":
+ return functools.partial(torch.compile, backend="cudagraphs")
+ else:
+ return functools.partial(torch.compile, mode="reduce-overhead")
+
+
class capture_stderr(list):
"""
Replace sys.stderr with a temporary StringIO
@@ -244,15 +253,16 @@
opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
FileCheck().check(
- "skipping cudagraphs due to mutaton on input. Found from"
+ "skipping cudagraphs due to mutation on input. Found from"
).check("torch.logical_xor").run(captured_output[0])
@requires_multigpu()
- def test_multiple_devices_msg(self):
- @torch.compile()
+ @parametrize("backend", ("inductor", "cudagraphs"))
+ def test_multiple_devices_msg(self, backend):
def foo(x, y):
return (x + 1, y + 2)
+ foo = get_compile_fn(backend)(foo)
with capture_stderr() as captured_output:
foo(torch.ones([10], device="cuda"), torch.ones([20]))
@@ -269,19 +279,22 @@
captured_output[0]
)
- def test_mutation(self):
- @torch.compile()
+ @parametrize("backend", ("inductor", "cudagraphs"))
+ @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
+ def test_mutation_on_inp(self, backend):
def foo(x):
x.add_(2)
return x
+ foo = get_compile_fn(backend)(foo)
+
def inp():
return torch.ones([10], device="cuda")
with capture_stderr() as captured_output:
foo(inp())
- FileCheck().check("skipping cudagraphs due to mutaton on input.").check(
+ FileCheck().check("skipping cudagraphs due to mutation on input.").check(
".add_(2)"
).run(captured_output[0])
@@ -1454,6 +1467,7 @@
with self.assertRaisesRegex(Exception, "custom error msg"):
device = x.untyped_storage()
+ instantiate_parametrized_tests(CudaGraphTreeTests)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index 5b894a7..999917c 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -1,13 +1,20 @@
# mypy: ignore-errors
-import logging
import operator
from collections import defaultdict
-from typing import Set
+from typing import Dict, Optional, Set
import torch
-from torch._inductor.utils import BoxedBool
-
+from torch._inductor.cudagraph_utils import (
+ check_multiple_devices_or_any_cpu_nodes,
+ get_mutation_stack_trace,
+)
+from torch._inductor.utils import (
+ BoxedBool,
+ count_tangents,
+ has_incompatible_cudagraph_ops,
+ num_fw_fixed_arguments,
+)
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
@@ -16,7 +23,7 @@
from .common import aot_autograd
from .registry import register_backend
-log = logging.getLogger(__name__)
+perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
def cloner(t):
@@ -95,7 +102,8 @@
mutated_inputs = set()
for n in g.nodes:
if n.op == "placeholder":
- inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
+ if isinstance(meta_fk(n.meta), torch.Tensor):
+ inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == "call_function":
if n.target is operator.getitem:
@@ -118,6 +126,7 @@
mutated_inputs |= inputs[
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
]
+
# TODO: error on unrecognized nodes
return mutated_inputs
@@ -134,25 +143,69 @@
# NB: we didn't actually change the graph, no need for recompile
+def get_device_node_mapping(gm: torch.fx.GraphModule):
+ device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
+ for n in gm.graph.nodes:
+ t = n.meta.get("val", None)
+ if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
+ device_node_mapping[t.device] = n
+ return device_node_mapping
+
+
+def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+ mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
+ if not mutation_indices:
+ return None
+
+ return get_mutation_stack_trace(aot_model, mutation_indices)
+
+
+def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+ if mut_skip := check_for_mutation(aot_model, num_fixed):
+ return mut_skip
+
+ if skip := check_multiple_devices_or_any_cpu_nodes(
+ get_device_node_mapping(aot_model)
+ ):
+ return skip
+
+ if has_incompatible_cudagraph_ops(aot_model):
+ return "skipping cudagraphs due to incompatible op"
+
+ return None
+
+
def cudagraphs(dynamo_model, dynamo_inputs):
do_cudagraphs = BoxedBool(True)
def forward_cudagraphs(aot_model, aot_inputs):
- fixed = torch._inductor.utils.num_fw_fixed_arguments(
- len(dynamo_inputs), len(aot_inputs)
- )
+ fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
+ if skip_msg := check_for_skip(aot_model, fixed):
+ BoxedBool.disable(cudagraphs)
+ perf_log.warning("skipping cudagraphs due to %s", skip_msg)
+ return aot_model
+
model = partition_cudagraphs(aot_model, aot_inputs)
apply_cuda_graphs(model)
return model
def backward_cudagraphs(aot_model, aot_inputs):
- fixed = torch._inductor.utils.count_tangents(aot_model)
+ if not do_cudagraphs:
+ return aot_model
+
+ fixed = count_tangents(aot_model)
+ if skip_msg := check_for_skip(aot_model, fixed):
+ perf_log.warning("skipping cudagraphs due to %s", skip_msg)
+ return aot_model
+
model = partition_cudagraphs(aot_model, aot_inputs)
apply_cuda_graphs(model)
return model
aot_cudagraphs = aot_autograd(
- fw_compiler=forward_cudagraphs, bw_compiler=backward_cudagraphs
+ fw_compiler=forward_cudagraphs,
+ bw_compiler=backward_cudagraphs,
+ keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
)
return aot_cudagraphs(dynamo_model, dynamo_inputs)
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 3179381..a41094f 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -353,6 +353,11 @@
"skipfiles_inline_module_allowlist",
}
+# For backend="cudagraphs", mutations on inputs can either be kept in the graph sent to
+# the cudagraph backend or replayed in the aot_autograd epilogue. Default is False because
+# mutation on inputs can prevent cudagraphing.
+cudagraph_backend_keep_input_mutation = False
+
# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index cceaabc..3bab6f5 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Iterable, Optional
import torch
from torch._inductor.codecache import CompiledFxGraph
@@ -21,6 +21,24 @@
return f"skipping cudagraphs due to {reason}"
+def get_mutation_stack_trace(
+ gm: torch.fx.GraphModule, mutation_indices: Iterable[int]
+) -> str:
+ stack_trace: Optional[str] = ""
+ placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+
+ for idx in mutation_indices:
+ placeholder = placeholders[idx]
+ if stack_trace := get_mutating_use_stack_trace(placeholder):
+ break
+
+ if stack_trace:
+ msg = f"skipping cudagraphs due to mutation on input. Found from : \n {stack_trace}"
+ return msg
+
+ return format_default_skip_message("mutated inputs")
+
+
def check_for_mutation(
gm: torch.fx.GraphModule, compiled_graph: CompiledFxGraph, num_fixed: int
) -> Optional[str]:
@@ -33,23 +51,10 @@
idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
]
has_mutation = len(mutation_indices) != 0
-
if not has_mutation:
return None
- stack_trace: Optional[str] = ""
- placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
-
- for idx in mutation_indices:
- placeholder = placeholders[idx]
- if stack_trace := get_mutating_use_stack_trace(placeholder):
- break
-
- if stack_trace:
- msg = f"skipping cudagraphs due to mutaton on input. Found from : \n {stack_trace}"
- return msg
-
- return default_msg
+ return get_mutation_stack_trace(gm, mutation_indices)
else:
has_mutation = len(compiled_graph.mutated_inputs) != 0
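
As a usage note (a sketch only, assuming a CUDA-capable build that includes this change): the new `cudagraph_backend_keep_input_mutation` config decides whether input mutations are kept in the graph handed to the cudagraphs backend or replayed in the aot_autograd epilogue, matching how the parametrized test above patches it.

```python
import torch

# Sketch: opt in to keeping input mutations in the cudagraphs backend graph,
# the same way test_mutation_on_inp does via torch._dynamo.config.patch.
with torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True):

    @torch.compile(backend="cudagraphs")
    def foo(x):
        x.add_(2)
        return x

    out = foo(torch.ones(10, device="cuda"))
```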