Implement wrappers for aot_dedupe and aot_synthetic_base (#125764)
It's kind of gross that aot_synthetic_base requires storing the *old* fw_metadata's InputInfo, but it is what it is. After this change, aot_dispatch_base's runtime wrappers should all be implemented; next, I'll start on aot_dispatch_autograd's remaining runtime wrapping changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125764
Approved by: https://github.com/bdhirsh
ghstack dependencies: #125610
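For context on the pattern this stack is converging on: instead of chaining `compiler_fn` through `functools.partial`, each pass becomes a `CompilerWrapper` with a `pre_compile` hook (rewrite the function/args/metadata before compilation) and a `post_compile` hook (wrap the compiled callable for runtime). Below is a minimal, self-contained sketch of that protocol, using dedup as the example; `DedupeSketch`, `compile_with_wrappers`, and the toy one-line "compiler" are hypothetical illustrations, not the real `CompilerWrapper`/`AOTDedupeWrapper` API:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class DedupeSketch:
    # Hypothetical stand-in for a CompilerWrapper that removes duplicate args.
    keep_arg_mask: List[bool] = field(default_factory=list)
    add_dupe_map: List[int] = field(default_factory=list)
    needs_post_compile: bool = True

    def pre_compile(self, fn: Callable, args: List[Any]):
        seen: Dict[int, int] = {}  # id(arg) -> position in deduped list
        deduped: List[Any] = []
        for a in args:
            if id(a) not in seen:
                seen[id(a)] = len(deduped)
                deduped.append(a)
                self.keep_arg_mask.append(True)
            else:
                self.keep_arg_mask.append(False)
            self.add_dupe_map.append(seen[id(a)])
        if all(self.keep_arg_mask):  # no dupes: skip the runtime wrapper
            self.needs_post_compile = False
            return fn, args

        def wrapped(*deduped_args):
            # Re-expand deduped args back to the original calling convention.
            return fn(*(deduped_args[i] for i in self.add_dupe_map))

        return wrapped, deduped

    def post_compile(self, compiled_fn: Callable):
        if not self.needs_post_compile:
            return compiled_fn

        def runtime_wrapper(args: List[Any]):
            # Strip duplicates at runtime before calling the compiled artifact.
            return compiled_fn(
                [a for a, keep in zip(args, self.keep_arg_mask) if keep]
            )

        return runtime_wrapper


def compile_with_wrappers(fn, args, wrappers, compiler):
    # Same shape as the loop added to aot_autograd.py below:
    # pre_compile first-to-last, post_compile last-to-first.
    for w in wrappers:
        fn, args = w.pre_compile(fn, args)
    compiled = compiler(fn, args)  # boxed convention: takes a list of args
    for w in reversed(wrappers):
        compiled = w.post_compile(compiled)
    return compiled


x, y = [1], [2]
f = compile_with_wrappers(
    lambda a, b, c: a + b + c,
    [x, y, x],                                    # arg 2 duplicates arg 0
    [DedupeSketch()],
    lambda fn, args: (lambda boxed: fn(*boxed)),  # toy "compiler"
)
assert f([x, y, x]) == [1, 2, 1]
```

Note the unwind order: because `post_compile` runs in reverse, each wrapper's runtime epilogue nests inside the epilogues of the wrappers that ran before it, mirroring the `reversed(wrappers)` loop added to `aot_autograd.py` below.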
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 1278864..934b783 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -691,370 +691,418 @@
# Dynamo's guards are not enough. In practice, this seems to cover
# everything.
#
-def aot_wrapper_dedupe(
- flat_fn,
- flat_args: List[Tensor],
- aot_config: AOTConfig,
- *,
- compiler_fn,
- fw_metadata,
-):
- # Use information about whether or not flat_fn mutates its arguments
- # or not to handle dupe args
-
- # Strategy 1: For any input that is not mutated, we can leafify it if we
- # need to remove a duplicate.
- leaf_flat_args = []
- args_set = set()
- ok = True
-
- for i, a in enumerate(flat_args):
- if not isinstance(a, torch.Tensor):
- leaf_flat_args.append(a)
- elif a not in args_set:
- args_set.add(a)
- leaf_flat_args.append(a)
- elif (
- not fw_metadata.input_info[i].mutates_data
- and not fw_metadata.input_info[i].mutates_metadata
- ):
- leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
- else:
- ok = False
- break
-
- if ok:
- return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
-
- if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
- raise RuntimeError(
- """\
-Encountered duplicate inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
- )
-
- # export path: ban duplicate inputs for now, add later if requested.
- if aot_config.is_export:
- raise RuntimeError(
- f"""\
-Encountered duplicated inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
-
-fw_metadata={str(fw_metadata)}
- """
- )
-
- # Strategy 2: Duplicate specialize.
- #
- # In Haskell types, suppose you have:
- #
- # add_dupe_args :: DedupedArgs -> Args
- # remove_dupe_args :: Args -> DedupedArgs
- #
- # compiler_fn
- # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
- # deped_compiler_fn
- # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
- #
- # Then the code below can be written in point-free style as:
- #
- # deduped_compiler_fn f a c =
- # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
- #
- # Suppose you have:
- #
- # [a, b, a, c]
- #
- # We want:
- #
- # remove_dupe_args([a, b, a, c]) == [a, b, c]
- # add_dupe_args([a, b, c]) == [a, b, a, c]
- #
- # This is done via (respectively):
- #
- # seen_args = {a: 0, b: 1, c: 2}
- # enumerate(add_dupe_map) = [ # how to get args from the deduped list
- # (0, 0),
- # (1, 1),
- # (2, 0),
- # (3, 2),
- # ]
- # keep_arg_mask = [True, True, False, True]
-
- seen_args: Dict[Tensor, int] = {}
- keep_arg_mask = []
- # Implicitly map duped arg position (list index) to de-duped arg position
- add_dupe_map: List[int] = []
- duped_arg_len = len(flat_args)
-
- j = 0 # index into deduped_flat_args
- for t in flat_args:
- if isinstance(t, torch.Tensor):
- if t in seen_args:
- keep_arg_mask.append(False)
- add_dupe_map.append(seen_args[t])
- continue
- seen_args[t] = j
-
- keep_arg_mask.append(True)
- add_dupe_map.append(j)
- j += 1
- assert (
- len(add_dupe_map) == duped_arg_len
- ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
+@dataclass
+class AOTDedupeWrapper(CompilerWrapper):
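+    # keep_arg_mask[i]: True if original arg position i survives deduplication.
+    # add_dupe_map[i]: position of original arg i in the deduped arg list.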
+ keep_arg_mask: List[bool] = field(default_factory=list)
+ add_dupe_map: List[int] = field(default_factory=list)
+ old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
+ needs_post_compile: bool = True
# NB: Hot path, avoid set lookups here
# TODO: Can avoid the zip here too, probably
- def remove_dupe_args(args):
- return [t for t, keep in zip(args, keep_arg_mask) if keep]
+ def remove_dupe_args(self, args):
+ return [t for t, keep in zip(args, self.keep_arg_mask) if keep]
- def add_dupe_args(args):
- return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
+ def add_dupe_args(self, args):
+ return [args[i] for i in self.add_dupe_map]
- deduped_flat_args = remove_dupe_args(flat_args)
-
- # Update our input metadata to remove duped input metadata.
- updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map)
-
- if (
- tracing_context := TracingContext.try_get()
- and aot_config.aot_autograd_arg_pos_to_source
+ def pre_compile(
+ self,
+ flat_fn,
+ flat_args: List[Tensor],
+ aot_config: AOTConfig,
+ *,
+ fw_metadata: ViewAndMutationMeta,
):
- # TODO(voz): This structure is 1:1, we could consider an alternate structure like
- # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
- # which feels like needless complexity for a tiny bit of efficiency at this point.
- for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
- zip(add_dupe_map, keep_arg_mask)
- ):
- if not keep_arg:
- dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
- dupe_arg_pos
- ]
- kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos]
- tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined]
- DuplicateInputs(kept_arg_source, dupe_arg_source)
- )
+        # Use information about whether or not flat_fn mutates its arguments
+        # to handle dupe args
- @wraps(flat_fn)
- def wrapped_flat_fn(*args):
- return flat_fn(*add_dupe_args(args))
+ # Strategy 1: For any input that is not mutated, we can leafify it if we
+ # need to remove a duplicate.
+ leaf_flat_args = []
+ args_set = set()
+ ok = True
- if config.debug_assert:
- ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
- wrapped_flat_fn,
- keep_input_mutations=fw_metadata.keep_input_mutations,
- is_train=fw_metadata.is_train,
- )(*deduped_flat_args)
- assert (
- ref_fw_metadata == updated_fw_metadata
- ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
+ for i, a in enumerate(flat_args):
+ if not isinstance(a, torch.Tensor):
+ leaf_flat_args.append(a)
+ elif a not in args_set:
+ args_set.add(a)
+ leaf_flat_args.append(a)
+ elif (
+ not fw_metadata.input_info[i].mutates_data
+ and not fw_metadata.input_info[i].mutates_metadata
+ ):
+ leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
+ else:
+ ok = False
+ break
- compiled_fn = compiler_fn(
- wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata
- )
+ if ok:
+ self.needs_post_compile = False
+ return flat_fn, leaf_flat_args, aot_config, fw_metadata
- @wraps(compiled_fn)
- def wrapped_compiled_fn(args: List[Any]):
- deduped_args = remove_dupe_args(args)
- args.clear()
- return compiled_fn(deduped_args)
-
- wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined]
-
- # This can be uncommented when we properly guard for duplicates,
- # but right now we must not do it.
- # if not config.debug_assert:
- # return wrapped_compiled_fn
-
- @wraps(wrapped_compiled_fn)
- def debugged_compiled_fn(args):
- # Test that the computed remove/add arg functions are an inverse
- new_args = add_dupe_args(remove_dupe_args(args))
- seen: Dict[Any, None] = {}
- for i, (x, y) in enumerate(zip(new_args, args)):
- seen[y] = None
- assert x is y, format_guard_bug_msg(
- aot_config,
- f"{describe_input(i, aot_config)} would be a duplicate of "
- f"{describe_input(add_dupe_map[i], aot_config)}",
+ if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
+ raise RuntimeError(
+ """\
+ Encountered duplicate inputs that are mutated in the graph, but at least one input/output
+ to the graph is a tensor subclass. This is not supported today. You can try to
+ remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
)
- # This is only an error if there is metadata mutation on both of
- # the duped arguments; in this case, we need to know what order
- # the metadata mutation applies in. You'll get the correct result
- # otherwise, because a graph that assumes distinct inputs works if
- # you dupe the inputs (the gradient contributions from each input
- # will get summed up appropriately.)
+
+ # export path: ban duplicate inputs for now, add later if requested.
+ if aot_config.is_export:
+ raise RuntimeError(
+ f"""\
+ Encountered duplicated inputs that are mutated in the graph you are trying to export.
+ This functionality is currently not supported. If needed, please file a github issue.
+
+ fw_metadata={str(fw_metadata)}
+ """
+ )
+
+ # Strategy 2: Duplicate specialize.
#
- # TODO: work out how to setup this assert correctly
- """
- assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
- f"there would be {unique_args} distinct arguments"
+ # In Haskell types, suppose you have:
+ #
+ # add_dupe_args :: DedupedArgs -> Args
+ # remove_dupe_args :: Args -> DedupedArgs
+ #
+ # compiler_fn
+ # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
+        #   deduped_compiler_fn
+ # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
+ #
+ # Then the code below can be written in point-free style as:
+ #
+ # deduped_compiler_fn f a c =
+ # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
+ #
+ # Suppose you have:
+ #
+ # [a, b, a, c]
+ #
+ # We want:
+ #
+ # remove_dupe_args([a, b, a, c]) == [a, b, c]
+ # add_dupe_args([a, b, c]) == [a, b, a, c]
+ #
+ # This is done via (respectively):
+ #
+ # seen_args = {a: 0, b: 1, c: 2}
+ # enumerate(add_dupe_map) = [ # how to get args from the deduped list
+ # (0, 0),
+ # (1, 1),
+ # (2, 0),
+ # (3, 2),
+ # ]
+ # keep_arg_mask = [True, True, False, True]
+
+ seen_args: Dict[Tensor, int] = {}
+        keep_arg_mask: List[bool] = []
+        # Implicitly map duped arg position (list index) to de-duped arg position
+        add_dupe_map: List[int] = []
+ duped_arg_len = len(flat_args)
+
+ j = 0 # index into deduped_flat_args
+ for t in flat_args:
+ if isinstance(t, torch.Tensor):
+ if t in seen_args:
+ keep_arg_mask.append(False)
+ add_dupe_map.append(seen_args[t])
+ continue
+ seen_args[t] = j
+
+ keep_arg_mask.append(True)
+ add_dupe_map.append(j)
+ j += 1
+ assert (
+ len(add_dupe_map) == duped_arg_len
+ ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
+
+ self.keep_arg_mask = keep_arg_mask
+ self.add_dupe_map = add_dupe_map
+
+ deduped_flat_args = self.remove_dupe_args(flat_args)
+
+ # Update our input metadata to remove duped input metadata.
+ updated_fw_metadata = remove_dupe_metadata(
+ fw_metadata, keep_arg_mask, add_dupe_map
)
- """
- return wrapped_compiled_fn(args)
- debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined]
+        if (
+            tracing_context := TracingContext.try_get()
+        ) and aot_config.aot_autograd_arg_pos_to_source:
+ # TODO(voz): This structure is 1:1, we could consider an alternate structure like
+ # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
+ # which feels like needless complexity for a tiny bit of efficiency at this point.
+ for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
+ zip(add_dupe_map, keep_arg_mask)
+ ):
+ if not keep_arg:
+ dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
+ dupe_arg_pos
+ ]
+ kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[
+ kept_pos
+ ]
+ tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined]
+ DuplicateInputs(kept_arg_source, dupe_arg_source)
+ )
- return debugged_compiled_fn
+ @wraps(flat_fn)
+ def wrapped_flat_fn(*args):
+ return flat_fn(*self.add_dupe_args(args))
+
+ if config.debug_assert:
+ ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+ wrapped_flat_fn,
+ keep_input_mutations=fw_metadata.keep_input_mutations,
+ is_train=fw_metadata.is_train,
+ )(*deduped_flat_args)
+ assert (
+ ref_fw_metadata == updated_fw_metadata
+ ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
+
+ return wrapped_flat_fn, deduped_flat_args, aot_config, updated_fw_metadata
+
+ def post_compile(
+ self,
+ compiled_fn,
+ aot_config: AOTConfig,
+ *,
+ fw_metadata: ViewAndMutationMeta,
+ ):
+ if not self.needs_post_compile:
+ return compiled_fn
+
+ @wraps(compiled_fn)
+ def wrapped_compiled_fn(args: List[Any]):
+ deduped_args = self.remove_dupe_args(args)
+ args.clear()
+ return compiled_fn(deduped_args)
+
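+        # _boxed_call marks the boxed calling convention: the callable takes a
+        # single list of args and may clear it to drop references early.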
+ wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined]
+
+ # This can be uncommented when we properly guard for duplicates,
+ # but right now we must not do it.
+ # if not config.debug_assert:
+ # return wrapped_compiled_fn
+
+ @wraps(wrapped_compiled_fn)
+ def debugged_compiled_fn(args):
+ # Test that the computed remove/add arg functions are an inverse
+ new_args = self.add_dupe_args(self.remove_dupe_args(args))
+ seen: Dict[Any, None] = {}
+ for i, (x, y) in enumerate(zip(new_args, args)):
+ seen[y] = None
+ assert x is y, format_guard_bug_msg(
+ aot_config,
+ f"{describe_input(i, aot_config)} would be a duplicate of "
+ f"{describe_input(self.add_dupe_map[i], aot_config)}",
+ )
+ # This is only an error if there is metadata mutation on both of
+ # the duped arguments; in this case, we need to know what order
+ # the metadata mutation applies in. You'll get the correct result
+ # otherwise, because a graph that assumes distinct inputs works if
+ # you dupe the inputs (the gradient contributions from each input
+ # will get summed up appropriately.)
+ #
+ # TODO: work out how to setup this assert correctly
+ """
+ assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
+ f"there would be {unique_args} distinct arguments"
+ )
+ """
+ return wrapped_compiled_fn(args)
+
+ debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined]
+
+ return debugged_compiled_fn
# This layer handles the situation where you have two inputs that alias each other,
# and one of the inputs is mutated.
# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
#
-# pre-condition: aot_wrapper_dedup has already run.
+# pre-condition: AOTDedupeWrapper has already run.
# (This function will in theory work if there are duplicate args.
# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
# would cause us to hit that path more frequently).
-def aot_wrapper_synthetic_base(
- flat_fn,
- flat_args: List[Tensor],
- aot_config: AOTConfig,
- *,
- fw_metadata: ViewAndMutationMeta,
- # Currently, the only reason we need to plumb this bool is because
- # the synthetic base code prohibits more cases in the autograd case than the inference case.
- needs_autograd: bool,
- compiler_fn,
-):
- is_inference = not needs_autograd
- flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
- flat_args,
- fw_metadata.input_info,
- is_inference=is_inference,
- )
- # Happy path: we don't need synthetic bases
- if synthetic_base_info is None:
- return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+@dataclass
+class AOTSyntheticBaseWrapper(CompilerWrapper):
+    # Currently, the only reason we need to plumb this bool is because
+    # the synthetic base code prohibits more cases in the autograd case
+    # than the inference case.
+    trace_joint: bool  # TODO: refactor trace_joint
+    needs_post_compile: bool = True
+    aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
+    # The pre-synthetic-base input metadata, saved in pre_compile so that
+    # post_compile can re-derive synthetic bases at runtime.
+    old_input_info: List[InputAliasInfo] = field(default_factory=list)
- # export path: ban synthetic bases for now, add later if requested.
- if requires_subclass_dispatch(flat_args, fw_metadata):
- raise RuntimeError(
- """\
-Encountered aliased inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+ def pre_compile(
+ self,
+ flat_fn,
+ flat_args: List[Any],
+ aot_config: AOTConfig,
+ *,
+ fw_metadata: ViewAndMutationMeta,
+ ):
+ is_inference = not self.trace_joint
+ flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+ flat_args,
+ fw_metadata.input_info,
+ is_inference=is_inference,
)
- if aot_config.is_export:
- raise RuntimeError(
- f"""\
-Encountered aliased inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
+ # Happy path: we don't need synthetic bases
+ if synthetic_base_info is None:
+ self.needs_post_compile = False
+ return flat_fn, flat_args, aot_config, fw_metadata
-synthetic_base_info={str(synthetic_base_info)}
+ if requires_subclass_dispatch(flat_args, fw_metadata):
+ raise RuntimeError(
+ """\
+ Encountered aliased inputs that are mutated in the graph, but at least one input/output
+ to the graph is a tensor subclass. This is not supported today. You can try to
+ remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+ )
-fw_metadata={str(fw_metadata)}
- """
+        # export path: ban synthetic bases for now, add later if requested.
+        if aot_config.is_export:
+ raise RuntimeError(
+ f"""\
+ Encountered aliased inputs that are mutated in the graph you are trying to export.
+ This functionality is currently not supported. If needed, please file a github issue.
+
+ synthetic_base_info={str(synthetic_base_info)}
+
+ fw_metadata={str(fw_metadata)}
+ """
+ )
+
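+        # synthetic_base_info has one entry per original input: either an int
+        # (the input's index in the new arg list) or a (base_index, view_tensor)
+        # pair saying how to rebuild the input as a view of a synthetic base.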
+ assert len(fw_metadata.input_info) == len(synthetic_base_info)
+
+ # Update our forward metadata to take synthetic bases into account
+ (
+ fw_metadata_updated,
+ aliased_arg_idx_with_metadata_mutations,
+ ) = create_synthetic_base_metadata(
+ fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
+ )
+        # Save the old input info for post_compile
+ self.old_input_info = fw_metadata.input_info
+
+ self.aliased_arg_idx_with_metadata_mutations = (
+ aliased_arg_idx_with_metadata_mutations
)
- assert len(fw_metadata.input_info) == len(synthetic_base_info)
+ num_aliased_args_with_metadata_mutations = len(
+ aliased_arg_idx_with_metadata_mutations
+ )
- # Update our forward metadata to take synthetic bases into account
- (
- fw_metadata_updated,
- aliased_arg_idx_with_metadata_mutations,
- ) = create_synthetic_base_metadata(
- fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
- )
+ def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
+ f_args_inner = []
+ for inner_idx_or_tuple in synthetic_base_info:
+ if isinstance(inner_idx_or_tuple, int):
+ f_args_inner.append(primals[inner_idx_or_tuple])
+ else:
+ inner_base_idx, view_tensor = inner_idx_or_tuple
+ base = primals[inner_base_idx]
+ view_arg = gen_alias_from_base(
+ base, view_tensor, view_tensor.requires_grad
+ )
+ f_args_inner.append(view_arg)
+ return f_args_inner
- num_aliased_args_with_metadata_mutations = len(
- aliased_arg_idx_with_metadata_mutations
- )
-
- def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
- f_args_inner = []
- for inner_idx_or_tuple in synthetic_base_info:
- if isinstance(inner_idx_or_tuple, int):
- f_args_inner.append(primals[inner_idx_or_tuple])
+ @wraps(flat_fn)
+ def wrapped_flat_fn(*args):
+ unpacked_args = _unpack_synthetic_bases(args)
+            # This is a bit subtle. The goal of this wrapper (AOTSyntheticBaseWrapper)
+ # is to relieve the downstream logic from having to reason about mutations on inputs that alias
+ # each other, by replacing aliased inputs with a synthetic base.
+ # One area where this breaks down a bit however is if one of those aliased inputs
+ # experienced a metadata mutation.
+ # We are now obligated to reapply the metadata mutation directly to the user's input;
+ # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
+ #
+ # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
+ # are additional outputs in the user's forward function.
+ # The downstream logic will just treat these as "user outputs that alias inputs".
+ # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
+ # to the user inputs, and not return them to the user.
+ aliased_args_with_metadata_mutations = [
+ x
+ for i, x in enumerate(unpacked_args)
+ if i in self.aliased_arg_idx_with_metadata_mutations
+ ]
+ if len(aliased_args_with_metadata_mutations) > 0:
+ return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
else:
- inner_base_idx, view_tensor = inner_idx_or_tuple
- base = primals[inner_base_idx]
- view_arg = gen_alias_from_base(
- base, view_tensor, view_tensor.requires_grad
- )
- f_args_inner.append(view_arg)
- return f_args_inner
+ return flat_fn(*unpacked_args)
- @wraps(flat_fn)
- def wrapped_flat_fn(*args):
- unpacked_args = _unpack_synthetic_bases(args)
- # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
- # is to relieve the downstream logic from having to reason about mutations on inputs that alias
- # each other, by replacing aliased inputs with a synthetic base.
- # One area where this breaks down a bit however is if one of those aliased inputs
- # experienced a metadata mutation.
- # We are now obligated to reapply the metadata mutation directly to the user's input;
- # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
- #
- # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
- # are additional outputs in the user's forward function.
- # The downstream logic will just treat these as "user outputs that alias inputs".
- # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
- # to the user inputs, and not return them to the user.
- aliased_args_with_metadata_mutations = [
- x
- for i, x in enumerate(unpacked_args)
- if i in aliased_arg_idx_with_metadata_mutations
- ]
- if len(aliased_args_with_metadata_mutations) > 0:
- return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
- else:
- return flat_fn(*unpacked_args)
-
- if config.debug_assert:
- ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+ if config.debug_assert:
+ ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+ wrapped_flat_fn,
+ keep_input_mutations=fw_metadata.keep_input_mutations,
+ is_train=fw_metadata.is_train,
+ )(*flat_args_with_synthetic_bases)
+ assert ref_fw_metadata == fw_metadata_updated, (
+ f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
+ f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
+ )
+ return (
wrapped_flat_fn,
- keep_input_mutations=fw_metadata.keep_input_mutations,
- is_train=fw_metadata.is_train,
- )(*flat_args_with_synthetic_bases)
- assert ref_fw_metadata == fw_metadata_updated, (
- f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
- f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
+ flat_args_with_synthetic_bases,
+ aot_config,
+ fw_metadata_updated,
)
- compiled_fn = compiler_fn(
- wrapped_flat_fn,
- flat_args_with_synthetic_bases,
- aot_config,
- fw_metadata=fw_metadata_updated,
- )
+ def post_compile(
+ self,
+ compiled_fn,
+ aot_config: AOTConfig,
+ *,
+ fw_metadata: ViewAndMutationMeta,
+ ):
+ if not self.needs_post_compile:
+ return compiled_fn
- @wraps(compiled_fn)
- def wrapped_compiled_fn(args):
- args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
- args, fw_metadata.input_info, is_inference=is_inference
- )
- assert synthetic_base_info is not None
- aliased_args_w_metadata_mutations = [
- args[i] for i in aliased_arg_idx_with_metadata_mutations
- ]
- args.clear()
- outs = compiled_fn(args_with_synthetic_bases)
- if num_aliased_args_with_metadata_mutations > 0:
- # This code does not handle **all** input metadata mutations.
- # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
- # (which only happens if at least one aliased input experienced a data mutation).
- # e.g:
- # def f(a, b):
- # a.mul_(2)
- # b.t_(1, 0)
- # f(x.view(2, 2), x.view(2, 2))
- mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
- user_outs = outs[:-num_aliased_args_with_metadata_mutations]
- for inp, mutated_inp in zip(
- aliased_args_w_metadata_mutations, mutated_metadata_inps
- ):
- inp.as_strided_(
- mutated_inp.size(),
- mutated_inp.stride(),
- mutated_inp.storage_offset(),
- )
- return user_outs
- return outs
+ is_inference = not self.trace_joint
- return wrapped_compiled_fn
+ @wraps(compiled_fn)
+ def wrapped_compiled_fn(args):
+ args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+ args, self.old_input_info, is_inference=is_inference
+ )
+ assert synthetic_base_info is not None
+ aliased_args_w_metadata_mutations = [
+ args[i] for i in self.aliased_arg_idx_with_metadata_mutations
+ ]
+ num_aliased_args_with_metadata_mutations = len(
+ aliased_args_w_metadata_mutations
+ )
+ args.clear()
+ outs = compiled_fn(args_with_synthetic_bases)
+ if num_aliased_args_with_metadata_mutations > 0:
+ # This code does not handle **all** input metadata mutations.
+ # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
+ # (which only happens if at least one aliased input experienced a data mutation).
+ # e.g:
+ # def f(a, b):
+ # a.mul_(2)
+ # b.t_(1, 0)
+ # f(x.view(2, 2), x.view(2, 2))
+ mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
+ user_outs = outs[:-num_aliased_args_with_metadata_mutations]
+ for inp, mutated_inp in zip(
+ aliased_args_w_metadata_mutations, mutated_metadata_inps
+ ):
+ inp.as_strided_(
+ mutated_inp.size(),
+ mutated_inp.stride(),
+ mutated_inp.storage_offset(),
+ )
+ return user_outs
+ return outs
+
+ return wrapped_compiled_fn
# Note [Handling mutations on an input that aliases other inputs]
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 74d9bc9..f1ba677 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -2,7 +2,7 @@
import itertools
from contextlib import contextmanager, nullcontext
-from functools import partial, wraps
+from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
@@ -67,8 +67,8 @@
track_graph_compiling,
)
from ._aot_autograd.runtime_wrappers import ( # noqa: F401
- aot_wrapper_dedupe,
- aot_wrapper_synthetic_base,
+ AOTDedupeWrapper,
+ AOTSyntheticBaseWrapper,
)
from ._aot_autograd.schemas import ( # noqa: F401
AOTConfig,
@@ -670,17 +670,25 @@
aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base
)
- compiler_fn = partial(
- aot_wrapper_synthetic_base,
- compiler_fn=compiler_fn,
- needs_autograd=needs_autograd,
- )
- compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
- # You can put more passes here
+ wrappers = [
+ AOTDedupeWrapper(),
+ AOTSyntheticBaseWrapper(trace_joint=needs_autograd),
+ # Add more passes here
+ ]
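+    # pre_compile runs wrappers first-to-last; post_compile unwinds them
+    # last-to-first, so each runtime epilogue nests inside the earlier ones.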
+ for wrapper in wrappers:
+ flat_fn, fake_flat_args, aot_config, fw_metadata = wrapper.pre_compile(
+ flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
+ )
compiled_fn = compiler_fn(
flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
)
+
+ for wrapper in reversed(wrappers):
+ compiled_fn = wrapper.post_compile(
+ compiled_fn, aot_config, fw_metadata=fw_metadata
+ )
+
if aot_config.is_export:
# During export, we don't get back a callable - we get back the raw fx graph
# (either a joint or an inference-only graph)