| import functools |
| import logging |
| import operator |
| from collections import defaultdict |
| from functools import partial |
| from importlib import import_module |
| from typing import Set |
| |
| from functorch.compile import ( |
| aot_module_simplified, |
| min_cut_rematerialization_partition, |
| nop, |
| ts_compile, |
| ) |
| |
| import torch |
| |
| from torch._functorch.compilers import debug_nop |
| from torch.fx import GraphModule |
| from torch.fx.passes.backends.cudagraphs import partition_cudagraphs |
| from torch.multiprocessing.reductions import StorageWeakRef |
| from torch.nn import Module |
| from torch.utils._pytree import tree_map |
| |
| from .. import config, eval_frame |
| from ..utils import clone_inputs, counters |
| from .backends import BACKENDS |
| from .normalize import normalize_ir |
| |
| log = logging.getLogger(__name__) |
| |
| |
| def aot_autograd(**kwargs): |
| def compiler_fn(gm: torch.fx.GraphModule, example_inputs): |
| import functorch.compile |
| |
| # Hack to get around circular import problems with aot_inductor_debug |
| if callable(kwargs.get("decompositions")): |
| kwargs["decompositions"] = kwargs["decompositions"]() |
| |
| # TODO: stop monkeypatching here (without even cleaning up, UGH!) |
| functorch.compile.config.use_functionalize = True |
| functorch.compile.config.use_fake_tensor = True |
| |
| counters["aot_autograd"]["total"] += 1 |
| use_fallback = False |
| |
| if not functorch.compile.config.use_functionalize and config.normalize_ir: |
| try: |
| gm = normalize_ir(gm, clone_inputs(example_inputs)) |
| except Exception: |
| log.debug("TorchDynamo unable to remove mutation") |
| use_fallback = True |
| |
| # NB: no clone here on example inputs |
| if not is_aot_autograd_safe_to_run(gm, example_inputs): |
| use_fallback = True |
| |
| if use_fallback: |
| log.debug("Unable to use AOT Autograd because graph has mutation") |
| counters["aot_autograd"]["not_ok"] += 1 |
| return gm |
| |
| # OK attempt to compile |
| |
| def _wrapped_bw_compiler(*args, **kwargs): |
| # stop TorchDynamo from trying to compile our generated backwards pass |
| return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs)) |
| |
| bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] |
| kwargs["bw_compiler"] = _wrapped_bw_compiler |
| |
| from torch._inductor.debug import enable_aot_logging |
| |
| try: |
| # NB: NOT cloned! |
| with enable_aot_logging(): |
| cg = aot_module_simplified(gm, example_inputs, **kwargs) |
| counters["aot_autograd"]["ok"] += 1 |
| return eval_frame.disable(cg) |
| except Exception: |
| counters["aot_autograd"]["not_ok"] += 1 |
| raise |
| |
| return compiler_fn |
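
# A minimal sketch of how the factory above is used in the rest of this file:
# the keyword arguments are forwarded to aot_module_simplified, and the
# returned compiler_fn has the usual TorchDynamo backend signature
# (gm: GraphModule, example_inputs) -> callable.  Note that bw_compiler
# defaults to fw_compiler when it is not given (see above).
#
#   my_backend = aot_autograd(fw_compiler=nop)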
| |
| |
| def is_aot_autograd_safe_to_run(gm, example_inputs): |
| """ |
| There are some known issues with Aot Autograd. This is a workaround to catch |
| such cases, and fallback to eager. We should fix these quickly. |
| |
| Issues |
| 1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147 |
| 2) LSTM - https://github.com/pytorch/functorch/issues/586 |
| 3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301 |
| """ |
| |
| def raise_or_warn(reason): |
| msg = f"Unable to use Aot Autograd because of presence of {reason}" |
| if config.raise_on_unsafe_aot_autograd: |
| raise NotImplementedError(msg) |
| else: |
| log.warning(msg) |
| return False |
| |
| # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586 |
| for submod in gm.modules(): |
| if submod.__class__.__name__ == "LSTM": |
| return raise_or_warn("LSTM") |
| |
    # 2) Mutation in the graph is now always handled by AOT Autograd, so no
    #    fallback is needed for it.
| return True |
| |
| |
| DEBUG = False |
| |
# Useful for debugging purposes; set DEBUG above to use the debug_nop compiler.
| aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop) |
| |
| # AOT Autograd with torchscript backend. Default partitioner. |
| aot_ts = aot_autograd(fw_compiler=ts_compile) |
| |
# Uses TorchInductor's AOT Autograd decompositions and partitioner (with nop
# compilers) to isolate AOT Autograd problems from Inductor problems.
| aot_inductor_debug = aot_autograd( |
| # these are taken from memory_efficient_fusion() |
| fw_compiler=nop, |
| bw_compiler=nop, |
| # NB: lambda here is to delay import of inductor |
| decompositions=lambda: import_module( |
| f"{config.inductor_import}.compile_fx" |
| ).select_decomp_table(), |
| partition_fn=functools.partial( |
| min_cut_rematerialization_partition, compiler="inductor" |
| ), |
| ) |
| |
| |
| def mem_efficient_fusion_kwargs(use_decomps): |
| from functorch.compile import ( |
| default_decompositions, |
| min_cut_rematerialization_partition, |
| ts_compile, |
| ) |
| |
| kwargs = { |
| # these are taken from memory_efficient_fusion() |
| "fw_compiler": ts_compile, |
| "bw_compiler": ts_compile, |
| "partition_fn": min_cut_rematerialization_partition, |
| } |
| |
| if use_decomps: |
| kwargs["decompositions"] = default_decompositions |
| |
| return kwargs |
| |
| |
| # Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd |
| aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True)) |
| aot_mem_efficient_fusion_no_decomp = aot_autograd( |
| **mem_efficient_fusion_kwargs(use_decomps=False) |
| ) |
| |
| # Pass TorchScript+nvFuser context to TorchDynamo |
| aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") |
| aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") |
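
# backend_ctx_ctor is consulted by TorchDynamo's eval_frame machinery (this is
# an assumption about its contract): the returned context manager is entered
# around compilation and execution of the optimized code, so TorchScript
# fusions are routed to nvFuser ("fuser2") instead of the default fuser.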
| |
| |
| def prims_executor(gm, inputs, *, executor): |
| from functorch.compile import make_boxed_func |
| |
    # This function is called once per forward/backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph and return
    # an execute function.
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch._prims.executor import execute |
| from torch.fx.experimental.proxy_tensor import make_fx |
| |
| # AOT Autograd might not use the partitioner, so we need to make sure that |
| # the graph is transformed to use nvFuser-compatible nodes. |
| if not getattr(gm, "_nvprim_transformed", False): |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(gm)(*inputs) |
| |
| # Then we return a callable that executes the "gm" graph |
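    # make_boxed_func wraps it in AOT Autograd's "boxed" calling convention:
    # the wrapped callable takes its inputs as a single list, which lets the
    # caller release references to individual inputs as soon as they are used.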
| return make_boxed_func(partial(execute, gm, executor=executor)) |
| |
| |
| def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs): |
| # This function is called once per forward+backward pass of a graph in AOT |
| # Autograd. We use it to set up the nvFuser-specific FX graph that is later |
| # passed to the executor. |
| from functorch.compile import min_cut_rematerialization_partition |
| |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch.fx.experimental.proxy_tensor import make_fx |
| |
| # AOT Autograd expects arguments of the traced function to be named exactly |
| # "primals, tangents" |
| def func(primals, tangents): |
| return joint_module(primals, tangents) |
| |
    # First we trace the graph, conditionally decomposing nodes that can be
    # sent to the nvFuser executor
| with TorchRefsNvfuserCapabilityMode(): |
| prim_gm = make_fx(func)(*joint_inputs) |
| |
    # Recomputable ops: for now, all nvprims that are flagged as recomputable
| recomputable_ops = { |
| getattr(torch.ops.nvprims, prim) |
| for prim in dir(torch.ops.nvprims) |
| if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket) |
| and getattr(torch.ops.nvprims, prim).is_recomputable |
| } |
| |
| fw_gm, bw_gm = min_cut_rematerialization_partition( |
| prim_gm, |
| joint_inputs, |
| recomputable_ops=recomputable_ops, |
| num_fwd_outputs=num_fwd_outputs, |
| ) |
| # AOT Autograd might not use the partitioner, so we need to make sure that |
| # the graph is marked as already transformed to use nvFuser-compatible nodes |
| fw_gm._nvprim_transformed = True |
| bw_gm._nvprim_transformed = True |
| return fw_gm, bw_gm |
| |
| |
| def create_nvprims_backend(*, executor): |
| return aot_autograd( |
| fw_compiler=partial(prims_executor, executor=executor), |
| bw_compiler=partial(prims_executor, executor=executor), |
| partition_fn=nvprims_fw_bw_partition_fn, |
| ) |
| |
| |
| aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser") |
| aot_nvprims_aten = create_nvprims_backend(executor="aten") |
| |
| |
| def cloner(t): |
| if isinstance(t, torch.Tensor): |
| return t.clone() |
| else: |
| return t |
| |
| |
| class CudaGraphModule(Module): |
| gm: GraphModule |
| mutated_inputs: Set[int] |
| |
| def __init__(self, gm, mutated_inputs): |
| super().__init__() |
| self.gm = gm |
| self.mutated_inputs = mutated_inputs |
| |
| warmed_up = False |
| |
| # these are all None or all filled |
| graph = None |
| static_inputs = None |
| static_outputs = None |
| |
    # NB: we override __call__ because we don't need any nn.Module machinery,
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into the static
        # inputs, replays the CUDA graph, then copies out.  The first branch
        # below is the hot path and needs optimizing.
| if self.graph is not None: |
| assert len(args) == len(self.static_inputs) |
| for dst, src in zip(self.static_inputs, args): |
| dst.copy_(src) |
| self.graph.replay() |
| for i in self.mutated_inputs: |
| args[i].copy_(self.static_inputs[i]) |
| return tree_map(cloner, self.static_outputs) |
| |
| elif self.warmed_up: |
| # record |
| self.static_inputs = [x.clone() for x in args] |
| self.graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(self.graph): |
| self.static_outputs = self.gm(*self.static_inputs) |
| # NB: recording doesn't actually run the operations, so |
| # now we immediately replay the graph to serve up the result |
| self.graph.replay() |
| for i in self.mutated_inputs: |
| args[i].copy_(self.static_inputs[i]) |
| return tree_map(cloner, self.static_outputs) |
| |
| else: |
| # warmup |
| stream = torch.cuda.Stream() |
| stream.wait_stream(torch.cuda.current_stream()) |
| with torch.cuda.stream(stream): |
| r = self.gm(*args) |
| torch.cuda.current_stream().wait_stream(stream) |
| self.warmed_up = True |
| return r |
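
# The __call__ method above implements the standard CUDA Graphs capture
# protocol: warm up on a side stream, record against static buffers, then
# replay with fresh inputs copied into those buffers.  A minimal standalone
# sketch of that pattern using only public torch.cuda APIs (`fn`, `new_input`,
# and `static_input` are stand-in names, not part of this module):
#
#   static_input = torch.randn(8, device="cuda")
#   # 1) warm up on a side stream so lazy initialization is not captured
#   s = torch.cuda.Stream()
#   s.wait_stream(torch.cuda.current_stream())
#   with torch.cuda.stream(s):
#       fn(static_input)
#   torch.cuda.current_stream().wait_stream(s)
#   # 2) record the graph against static input/output buffers
#   g = torch.cuda.CUDAGraph()
#   with torch.cuda.graph(g):
#       static_output = fn(static_input)
#   # 3) replay: copy fresh data into the static inputs, then replay
#   static_input.copy_(new_input)
#   g.replay()
#   result = static_output.clone()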
| |
| |
| # Interpreter versions of these passes can be found at |
| # https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23 |
| |
| |
| def find_input_mutations(g): |
| def meta_fk(meta): |
| return meta["val"] if "val" in meta else meta["fake_result"] |
| |
| inputs = defaultdict(set) |
| input_idx = 0 |
| mutated_inputs = set() |
| for n in g.nodes: |
| if n.op == "placeholder": |
| inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) |
| input_idx += 1 |
| elif n.op == "call_function": |
| if n.target is operator.getitem: |
| continue |
| schema = n.target._schema |
| for i, arg in enumerate(schema.arguments): |
| if i < len(n.args): |
| argument = n.args[i] |
| else: |
| if arg.name not in n.kwargs: |
| continue |
| argument = n.kwargs[arg.name] |
| mut_arg = False |
| if arg.alias_info: |
| if arg.alias_info.is_write: |
| mut_arg = True |
| if mut_arg: |
| # TODO: not correct for args that contain tensors in a struct |
| # like list |
| mutated_inputs |= inputs[ |
| StorageWeakRef(meta_fk(argument.meta)._typed_storage()) |
| ] |
| # TODO: error on unrecognized nodes |
| return mutated_inputs |
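
# Example of what the schema-based check above keys on (a sketch; exact schema
# strings vary across PyTorch versions): the in-place op aten.add_.Tensor has
# the schema "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)".
# The "(a!)" annotation means alias_info.is_write is True for `self`, so any
# placeholder whose storage aliases `self` gets recorded as a mutated input.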
| |
| |
| # Mutates input graph |
| def apply_cuda_graphs(gm): |
| for n in gm.graph.nodes: |
| if n.op == "call_module": |
| assert not n.kwargs |
| submod = gm.get_submodule(n.target) |
| gm.delete_submodule(n.target) |
| mutated_inputs = find_input_mutations(submod.graph) |
| gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs)) |
    # NB: we didn't actually change the graph structure, so there is no need to
    # call gm.recompile()
| |
| |
def cudagraphs(model, inputs):
    # Split the model into CUDA-only subgraphs, then wrap each one in a
    # CudaGraphModule so it is captured and replayed as a CUDA graph.
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model
| |
| |
| aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs) |
| |
| aot_torchxla_trivial = aot_autograd( |
| fw_compiler=BACKENDS["torchxla_trivial"], |
| ) |
| |
| aot_torchxla_trace_once = aot_autograd( |
| fw_compiler=BACKENDS["torchxla_trace_once"], |
| ) |
| |
| |
| def create_aot_backends(): |
| """ |
| Register aliases for the AOT backends |
| """ |
| # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. |
| BACKENDS["aot_eager"] = aot_eager |
| |
    # aot_ts uses the TorchScript backend. It can be used with both NNC and
    # nvFuser by selecting the relevant fuser with torch.jit.fuser(...)
| BACKENDS["aot_ts"] = aot_ts |
| |
| # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be |
| # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch. |
| BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser |
| # This is useful for debugging. Can be removed later. |
| BACKENDS["nvprims_aten"] = aot_nvprims_aten |
| |
    # aot_ts_nvfuser uses the memory-efficient fusion algorithm from AOT Autograd.
    # It uses the min-cut rematerialization algorithm, nvFuser as the compiler
    # backend, and TorchScript as the frontend.
| BACKENDS["aot_ts_nvfuser"] = aot_mem_efficient_fusion |
| |
    # Similar to aot_ts_nvfuser, but disables the decompositions. Decompositions
    # can cause accuracy deviations. This setting allows us to compare accuracy
    # without worrying about the impact of decompositions. More details at
    # https://github.com/pytorch/torchdynamo/issues/611
| BACKENDS["aot_ts_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp |
| |
| # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful |
| # for debugging and can serve as a perf baseline. |
| BACKENDS["aot_cudagraphs"] = aot_cudagraphs |
| |
| # aot_inductor_debug just replaces the inductor compiler with nop to help |
| # isolate inductor vs aot_eager errors |
| BACKENDS["aot_inductor_debug"] = aot_inductor_debug |
| |
| BACKENDS["aot_torchxla_trivial"] = aot_torchxla_trivial |
| BACKENDS["aot_torchxla_trace_once"] = aot_torchxla_trace_once |