[AOT Refactor] alias runtime wrappers (#114562)
---
Part _ of https://github.com/pytorch/pytorch/issues/114548
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114562
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114550, #114551, #114552, #114553, #114554, #114555, #114556, #114557, #114558, #114559, #114561
diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py
index 0143603..4474666 100644
--- a/torch/_functorch/_aot_autograd/input_output_analysis.py
+++ b/torch/_functorch/_aot_autograd/input_output_analysis.py
@@ -6,12 +6,9 @@
In particular, the following analyses are provided:
1. Refine the view and mutation metadata collected previously - removing duplicate
inputs or mapping views to their bases.
-2. Based on view base analysis, it may merge inputs of different views into a single
- input corresponding to a common base.
-3. We also analyze the function signature for export graphs.
+2. We also analyze the function signature for export graphs.
"""
-import collections
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -21,7 +18,6 @@
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
-from torch.multiprocessing.reductions import StorageWeakRef
from .functional_utils import _get_mutation_type
from .schemas import (
BackwardSignature,
@@ -323,7 +319,7 @@
return False
-def _compute_overlapping_inputs(fwd_inputs, aliased_input_indices):
+def compute_overlapping_inputs(fwd_inputs, aliased_input_indices):
actual_aliased_indices = set()
for j in range(len(aliased_input_indices)):
for i in range(j):
@@ -335,248 +331,6 @@
return actual_aliased_indices
-# Note [Handling mutations on an input that aliases other inputs]
-# The easiest example to show-case this edge case is here:
-#
-# def f(a, b):
-# a.mul_(2)
-# out = a + b
-# return out
-# b = torch.ones(...)
-# a = b.view(-1)
-# f(a, b)
-#
-# In this situation, if a and b happened to be aliased, we need to trace something different!
-# Suppose we had b = a.view(-1)
-# (In this case, that means that `a._base is b`)
-#
-# We need to ensure that the aliasing relationship between a and b is preserved.
-# We do that detecting the specific situation above (mutate an input that aliases another input),
-# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
-# we regenerate a and b off of that base.
-# The complete example of the transformed function looks like this:
-#
-# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
-# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
-# def traced_forward(base):
-# a = base.as_strided(...)
-# b = base.as_strided(...)
-# a_updated = a.mul(2)
-# base_updated = torch.as_strided_scatter(base, a_updated, ...)
-# b_updated = base_updated.as_strided(...)
-# out = a_updated + b_updated
-# return a_updated, out
-#
-# def compiled_fn(a, b):
-# // we detect that a is the "differentiable base" here
-# base = a
-# // In other situations, we might do either:
-# // (1) a and b are both views off of some larger differentiable base
-# // assert a._base is b._base and a._base is not None
-# // base = a._base
-# // (2) a and b both don't require gradients. Create a base from the storage
-# // assert a._base is None and b._base is None
-# // base = torch.Tensor(a.storage())
-# a_updated, out = traced_forward(base)
-# a.copy_(a_updated)
-# return out
-#
-# This function:
-# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
-# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
-# to respect the new calling convention.
-#
-# The calling convention is as follows.
-# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
-# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
-# Where the ordering of the bases is determined from the ordering of the original view args.
-# baseA will come before baseB if the earliest original argument coming from baseA
-# showed up earlier in the argument list than the earliest original argument coming from baseB.
-#
-# Example, given some tensors a, b, c, d
-# call site:
-# f(a, c.view(-1), b.view(-1), b, c, d)
-# Modified argument list:
-# c_base comes first because the first c view came earlier in arg list than the first b view
-# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases
-# b_base = torch.Tensor(b.storage())
-# c_base = torch.Tensor(c.storage())
-# f(c_base, b_base, a, d)
-def merge_view_inputs(
- fwd_inputs: List[Any],
- mutated_input_info: List[InputAliasInfo],
- *,
- # The autograd case currently has more restrictions than the inference case.
- is_inference: bool,
-) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
- def _are_differentiable_views(view1, view2):
- if view1 is view2:
- return True
- if view1._base is None and view2._base is None:
- return False
- if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
- return True
- return False
-
- def _same_dtype_views(view1, view2):
- if view1.dtype != view2.dtype:
- return False
- if view1._base is not None and view1.dtype != view1._base.dtype:
- return False
- if view2._base is not None and view2.dtype != view2._base.dtype:
- return False
- return True
-
- assert len(fwd_inputs) == len(mutated_input_info)
- storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
- base_args = []
- other_args = []
- for i, inpt in enumerate(fwd_inputs):
- if isinstance(inpt, Tensor):
- storage_ref = StorageWeakRef(inpt.untyped_storage())
- storage_ref_to_idx[storage_ref].append(i)
- else:
- other_args.append(inpt)
- # Note [Synthetic Base Info Metadata]
- # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
- # It's either:
- # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
- # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
- # idx corresponds to which synthetic base from the outer calling context to view
- inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
- for aliased_input_indices in storage_ref_to_idx.values():
- if len(aliased_input_indices) <= 1 or not any(
- # We only care about mutations that affect all aliases,
- # so metadata mutations on an input doesn't require us to do synthetic base handling.
- mutated_input_info[inpt_idx].mutates_data
- for inpt_idx in aliased_input_indices
- ):
- for curr_idx in aliased_input_indices:
- other_args.append(fwd_inputs[curr_idx])
- continue
-
- # Here, we attempt to do a more complicated check to detect false aliasing
- # (e.g. if all the tensors have the same storage, but don't actually overlap)
- # In theory, we could have a large group of tensors that all share storages, where only *some* of them
- # have overlapping memory.
- # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair
- # of tensors in the current group that shares a storage is non-overlapping.
- aliased_input_indices_no_false_sharing = _compute_overlapping_inputs(
- fwd_inputs, aliased_input_indices
- )
- if len(aliased_input_indices_no_false_sharing) <= 1:
- for curr_idx in aliased_input_indices:
- other_args.append(fwd_inputs[curr_idx])
- continue
-
- # We detected an input that was mutated, AND aliases with another input.
- # we need to replace this set of aliased inputs with a single synthetic base.
- # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
- # and error out. We can fix them later.
- # These checks are transitive, so we don't need to check every pair.
- for idx1, idx2 in zip(
- aliased_input_indices, aliased_input_indices[1:], strict=False
- ):
- view1 = fwd_inputs[idx1]
- view2 = fwd_inputs[idx2]
- # The "inputs that are aliased but have different differentiable bases" case
- # is more complicated and hopefully pretty rare. Not currently handled.
- if not is_inference:
- assert _are_differentiable_views(
- view1, view2
- ), "aot_autograd() does not yet handle non-differentiable view input mutations."
- # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
- # not handling for now
- assert _same_dtype_views(
- view1, view2
- ), "aot_autograd() does not yet handle input mutations on views with different dtypes."
- non_none_bases = [
- fwd_inputs[i]._base
- for i in aliased_input_indices
- if fwd_inputs[i]._base is not None
- ]
- aliases_with_none_bases = [
- fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
- ]
- if len(non_none_bases) == 0:
- # Case where none of the aliases have a ._base
- # we generate a synthetic base without gradients, and generate views off of it
- # We hit this case when we have input tensors to the graph that share a storage,
- # but do not have a ._base field.
- # Wondering when we hit this case?
- # The _base field simply says that autograd knows about the aliasing relationship,
- # but sometimes we create tensors which are aliased out of the same storage but guaranteed
- # to be disjoint. In these cases, we will skip setting up the _base relationship
- # for performance reasons (because the fact that the tensors share the same storage
- # is unobservable unless you (1) do naughty things with resize_/as_strided
- # or (2) look at the storage--as we are doing here.)
- # One particular example of this is optimizer steps on the LSTM module:
- # LSTM parameters are packed into a contiguous storage for efficiency reasons when
- # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
- # find they share the same storage, but do not have _base set since they are all disjoint.
- #
- # NOTE: There is one case where this is unsafe:
- # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
- # the same shape as the "actual" base that the tensor came from.
- # For the most part this is fine, because we always use as_strided()
- # to generate the original aliased inputs again.
- # If we were to use view-replay though, this could cause the aliased views
- # to have incorrect sizes.
- example_idx = aliased_input_indices[0]
- example_alias = fwd_inputs[example_idx]
- # Note that this function is re-used at both trace time and runtime.
- # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
- synthetic_base = torch.empty(
- (0,), dtype=example_alias.dtype, device=example_alias.device
- )
- # We don't actually have a convenient way of going from storage -> tensor,
- # So using set_() here (we suffer some minor overhead, but this case is rare).
- synthetic_base.set_(example_alias.untyped_storage())
- else:
- # Case where all of the aliases require gradients, and have the same _base.
- synthetic_base = non_none_bases[0]
- for other_base in non_none_bases[1:]:
- assert (
- other_base is synthetic_base
- ), "aot_autograd() does not yet handle non-differentiable view input mutations."
- for alias in aliases_with_none_bases:
- assert (
- alias is synthetic_base
- ), "aot_autograd() does not yet handle non-differentiable view input mutations."
- base_args.append(synthetic_base)
- for curr_view_idx in aliased_input_indices:
- curr_view = fwd_inputs[curr_view_idx]
- base_idx = len(base_args) - 1
- # We store just enough info here so that we can regenerate the view later.
- # Regeneration: curr_view._view_func(args[base_idx])
- inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
- if len(base_args) == 0:
- assert len(other_args) == len(fwd_inputs)
- # If no synthetic bases are necessary, just return the original inputs.
- return fwd_inputs, None
- else:
- # Otherwise, return:
- # (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
- # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
- # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
- args_to_functionalization = base_args + other_args
- arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
- for i, other_arg in enumerate(other_args):
- new_idx = len(base_args) + i
- old_idx = arg_to_old_idx_map[other_arg]
- inner_calling_convention_meta[old_idx] = new_idx
- # post process into a list
- post_processed_calling_convention_meta: List[
- Union[int, Tuple[int, torch.Tensor]]
- ] = [-1 for _ in range(len(inner_calling_convention_meta))]
- for k, v in inner_calling_convention_meta.items():
- post_processed_calling_convention_meta[k] = v
- # Quick assert: every argument in the inner calling convention should be accounted for.
- for x in post_processed_calling_convention_meta:
- assert x != -1
- return args_to_functionalization, post_processed_calling_convention_meta
-
-
def _graph_input_names(gm):
return [node.name for node in gm.graph.nodes if node.op == "placeholder"]
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 91cdd19..7adc8a9 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -1,18 +1,55 @@
"""
-This module holds defines runtime wrappers, which, based on previous analysis
-attempts to process the inputs and outputs, apply mutations and handle
-functionalized randomness at runtime.
+This module defines runtime wrappers, which, based on previous analysis, attempt to:
+1. process the inputs and outputs
+2. apply mutations
+3. handle functionalized randomness
+4. deduplicate inputs and consolidate views into their bases (see input_output_analysis)
"""
-from typing import Callable, List, Optional, Union
+import collections
+import pprint
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.utils.dlpack
+from torch import Tensor
+from torch._guards import DuplicateInputs, TracingContext
from torch._prims_common import CUDARngStateHelper
+from torch.multiprocessing.reductions import StorageWeakRef
+from .. import config
+from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import gen_alias_from_base
-from .schemas import OutputType, SubclassCreationMeta, TensorAlias, ViewAndMutationMeta
-from .subclass_utils import unwrap_tensor_subclasses, wrap_tensor_subclasses
-from .utils import call_func_at_runtime_with_args, make_boxed_func
+from .input_output_analysis import (
+ compute_overlapping_inputs,
+ create_synthetic_base_metadata,
+ remove_dupe_metadata,
+)
+from .logging_utils import describe_input, format_guard_bug_msg
+from .schemas import (
+ AOTConfig,
+ InputAliasInfo,
+ OutputType,
+ SubclassCreationMeta,
+ TensorAlias,
+ ViewAndMutationMeta,
+)
+from .subclass_utils import (
+ requires_subclass_dispatch,
+ unwrap_tensor_subclasses,
+ wrap_tensor_subclasses,
+)
+
+from .utils import (
+ call_func_at_runtime_with_args,
+ make_boxed_func,
+ partial_flatten_asdict,
+ strict_zip,
+)
+
+
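+# strict_zip is a stricter variant of the builtin zip that raises if its iterables
+# have different lengths; shadowing `zip` keeps the call sites below terse.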
+zip = strict_zip
# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
@@ -268,3 +305,698 @@
# box it
inner_fn._boxed_call = True # type: ignore[attr-defined]
return inner_fn
+
+
+# MOTIVATION:
+#
+# When tracing functions for future execution, one must be careful not to pass
+# in the same input tensor multiple times (e.g., f(x, x)), as this can result
+# in graphs that are ONLY valid if you later pass a new tensor in exactly the
+# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct
+# tensors that alias each other is a different situation that is covered by
+# aot_dispatch_deduplicated_autograd). Here are two examples:
+#
+# (1) Suppose you have a function:
+#
+# def f(x, y):
+# return x + y
+#
+# If you make_fx(f)(x, x), you will trace out:
+#
+# def f(x, y):
+# return y + y
+#
+# Oops!
+#
+# (2) For most tensors x and y, you can compute f's gradient with respect to
+# these two inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,
+# if x is y, you will trace out a program that gets incorrect gradients:
+#
+# >>> x = torch.randn(1, requires_grad=True)
+# >>> torch.autograd.grad(x + x, (x, x))
+# (tensor([2.]), tensor([2.]))
+#
+# In other words, the gradient is double-counted. Deduplicating the arguments
+# gives you an appropriate gradient:
+#
+# >>> y = torch.randn(1, requires_grad=True)
+# >>> torch.autograd.grad(x + y, (x, y))
+# (tensor([1.]), tensor([1.]))
+#
+# HOW TO DEDUPLICATE:
+#
+# There are a few strategies, in order of preference:
+#
+# 1. For every duplicate argument to the function, detach it into
+# a separate leaf tensor, so that it is no longer duplicated.
+#
+# PRO: The resulting compiled graph works for any configuration
+# of duplicated arguments.
+#
+# CON: It does not (naively) work if you mutate the metadata of inputs:
+#
+# def f(x, y):
+# x.transpose_(0, 1)
+# y.transpose_(0, 2)
+#
+# x = torch.randn(2, 3, 4)
+# f(x, x)
+#
+# The ordering of the transposes inside f dictates whether or not
+# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute
+# what metadata mutations should get applied to each input; you need to
+# assume they aren't duplicates (what we do today) or preserve
+# the original metadata mutations exactly in order, so that they work
+# for any duplicate configuration.
+#
+# CON: It does not (naively) work if you mutate the data of inputs.
+# In particular, leaf tensors that require grad cannot be mutated;
+# this makes it impossible to differentiate with respect to the original
+# base.
+#
+# 2. For every duplicate argument to the function, remove it, so it is
+# no longer part of the "true" signature:
+#
+# PRO: Implemented naively, it still works for metadata/data mutation.
+#
+# CON: The resulting compiled graph is duplicate-specialized: it only
+# works if future calls duplicate arguments in exactly the same way.
+# Horribly, Dynamo doesn't guard on this at the moment. But even if
+# it did, you could still end up recompiling for each distinct duplication pattern.
+#
+# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
+# Dynamo's guards are not enough. In practice, this seems to cover
+# everything.
+#
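+# A minimal, illustrative sketch of strategy 1 (not part of this wrapper): detaching
+# the duplicate into a fresh leaf restores per-argument gradients:
+#
+#   >>> x = torch.randn(1, requires_grad=True)
+#   >>> x_leaf = x.detach().requires_grad_(True)
+#   >>> torch.autograd.grad(x + x_leaf, (x, x_leaf))
+#   (tensor([1.]), tensor([1.]))
+#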
+def aot_wrapper_dedupe(
+ flat_fn,
+ flat_args: List[Tensor],
+ aot_config: AOTConfig,
+ *,
+ compiler_fn,
+ fw_metadata,
+):
+    # Use information about whether or not flat_fn mutates its arguments
+    # to handle dupe args
+
+ # Strategy 1: For any input that is not mutated, we can leafify it if we
+ # need to remove a duplicate.
+ leaf_flat_args = []
+ args_set = set()
+ ok = True
+
+ for i, a in enumerate(flat_args):
+ if not isinstance(a, torch.Tensor):
+ leaf_flat_args.append(a)
+ elif a not in args_set:
+ args_set.add(a)
+ leaf_flat_args.append(a)
+ elif (
+ not fw_metadata.input_info[i].mutates_data
+ and not fw_metadata.input_info[i].mutates_metadata
+ ):
+ leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
+ else:
+ ok = False
+ break
+
+ if ok:
+ return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
+
+ if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
+ raise RuntimeError(
+ """\
+Encountered duplicate inputs that are mutated in the graph, but at least one input/output
+to the graph is a tensor subclass. This is not supported today. You can try to
+remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+ )
+
+ # export path: ban duplicate inputs for now, add later if requested.
+ if aot_config.is_export:
+ raise RuntimeError(
+ f"""\
+Encountered duplicated inputs that are mutated in the graph you are trying to export.
+This functionality is currently not supported. If needed, please file a github issue.
+
+fw_metadata={str(fw_metadata)}
+ """
+ )
+
+ # Strategy 2: Duplicate specialize.
+ #
+ # In Haskell types, suppose you have:
+ #
+ # add_dupe_args :: DedupedArgs -> Args
+ # remove_dupe_args :: Args -> DedupedArgs
+ #
+ # compiler_fn
+ # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
+# deduped_compiler_fn
+ # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
+ #
+ # Then the code below can be written in point-free style as:
+ #
+ # deduped_compiler_fn f a c =
+ # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
+ #
+ # Suppose you have:
+ #
+ # [a, b, a, c]
+ #
+ # We want:
+ #
+ # remove_dupe_args([a, b, a, c]) == [a, b, c]
+ # add_dupe_args([a, b, c]) == [a, b, a, c]
+ #
+ # This is done via (respectively):
+ #
+ # seen_args = {a: 0, b: 1, c: 2}
+ # enumerate(add_dupe_map) = [ # how to get args from the deduped list
+ # (0, 0),
+ # (1, 1),
+ # (2, 0),
+ # (3, 2),
+ # ]
+ # keep_arg_mask = [True, True, False, True]
+
+ seen_args: Dict[Tensor, int] = {}
+ keep_arg_mask = []
+ # Implicitly map duped arg position (list index) to de-duped arg position
+ add_dupe_map: List[int] = []
+ duped_arg_len = len(flat_args)
+
+ j = 0 # index into deduped_flat_args
+ for t in flat_args:
+ if isinstance(t, torch.Tensor):
+ if t in seen_args:
+ keep_arg_mask.append(False)
+ add_dupe_map.append(seen_args[t])
+ continue
+ seen_args[t] = j
+
+ keep_arg_mask.append(True)
+ add_dupe_map.append(j)
+ j += 1
+ assert (
+ len(add_dupe_map) == duped_arg_len
+ ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
+
+ # NB: Hot path, avoid set lookups here
+ # TODO: Can avoid the zip here too, probably
+ def remove_dupe_args(args):
+ return [t for t, keep in zip(args, keep_arg_mask) if keep]
+
+ def add_dupe_args(args):
+ return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
+
+ deduped_flat_args = remove_dupe_args(flat_args)
+
+ # Update our input metadata to remove duped input metadata.
+ updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map)
+
+    if (
+        tracing_context := TracingContext.try_get()
+    ) and aot_config.aot_autograd_arg_pos_to_source:
+ # TODO(voz): This structure is 1:1, we could consider an alternate structure like
+ # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
+ # which feels like needless complexity for a tiny bit of efficiency at this point.
+ for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
+ zip(add_dupe_map, keep_arg_mask)
+ ):
+ if not keep_arg:
+ dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
+ dupe_arg_pos
+ ]
+ kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos]
+ tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined]
+ DuplicateInputs(kept_arg_source, dupe_arg_source)
+ )
+
+ @wraps(flat_fn)
+ def wrapped_flat_fn(*args):
+ return flat_fn(*add_dupe_args(args))
+
+ if config.debug_assert:
+ ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+ wrapped_flat_fn,
+ keep_input_mutations=fw_metadata.keep_input_mutations,
+ is_train=fw_metadata.is_train,
+ )(*deduped_flat_args)
+ assert (
+ ref_fw_metadata == updated_fw_metadata
+ ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
+
+ compiled_fn = compiler_fn(
+ wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata
+ )
+
+ if not hasattr(compiled_fn, "_boxed_call"):
+ compiled_fn = make_boxed_func(compiled_fn)
+
+ @wraps(compiled_fn)
+ def wrapped_compiled_fn(args):
+ deduped_args = remove_dupe_args(args)
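+        # Drop the caller's references in the boxed args list so inputs can be freed eagerly.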
+ args.clear()
+ return compiled_fn(deduped_args)
+
+ wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined]
+
+ # This can be uncommented when we properly guard for duplicates,
+ # but right now we must not do it.
+ # if not config.debug_assert:
+ # return wrapped_compiled_fn
+
+ @wraps(wrapped_compiled_fn)
+ def debugged_compiled_fn(args):
+ # Test that the computed remove/add arg functions are an inverse
+ new_args = add_dupe_args(remove_dupe_args(args))
+ seen: Dict[Any, None] = {}
+ for i, (x, y) in enumerate(zip(new_args, args)):
+ seen[y] = None
+ assert x is y, format_guard_bug_msg(
+ aot_config,
+ f"{describe_input(i, aot_config)} would be a duplicate of "
+ f"{describe_input(add_dupe_map[i], aot_config)}",
+ )
+ # This is only an error if there is metadata mutation on both of
+ # the duped arguments; in this case, we need to know what order
+ # the metadata mutation applies in. You'll get the correct result
+ # otherwise, because a graph that assumes distinct inputs works if
+ # you dupe the inputs (the gradient contributions from each input
+ # will get summed up appropriately.)
+ #
+ # TODO: work out how to setup this assert correctly
+ """
+ assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
+ f"there would be {unique_args} distinct arguments"
+ )
+ """
+ return wrapped_compiled_fn(args)
+
+ debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined]
+
+ return debugged_compiled_fn
+
+
+# This layer handles the situation where you have two inputs that alias each other,
+# and one of the inputs is mutated.
+# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
+#
+# pre-condition: aot_wrapper_dedupe has already run.
+# (This function will in theory work if there are duplicate args.
+# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
+# would cause us to hit that path more frequently).
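+#
+# A hypothetical trigger for this wrapper (mirrors the Note further below; names are illustrative):
+#
+#   b = torch.ones(4)
+#   a = b.view(-1)          # a and b alias the same storage
+#   def f(a, b):
+#       a.mul_(2)           # data mutation on one alias
+#       return a + b
+#   # compiling f(a, b) routes through the synthetic-base logic below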
+def aot_wrapper_synthetic_base(
+ flat_fn,
+ flat_args: List[Tensor],
+ aot_config: AOTConfig,
+ *,
+ fw_metadata: ViewAndMutationMeta,
+ # Currently, the only reason we need to plumb this bool is because
+ # the synthetic base code prohibits more cases in the autograd case than the inference case.
+ needs_autograd: bool,
+ compiler_fn,
+):
+ is_inference = not needs_autograd
+ flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+ flat_args,
+ fw_metadata.input_info,
+ is_inference=is_inference,
+ )
+ # Happy path: we don't need synthetic bases
+ if synthetic_base_info is None:
+ return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+
+    # Tensor-subclass path: ban aliased inputs that are mutated for now.
+    if requires_subclass_dispatch(flat_args, fw_metadata):
+        raise RuntimeError(
+            """\
+Encountered aliased inputs that are mutated in the graph, but at least one input/output
+to the graph is a tensor subclass. This is not supported today. You can try to
+remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
+        )
+
+    # export path: ban synthetic bases for now, add later if requested.
+    if aot_config.is_export:
+ raise RuntimeError(
+ f"""\
+Encountered aliased inputs that are mutated in the graph you are trying to export.
+This functionality is currently not supported. If needed, please file a github issue.
+
+synthetic_base_info={str(synthetic_base_info)}
+
+fw_metadata={str(fw_metadata)}
+ """
+ )
+
+ assert len(fw_metadata.input_info) == len(synthetic_base_info)
+
+ # Update our forward metadata to take synthetic bases into account
+ (
+ fw_metadata_updated,
+ aliased_arg_idx_with_metadata_mutations,
+ ) = create_synthetic_base_metadata(
+ fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
+ )
+
+ num_aliased_args_with_metadata_mutations = len(
+ aliased_arg_idx_with_metadata_mutations
+ )
+
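+    # Illustrative behavior of _unpack_synthetic_bases below (hypothetical values):
+    # with synthetic_base_info == [1, (0, v1), (0, v2)] and primals == (v_base, a),
+    # it returns [a, v1', v2'], where v1' and v2' are views regenerated from v_base
+    # via gen_alias_from_base.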
+ def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
+ f_args_inner = []
+ for inner_idx_or_tuple in synthetic_base_info:
+ if isinstance(inner_idx_or_tuple, int):
+ f_args_inner.append(primals[inner_idx_or_tuple])
+ else:
+ inner_base_idx, view_tensor = inner_idx_or_tuple
+ base = primals[inner_base_idx]
+ view_arg = gen_alias_from_base(
+ base, view_tensor, view_tensor.requires_grad
+ )
+ f_args_inner.append(view_arg)
+ return f_args_inner
+
+ @wraps(flat_fn)
+ def wrapped_flat_fn(*args):
+ unpacked_args = _unpack_synthetic_bases(args)
+        # This is a bit subtle. The goal of this entire function (aot_wrapper_synthetic_base)
+ # is to relieve the downstream logic from having to reason about mutations on inputs that alias
+ # each other, by replacing aliased inputs with a synthetic base.
+ # One area where this breaks down a bit however is if one of those aliased inputs
+ # experienced a metadata mutation.
+ # We are now obligated to reapply the metadata mutation directly to the user's input;
+ # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
+ #
+ # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
+ # are additional outputs in the user's forward function.
+ # The downstream logic will just treat these as "user outputs that alias inputs".
+ # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
+ # to the user inputs, and not return them to the user.
+ aliased_args_with_metadata_mutations = [
+ x
+ for i, x in enumerate(unpacked_args)
+ if i in aliased_arg_idx_with_metadata_mutations
+ ]
+ if len(aliased_args_with_metadata_mutations) > 0:
+ return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
+ else:
+ return flat_fn(*unpacked_args)
+
+ if config.debug_assert:
+ ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
+ wrapped_flat_fn,
+ keep_input_mutations=fw_metadata.keep_input_mutations,
+ is_train=fw_metadata.is_train,
+ )(*flat_args_with_synthetic_bases)
+ assert ref_fw_metadata == fw_metadata_updated, (
+ f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
+ f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
+ )
+
+ compiled_fn = compiler_fn(
+ wrapped_flat_fn,
+ flat_args_with_synthetic_bases,
+ aot_config,
+ fw_metadata=fw_metadata_updated,
+ )
+
+ if not hasattr(compiled_fn, "_boxed_call"):
+ compiled_fn = make_boxed_func(compiled_fn)
+
+ @wraps(compiled_fn)
+ def wrapped_compiled_fn(args):
+ args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
+ args, fw_metadata.input_info, is_inference=is_inference
+ )
+ assert synthetic_base_info is not None
+ aliased_args_w_metadata_mutations = [
+ args[i] for i in aliased_arg_idx_with_metadata_mutations
+ ]
+ args.clear()
+ outs = compiled_fn(args_with_synthetic_bases)
+ if num_aliased_args_with_metadata_mutations > 0:
+ # This code does not handle **all** input metadata mutations.
+ # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
+ # (which only happens if at least one aliased input experienced a data mutation).
+ # e.g:
+ # def f(a, b):
+ # a.mul_(2)
+            #     b.transpose_(1, 0)
+ # f(x.view(2, 2), x.view(2, 2))
+ mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
+ user_outs = outs[:-num_aliased_args_with_metadata_mutations]
+ for inp, mutated_inp in zip(
+ aliased_args_w_metadata_mutations, mutated_metadata_inps
+ ):
+ inp.as_strided_(
+ mutated_inp.size(),
+ mutated_inp.stride(),
+ mutated_inp.storage_offset(),
+ )
+ return user_outs
+ return outs
+
+ return wrapped_compiled_fn
+
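+# How these wrappers typically compose (a sketch, not a definitive recipe; the actual
+# wiring lives in aot_autograd.py and `inner_compiler` here is a placeholder for the
+# real compiler function):
+#
+#   compiler_fn = partial(aot_wrapper_synthetic_base,
+#                         compiler_fn=inner_compiler, needs_autograd=needs_autograd)
+#   compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
+#   compiled_fn = compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+#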
+
+# Note [Handling mutations on an input that aliases other inputs]
+# The easiest example to show-case this edge case is here:
+#
+# def f(a, b):
+# a.mul_(2)
+# out = a + b
+# return out
+# b = torch.ones(...)
+# a = b.view(-1)
+# f(a, b)
+#
+# In this situation, if a and b happened to be aliased, we need to trace something different!
+# In the example above, a = b.view(-1)
+# (which means that `a._base is b`)
+#
+# We need to ensure that the aliasing relationship between a and b is preserved.
+# We do that by detecting the specific situation above (mutate an input that aliases another input),
+# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
+# we regenerate a and b off of that base.
+# The complete example of the transformed function looks like this:
+#
+# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
+# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
+# def traced_forward(base):
+# a = base.as_strided(...)
+# b = base.as_strided(...)
+# a_updated = a.mul(2)
+# base_updated = torch.as_strided_scatter(base, a_updated, ...)
+# b_updated = base_updated.as_strided(...)
+# out = a_updated + b_updated
+# return a_updated, out
+#
+# def compiled_fn(a, b):
+# // we detect that a is the "differentiable base" here
+# base = a
+# // In other situations, we might do either:
+# // (1) a and b are both views off of some larger differentiable base
+# // assert a._base is b._base and a._base is not None
+# // base = a._base
+# // (2) a and b both don't require gradients. Create a base from the storage
+# // assert a._base is None and b._base is None
+# // base = torch.Tensor(a.storage())
+# a_updated, out = traced_forward(base)
+# a.copy_(a_updated)
+# return out
+#
+# This function:
+# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
+# (2) Returns metadata telling the autograd.Function how to modify its arguments properly,
+# to respect the new calling convention.
+#
+# The calling convention is as follows.
+# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
+# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
+# where the ordering of the bases is determined by the ordering of the original view args.
+# baseA will come before baseB if the earliest original argument coming from baseA
+# showed up earlier in the argument list than the earliest original argument coming from baseB.
+#
+# Example, given some tensors a, b, c, d
+# call site:
+# f(a, c.view(-1), b.view(-1), b, c, d)
+# Modified argument list:
+# c_base comes first because the first c view came earlier in arg list than the first b view
+# a and d still show up in the modified arg list, but b and c don't; they're regenerated from their bases
+# b_base = torch.Tensor(b.storage())
+# c_base = torch.Tensor(c.storage())
+# f(c_base, b_base, a, d)
+def merge_view_inputs(
+ fwd_inputs: List[Any],
+ mutated_input_info: List[InputAliasInfo],
+ *,
+ # The autograd case currently has more restrictions than the inference case.
+ is_inference: bool,
+) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
+ def _are_differentiable_views(view1, view2):
+ if view1 is view2:
+ return True
+ if view1._base is None and view2._base is None:
+ return False
+ if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
+ return True
+ return False
+
+ def _same_dtype_views(view1, view2):
+ if view1.dtype != view2.dtype:
+ return False
+ if view1._base is not None and view1.dtype != view1._base.dtype:
+ return False
+ if view2._base is not None and view2.dtype != view2._base.dtype:
+ return False
+ return True
+
+ assert len(fwd_inputs) == len(mutated_input_info)
+ storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
+ base_args = []
+ other_args = []
+ for i, inpt in enumerate(fwd_inputs):
+ if isinstance(inpt, Tensor):
+ storage_ref = StorageWeakRef(inpt.untyped_storage())
+ storage_ref_to_idx[storage_ref].append(i)
+ else:
+ other_args.append(inpt)
+ # Note [Synthetic Base Info Metadata]
+ # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
+ # It's either:
+ # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
+ # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
+ # idx corresponds to which synthetic base from the outer calling context to view
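+    # e.g. (illustrative): for fwd_inputs == [a, v1, v2], where v1 and v2 are views of the
+    # same storage and at least one of them has its data mutated, the merged outer args
+    # become [v_base, a] and this metadata becomes {0: 1, 1: (0, v1), 2: (0, v2)}.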
+ inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
+ for aliased_input_indices in storage_ref_to_idx.values():
+ if len(aliased_input_indices) <= 1 or not any(
+ # We only care about mutations that affect all aliases,
+            # so metadata mutations on an input don't require us to do synthetic base handling.
+ mutated_input_info[inpt_idx].mutates_data
+ for inpt_idx in aliased_input_indices
+ ):
+ for curr_idx in aliased_input_indices:
+ other_args.append(fwd_inputs[curr_idx])
+ continue
+
+ # Here, we attempt to do a more complicated check to detect false aliasing
+ # (e.g. if all the tensors have the same storage, but don't actually overlap)
+ # In theory, we could have a large group of tensors that all share storages, where only *some* of them
+ # have overlapping memory.
+ # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair
+ # of tensors in the current group that shares a storage is non-overlapping.
+ aliased_input_indices_no_false_sharing = compute_overlapping_inputs(
+ fwd_inputs, aliased_input_indices
+ )
+ if len(aliased_input_indices_no_false_sharing) <= 1:
+ for curr_idx in aliased_input_indices:
+ other_args.append(fwd_inputs[curr_idx])
+ continue
+
+ # We detected an input that was mutated, AND aliases with another input.
+        # We need to replace this set of aliased inputs with a single synthetic base.
+ # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
+ # and error out. We can fix them later.
+ # These checks are transitive, so we don't need to check every pair.
+ for idx1, idx2 in zip(
+ aliased_input_indices, aliased_input_indices[1:], strict=False
+ ):
+ view1 = fwd_inputs[idx1]
+ view2 = fwd_inputs[idx2]
+ # The "inputs that are aliased but have different differentiable bases" case
+ # is more complicated and hopefully pretty rare. Not currently handled.
+ if not is_inference:
+ assert _are_differentiable_views(
+ view1, view2
+ ), "aot_autograd() does not yet handle non-differentiable view input mutations."
+ # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
+ # not handling for now
+ assert _same_dtype_views(
+ view1, view2
+ ), "aot_autograd() does not yet handle input mutations on views with different dtypes."
+ non_none_bases = [
+ fwd_inputs[i]._base
+ for i in aliased_input_indices
+ if fwd_inputs[i]._base is not None
+ ]
+ aliases_with_none_bases = [
+ fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
+ ]
+ if len(non_none_bases) == 0:
+ # Case where none of the aliases have a ._base
+ # we generate a synthetic base without gradients, and generate views off of it
+ # We hit this case when we have input tensors to the graph that share a storage,
+ # but do not have a ._base field.
+ # Wondering when we hit this case?
+ # The _base field simply says that autograd knows about the aliasing relationship,
+ # but sometimes we create tensors which are aliased out of the same storage but guaranteed
+ # to be disjoint. In these cases, we will skip setting up the _base relationship
+ # for performance reasons (because the fact that the tensors share the same storage
+ # is unobservable unless you (1) do naughty things with resize_/as_strided
+ # or (2) look at the storage--as we are doing here.)
+ # One particular example of this is optimizer steps on the LSTM module:
+ # LSTM parameters are packed into a contiguous storage for efficiency reasons when
+ # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
+ # find they share the same storage, but do not have _base set since they are all disjoint.
+ #
+ # NOTE: There is one case where this is unsafe:
+ # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
+ # the same shape as the "actual" base that the tensor came from.
+ # For the most part this is fine, because we always use as_strided()
+ # to generate the original aliased inputs again.
+ # If we were to use view-replay though, this could cause the aliased views
+ # to have incorrect sizes.
+ example_idx = aliased_input_indices[0]
+ example_alias = fwd_inputs[example_idx]
+ # Note that this function is re-used at both trace time and runtime.
+ # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
+ synthetic_base = torch.empty(
+ (0,), dtype=example_alias.dtype, device=example_alias.device
+ )
+ # We don't actually have a convenient way of going from storage -> tensor,
+            # so we use set_() here (we suffer some minor overhead, but this case is rare).
+ synthetic_base.set_(example_alias.untyped_storage())
+ else:
+            # Case where at least one of the aliases has a ._base; they must all share the same base.
+ synthetic_base = non_none_bases[0]
+ for other_base in non_none_bases[1:]:
+ assert (
+ other_base is synthetic_base
+ ), "aot_autograd() does not yet handle non-differentiable view input mutations."
+ for alias in aliases_with_none_bases:
+ assert (
+ alias is synthetic_base
+ ), "aot_autograd() does not yet handle non-differentiable view input mutations."
+ base_args.append(synthetic_base)
+ for curr_view_idx in aliased_input_indices:
+ curr_view = fwd_inputs[curr_view_idx]
+ base_idx = len(base_args) - 1
+ # We store just enough info here so that we can regenerate the view later.
+ # Regeneration: curr_view._view_func(args[base_idx])
+ inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
+ if len(base_args) == 0:
+ assert len(other_args) == len(fwd_inputs)
+ # If no synthetic bases are necessary, just return the original inputs.
+ return fwd_inputs, None
+ else:
+ # Otherwise, return:
+ # (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
+ # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
+ # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
+ args_to_functionalization = base_args + other_args
+ arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
+ for i, other_arg in enumerate(other_args):
+ new_idx = len(base_args) + i
+ old_idx = arg_to_old_idx_map[other_arg]
+ inner_calling_convention_meta[old_idx] = new_idx
+ # post process into a list
+ post_processed_calling_convention_meta: List[
+ Union[int, Tuple[int, torch.Tensor]]
+ ] = [-1 for _ in range(len(inner_calling_convention_meta))]
+ for k, v in inner_calling_convention_meta.items():
+ post_processed_calling_convention_meta[k] = v
+ # Quick assert: every argument in the inner calling convention should be accounted for.
+ for x in post_processed_calling_convention_meta:
+ assert x != -1
+ return args_to_functionalization, post_processed_calling_convention_meta
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index d0aee35..a202f95 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1,6 +1,5 @@
import itertools
import logging
-import pprint
from contextlib import nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -28,7 +27,7 @@
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from . import config
from .partitioners import default_partition
-from torch._guards import TracingContext, DuplicateInputs
+from torch._guards import TracingContext
from ._aot_autograd.utils import ( # noqa: F401
strict_zip,
@@ -101,15 +100,9 @@
remove_dupe_metadata,
create_synthetic_base_metadata,
_tensors_definitely_do_not_overlap,
- _compute_overlapping_inputs,
- merge_view_inputs,
+ compute_overlapping_inputs,
create_graph_signature,
)
-from ._aot_autograd.runtime_wrappers import ( # noqa: F401
- create_runtime_wrapper,
- functionalized_rng_runtime_epilogue,
- aot_dispatch_subclass_wrapper,
-)
from ._aot_autograd.traced_function_transforms import ( # noqa: F401
fn_input_mutations_to_outputs,
fn_prepped_for_autograd,
@@ -119,6 +112,14 @@
aot_dispatch_subclass,
create_functional_call,
)
+from ._aot_autograd.runtime_wrappers import ( # noqa: F401
+ create_runtime_wrapper,
+ functionalized_rng_runtime_epilogue,
+ aot_dispatch_subclass_wrapper,
+ aot_wrapper_dedupe,
+ aot_wrapper_synthetic_base,
+ merge_view_inputs,
+)
zip = strict_zip
@@ -508,412 +509,6 @@
return compiled_fn
-
-# MOTIVATION:
-#
-# When tracing functions for future execution, one must be careful not to pass
-# in the same input tensor multiple times (e.g., f(x, x), as this can result
-# in graphs that are ONLY valid if you later pass a new tensor in exactly the
-# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct
-# tensors that alias each other is a different situation that is covered by
-# aot_dispatch_deduplicated_autograd). Here are two examples:
-#
-# (1) Suppose you have a function:
-#
-# def f(x, y):
-# return x + y
-#
-# If you make_fx(f)(x, x), you will trace out:
-#
-# def f(x, y):
-# return y + y
-#
-# Oops!
-#
-# (2) For most tensors x and y, you can compute f's gradient with respect to
-# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,
-# if x is y, you will trace out a program that gets incorrect gradients:
-#
-# >>> x = torch.randn(1, requires_grad=True)
-# >>> torch.autograd.grad(x + x, (x, x))
-# (tensor([2.]), tensor([2.]))
-#
-# In other words, the gradient is double-counted. Deduplicating the arguments
-# gives you an appropriate gradient:
-#
-# >>> y = torch.randn(1, requires_grad=True)
-# >>> torch.autograd.grad(x + y, (x, y))
-# (tensor([1.]), tensor([1.]))
-#
-# HOW TO DEDUPLICATE:
-#
-# There are a few strategies, in order of preference:
-#
-# 1. For every duplicate argument to the function, detach it into
-# a separate leaf tensor, so that it is no longer duplicated.
-#
-# PRO: The resulting compiled graph works for any configuration
-# of duplicated arguments.
-#
-# CON: It does not (naively) work if you mutate the metadata of inputs:
-#
-# def f(x, y):
-# x.transpose_(0, 1)
-# y.transpose_(0, 2)
-#
-# x = torch.randn(2, 3, 4)
-# f(x, x)
-#
-# The ordering of the transposes inside f dictates whether or not
-# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute
-# what metadata mutations should get applied to each input; you need to
-# assume they aren't duplicates (what we do today) or preserve
-# the original metadata mutations exactly in order, so that they work
-# for any duplicate configuration.
-#
-# CON: It does not (naively) work if you mutate the data of inputs.
-# In particular, leaf tensors that require grad cannot be mutated,
-# this makes it impossible to differentiate with respect to the original
-# base.
-#
-# 2. For every duplicate argument to the function, remove it, so it is
-# no longer part of the "true" signature:
-#
-# PRO: Implemented naively, it still works for metadata/data mutation.
-#
-# CON: The resulting compiled graph is duplicate-specialized: it only
-# works if future calls duplicate arguments in exactly the same way.
-# Horribly, Dynamo doesn't guard on this at the moment. But even if
-# it did, you could still end up recompiling a bunch of each duplicate.
-#
-# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
-# Dynamo's guards are not enough. In practice, this seems to cover
-# everything.
-#
-def aot_wrapper_dedupe(
- flat_fn,
- flat_args: List[Tensor],
- aot_config: AOTConfig,
- *,
- compiler_fn,
- fw_metadata,
-):
- # Use information about whether or not flat_fn mutates its arguments
- # or not to handle dupe args
-
- # Strategy 1: For any input that is not mutated, we can leafify it if we
- # need to remove a duplicate.
- leaf_flat_args = []
- args_set = set()
- ok = True
-
- for i, a in enumerate(flat_args):
- if not isinstance(a, torch.Tensor):
- leaf_flat_args.append(a)
- elif a not in args_set:
- args_set.add(a)
- leaf_flat_args.append(a)
- elif not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata:
- leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
- else:
- ok = False
- break
-
- if ok:
- return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
-
- if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
- raise RuntimeError("""\
-Encountered duplicate inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github.""")
-
- # export path: ban duplicate inputs for now, add later if requested.
- if aot_config.is_export:
- raise RuntimeError(f"""\
-Encountered duplicated inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
-
-fw_metadata={str(fw_metadata)}
- """)
-
- # Strategy 2: Duplicate specialize.
- #
- # In Haskell types, suppose you have:
- #
- # add_dupe_args :: DedupedArgs -> Args
- # remove_dupe_args :: Args -> DedupedArgs
- #
- # compiler_fn
- # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
- # deped_compiler_fn
- # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
- #
- # Then the code below can be written in point-free style as:
- #
- # deduped_compiler_fn f a c =
- # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
- #
- # Suppose you have:
- #
- # [a, b, a, c]
- #
- # We want:
- #
- # remove_dupe_args([a, b, a, c]) == [a, b, c]
- # add_dupe_args([a, b, c]) == [a, b, a, c]
- #
- # This is done via (respectively):
- #
- # seen_args = {a: 0, b: 1, c: 2}
- # enumerate(add_dupe_map) = [ # how to get args from the deduped list
- # (0, 0),
- # (1, 1),
- # (2, 0),
- # (3, 2),
- # ]
- # keep_arg_mask = [True, True, False, True]
-
- seen_args = {}
- keep_arg_mask = []
- # Implicitly map duped arg position (list index) to de-duped arg position
- add_dupe_map: List[int] = []
- duped_arg_len = len(flat_args)
-
- j = 0 # index into deduped_flat_args
- for t in flat_args:
- if isinstance(t, torch.Tensor):
- if t in seen_args:
- keep_arg_mask.append(False)
- add_dupe_map.append(seen_args[t])
- continue
- seen_args[t] = j
-
- keep_arg_mask.append(True)
- add_dupe_map.append(j)
- j += 1
- assert len(add_dupe_map) == duped_arg_len, (
- f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
- )
-
- # NB: Hot path, avoid set lookups here
- # TODO: Can avoid the zip here too, probably
- def remove_dupe_args(args):
- return [t for t, keep in zip(args, keep_arg_mask) if keep]
-
- def add_dupe_args(args):
- return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
-
- deduped_flat_args = remove_dupe_args(flat_args)
-
- # Update our input metadata to remove duped input metadata.
- updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map)
-
- if tracing_context := TracingContext.try_get() and aot_config.aot_autograd_arg_pos_to_source:
- # TODO(voz): This structure is 1:1, we could consider an alternate structure like
- # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
- # which feels like needless complexity for a tiny bit of efficiency at this point.
- for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(zip(add_dupe_map, keep_arg_mask)):
- if not keep_arg:
- dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[dupe_arg_pos]
- kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos]
- tracing_context.guards_context.aotautograd_guards.append(DuplicateInputs(kept_arg_source, dupe_arg_source))
-
- @wraps(flat_fn)
- def wrapped_flat_fn(*args):
- return flat_fn(*add_dupe_args(args))
-
- if config.debug_assert:
- ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
- wrapped_flat_fn,
- keep_input_mutations=fw_metadata.keep_input_mutations,
- is_train=fw_metadata.is_train,
- )(*deduped_flat_args)
- assert ref_fw_metadata == updated_fw_metadata, \
- f'ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}'
-
- compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata)
-
- if not hasattr(compiled_fn, "_boxed_call"):
- compiled_fn = make_boxed_func(compiled_fn)
-
- @wraps(compiled_fn)
- def wrapped_compiled_fn(args):
- deduped_args = remove_dupe_args(args)
- args.clear()
- return compiled_fn(deduped_args)
-
- wrapped_compiled_fn._boxed_call = True
-
- # This can be uncommented when we properly guard for duplicates,
- # but right now we must not do it.
- # if not config.debug_assert:
- # return wrapped_compiled_fn
-
- @wraps(wrapped_compiled_fn)
- def debugged_compiled_fn(args):
- # Test that the computed remove/add arg functions are an inverse
- new_args = add_dupe_args(remove_dupe_args(args))
- seen = {}
- for i, (x, y) in enumerate(zip(new_args, args)):
- seen[y] = None
- assert x is y, format_guard_bug_msg(
- aot_config,
- f"{describe_input(i, aot_config)} would be a duplicate of "
- f"{describe_input(add_dupe_map[i], aot_config)}",
- )
- # This is only an error if there is metadata mutation on both of
- # the duped arguments; in this case, we need to know what order
- # the metadata mutation applies in. You'll get the correct result
- # otherwise, because a graph that assumes distinct inputs works if
- # you dupe the inputs (the gradient contributions from each input
- # will get summed up appropriately.)
- #
- # TODO: work out how to setup this assert correctly
- """
- assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
- f"there would be {unique_args} distinct arguments"
- )
- """
- return wrapped_compiled_fn(args)
-
- debugged_compiled_fn._boxed_call = True
-
- return debugged_compiled_fn
-
-# This layer handles the situation where you have two inputs that alias each other,
-# and one of the inputs is mutated.
-# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
-#
-# pre-condition: aot_wrapper_dedup has already run.
-# (This function will in theory work if there are duplicate args.
-# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
-# would cause us to hit that path more frequently).
-def aot_wrapper_synthetic_base(
- flat_fn,
- flat_args: List[Tensor],
- aot_config: AOTConfig,
- *,
- fw_metadata: ViewAndMutationMeta,
- # Currently, the only reason we need to plumb this bool is because
- # the synthetic base code prohibits more cases in the autograd case than the inference case.
- needs_autograd: bool,
- compiler_fn,
-):
- is_inference = not needs_autograd
- flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
- flat_args, fw_metadata.input_info, is_inference=is_inference,
- )
- # Happy path: we don't need synthetic bases
- if synthetic_base_info is None:
- return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
-
- # export path: ban synthetic bases for now, add later if requested.
- if requires_subclass_dispatch(flat_args, fw_metadata):
- raise RuntimeError("""\
-Encountered aliased inputs that are mutated in the graph, but at least one input/output
-to the graph is a tensor subclass. This is not supported today. You can try to
-remove the aliasing yourself as a workaround, or otherwise file an issue on github.""")
-
- if aot_config.is_export:
- raise RuntimeError(f"""\
-Encountered aliased inputs that are mutated in the graph you are trying to export.
-This functionality is currently not supported. If needed, please file a github issue.
-
-synthetic_base_info={str(synthetic_base_info)}
-
-fw_metadata={str(fw_metadata)}
- """)
-
- assert len(fw_metadata.input_info) == len(synthetic_base_info)
-
- # Update our forward metadata to take synthetic bases into account
- fw_metadata_updated, aliased_arg_idx_with_metadata_mutations = \
- create_synthetic_base_metadata(fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases)
-
- num_aliased_args_with_metadata_mutations = len(aliased_arg_idx_with_metadata_mutations)
-
- def unpack_synthetic_bases(primals: List[Any]) -> List[Any]:
- f_args_inner = []
- for inner_idx_or_tuple in synthetic_base_info:
- if isinstance(inner_idx_or_tuple, int):
- f_args_inner.append(primals[inner_idx_or_tuple])
- else:
- inner_base_idx, view_tensor = inner_idx_or_tuple
- base = primals[inner_base_idx]
- view_arg = gen_alias_from_base(
- base, view_tensor, view_tensor.requires_grad
- )
- f_args_inner.append(view_arg)
- return f_args_inner
-
- @wraps(flat_fn)
- def wrapped_flat_fn(*args):
- unpacked_args = unpack_synthetic_bases(args)
- # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
- # is to relieve the downstream logic from having to reason about mutations on inputs that alias
- # each other, by replacing aliased inputs with a synthetic base.
- # One area where this breaks down a bit however is if one of those aliased inputs
- # experienced a metadata mutation.
- # We are now obligated to reapply the metadata mutation directly to the user's input;
- # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
- #
- # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
- # are additional outputs in the user's forward function.
- # The downstream logic will just treat these as "user outputs that alias inputs".
- # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
- # to the user inputs, and not return them to the user.
- aliased_args_with_metadata_mutations = [
- x for i, x in enumerate(unpacked_args) if i in aliased_arg_idx_with_metadata_mutations]
- if len(aliased_args_with_metadata_mutations) > 0:
- return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
- else:
- return flat_fn(*unpacked_args)
-
- if config.debug_assert:
- ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
- wrapped_flat_fn,
- keep_input_mutations=fw_metadata.keep_input_mutations,
- is_train=fw_metadata.is_train,
- )(*flat_args_with_synthetic_bases)
- assert ref_fw_metadata == fw_metadata_updated, (
- f'ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, '
- f'\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}'
- )
-
- compiled_fn = compiler_fn(wrapped_flat_fn, flat_args_with_synthetic_bases, aot_config, fw_metadata=fw_metadata_updated)
-
- if not hasattr(compiled_fn, "_boxed_call"):
- compiled_fn = make_boxed_func(compiled_fn)
-
- @wraps(compiled_fn)
- def wrapped_compiled_fn(args):
- args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
- args, fw_metadata.input_info, is_inference=is_inference
- )
- assert synthetic_base_info is not None
- aliased_args_w_metadata_mutations = [args[i] for i in aliased_arg_idx_with_metadata_mutations]
- args.clear()
- outs = compiled_fn(args_with_synthetic_bases)
- if num_aliased_args_with_metadata_mutations > 0:
- # This code does not handle **all** input metadata mutations.
- # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
- # (which only happens if at least one aliased input experienced a data mutation).
- # e.g:
- # def f(a, b):
- # a.mul_(2)
- # b.t_(1, 0)
- # f(x.view(2, 2), x.view(2, 2))
- mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
- user_outs = outs[:-num_aliased_args_with_metadata_mutations]
- for inp, mutated_inp in zip(aliased_args_w_metadata_mutations, mutated_metadata_inps):
- inp.as_strided_(mutated_inp.size(), mutated_inp.stride(), mutated_inp.storage_offset())
- return user_outs
- return outs
-
- return wrapped_compiled_fn
-
-
# Has the precondition that there
# are no duplicate arguments in flat_args (e.g., the same Tensor
# object never shows up twice. However, two tensor inputs MAY alias