Implement wrappers for aot_dedupe and aot_synthetic_base (#125764)

It's kind of gross that aot_synthetic_base requires storing the *old* fw_metadata's InputInfo, but it is what it is. After this change, all of aot_dispatch_base's runtime wrappers should be implemented; next, I'll start on aot_dispatch_autograd's remaining runtime wrapping changes.
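
Both new wrappers subclass CompilerWrapper and follow its pre_compile/post_compile
protocol. For readers skimming the diff, here is a rough, self-contained sketch of
how that contract composes (illustrative names only; signatures are inferred from
this diff, not copied from the upstream base class):

from typing import Any, List

class CompilerWrapperSketch:
    """Hypothetical stand-in for the CompilerWrapper base class."""

    def pre_compile(self, flat_fn, flat_args: List[Any], aot_config, *, fw_metadata):
        # Rewrite the traced fn / inputs / metadata before the inner compiler runs,
        # stashing on `self` whatever state post_compile will need at runtime.
        return flat_fn, flat_args, aot_config, fw_metadata

    def post_compile(self, compiled_fn, aot_config, *, fw_metadata):
        # Wrap the compiled callable so runtime inputs get the same treatment
        # that pre_compile gave the traced inputs.
        return compiled_fn

def apply_wrappers(wrappers, flat_fn, flat_args, aot_config, fw_metadata, compiler_fn):
    # Wrappers compose like an onion: pre_compile in order, inner compile, then
    # post_compile in reverse (mirroring the loop added to aot_autograd.py below).
    for w in wrappers:
        flat_fn, flat_args, aot_config, fw_metadata = w.pre_compile(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )
    compiled_fn = compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
    for w in reversed(wrappers):
        compiled_fn = w.post_compile(compiled_fn, aot_config, fw_metadata=fw_metadata)
    return compiled_fn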

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125764
Approved by: https://github.com/bdhirsh
ghstack dependencies: #125610
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 1278864..934b783 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -691,370 +691,418 @@
 # Dynamo's guards are not enough.  In practice, this seems to cover
 # everything.
 #
-def aot_wrapper_dedupe(
-    flat_fn,
-    flat_args: List[Tensor],
-    aot_config: AOTConfig,
-    *,
-    compiler_fn,
-    fw_metadata,
-):
-    # Use information about whether or not flat_fn mutates its arguments
-    # or not to handle dupe args
-
-    # Strategy 1: For any input that is not mutated, we can leafify it if we
-    # need to remove a duplicate.
-    leaf_flat_args = []
-    args_set = set()
-    ok = True
-
-    for i, a in enumerate(flat_args):
-        if not isinstance(a, torch.Tensor):
-            leaf_flat_args.append(a)
-        elif a not in args_set:
-            args_set.add(a)
-            leaf_flat_args.append(a)
-        elif (
-            not fw_metadata.input_info[i].mutates_data
-            and not fw_metadata.input_info[i].mutates_metadata
-        ):
-            leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
-        else:
-            ok = False
-            break
-
-    if ok:
-        return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
-
-    if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
-        raise RuntimeError(
-            """\
-Encountered duplicate inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
-        )
-
-    # export path: ban duplicate inputs for now, add later if requested.
-    if aot_config.is_export:
-        raise RuntimeError(
-            f"""\
-Encountered duplicated inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
-
-fw_metadata={str(fw_metadata)}
-        """
-        )
-
-    # Strategy 2: Duplicate specialize.
-    #
-    # In Haskell types, suppose you have:
-    #
-    #   add_dupe_args :: DedupedArgs -> Args
-    #   remove_dupe_args :: Args -> DedupedArgs
-    #
-    #   compiler_fn
-    #       :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
-    #   deped_compiler_fn
-    #       :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
-    #
-    # Then the code below can be written in point-free style as:
-    #
-    #   deduped_compiler_fn f a c =
-    #       compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
-    #
-    # Suppose you have:
-    #
-    #   [a, b, a, c]
-    #
-    # We want:
-    #
-    #   remove_dupe_args([a, b, a, c]) == [a, b, c]
-    #   add_dupe_args([a, b, c]) == [a, b, a, c]
-    #
-    # This is done via (respectively):
-    #
-    #   seen_args = {a: 0, b: 1, c: 2}
-    #   enumerate(add_dupe_map) = [  # how to get args from the deduped list
-    #       (0, 0),
-    #       (1, 1),
-    #       (2, 0),
-    #       (3, 2),
-    #   ]
-    #   keep_arg_mask = [True, True, False, True]
-
-    seen_args: Dict[Tensor, int] = {}
-    keep_arg_mask = []
-    # Implicitly map duped arg position (list index) to de-duped arg position
-    add_dupe_map: List[int] = []
-    duped_arg_len = len(flat_args)
-
-    j = 0  # index into deduped_flat_args
-    for t in flat_args:
-        if isinstance(t, torch.Tensor):
-            if t in seen_args:
-                keep_arg_mask.append(False)
-                add_dupe_map.append(seen_args[t])
-                continue
-            seen_args[t] = j
-
-        keep_arg_mask.append(True)
-        add_dupe_map.append(j)
-        j += 1
-    assert (
-        len(add_dupe_map) == duped_arg_len
-    ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
+@dataclass
+class AOTDedupeWrapper(CompilerWrapper):
+    keep_arg_mask: List[bool] = field(default_factory=list)
+    add_dupe_map: List[int] = field(default_factory=list)
+    old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
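+    # Set in pre_compile: False when Strategy 1 below handles the dupes at trace
+    # time, in which case post_compile leaves the compiled fn untouched.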
+    needs_post_compile: bool = True
 
     # NB: Hot path, avoid set lookups here
     # TODO: Can avoid the zip here too, probably
-    def remove_dupe_args(args):
-        return [t for t, keep in zip(args, keep_arg_mask) if keep]
+    def remove_dupe_args(self, args):
+        return [t for t, keep in zip(args, self.keep_arg_mask) if keep]
 
-    def add_dupe_args(args):
-        return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
+    def add_dupe_args(self, args):
+        return [args[i] for i in self.add_dupe_map]
 
-    deduped_flat_args = remove_dupe_args(flat_args)
-
-    # Update our input metadata to remove duped input metadata.
-    updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map)
-
-    if (
-        tracing_context := TracingContext.try_get()
-        and aot_config.aot_autograd_arg_pos_to_source
+    def pre_compile(
+        self,
+        flat_fn,
+        flat_args: List[Tensor],
+        aot_config: AOTConfig,
+        *,
+        fw_metadata: ViewAndMutationMeta,
     ):
-        # TODO(voz): This structure is 1:1, we could consider an alternate structure like
-        # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
-        # which feels like needless complexity for a tiny bit of efficiency at this point.
-        for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
-            zip(add_dupe_map, keep_arg_mask)
-        ):
-            if not keep_arg:
-                dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
-                    dupe_arg_pos
-                ]
-                kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos]
-                tracing_context.guards_context.aotautograd_guards.append(  # type: ignore[attr-defined]
-                    DuplicateInputs(kept_arg_source, dupe_arg_source)
-                )
+        # Use information about whether or not flat_fn mutates its arguments
+        # to handle dupe args
 
-    @wraps(flat_fn)
-    def wrapped_flat_fn(*args):
-        return flat_fn(*add_dupe_args(args))
+        # Strategy 1: For any input that is not mutated, we can leafify it if we
+        # need to remove a duplicate.
+        leaf_flat_args = []
+        args_set = set()
+        ok = True
 
-    if config.debug_assert:
-        ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
-            wrapped_flat_fn,
-            keep_input_mutations=fw_metadata.keep_input_mutations,
-            is_train=fw_metadata.is_train,
-        )(*deduped_flat_args)
-        assert (
-            ref_fw_metadata == updated_fw_metadata
-        ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
+        for i, a in enumerate(flat_args):
+            if not isinstance(a, torch.Tensor):
+                leaf_flat_args.append(a)
+            elif a not in args_set:
+                args_set.add(a)
+                leaf_flat_args.append(a)
+            elif (
+                not fw_metadata.input_info[i].mutates_data
+                and not fw_metadata.input_info[i].mutates_metadata
+            ):
+                leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
+            else:
+                ok = False
+                break
 
-    compiled_fn = compiler_fn(
-        wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata
-    )
+        if ok:
+            self.needs_post_compile = False
+            return flat_fn, leaf_flat_args, aot_config, fw_metadata
 
-    @wraps(compiled_fn)
-    def wrapped_compiled_fn(args: List[Any]):
-        deduped_args = remove_dupe_args(args)
-        args.clear()
-        return compiled_fn(deduped_args)
-
-    wrapped_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
-
-    # This can be uncommented when we properly guard for duplicates,
-    # but right now we must not do it.
-    # if not config.debug_assert:
-    #     return wrapped_compiled_fn
-
-    @wraps(wrapped_compiled_fn)
-    def debugged_compiled_fn(args):
-        # Test that the computed remove/add arg functions are an inverse
-        new_args = add_dupe_args(remove_dupe_args(args))
-        seen: Dict[Any, None] = {}
-        for i, (x, y) in enumerate(zip(new_args, args)):
-            seen[y] = None
-            assert x is y, format_guard_bug_msg(
-                aot_config,
-                f"{describe_input(i, aot_config)} would be a duplicate of "
-                f"{describe_input(add_dupe_map[i], aot_config)}",
+        if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
+            raise RuntimeError(
+                """\
+        Encountered duplicate inputs that are mutated in the graph, but at least one input/output
+        to the graph is a tensor subclass. This is not supported today. You can try to
+        remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
             )
-        # This is only an error if there is metadata mutation on both of
-        # the duped arguments; in this case, we need to know what order
-        # the metadata mutation applies in.  You'll get the correct result
-        # otherwise, because a graph that assumes distinct inputs works if
-        # you dupe the inputs (the gradient contributions from each input
-        # will get summed up appropriately.)
+
+        # export path: ban duplicate inputs for now, add later if requested.
+        if aot_config.is_export:
+            raise RuntimeError(
+                f"""\
+        Encountered duplicated inputs that are mutated in the graph you are trying to export.
+        This functionality is currently not supported. If needed, please file a github issue.
+
+        fw_metadata={str(fw_metadata)}
+            """
+            )
+
+        # Strategy 2: Duplicate specialize.
         #
-        # TODO: work out how to setup this assert correctly
-        """
-        assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
-            f"there would be {unique_args} distinct arguments"
+        # In Haskell types, suppose you have:
+        #
+        #   add_dupe_args :: DedupedArgs -> Args
+        #   remove_dupe_args :: Args -> DedupedArgs
+        #
+        #   compiler_fn
+        #       :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
+        #   deduped_compiler_fn
+        #       :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
+        #
+        # Then the code below can be written in point-free style as:
+        #
+        #   deduped_compiler_fn f a c =
+        #       compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
+        #
+        # Suppose you have:
+        #
+        #   [a, b, a, c]
+        #
+        # We want:
+        #
+        #   remove_dupe_args([a, b, a, c]) == [a, b, c]
+        #   add_dupe_args([a, b, c]) == [a, b, a, c]
+        #
+        # This is done via (respectively):
+        #
+        #   seen_args = {a: 0, b: 1, c: 2}
+        #   enumerate(add_dupe_map) = [  # how to get args from the deduped list
+        #       (0, 0),
+        #       (1, 1),
+        #       (2, 0),
+        #       (3, 2),
+        #   ]
+        #   keep_arg_mask = [True, True, False, True]
+
+        seen_args: Dict[Tensor, int] = {}
+        keep_arg_mask: List[bool] = []
+        # Implicitly map duped arg position (list index) to de-duped arg position
+        add_dupe_map: List[int] = []
+        duped_arg_len = len(flat_args)
+
+        j = 0  # index into deduped_flat_args
+        for t in flat_args:
+            if isinstance(t, torch.Tensor):
+                if t in seen_args:
+                    keep_arg_mask.append(False)
+                    add_dupe_map.append(seen_args[t])
+                    continue
+                seen_args[t] = j
+
+            keep_arg_mask.append(True)
+            add_dupe_map.append(j)
+            j += 1
+        assert (
+            len(add_dupe_map) == duped_arg_len
+        ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
+
+        self.keep_arg_mask = keep_arg_mask
+        self.add_dupe_map = add_dupe_map
+
+        deduped_flat_args = self.remove_dupe_args(flat_args)
+
+        # Update our input metadata to remove duped input metadata.
+        updated_fw_metadata = remove_dupe_metadata(
+            fw_metadata, keep_arg_mask, add_dupe_map
         )
-        """
-        return wrapped_compiled_fn(args)
 
-    debugged_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
+        if (
+            tracing_context := TracingContext.try_get()
+            and aot_config.aot_autograd_arg_pos_to_source
+        ):
+            # TODO(voz): This structure is 1:1, we could consider an alternate structure like
+            # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
+            # which feels like needless complexity for a tiny bit of efficiency at this point.
+            for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
+                zip(add_dupe_map, keep_arg_mask)
+            ):
+                if not keep_arg:
+                    dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
+                        dupe_arg_pos
+                    ]
+                    kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[
+                        kept_pos
+                    ]
+                    tracing_context.guards_context.aotautograd_guards.append(  # type: ignore[attr-defined]
+                        DuplicateInputs(kept_arg_source, dupe_arg_source)
+                    )
 
-    return debugged_compiled_fn
+        @wraps(flat_fn)
+        def wrapped_flat_fn(*args):
+            return flat_fn(*self.add_dupe_args(args))
+
+        if config.debug_assert:
+            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+                wrapped_flat_fn,
+                keep_input_mutations=fw_metadata.keep_input_mutations,
+                is_train=fw_metadata.is_train,
+            )(*deduped_flat_args)
+            assert (
+                ref_fw_metadata == updated_fw_metadata
+            ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
+
+        return wrapped_flat_fn, deduped_flat_args, aot_config, updated_fw_metadata
+
+    def post_compile(
+        self,
+        compiled_fn,
+        aot_config: AOTConfig,
+        *,
+        fw_metadata: ViewAndMutationMeta,
+    ):
+        if not self.needs_post_compile:
+            return compiled_fn
+
+        @wraps(compiled_fn)
+        def wrapped_compiled_fn(args: List[Any]):
+            deduped_args = self.remove_dupe_args(args)
+            args.clear()
+            return compiled_fn(deduped_args)
+
+        wrapped_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
+
+        # This can be uncommented when we properly guard for duplicates,
+        # but right now we must not do it.
+        # if not config.debug_assert:
+        #     return wrapped_compiled_fn
+
+        @wraps(wrapped_compiled_fn)
+        def debugged_compiled_fn(args):
+            # Test that the computed remove/add arg functions are an inverse
+            new_args = self.add_dupe_args(self.remove_dupe_args(args))
+            seen: Dict[Any, None] = {}
+            for i, (x, y) in enumerate(zip(new_args, args)):
+                seen[y] = None
+                assert x is y, format_guard_bug_msg(
+                    aot_config,
+                    f"{describe_input(i, aot_config)} would be a duplicate of "
+                    f"{describe_input(self.add_dupe_map[i], aot_config)}",
+                )
+            # This is only an error if there is metadata mutation on both of
+            # the duped arguments; in this case, we need to know what order
+            # the metadata mutation applies in.  You'll get the correct result
+            # otherwise, because a graph that assumes distinct inputs works if
+            # you dupe the inputs (the gradient contributions from each input
+            # will get summed up appropriately.)
+            #
+            # TODO: work out how to setup this assert correctly
+            """
+            assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
+                f"there would be {unique_args} distinct arguments"
+            )
+            """
+            return wrapped_compiled_fn(args)
+
+        debugged_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
+
+        return debugged_compiled_fn
 
 
 # This layer handles the situation where you have two inputs that alias each other,
 # and one of the inputs is mutated.
 # We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
 #
-# pre-condition: aot_wrapper_dedup has already run.
+# pre-condition: AOTDedupeWrapper has already run.
 # (This function will in theory work if there are duplicate args.
 # However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
 # would cause us to hit that path more frequently).
-def aot_wrapper_synthetic_base(
-    flat_fn,
-    flat_args: List[Tensor],
-    aot_config: AOTConfig,
-    *,
-    fw_metadata: ViewAndMutationMeta,
-    # Currently, the only reason we need to plumb this bool is because
-    # the synthetic base code prohibits more cases in the autograd case than the inference case.
-    needs_autograd: bool,
-    compiler_fn,
-):
-    is_inference = not needs_autograd
-    flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
-        flat_args,
-        fw_metadata.input_info,
-        is_inference=is_inference,
-    )
-    # Happy path: we don't need synthetic bases
-    if synthetic_base_info is None:
-        return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+@dataclass
+class AOTSyntheticBaseWrapper(CompilerWrapper):
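+    # Currently, the only reason we need to plumb this bool is because
+    # the synthetic base code prohibits more cases in the autograd case than the inference case.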
+    trace_joint: bool  # TODO: refactor trace_joint
+    needs_post_compile: bool = True
+    aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
 
-    # export path: ban synthetic bases for now, add later if requested.
-    if requires_subclass_dispatch(flat_args, fw_metadata):
-        raise RuntimeError(
-            """\
-Encountered aliased inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+    def pre_compile(
+        self,
+        flat_fn,
+        flat_args: List[Any],
+        aot_config: AOTConfig,
+        *,
+        fw_metadata: ViewAndMutationMeta,
+    ):
+        is_inference = not self.trace_joint
+        flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+            flat_args,
+            fw_metadata.input_info,
+            is_inference=is_inference,
         )
 
-    if aot_config.is_export:
-        raise RuntimeError(
-            f"""\
-Encountered aliased inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
+        # Happy path: we don't need synthetic bases
+        if synthetic_base_info is None:
+            self.needs_post_compile = False
+            return flat_fn, flat_args, aot_config, fw_metadata
 
-synthetic_base_info={str(synthetic_base_info)}
+        # export path: ban synthetic bases for now, add later if requested.
+        if requires_subclass_dispatch(flat_args, fw_metadata):
+            raise RuntimeError(
+                """\
+        Encountered aliased inputs that are mutated in the graph, but at least one input/output
+        to the graph is a tensor subclass. This is not supported today. You can try to
+        remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+            )
 
-fw_metadata={str(fw_metadata)}
-        """
+        if aot_config.is_export:
+            raise RuntimeError(
+                f"""\
+        Encountered aliased inputs that are mutated in the graph you are trying to export.
+        This functionality is currently not supported. If needed, please file a github issue.
+
+        synthetic_base_info={str(synthetic_base_info)}
+
+        fw_metadata={str(fw_metadata)}
+                """
+            )
+
+        assert len(fw_metadata.input_info) == len(synthetic_base_info)
+
+        # Update our forward metadata to take synthetic bases into account
+        (
+            fw_metadata_updated,
+            aliased_arg_idx_with_metadata_mutations,
+        ) = create_synthetic_base_metadata(
+            fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
+        )
+        # Save the old fw_metadata's input info for post_compile
+        self.old_input_info = fw_metadata.input_info
+
+        self.aliased_arg_idx_with_metadata_mutations = (
+            aliased_arg_idx_with_metadata_mutations
         )
 
-    assert len(fw_metadata.input_info) == len(synthetic_base_info)
+        num_aliased_args_with_metadata_mutations = len(
+            aliased_arg_idx_with_metadata_mutations
+        )
 
-    # Update our forward metadata to take synthetic bases into account
-    (
-        fw_metadata_updated,
-        aliased_arg_idx_with_metadata_mutations,
-    ) = create_synthetic_base_metadata(
-        fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
-    )
+        def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
+            f_args_inner = []
+            for inner_idx_or_tuple in synthetic_base_info:
+                if isinstance(inner_idx_or_tuple, int):
+                    f_args_inner.append(primals[inner_idx_or_tuple])
+                else:
+                    inner_base_idx, view_tensor = inner_idx_or_tuple
+                    base = primals[inner_base_idx]
+                    view_arg = gen_alias_from_base(
+                        base, view_tensor, view_tensor.requires_grad
+                    )
+                    f_args_inner.append(view_arg)
+            return f_args_inner
 
-    num_aliased_args_with_metadata_mutations = len(
-        aliased_arg_idx_with_metadata_mutations
-    )
-
-    def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
-        f_args_inner = []
-        for inner_idx_or_tuple in synthetic_base_info:
-            if isinstance(inner_idx_or_tuple, int):
-                f_args_inner.append(primals[inner_idx_or_tuple])
+        @wraps(flat_fn)
+        def wrapped_flat_fn(*args):
+            unpacked_args = _unpack_synthetic_bases(args)
+            # This is a bit subtle. The goal of this entire wrapper (AOTSyntheticBaseWrapper)
+            # is to relieve the downstream logic from having to reason about mutations on inputs that alias
+            # each other, by replacing aliased inputs with a synthetic base.
+            # One area where this breaks down a bit however is if one of those aliased inputs
+            # experienced a metadata mutation.
+            # We are now obligated to reapply the metadata mutation directly to the user's input;
+            # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
+            #
+            # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
+            # are additional outputs in the user's forward function.
+            # The downstream logic will just treat these as "user outputs that alias inputs".
+            # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
+            # to the user inputs, and not return them to the user.
+            aliased_args_with_metadata_mutations = [
+                x
+                for i, x in enumerate(unpacked_args)
+                if i in self.aliased_arg_idx_with_metadata_mutations
+            ]
+            if len(aliased_args_with_metadata_mutations) > 0:
+                return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
             else:
-                inner_base_idx, view_tensor = inner_idx_or_tuple
-                base = primals[inner_base_idx]
-                view_arg = gen_alias_from_base(
-                    base, view_tensor, view_tensor.requires_grad
-                )
-                f_args_inner.append(view_arg)
-        return f_args_inner
+                return flat_fn(*unpacked_args)
 
-    @wraps(flat_fn)
-    def wrapped_flat_fn(*args):
-        unpacked_args = _unpack_synthetic_bases(args)
-        # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
-        # is to relieve the downstream logic from having to reason about mutations on inputs that alias
-        # each other, by replacing aliased inputs with a synthetic base.
-        # One area where this breaks down a bit however is if one of those aliased inputs
-        # experienced a metadata mutation.
-        # We are now obligated to reapply the metadata mutation directly to the user's input;
-        # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
-        #
-        # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
-        # are additional outputs in the user's forward function.
-        # The downstream logic will just treat these as "user outputs that alias inputs".
-        # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
-        # to the user inputs, and not return them to the user.
-        aliased_args_with_metadata_mutations = [
-            x
-            for i, x in enumerate(unpacked_args)
-            if i in aliased_arg_idx_with_metadata_mutations
-        ]
-        if len(aliased_args_with_metadata_mutations) > 0:
-            return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
-        else:
-            return flat_fn(*unpacked_args)
-
-    if config.debug_assert:
-        ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+        if config.debug_assert:
+            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+                wrapped_flat_fn,
+                keep_input_mutations=fw_metadata.keep_input_mutations,
+                is_train=fw_metadata.is_train,
+            )(*flat_args_with_synthetic_bases)
+            assert ref_fw_metadata == fw_metadata_updated, (
+                f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
+                f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
+            )
+        return (
             wrapped_flat_fn,
-            keep_input_mutations=fw_metadata.keep_input_mutations,
-            is_train=fw_metadata.is_train,
-        )(*flat_args_with_synthetic_bases)
-        assert ref_fw_metadata == fw_metadata_updated, (
-            f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
-            f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
+            flat_args_with_synthetic_bases,
+            aot_config,
+            fw_metadata_updated,
         )
 
-    compiled_fn = compiler_fn(
-        wrapped_flat_fn,
-        flat_args_with_synthetic_bases,
-        aot_config,
-        fw_metadata=fw_metadata_updated,
-    )
+    def post_compile(
+        self,
+        compiled_fn,
+        aot_config: AOTConfig,
+        *,
+        fw_metadata: ViewAndMutationMeta,
+    ):
+        if not self.needs_post_compile:
+            return compiled_fn
 
-    @wraps(compiled_fn)
-    def wrapped_compiled_fn(args):
-        args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
-            args, fw_metadata.input_info, is_inference=is_inference
-        )
-        assert synthetic_base_info is not None
-        aliased_args_w_metadata_mutations = [
-            args[i] for i in aliased_arg_idx_with_metadata_mutations
-        ]
-        args.clear()
-        outs = compiled_fn(args_with_synthetic_bases)
-        if num_aliased_args_with_metadata_mutations > 0:
-            # This code does not handle **all** input metadata mutations.
-            # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
-            # (which only happens if at least one aliased input experienced a data mutation).
-            # e.g:
-            # def f(a, b):
-            #     a.mul_(2)
-            #     b.t_(1, 0)
-            # f(x.view(2, 2), x.view(2, 2))
-            mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
-            user_outs = outs[:-num_aliased_args_with_metadata_mutations]
-            for inp, mutated_inp in zip(
-                aliased_args_w_metadata_mutations, mutated_metadata_inps
-            ):
-                inp.as_strided_(
-                    mutated_inp.size(),
-                    mutated_inp.stride(),
-                    mutated_inp.storage_offset(),
-                )
-            return user_outs
-        return outs
+        is_inference = not self.trace_joint
 
-    return wrapped_compiled_fn
+        @wraps(compiled_fn)
+        def wrapped_compiled_fn(args):
+            args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+                args, self.old_input_info, is_inference=is_inference
+            )
+            assert synthetic_base_info is not None
+            aliased_args_w_metadata_mutations = [
+                args[i] for i in self.aliased_arg_idx_with_metadata_mutations
+            ]
+            num_aliased_args_with_metadata_mutations = len(
+                aliased_args_w_metadata_mutations
+            )
+            args.clear()
+            outs = compiled_fn(args_with_synthetic_bases)
+            if num_aliased_args_with_metadata_mutations > 0:
+                # This code does not handle **all** input metadata mutations.
+                # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
+                # (which only happens if at least one aliased input experienced a data mutation).
+                # e.g:
+                # def f(a, b):
+                #     a.mul_(2)
+                #     b.t_(1, 0)
+                # f(x.view(2, 2), x.view(2, 2))
+                mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
+                user_outs = outs[:-num_aliased_args_with_metadata_mutations]
+                for inp, mutated_inp in zip(
+                    aliased_args_w_metadata_mutations, mutated_metadata_inps
+                ):
+                    inp.as_strided_(
+                        mutated_inp.size(),
+                        mutated_inp.stride(),
+                        mutated_inp.storage_offset(),
+                    )
+                return user_outs
+            return outs
+
+        return wrapped_compiled_fn
 
 
 # Note [Handling mutations on an input that aliases other inputs]
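
To make the dedupe bookkeeping concrete, here is a toy round-trip for the
[a, b, a, c] example from the comments in AOTDedupeWrapper.pre_compile above
(a standalone sketch with placeholder objects, not torch internals):

# Placeholder stand-ins for tensors; identity is all that matters here.
a, b, c = object(), object(), object()
flat_args = [a, b, a, c]

keep_arg_mask = [True, True, False, True]  # drop the second occurrence of `a`
add_dupe_map = [0, 1, 0, 2]                # deduped index for each original position

def remove_dupe_args(args):
    return [t for t, keep in zip(args, keep_arg_mask) if keep]

def add_dupe_args(args):
    return [args[i] for i in add_dupe_map]

deduped = remove_dupe_args(flat_args)       # [a, b, c]
assert add_dupe_args(deduped) == flat_args  # add/remove round-trip on this input

AOTDedupeWrapper computes keep_arg_mask and add_dupe_map once in pre_compile and
then reuses them at runtime inside the wrapped_compiled_fn built in post_compile.
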
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 74d9bc9..f1ba677 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -2,7 +2,7 @@
 
 import itertools
 from contextlib import contextmanager, nullcontext
-from functools import partial, wraps
+from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
@@ -67,8 +67,8 @@
     track_graph_compiling,
 )
 from ._aot_autograd.runtime_wrappers import (  # noqa: F401
-    aot_wrapper_dedupe,
-    aot_wrapper_synthetic_base,
+    AOTDedupeWrapper,
+    AOTSyntheticBaseWrapper,
 )
 from ._aot_autograd.schemas import (  # noqa: F401
     AOTConfig,
@@ -670,17 +670,25 @@
                 aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base
             )
 
-        compiler_fn = partial(
-            aot_wrapper_synthetic_base,
-            compiler_fn=compiler_fn,
-            needs_autograd=needs_autograd,
-        )
-        compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
-        # You can put more passes here
+        wrappers = [
+            AOTDedupeWrapper(),
+            AOTSyntheticBaseWrapper(trace_joint=needs_autograd),
+            # Add more passes here
+        ]
+        for wrapper in wrappers:
+            flat_fn, fake_flat_args, aot_config, fw_metadata = wrapper.pre_compile(
+                flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
+            )
 
         compiled_fn = compiler_fn(
             flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
         )
+
+        for wrapper in reversed(wrappers):
+            compiled_fn = wrapper.post_compile(
+                compiled_fn, aot_config, fw_metadata=fw_metadata
+            )
+
         if aot_config.is_export:
             # During export, we don't get back a callable - we get back the raw fx graph
             # (either a joint or an inference-only graph)