import functools
import logging
import operator
from collections import defaultdict
from functools import partial
from importlib import import_module
from typing import Set

from functorch.compile import (
    aot_module_simplified,
    min_cut_rematerialization_partition,
    nop,
    ts_compile,
)

import torch
from torch._functorch.compilers import debug_nop
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .. import config, eval_frame
from ..utils import clone_inputs, counters
from .backends import BACKENDS
from .normalize import normalize_ir

log = logging.getLogger(__name__)


def aot_autograd(**kwargs):
    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        # Hack to get around circular import problems with aot_inductor_debug
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()

        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if not functorch.compile.config.use_functionalize and config.normalize_ir:
            try:
                gm = normalize_ir(gm, clone_inputs(example_inputs))
            except Exception:
                log.debug("TorchDynamo unable to remove mutation")
                use_fallback = True

        # NB: no clone here on example inputs
        if not is_aot_autograd_safe_to_run(gm, example_inputs):
            use_fallback = True

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK attempt to compile
        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))

        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
        kwargs["bw_compiler"] = _wrapped_bw_compiler

        from torch._inductor.debug import enable_aot_logging

        try:
            # NB: NOT cloned!
            with enable_aot_logging():
                cg = aot_module_simplified(gm, example_inputs, **kwargs)
            counters["aot_autograd"]["ok"] += 1
            return eval_frame.disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise

    return compiler_fn
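

# Usage sketch (illustrative only; not used elsewhere in this file). The
# factory above returns a dynamo-style backend: a callable that takes a
# torch.fx.GraphModule plus its example inputs and returns a compiled
# callable. "my_gm" and "my_inputs" below are hypothetical placeholders.
#
#   my_backend = aot_autograd(fw_compiler=nop)
#   compiled_fn = my_backend(my_gm, my_inputs)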


def is_aot_autograd_safe_to_run(gm, example_inputs):
    """
    There are some known issues with AOT Autograd. This is a workaround to
    catch such cases and fall back to eager. We should fix these quickly.

    Issues
    1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
    2) LSTM - https://github.com/pytorch/functorch/issues/586
    3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
    """

    def raise_or_warn(reason):
        msg = f"Unable to use Aot Autograd because of presence of {reason}"
        if config.raise_on_unsafe_aot_autograd:
            raise NotImplementedError(msg)
        else:
            log.warning(msg)
        return False

    # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586
    for submod in gm.modules():
        if submod.__class__.__name__ == "LSTM":
            return raise_or_warn("LSTM")

    # 2) Mutation in the graphs is now always handled by AOT Autograd.
    return True
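

# Fallback sketch (hypothetical): a traced module that contains an nn.LSTM
# submodule makes the check above return False (or raise, when
# config.raise_on_unsafe_aot_autograd is set), so aot_autograd() hands back
# the original GraphModule unchanged instead of compiling it, e.g.
#
#   class M(torch.nn.Module):          # hypothetical module
#       def __init__(self):
#           super().__init__()
#           self.rnn = torch.nn.LSTM(8, 8)
#
#   is_aot_autograd_safe_to_run(gm_of_M, example_inputs)  # -> False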


DEBUG = False

# Useful for debugging purposes
aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop)

# AOT Autograd with torchscript backend. Default partitioner.
aot_ts = aot_autograd(fw_compiler=ts_compile)

# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
# inductor problems.
aot_inductor_debug = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=nop,
    bw_compiler=nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        f"{config.inductor_import}.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
aot_mem_efficient_fusion_no_decomp = aot_autograd(
    **mem_efficient_fusion_kwargs(use_decomps=False)
)

# Pass TorchScript+nvFuser context to TorchDynamo
aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")


def prims_executor(gm, inputs, *, executor):
    from functorch.compile import make_boxed_func

    # This function is called once per forward/backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph and return
    # the execute function.
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute
    from torch.fx.experimental.proxy_tensor import make_fx

    # AOT Autograd might not use the partitioner, so we need to make sure that
    # the graph is transformed to use nvFuser-compatible nodes.
    if not getattr(gm, "_nvprim_transformed", False):
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(gm)(*inputs)

    # Then we return a callable that executes the "gm" graph
    return make_boxed_func(partial(execute, gm, executor=executor))
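

# Note: prims_executor serves as both fw_compiler and bw_compiler in
# create_nvprims_backend below, so the "_nvprim_transformed" marker set by
# nvprims_fw_bw_partition_fn lets it skip re-tracing graphs that the
# partitioner has already lowered to nvprims.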


def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
    # This function is called once per forward+backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph that is later
    # passed to the executor.
    from functorch.compile import min_cut_rematerialization_partition
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch.fx.experimental.proxy_tensor import make_fx

    # AOT Autograd expects arguments of the traced function to be named exactly
    # "primals, tangents"
    def func(primals, tangents):
        return joint_module(primals, tangents)

    # First we trace the graph conditionally decomposing nodes
    # that can be sent to the nvfuser executor
    with TorchRefsNvfuserCapabilityMode():
        prim_gm = make_fx(func)(*joint_inputs)

    # all nvprims for now
    recomputable_ops = {
        getattr(torch.ops.nvprims, prim)
        for prim in dir(torch.ops.nvprims)
        if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
        and getattr(torch.ops.nvprims, prim).is_recomputable
    }

    fw_gm, bw_gm = min_cut_rematerialization_partition(
        prim_gm,
        joint_inputs,
        recomputable_ops=recomputable_ops,
        num_fwd_outputs=num_fwd_outputs,
    )

    # AOT Autograd might not use the partitioner, so we need to make sure that
    # the graph is marked as already transformed to use nvFuser-compatible nodes
    fw_gm._nvprim_transformed = True
    bw_gm._nvprim_transformed = True

    return fw_gm, bw_gm


def create_nvprims_backend(*, executor):
    return aot_autograd(
        fw_compiler=partial(prims_executor, executor=executor),
        bw_compiler=partial(prims_executor, executor=executor),
        partition_fn=nvprims_fw_bw_partition_fn,
    )


aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
aot_nvprims_aten = create_nvprims_backend(executor="aten")


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out. First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r
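

# Lifecycle of CudaGraphModule.__call__ above: the first call runs the wrapped
# module eagerly on a side stream (warmup); the second call clones the inputs
# into static buffers and records a CUDA graph; every later call copies fresh
# inputs into the static buffers, replays the recorded graph, and copies
# mutated inputs back out before returning cloned outputs.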


# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
def find_input_mutations(g):
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs
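

# find_input_mutations returns the set of positional input indices whose
# underlying storage is written to by an in-place op in the graph. For example
# (hypothetical), for a graph lowered from
#
#   def f(a, b):
#       a.add_(b)
#       return a * b
#
# it would return {0}, since only the first placeholder's storage is mutated.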


# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)

aot_torchxla_trivial = aot_autograd(
    fw_compiler=BACKENDS["torchxla_trivial"],
)

aot_torchxla_trace_once = aot_autograd(
    fw_compiler=BACKENDS["torchxla_trace_once"],
)


def create_aot_backends():
    """
    Register aliases for the AOT backends
    """
    # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
    BACKENDS["aot_eager"] = aot_eager

    # aot_ts uses the torchscript backend. We can use this with both nnc and nvfuser
    # by using the relevant fuser with torch.jit.fuser(...)
    BACKENDS["aot_ts"] = aot_ts

    # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
    # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
    BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser
    # This is useful for debugging. Can be removed later.
    BACKENDS["nvprims_aten"] = aot_nvprims_aten

    # aot_ts_nvfuser uses the memory-efficient fusion algorithm from AOT Autograd.
    # It uses the min cut rematerialization algorithm, nvFuser as the compiler
    # backend, and TorchScript as the frontend.
    BACKENDS["aot_ts_nvfuser"] = aot_mem_efficient_fusion

    # Similar to aot_ts_nvfuser, but disables the decompositions. Decompositions
    # can cause accuracy deviations. This setting allows us to compare accuracy
    # without worrying about the impact of decompositions. More details at
    # https://github.com/pytorch/torchdynamo/issues/611
    BACKENDS["aot_ts_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp

    # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
    # for debugging and can serve as a perf baseline.
    BACKENDS["aot_cudagraphs"] = aot_cudagraphs

    # aot_inductor_debug just replaces the inductor compiler with nop to help
    # isolate inductor vs aot_eager errors
    BACKENDS["aot_inductor_debug"] = aot_inductor_debug

    BACKENDS["aot_torchxla_trivial"] = aot_torchxla_trivial
    BACKENDS["aot_torchxla_trace_once"] = aot_torchxla_trace_once