blob: 999917cdec27c4828f9cd929aecef106ebd271a8 [file] [log] [blame]
# mypy: ignore-errors
import operator
from collections import defaultdict
from typing import Dict, Optional, Set
import torch
from torch._inductor.cudagraph_utils import (
check_multiple_devices_or_any_cpu_nodes,
get_mutation_stack_trace,
)
from torch._inductor.utils import (
BoxedBool,
count_tangents,
has_incompatible_cudagraph_ops,
num_fw_fixed_arguments,
)
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map
from .common import aot_autograd
from .registry import register_backend
# Logger for the "perf_hints" artifact; used below to report when CUDA graphs
# are skipped for a compiled graph.
perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
def cloner(t):
    """Return a fresh clone of ``t`` if it is a tensor, else ``t`` unchanged."""
    return t.clone() if isinstance(t, torch.Tensor) else t
class CudaGraphModule(Module):
    """Serve calls to ``gm`` from a recorded CUDA graph.

    Call sequence implemented by ``__call__``:

    1. first call: run ``gm`` eagerly on a side stream (warmup);
    2. second call: clone the arguments into "static" inputs, capture ``gm``
       into a ``torch.cuda.CUDAGraph`` and replay it once;
    3. later calls: copy the new arguments into the static inputs and replay.

    After every replay, the static copies of the arguments listed in
    ``mutated_inputs`` are copied back into the caller's tensors, and the
    static outputs are returned as clones.

    Args:
        gm: the module that is run and captured.
        mutated_inputs: positional indices of arguments that ``gm`` mutates.
    """

    gm: GraphModule
    mutated_inputs: Set[int]

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out.  First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
        elif self.warmed_up:
            # record
            static_inputs = [x.clone() for x in args]
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_outputs = self.gm(*static_inputs)
            # Publish the three fields only once capture has succeeded, so a
            # failed capture cannot leave ``graph`` set while
            # ``static_outputs`` is still None (which would send later calls
            # down the replay path with nothing recorded).
            self.static_inputs = static_inputs
            self.static_outputs = static_outputs
            self.graph = graph
            # NB: recording doesn't actually run the operations, so
            # the replay below is what serves up the result
        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r

        self.graph.replay()
        # Propagate in-place updates from the static copies to the caller.
        for i in self.mutated_inputs:
            args[i].copy_(self.static_inputs[i])
        return tree_map(cloner, self.static_outputs)
# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
def find_input_mutations(g):
    """Return the positions of graph inputs that some op in ``g`` writes to.

    Placeholders are numbered in graph order. A tensor placeholder is
    identified by the storage of its ``meta["val"]`` (or
    ``meta["fake_result"]``), so inputs sharing a storage are reported
    together. A ``call_function`` argument counts as written when the
    target's schema marks it with ``alias_info.is_write``.

    Args:
        g: a graph whose nodes expose ``op``, ``target``, ``args``, ``kwargs``
            and ``meta``.

    Returns:
        Set of placeholder indices whose storage is written by an op.
    """

    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    def iter_nodes(argument):
        # Yield the fx nodes held by ``argument``, descending into lists and
        # tuples (e.g. the ``Tensor(a!)[]`` argument of foreach ops).
        # Anything that is not a node (None, scalars, ...) is ignored.
        if isinstance(argument, torch.fx.Node):
            yield argument
        elif isinstance(argument, (list, tuple)):
            for item in argument:
                yield from iter_nodes(item)

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            val = meta_fk(n.meta)
            if isinstance(val, torch.Tensor):
                inputs[StorageWeakRef(val._typed_storage())].add(input_idx)
            # Non-tensor placeholders still occupy an input position.
            input_idx += 1
        elif n.op == "call_function":
            # Targets without a schema (operator.getitem and other plain
            # Python callables) carry no alias information to inspect.
            schema = getattr(n.target, "_schema", None)
            if schema is None:
                continue
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                elif arg.name in n.kwargs:
                    argument = n.kwargs[arg.name]
                else:
                    continue
                if not (arg.alias_info and arg.alias_info.is_write):
                    continue
                for node in iter_nodes(argument):
                    val = meta_fk(node.meta)
                    if isinstance(val, torch.Tensor):
                        key = StorageWeakRef(val._typed_storage())
                        mutated_inputs |= inputs.get(key, set())
        # TODO: error on unrecognized nodes
    return mutated_inputs
# Mutates input graph
def apply_cuda_graphs(gm):
    """Replace each submodule called by ``gm``'s graph with a ``CudaGraphModule``.

    For every ``call_module`` node, the target submodule is removed from
    ``gm`` and re-registered under the same name, wrapped together with the
    set of input indices that its graph mutates. ``gm`` is modified in place.
    """
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        assert not node.kwargs
        target = node.target
        wrapped = gm.get_submodule(target)
        gm.delete_submodule(target)
        written = find_input_mutations(wrapped.graph)
        gm.add_submodule(target, CudaGraphModule(wrapped, written))
    # NB: only the submodules were swapped; the graph itself is untouched,
    # so no recompile is needed
def get_device_node_mapping(gm: torch.fx.GraphModule):
    """Map each device to the first node whose ``meta["val"]`` is a tensor on it.

    Nodes without a ``"val"`` entry, or whose value is not a tensor, are
    ignored.
    """
    first_node_on_device: Dict[torch.device, torch.fx.Node] = {}
    for node in gm.graph.nodes:
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            # setdefault keeps the earliest node recorded for a device.
            first_node_on_device.setdefault(val.device, node)
    return first_node_on_device
def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a skip message if any input other than the first ``num_fixed`` is mutated.

    Mutations of inputs with index below ``num_fixed`` are ignored. Returns
    ``None`` when no other input is mutated; otherwise returns whatever
    ``get_mutation_stack_trace`` produces for the offending indices.
    """
    mutated = find_input_mutations(aot_model.graph)
    mutation_indices = mutated - set(range(num_fixed))
    if mutation_indices:
        return get_mutation_stack_trace(aot_model, mutation_indices)
    return None
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return the first reason CUDA graphs cannot be used for ``aot_model``, or ``None``.

    Checks run in order: input mutation, the multi-device / CPU-node check,
    then incompatible ops.
    """
    mutation_msg = check_for_mutation(aot_model, num_fixed)
    if mutation_msg:
        return mutation_msg

    device_msg = check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    )
    if device_msg:
        return device_msg

    if has_incompatible_cudagraph_ops(aot_model):
        return "skipping cudagraphs due to incompatible op"

    return None
def cudagraphs(dynamo_model, dynamo_inputs):
    """Backend entry point: compile ``dynamo_model`` through AOT autograd with CUDA graphs.

    The forward and backward graphs are each checked with ``check_for_skip``.
    A graph that passes is partitioned with ``partition_cudagraphs`` and its
    partitions are wrapped by ``apply_cuda_graphs``; a graph that fails is
    returned unchanged and a perf hint is logged. If the forward graph is
    skipped, the backward graph is skipped as well.

    Args:
        dynamo_model: the model handed over by dynamo.
        dynamo_inputs: the example inputs handed over by dynamo.

    Returns:
        The result of calling the ``aot_autograd`` wrapper on the model and
        inputs.
    """
    # Shared between the two compilers so the forward pass can veto the backward.
    do_cudagraphs = BoxedBool(True)

    def forward_cudagraphs(aot_model, aot_inputs):
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            # Disable the shared flag (not the backend function itself), so
            # that backward_cudagraphs sees the skip.
            BoxedBool.disable(do_cudagraphs)
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return aot_model
        model = partition_cudagraphs(aot_model, aot_inputs)
        apply_cuda_graphs(model)
        return model

    def backward_cudagraphs(aot_model, aot_inputs):
        if not do_cudagraphs:
            return aot_model
        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return aot_model
        model = partition_cudagraphs(aot_model, aot_inputs)
        apply_cuda_graphs(model)
        return model

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)
# NOTE(review): `cudagraphs` above takes (dynamo_model, dynamo_inputs) and wraps
# aot_autograd itself, so passing it as both fw_compiler and bw_compiler here
# looks suspicious -- confirm whether this module-level alias is still used.
aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)
# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
# TODO(jansel): rename to just "cudagraphs"?
register_backend(name="cudagraphs", compiler_fn=cudagraphs)
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """Capture ``model`` into a CUDA graph and return a callable that replays it.

    Not registered as a backend; used by some benchmarks.

    Args:
        model: callable invoked as ``model(*inputs)``.
        inputs: list or tuple of example inputs.
        copy_outputs: when true, the returned callable hands back clones of
            the captured outputs; otherwise the captured outputs themselves.
        copy_inputs: when true, capture runs on zero-filled buffers shaped
            like ``inputs`` and each call copies its arguments into them;
            otherwise ``inputs`` themselves are captured and call arguments
            are only length-checked.

    Returns:
        ``run(*new_inputs)``, which replays the graph and returns the outputs
        as a list (or as the captured list/tuple when ``copy_outputs`` is
        false).
    """
    assert isinstance(inputs, (list, tuple))
    static_inputs = (
        [torch.zeros_like(x) for x in inputs] if copy_inputs else list(inputs)
    )

    # warmup
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=side_stream):
        static_outputs = model(*static_inputs)
    # Normalize a single result to a 1-tuple so ``run`` can always iterate.
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for buffer, fresh in zip(static_inputs, new_inputs):
                buffer.copy_(fresh)
        graph.replay()
        if not copy_outputs:
            return static_outputs
        return [out.clone() for out in static_outputs]

    return run