| # mypy: ignore-errors |
| |
| import functools |
| from collections import defaultdict |
| from typing import Dict, List, Optional |
| |
| import torch |
| from torch._dynamo import config |
| from torch._dynamo.backends.common import aot_autograd |
| from torch._dynamo.backends.debugging import boxed_nop |
| from torch._inductor.cudagraph_utils import ( |
| BoxedDeviceIndex, |
| check_multiple_devices_or_any_cpu_nodes, |
| format_default_skip_message, |
| get_mutation_stack_trace, |
| get_placeholder_info, |
| log_cudagraph_skip_and_bump_counter, |
| ) |
| from torch._inductor.utils import ( |
| BoxedBool, |
| count_tangents, |
| get_first_incompatible_cudagraph_node, |
| num_fw_fixed_arguments, |
| output_node, |
| ) |
| |
from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend
| |
| |
| def find_input_mutations(g): |
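    """Return the set of placeholder (input) indices whose underlying storage
    is written to in place by some ``call_function`` node in ``g``.

    Inputs are keyed by storage via ``StorageWeakRef`` so that every
    placeholder aliasing a mutated storage is reported.
    """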
| def meta_fk(meta): |
| return meta["val"] if "val" in meta else meta["fake_result"] |
| |
| inputs = defaultdict(set) |
| input_idx = 0 |
| mutated_inputs = set() |
| for n in g.nodes: |
| if n.op == "placeholder": |
| if isinstance(meta_fk(n.meta), torch.Tensor): |
| inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) |
| input_idx += 1 |
| elif n.op == "call_function": |
| if not hasattr(n.target, "_schema"): |
| continue |
| |
| schema = n.target._schema |
| for i, arg in enumerate(schema.arguments): |
                if i < len(n.args):
                    argument = n.args[i]
                elif arg.name in n.kwargs:
                    argument = n.kwargs[arg.name]
                else:
                    continue
                # The operator schema marks mutated arguments with a write alias.
                if arg.alias_info is not None and arg.alias_info.is_write:
| # TODO: not correct for args that contain tensors in a struct |
| # like list |
| mutated_inputs |= inputs[ |
| StorageWeakRef(meta_fk(argument.meta)._typed_storage()) |
| ] |
| |
| # TODO: error on unrecognized nodes |
| return mutated_inputs |
| |
| |
| def get_device_node_mapping(gm: torch.fx.GraphModule): |
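    """Map each device appearing in the graph's tensor metadata to the first
    node whose example value lives on that device."""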
| device_node_mapping: Dict[torch.device, torch.fx.Node] = {} |
| for n in gm.graph.nodes: |
| t = n.meta.get("val", None) |
| if isinstance(t, torch.Tensor) and t.device not in device_node_mapping: |
| device_node_mapping[t.device] = n |
| return device_node_mapping |
| |
| |
| def check_for_mutation_ignore_cuda_graph_managed_tensor( |
    aot_model: torch.fx.GraphModule, num_fixed: int
| ) -> Optional[str]: |
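    """Return a stack-trace message if any input beyond the first ``num_fixed``
    (cudagraph-managed) inputs is mutated, else None."""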
| mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed)) |
| if not mutation_indices: |
| return None |
| |
| placeholders = get_placeholder_info(aot_model.graph) |
| return get_mutation_stack_trace(placeholders, mutation_indices) |
| |
| |
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
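    """Return a human-readable reason to skip cudagraphs for this graph, or
    None if the graph is eligible: no disallowed input mutation, a single CUDA
    device, no CPU nodes, and no cudagraph-incompatible ops."""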
| if not config.cudagraph_backend_support_input_mutation: |
| if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor( |
| aot_model, num_fixed |
| ): |
| return mut_skip |
| |
| if skip := check_multiple_devices_or_any_cpu_nodes( |
| get_device_node_mapping(aot_model) |
| ): |
| return skip |
| |
| if node := get_first_incompatible_cudagraph_node(aot_model): |
| return format_default_skip_message(f"incompatible op ({node.name})") |
| |
| return None |
| |
| |
def get_device_index(gm: torch.fx.GraphModule) -> int:
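    """Return the CUDA device index of the single device used by ``gm``."""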
| device = next(iter(get_device_node_mapping(gm))) |
| assert device.type == "cuda" |
| return device.index |
| |
| |
def get_stack_traces(gm: torch.fx.GraphModule) -> List[Optional[str]]:
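    """Return the recorded stack trace for each graph output, with None for
    outputs that are not FX nodes (e.g. constants)."""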
| output = output_node(gm) |
| assert len(output.args) == 1 |
| return [ |
| (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) |
| for arg in output.args[0] |
| ] |
| |
| |
| def cudagraphs(dynamo_model, dynamo_inputs): |
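    """Dynamo backend that runs AOTAutograd's forward and backward graphs
    under CUDA graph trees, falling back to running without cudagraphs when a
    graph is ineligible (input mutation, CPU nodes, multiple devices, ...)."""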
| from torch._inductor.cudagraph_trees import cudagraphify_impl |
| |
| do_cudagraphs = BoxedBool(True) |
| boxed_device_index = BoxedDeviceIndex(None) |
| |
| def forward_cudagraphs(aot_model, aot_inputs, is_inference=False): |
| interp = boxed_nop(aot_model, aot_inputs) |
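        # Parameters and buffers are lifted to the front of the argument list
        # by AOTAutograd; they are stable across calls, so cudagraph trees can
        # treat them as fixed inputs that need no per-run copy.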
| fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) |
| if skip_msg := check_for_skip(aot_model, fixed): |
| BoxedBool.disable(do_cudagraphs) |
| log_cudagraph_skip_and_bump_counter( |
| f"skipping cudagraphs due to {skip_msg}" |
| ) |
| return interp |
| |
| boxed_device_index.set(get_device_index(aot_model)) |
| out = cudagraphify_impl( |
| interp, |
| aot_inputs, |
| range(fixed), |
| device_index=boxed_device_index.value, |
| is_backward=False, |
            is_inference=is_inference,
| stack_traces=get_stack_traces(aot_model), |
| placeholders=get_placeholder_info(aot_model.graph), |
| mutated_input_idxs=find_input_mutations(aot_model.graph), |
| ) |
| out._boxed_call = True |
| return out |
| |
    def backward_cudagraphs(aot_model, aot_inputs):
        if not do_cudagraphs:
            return aot_model

        interp = boxed_nop(aot_model, aot_inputs)
| |
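        # Everything except the incoming tangents (saved activations, etc.) is
        # produced by the cudagraph-managed forward, so treat it as fixed.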
| fixed = count_tangents(aot_model) |
        if skip_msg := check_for_skip(aot_model, fixed):
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )

            # See [Backward Generation Handling]
            # The forward ran under cudagraphs, so the manager must be told
            # that this backward is executing eagerly, outside of recording.
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn
| |
| out = cudagraphify_impl( |
| interp, |
| aot_inputs, |
| range(fixed), |
| device_index=get_device_index(aot_model), |
| is_backward=True, |
| is_inference=False, |
| stack_traces=get_stack_traces(aot_model), |
| placeholders=get_placeholder_info(aot_model.graph), |
| mutated_input_idxs=find_input_mutations(aot_model.graph), |
| ) |
| out._boxed_call = True |
| return out |
| |
| aot_cudagraphs = aot_autograd( |
| fw_compiler=forward_cudagraphs, |
| bw_compiler=backward_cudagraphs, |
| inference_compiler=functools.partial(forward_cudagraphs, is_inference=True), |
| keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation, |
| ) |
| return aot_cudagraphs(dynamo_model, dynamo_inputs) |
| |
| |
| class CudagraphsBackend: |
| compiler_name = "cudagraphs" |
| |
| @staticmethod |
| def reset(): |
| from torch._inductor.cudagraph_trees import reset_cudagraph_trees |
| |
| reset_cudagraph_trees() |
| |
| @staticmethod |
| def __call__(model, inputs): |
| return cudagraphs(model, inputs) |
| |
| |
# This backend only applies CUDA graphs to the FX graphs produced by
# AOTAutograd, with no further compilation. It is also helpful for debugging
# and can serve as a performance baseline.
| register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend()) |
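
# A minimal usage sketch (illustrative, not part of this module): the backend
# registered above can be selected by name through torch.compile, assuming a
# CUDA-enabled build and an available GPU.
#
#     model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda()
#     compiled = torch.compile(model, backend="cudagraphs")
#     out = compiled(torch.randn(4, 8, device="cuda"))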
| |
| |
| def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True): |
| """This isn't registered as a backend, but is used in some benchmarks""" |
| assert isinstance(inputs, (list, tuple)) |
| if copy_inputs: |
| static_inputs = [torch.zeros_like(x) for x in inputs] |
| else: |
| static_inputs = list(inputs) |
| |
    # warmup: run the model once on a side stream so one-time work (autograd
    # setup, cuDNN algorithm selection, lazy initialization) happens outside
    # of graph capture
| torch.cuda.synchronize() |
| stream = torch.cuda.Stream() |
| stream.wait_stream(torch.cuda.current_stream()) |
| with torch.cuda.stream(stream): |
| model(*inputs) |
| stream.synchronize() |
| torch.cuda.current_stream().wait_stream(stream) |
| torch.cuda.synchronize() |
| |
    # record: capture the model into a CUDA graph; tensors allocated during
    # capture are served from the graph's private memory pool
| graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(graph, stream=stream): |
| static_outputs = model(*static_inputs) |
| if not isinstance(static_outputs, (list, tuple)): |
| static_outputs = (static_outputs,) |
| |
| def run(*new_inputs): |
| assert len(static_inputs) == len(new_inputs) |
| if copy_inputs: |
| for dst, src in zip(static_inputs, new_inputs): |
| dst.copy_(src) |
| graph.replay() |
| if copy_outputs: |
| return [x.clone() for x in static_outputs] |
| else: |
| return static_outputs |
| |
| return run |
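
# A minimal usage sketch for cudagraphs_inner (illustrative; assumes a CUDA
# device). With the defaults, inputs are copied into static buffers before
# each replay and outputs are cloned after it:
#
#     model = torch.nn.Linear(8, 8).cuda()
#     inputs = [torch.randn(4, 8, device="cuda")]
#     run = cudagraphs_inner(model, inputs)
#     (out,) = run(*inputs)                            # replays the graph
#     (out2,) = run(torch.randn(4, 8, device="cuda"))  # same graph, new data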