| import contextlib |
| import warnings |
| import weakref |
| from typing import ContextManager, Dict, List, Optional, Tuple, TYPE_CHECKING |
| |
| import torch |
| from torch._C._functorch import ( |
| _add_batch_dim, |
| _unwrap_functional_tensor, |
| _wrap_functional_tensor, |
| current_level, |
| get_unwrapped, |
| is_batchedtensor, |
| is_functorch_wrapped_tensor, |
| is_gradtrackingtensor, |
| maybe_get_bdim, |
| maybe_get_level, |
| peek_interpreter_stack, |
| TransformType, |
| ) |
| from torch._guards import Source |
| |
| from torch.multiprocessing.reductions import StorageWeakRef |
| from torch.utils._python_dispatch import ( |
| is_traceable_wrapper_subclass, |
| transform_subclass, |
| ) |
| from torch.utils.weak import WeakIdRef |
| |
| if TYPE_CHECKING: |
| # Import the following modules during type checking to enable code intelligence features. |
| # Do not import them unconditionally, as they import sympy and importing sympy is very slow. |
| from torch.fx.experimental.symbolic_shapes import SymbolicContext |
| |
| DimList = List |
| |
| |
| def safe_is_leaf(t): |
| try: |
| return t.is_leaf |
| except RuntimeError: |
| # inference mode can trigger this |
| return False |
| |
| |
| def safe_grad(t): |
| with warnings.catch_warnings(): |
| warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") |
| return t.grad |
| |
| |
| def assert_eq(a, b): |
| assert a == b, f"{a} != {b}" |
| |
| |
| def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False): |
| def go(m1, m2): |
| assert_eq(m1.dtype, m2.dtype) |
| if not skip_symbolic: |
| assert_eq(m1.shape, m2.shape) |
| assert_eq(m1.requires_grad, m2.requires_grad) |
| assert_eq(m1.is_leaf, m2.is_leaf) |
| assert_eq(m1.grad_fn is None, m2.grad_fn is None) |
| assert_eq(m1.is_sparse, m2.is_sparse) |
| assert_eq(m1.is_inference(), m2.is_inference()) |
| assert_eq(m1.is_conj(), m2.is_conj()) |
| assert_eq(m1.is_neg(), m2.is_neg()) |
| assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None) |
| if safe_grad(m1) is not None: |
| go(safe_grad(m1), safe_grad(m2)) |
| if m1.is_sparse: |
| assert_eq(m1.dense_dim(), m2.dense_dim()) |
| assert_eq(m1.sparse_dim(), m2.sparse_dim()) |
| assert_eq(m1.is_coalesced(), m2.is_coalesced()) |
| else: |
| if not skip_symbolic: |
| assert_eq(m1.stride(), m2.stride()) |
| assert_eq(m1.storage_offset(), m2.storage_offset()) |
| assert_eq(m1._is_view(), m2._is_view()) |
| if m1._is_view(): |
| go(m1._base, m2._base) |
| # TODO: test if is resizable (no direct query for this atm) |
| # TODO: audit AutogradMeta to see if it matches |
| # TODO: test forward AD |
| |
| return go(m1, m2) |
| |
| |
| def is_sparse_coo(t): |
| return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo |
| |
| |
| def is_sparse_compressed(t): |
| return isinstance(t, torch.Tensor) and t.layout in { |
| torch.sparse_csr, |
| torch.sparse_csc, |
| torch.sparse_bsr, |
| torch.sparse_bsc, |
| } |
| |
| |
| def is_sparse_any(t): |
| return is_sparse_coo(t) or is_sparse_compressed(t) |
| |
| |
| # This is a class for converting multiple tensors into meta tensors which |
| # share the same view/storage structure. The usage model is that you allocate |
| # one of these and then call it repeatedly on all the tensors you want to |
| # convert. It's important to use the same object for tensors that should |
| # share storage, because this is how we correlate shared storages to the same |
| # meta storages. This class will hold weak references to the cached tensors |
| # and tensor storages. |
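| # |
| # A minimal usage sketch (variable names are illustrative only): |
| # |
| #   converter = MetaConverter() |
| #   meta_a = converter(a)           # meta tensor mirroring the real tensor `a` |
| #   meta_b = converter(a.view(-1))  # converted via the same object, so it |
| #                                   # views the same meta storage as `meta_a` |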
| class MetaConverter: |
| def __init__(self): |
| self.storage_memo = {} |
| self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |
| self.maybe_storages_to_delete = [] |
| self.check_expired_frequency = 128 |
| self.check_expired_count = 0 |
| self.hit = 0 |
| self.miss = 0 |
| self.del_hook = None |
| self.arg_cnt = 0 |
| |
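| # Whether every conversion so far succeeded (at least one hit and no misses). |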
| def successful(self): |
| return self.hit > 0 and self.miss == 0 |
| |
| def check_for_expired_weak_storages(self): |
| new_li = [] |
| stor_to_delete = [] |
| for obj in self.maybe_storages_to_delete: |
| if not obj.expired(): |
| new_li.append(obj) |
| else: |
| stor_to_delete.append(obj) |
| for obj in stor_to_delete: |
| self.storage_memo.pop(obj, None) |
| self.maybe_storages_to_delete = new_li |
| |
| # if for some reason we have acquired many storages which have not expired, |
| # even though a tensor with their storage has expired (aliasing or otherwise), |
| # check for expired storages less often so as to bound the amount of work we |
| # do checking for expired storages |
| self.check_expired_frequency = max( |
| self.check_expired_frequency, len(self.maybe_storages_to_delete) |
| ) |
| |
| def get_tensor_memo(self, t): |
| return self.tensor_memo.get(WeakIdRef(t), None) |
| |
| def set_tensor_memo(self, t, v): |
| # hold a weak ref to self, otherwise it will be kept alive |
| # by the del_ten closure |
| self_weak_ref = weakref.ref(self) |
| if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t): |
| weak_st = None |
| else: |
| weak_st = StorageWeakRef(t._typed_storage()) |
| tensor_ref_key = WeakIdRef(t) |
| |
| def del_ten(): |
| # tensor outlives the converter |
| self_ref = self_weak_ref() |
| if self_ref is None: |
| return |
| # on shutdown, tensor_ref_key may not be in memo |
| self_ref.tensor_memo.pop(tensor_ref_key, None) |
| if weak_st and weak_st.expired(): |
| self_ref.storage_memo.pop(weak_st, None) |
| elif weak_st is not None: |
| # [expired-storages] |
| # NB: even though the tensor has died, |
| # the deallocation of its storage can take longer, |
| # even when the storage has no other uses/views. |
| # In this case, the StorageWeakRef object will be kept alive |
| # longer than it needs to be, however the storage itself |
| # will be deallocated. We retain the possibly dead storages |
| # and periodically check if any of them are expired and |
| # can be freed. |
| self_ref.maybe_storages_to_delete.append(weak_st) |
| |
| weakref.finalize(t, del_ten) |
| self.tensor_memo[tensor_ref_key] = v |
| |
| # Returns a meta storage corresponding to s: the untyped storage of a |
| # freshly allocated meta tensor, memoized so that tensors sharing a real |
| # storage map onto meta tensors sharing one meta storage. |
| def meta_storage(self, s, callback): |
| # NB: the TypedStorage wrapper is freshly allocated on each access and so |
| # cannot be used as a hash key; we key on a StorageWeakRef to the |
| # underlying storage instead. |
| |
| # Use a Weak Ref to s in order to not leak memory |
| swr = StorageWeakRef(s) |
| if swr not in self.storage_memo: |
| self.storage_memo[swr] = callback( |
| lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta") |
| ).untyped_storage() |
| return self.storage_memo[swr] |
| |
| # This function assumes that it's possible to do the conversion |
| # NB: name here is used in a conventional way by Dynamo; it corresponds |
| # precisely to the Source.name() of the tensor we're fakeifying and is a |
| # valid Python expression. When we construct sub-names as part of this |
| # process, we will maintain this invariant! (Even though other users of |
| # this may not need this property to be upheld.) |
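| # For example (an illustrative sketch of Dynamo's usual source naming, not |
| # something this file defines): if source.name() is "L['x']", then the |
| # sub-source built below for a view's base, AttrSource(source, "_base"), |
| # has the name "L['x']._base", which is still a valid Python expression. |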
| def meta_tensor( |
| self, |
| t, |
| shape_env=None, |
| callback=lambda t: t(), |
| source: Optional[Source] = None, |
| symbolic_context: Optional["SymbolicContext"] = None, |
| ): |
| if source is None: |
| from torch._dynamo.source import ConstantSource |
| |
| # TODO: make a dedicated UnknownSource for this? |
| source = ConstantSource( |
| f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" |
| ) |
| |
| # This indicates you set no_dispatch() before calling into this |
| # function. This is an error: we may be creating fake tensors and |
| # will perform operations on them which need fake tensor mode to |
| # be active. You will segfault if you are in a no_dispatch() block. |
| assert not torch._C._dispatch_tls_local_exclude_set().has( |
| torch._C.DispatchKey.Python |
| ) |
| arg_cnt = self.arg_cnt |
| self.arg_cnt += 1 |
| |
| # When we make as_strided calls, we end up generating a guard |
| # that the new as_strided tensor is in bounds for the old storage |
| # for the base (since as_strided calls can "bust" out of their |
| # bounding box.) This guard is unnecessary: if a user is able |
| # to provide us a tensor with the view base setup this way, we |
| # don't need to produce a guard, because the fact that they |
| # were able to produce the view base means it's in bounds. |
| # |
| # Now, ordinarily, this guard would be harmless. However, the |
| # generated guard refers to variables bound on the base variable. |
| # At the moment, Dynamo doesn't actually guard on x._base, because |
| # according to Voz this results in a lot of spurious invalidations, |
| # and also if the user doesn't directly make use of _base, it's |
| # pointless anyway (because programs should be parametric over |
| # whether or not the input tensor is a view or not--unless you're |
| # mutating the input, but that's a whole 'nother ballgame). So |
| # for expediency, we suppress these guards so we don't have to |
| # deal with this (yet, anyway.) |
| # |
| # NB: An old version of this code suppressed guards for ALL operations |
| # happening during meta conversion, not just as_strided calls. |
| # This is too aggressive: we do duck sizing and 0/1 simplification |
| # as we allocate variables, and we do need to register guards for |
| # these cases. |
| maybe_suppress = contextlib.nullcontext |
| if shape_env is not None: |
| maybe_suppress = shape_env.suppress_guards |
| |
| def sym_sizes_strides_storage_offset( |
| t, src, symbolic_context=symbolic_context |
| ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: |
| if shape_env is not None: |
| fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t) |
| if fake_mode is not None and fake_mode.shape_env is shape_env: |
| # Don't reallocate the sizes; the shape envs are the same, |
| # so reuse the old sizes/strides/etc |
| return (t.size(), t.stride(), t.storage_offset()) |
| else: |
| return shape_env.create_symbolic_sizes_strides_storage_offset( |
| t, |
| src, |
| symbolic_context=symbolic_context, |
| ) |
| else: |
| assert symbolic_context is None |
| return (t.size(), t.stride(), t.storage_offset()) |
| |
| def empty_create(inner_t, inner_src, symbolic_context=symbolic_context): |
| ( |
| inner_sizes, |
| inner_strides, |
| inner_storage_offset, |
| ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) |
| return torch.empty_strided( |
| inner_sizes, |
| inner_strides, |
| dtype=inner_t.dtype, |
| device="meta", |
| ) |
| |
| # Creates a subclass instance with empty inner tensors according to the specified |
| # symbolic context. |
| def empty_create_subclass( |
| t, |
| outer_size, |
| outer_stride, |
| symbolic_context=symbolic_context, |
| callback=callback, |
| source=source, |
| ): |
| from torch._dynamo.source import AttrSource |
| from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext |
| |
| assert symbolic_context is None or isinstance( |
| symbolic_context, SubclassSymbolicContext |
| ) |
| |
| # Note: transform_subclass will use __tensor_unflatten__ to generate |
| # a fresh subclass wrapper with outer sizes / strides according to the |
| # outer symbolic context (passed in to this function). Inner size / stride |
| # / storage offset symbols are allocated according to the appropriate inner |
| # symbolic contexts, after which the checks in transform_subclass() will |
| # relate them to the outer metadata where possible. |
| return transform_subclass( |
| t, |
| lambda attr, inner_t: callback( |
| lambda: empty_create( |
| inner_t, |
| AttrSource(source, attr), |
| symbolic_context=( |
| None |
| if symbolic_context is None |
| else symbolic_context.inner_contexts[attr] |
| ), |
| ) |
| ), |
| outer_size=outer_size, |
| outer_stride=outer_stride, |
| ) |
| |
| # Returns an all-dynamic symbolic context used for metafying the given tensor with |
| # fully dynamic dims. This is useful when fake-ifying intermediate tensors in |
| # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we |
| # don't want to over-specialize during view replay. |
| def all_dynamic_symbolic_context(t, source, shape_env, callback): |
| from torch._dynamo.source import AttrSource |
| from torch.fx.experimental.symbolic_shapes import ( |
| DimDynamic, |
| StatelessSymbolicContext, |
| SubclassSymbolicContext, |
| SymbolicContext, |
| ) |
| |
| view_base_context: Optional[SymbolicContext] = None |
| if t._is_view(): |
| view_base_context = all_dynamic_symbolic_context( |
| t._base, AttrSource(source, "_base"), shape_env, callback |
| ) |
| |
| t_symbolic_context: SymbolicContext |
| t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.dim() |
| if is_traceable_wrapper_subclass(t): |
| inner_contexts: Dict[str, SymbolicContext] = {} |
| attrs, _ = t.__tensor_flatten__() |
| for attr in attrs: |
| assert isinstance(attr, str) |
| inner = getattr(t, attr) |
| inner_contexts[attr] = all_dynamic_symbolic_context( |
| inner, AttrSource(source, attr), shape_env, callback |
| ) |
| t_symbolic_context = SubclassSymbolicContext( |
| dynamic_sizes=t_dynamic_sizes, |
| constraint_sizes=[None] * t.dim(), |
| inner_contexts=inner_contexts, |
| tensor_source=source, |
| view_base_context=view_base_context, |
| ) |
| else: |
| t_symbolic_context = StatelessSymbolicContext( |
| dynamic_sizes=t_dynamic_sizes, |
| constraint_sizes=[None] * t.dim(), |
| view_base_context=view_base_context, |
| ) |
| |
| return t_symbolic_context |
| |
| # Returns a fake-ified version of an input view tensor t, given an already fake-ified |
| # base. At a high level, we want two things: |
| # 1. fake_t should have the same view relationship to the given fake base as the |
| # input t has to its _base. |
| # 2. fake_t should have symbolic sizes / strides / storage offset according to the |
| # appropriate symbolic context (i.e. from the automatic dynamic algorithm). |
| # |
| # We currently take different strategies across view types: |
| # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an |
| # as_strided() call on the fake-ified base, passing symbolic metadata. |
| # * For views involving subclasses, perform view replay using view funcs to |
| # achieve (1). It's necessary for (2) to swap out any closed-over state in |
| # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this |
| # avoids specialization (and thus over-eager simplification of symbols) that |
| # could occur during view replay on the fake-ified base. |
| # |
| # Examples: |
| # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled |
| # with an as_strided() call on the fake base passing symbolic metadata. |
| # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg |
| # is made symbolic to avoid invalid specialization and view replay is then |
| # done to reconstruct the view. |
| # * _nested_from_jagged(values, offsets) is a dense -> subclass view |
| # that returns a subclass instance from a dense values tensor. The offsets |
| # tensor is closed over in the view func, as it can be considered view metadata. |
| # First, the offsets tensor is fake-ified according to the inner symbolic |
| # context and with the correct relationship to the outer size / stride metadata. |
| # Then view replay is done, swapping in the fake offsets so the view replay output |
| # is fully fake with no invalid specialization. |
| def view_from_base(base, t, source=source, shape_env=shape_env): |
| # fake-ify t's metadata according to the outer symbolic context |
| (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( |
| t, source |
| ) |
| if not is_traceable_wrapper_subclass( |
| t |
| ) and not is_traceable_wrapper_subclass(base): |
| # Dense -> Dense view case uses as_strided() to construct view relationship. |
| # TODO: Change this logic to use view replay for consistency? |
| # It's likely there is no view func available. |
| return base.as_strided(sizes, strides, storage_offset) |
| |
| from torch._dynamo.source import EphemeralSource |
| from torch.fx.experimental.symbolic_shapes import sym_eq |
| |
| def symint_visitor_fn(s): |
| if shape_env is None: |
| return s |
| |
| # NB: The symbol here is expected to be simplified out because we a priori |
| # allocate inner and outer symbols according to the appropriate symbolic |
| # contexts and prefer those over this symbol during symbol simplification |
| # (via usage of EphemeralSource below). A leak -shouldn't- happen, but if |
| # this symbol somehow leaks out beyond the view tensor's shape metadata, our |
| # assumption of it being simplified out will fail and it may be guarded on, |
| # which will hard error. |
| sym_source = EphemeralSource("symint_visitor_fn") |
| symbol = shape_env.create_symbol(s, sym_source) |
| return shape_env.create_symintnode(symbol, hint=s, source=sym_source) |
| |
| real_to_fake_mapping = {} |
| if is_traceable_wrapper_subclass(t): |
| # Fake-ify t naively here; this is only done so we can get fake-ified inner |
| # tensors with the correct relationships to the outer sizes / strides for use |
| # in view replay. It's done beforehand here because it's not easy to do when |
| # visiting tensors one-by-one during view replay. |
| # |
| # Example: |
| # Consider a Dense -> NJT view. NJT has (values, offsets) components and we |
| # want a view of values with the offsets closed over. As the offsets component |
| # is needed to describe the output view, it's important that it's fakeified |
| # correctly. |
| fake_t = empty_create_subclass( |
| t, outer_size=sizes, outer_stride=strides |
| ) |
| attrs, _ = fake_t.__tensor_flatten__() |
| for attr in attrs: |
| real_to_fake_mapping[getattr(t, attr)] = getattr(fake_t, attr) |
| |
| def tensor_visitor_fn( |
| visited_t, shape_env=shape_env, callback=callback, source=source |
| ): |
| # It's possible to close over an undefined tensor (e.g. NJT's lengths). |
| if visited_t is None: |
| return None |
| |
| # Fake inner tensors of view subclasses will come from the mapping built above. |
| fake_visited_t = real_to_fake_mapping.get(visited_t, None) |
| if fake_visited_t is not None: |
| return fake_visited_t |
| |
| # For other closed-over tensor state, fake-ify it as all dynamic with an |
| # ephemeral source. This avoids invalid specialization during view replay. |
| # If we find that in practice the usage of ephemeral sources isn't enough |
| # to guarantee that we don't have guards on these symbols, we may need to |
| # explicitly suppress guards (as is done for _base in the dense -> dense |
| # view case). |
| temp_source = EphemeralSource("tensor_visitor_fn") |
| return self.meta_tensor( |
| visited_t, |
| shape_env, |
| callback, |
| source=temp_source, |
| symbolic_context=all_dynamic_symbolic_context( |
| visited_t, temp_source, shape_env, callback |
| ), |
| ) |
| |
| # Replay the view, swapping out any non-symbolic SymInts or real tensors |
| # for symbolic SymInts or fake tensors. |
| fake_t = t._view_func_unsafe(base, symint_visitor_fn, tensor_visitor_fn) |
| |
| # Ensure the output has symbolic shapes according to the outer symbolic context. |
| # These checks should simplify out any symbols created for closed-over view func |
| # SymInts. |
| torch._check(sym_eq(fake_t.size(), sizes)) |
| torch._check(sym_eq(fake_t.stride(), strides)) |
| torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) |
| return fake_t |
| |
| # see expired-storages |
| self.check_expired_count += 1 |
| if self.check_expired_count >= self.check_expired_frequency: |
| self.check_for_expired_weak_storages() |
| self.check_expired_count = 0 |
| |
| if self.get_tensor_memo(t) is None: |
| with torch.inference_mode(t.is_inference()): |
| if t.is_sparse: |
| is_leaf = safe_is_leaf(t) |
| |
| # The lambda function below is similar to |
| # `t.to(device='meta')` except the latter |
| # preserves nnz value |
| r = callback( |
| lambda: torch.ops.aten._sparse_coo_tensor_with_dims( |
| t.sparse_dim(), |
| t.dense_dim(), |
| t.shape, |
| dtype=t.dtype, |
| layout=torch.sparse_coo, |
| device="meta", |
| ) |
| ) |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| # Note [is_coalesced is dispatched] |
| # Strangely enough, is_coalesced() is a dispatched operator, |
| # which means that it will get caught by fake tensor mode. |
| # Ordinarily this would error, but there's some logic in |
| # fake tensor that ensures this doesn't happen. |
| r._coalesced_(t.is_coalesced()) |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| with torch.enable_grad(): |
| r = r.clone() |
| r._coalesced_(t.is_coalesced()) |
| elif is_sparse_compressed(t): |
| is_leaf = safe_is_leaf(t) |
| |
| def mk_meta(): |
| nnz = 0 |
| batch_dim = t.ndim - t.sparse_dim() - t.dense_dim() |
| batch_size = t.shape[:batch_dim] |
| if t.layout in {torch.sparse_csr, torch.sparse_bsr}: |
| index_dtype = t.crow_indices().dtype |
| compressed_indices = torch.empty( |
| t.crow_indices().shape, device="meta", dtype=index_dtype |
| ) |
| plain_indices = torch.empty( |
| (*t.col_indices().shape[:-1], nnz), |
| device="meta", |
| dtype=index_dtype, |
| ) |
| else: |
| index_dtype = t.ccol_indices().dtype |
| compressed_indices = torch.empty( |
| t.ccol_indices().shape, device="meta", dtype=index_dtype |
| ) |
| plain_indices = torch.empty( |
| (*t.row_indices().shape[:-1], nnz), |
| device="meta", |
| dtype=index_dtype, |
| ) |
| values_shape = t.values().shape |
| values = torch.empty( |
| ( |
| *values_shape[:batch_dim], |
| nnz, |
| *values_shape[batch_dim + 1 :], |
| ), |
| dtype=t.dtype, |
| device="meta", |
| ) |
| return torch.ops.aten.sparse_compressed_tensor( |
| compressed_indices, |
| plain_indices, |
| values, |
| t.shape, |
| layout=t.layout, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| |
| # `mk_meta()` is similar to `t.to(device='meta')` |
| # except `to('meta')` preserves nnz value while |
| # `mk_meta` result has nnz == 0. |
| r = callback(mk_meta) |
| |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| with torch.enable_grad(): |
| r = r.clone() |
| elif t.is_nested and not is_traceable_wrapper_subclass(t): |
| # TODO: Handle this better in Dynamo? |
| # There are checks there now, but this can still be triggered by a dense |
| # tensor graph input that is a view of a strided NT. |
| from torch._dynamo.exc import unimplemented |
| |
| unimplemented( |
| "strided nested tensors are not supported by meta conversion" |
| ) |
| elif t.is_mkldnn: |
| is_leaf = safe_is_leaf(t) |
| sizes, strides, _storage_offset = sym_sizes_strides_storage_offset( |
| t, source |
| ) |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, strides, dtype=t.dtype, device="meta" |
| ) |
| ) |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| with torch.enable_grad(): |
| r = r.clone() |
| elif is_functorch_wrapped_tensor(t): |
| if t._is_view(): |
| from torch._dynamo.exc import unimplemented |
| |
| unimplemented( |
| "view functorch tensors are not supported by meta conversion" |
| ) |
| |
| # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) |
| # in a FakeTensor |
| def _to_fake_tensor(t): |
| if is_batchedtensor(t): |
| ft = _to_fake_tensor(get_unwrapped(t)) |
| lvl = maybe_get_level(t) |
| bdim = maybe_get_bdim(t) |
| r = _add_batch_dim(ft, bdim, lvl) |
| elif is_gradtrackingtensor(t): |
| disable_functorch = torch._C._DisableFuncTorch |
| with disable_functorch(): |
| ft = _to_fake_tensor(get_unwrapped(t)) |
| lvl = torch._C._functorch.maybe_get_level(t) |
| r = torch._C._functorch._wrap_for_grad(ft, lvl) |
| |
| is_leaf = safe_is_leaf(t) |
| if t.requires_grad and safe_is_leaf(r): |
| r.requires_grad = True |
| elif t.requires_grad and not is_leaf: |
| with torch.enable_grad(): |
| r = r.clone() |
| else: |
| sizes = t.size() |
| strides = t.stride() |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, |
| strides, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| ) |
| return r |
| |
| r = _to_fake_tensor(t) |
| |
| elif t._is_view(): |
| # Construct views in two steps: recursively meta-fy their |
| # base, and then create view(s) off that. NB: doing it |
| # directly from storage is WRONG because this won't cause |
| # version counters to get shared. |
| assert t._is_view() |
| |
| base_symbolic_context = None |
| if shape_env and symbolic_context is not None: |
| from torch.fx.experimental.symbolic_shapes import ( |
| StatelessSymbolicContext, |
| ) |
| |
| assert isinstance(symbolic_context, StatelessSymbolicContext) |
| # NB: This should generally be set when the input is a view, |
| # but the exception right now is for fake-ifying grads, which is |
| # a work in progress. |
| if symbolic_context.view_base_context is not None: |
| base_symbolic_context = symbolic_context.view_base_context |
| |
| base = self.meta_tensor( |
| t._base, |
| shape_env, |
| callback, |
| source=torch._dynamo.source.AttrSource(source, "_base"), |
| symbolic_context=base_symbolic_context, |
| ) |
| |
| def is_c_of_r(complex_dtype, real_dtype): |
| return ( |
| utils.is_complex_dtype(complex_dtype) |
| and utils.corresponding_real_dtype(complex_dtype) |
| == real_dtype |
| ) |
| |
| # In some situations, MetaConverter may be called in a |
| # context where autograd is disabled. For the _is_view |
| # assert to pass, we have to set up the autograd view |
| # metadata anyway. Do this by reenabling the |
| # ADInplaceOrView key. This is kind of a hack. |
| old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView |
| ) |
| torch._C._dispatch_tls_set_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView, False |
| ) |
| try: |
| if base.dtype == t.dtype: |
| pass |
| elif is_c_of_r(base.dtype, t.dtype): |
| base = torch.view_as_real(base) |
| elif is_c_of_r(t.dtype, base.dtype): |
| base = torch.view_as_complex(base) |
| else: |
| # This is not guaranteed to succeed. If it fails, it |
| # means there is another dtype-converting view function |
| # that hasn't been handled here |
| base = base.view(t.dtype) |
| |
| # This is very tricky. Naively, you might expect this |
| # to hold: |
| # |
| # if t.requires_grad and not safe_is_leaf(t) |
| # assert t._base.requires_grad |
| # |
| # But it's not true! As you can see in the following |
| # program: |
| # |
| # x = torch.zeros(4) |
| # y = x.view(1, 4) |
| # y.requires_grad = True |
| # z = y.view(1, 1, 4) |
| # assert z._base is x |
| # |
| # So we may have to do *two* views out of the base to |
| # recreate this situation. |
| if safe_is_leaf(t): |
| # Leaf views that track view metadata are created by |
| # creating a view inside a no_grad block |
| with torch.no_grad(), maybe_suppress(): |
| r = view_from_base(base, t) |
| # As it's a leaf, we can directly assign requires_grad |
| r.requires_grad = t.requires_grad |
| else: |
| if t._base.requires_grad == t.requires_grad: |
| # Easy case, just run the view op |
| with torch.enable_grad(), maybe_suppress(): |
| r = view_from_base(base, t) |
| |
| # NB: We don't actually faithfully replicate |
| # autograd connectivity, but that doesn't matter |
| # today. See following for more info: |
| # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 |
| else: |
| # Obscure case. Create a leaf view and give it the |
| # correct requires_grad, then do the final view. |
| # NB: Can't have a non-leaf without requiring grad! |
| assert t.requires_grad |
| with torch.no_grad(): |
| mid = base.view(base.shape) |
| mid.requires_grad = t.requires_grad |
| with torch.enable_grad(), maybe_suppress(): |
| r = view_from_base(mid, t) |
| # The CreationMeta influences whether or not inplace |
| # mutation is an error or not. So we need to make |
| # sure we properly propagate this as well. |
| torch._C._autograd._set_creation_meta( |
| r, torch._C._autograd._get_creation_meta(t) |
| ) |
| finally: |
| torch._C._dispatch_tls_set_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView, old_exclude |
| ) |
| |
| else: |
| is_leaf = safe_is_leaf(t) |
| |
| ( |
| sizes, |
| strides, |
| storage_offset, |
| ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) |
| |
| # If we have a subclass that desugars into dense tensors, |
| # perform our callback on each inner tensor. |
| if is_traceable_wrapper_subclass(t): |
| r = empty_create_subclass( |
| t, outer_size=sizes, outer_stride=strides |
| ) |
| else: |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, |
| strides, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| ) |
| |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = t.requires_grad |
| if not is_leaf: |
| # Fake up some autograd history. |
| with torch.enable_grad(): |
| # preserve_format is the default, but we want to |
| # emphasize how important it is to preserve |
| # format here |
| r = r.clone(memory_format=torch.preserve_format) |
| |
| # Graph-Break for wrapped tensors |
| if not ( |
| is_batchedtensor(t) or is_gradtrackingtensor(t) |
| ) and torch._C._functorch.is_functorch_wrapped_tensor(t): |
| return NotImplemented |
| |
| s = t.untyped_storage() |
| swr = StorageWeakRef(s) |
| if swr not in self.storage_memo and ( |
| r.is_nested |
| or ( |
| r.stride() == strides |
| and r.storage_offset() == storage_offset |
| ) |
| ): |
| # You're normal and happy, install the fresh storage into the memo |
| self.storage_memo[swr] = r.untyped_storage() |
| else: |
| # You're in crazy town; somehow you gave us a tensor |
| # that wasn't a view, but had nonzero storage offset, |
| # nontrivial strides (such that clone() couldn't |
| # preserve them), or already aliases with another |
| # tensor's storage. The most typical way to end |
| # up here is with set_. So use set_ to bludgeon this |
| # in. |
| r_s = self.meta_storage(s, callback=callback) |
| # NB: In principle, this should always work, but there |
| # is some subtle difference in the autograd metadata |
| # that means we will backprop the set_ call, even if |
| # r is declared as an input to grad. |
| # See https://github.com/pytorch/pytorch/issues/87956 |
| # for the reproducer. |
| # NB: The in_kernel_invocation_manager here is necessary |
| # for fake tensor. If we run the set_ call with fake |
| # tensor on, r will improperly report that it is NOT a |
| # meta tensor but a cpu tensor, and then the set_ call |
| # will fail due to device mismatch. no_dispatch() is |
| # not enough, because the fake tensor will still claim |
| # to be a CPU tensor and you'll end up in the CPU |
| # kernel. Arguably this is a hack; a cleaner way to |
| # solve this is to have a FakeStorage concept which |
| # would report its CPU device--no problem now! But |
| # this is difficult to do because we don't have storage |
| # subclasses. Relevant test is |
| # DynamicShapesFunctionTests::test_add_dynamic_shapes in |
| # test/dynamo/test_dynamic_shapes.py |
| maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() |
| from torch._subclasses.fake_tensor import ( |
| in_kernel_invocation_manager, |
| maybe_get_fake_mode, |
| ) |
| |
| mb_fake_mode = maybe_get_fake_mode(r) |
| if mb_fake_mode is not None: |
| maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) |
| with maybe_fake_mgr, torch.no_grad(): |
| r.set_(r_s, storage_offset, sizes, strides) |
| |
| if safe_grad(t) is not None: |
| from torch._dynamo.source import AttrSource |
| |
| # TODO: Use a valid grad-specific symbolic context instead of recycling |
| # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). |
| r.grad = self.meta_tensor( |
| safe_grad(t), |
| shape_env, |
| callback, |
| source=AttrSource(source, "grad"), |
| symbolic_context=symbolic_context, |
| ) |
| torch._C._set_conj(r, t.is_conj()) |
| torch._C._set_neg(r, t.is_neg()) |
| # This can be skipped if necessary for performance reasons |
| assert_metadata_eq(assert_eq, t, r, skip_symbolic=True) |
| self.set_tensor_memo(t, r) |
| |
| return self.get_tensor_memo(t) |
| |
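| # Entry point: convert t into its meta/fake counterpart via meta_tensor(). |
| # Functional tensors are unwrapped, converted, and re-wrapped; unsupported |
| # tensor types yield NotImplemented; non-Tensor values pass through as-is. |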
| def __call__( |
| self, |
| t, |
| shape_env=None, |
| *, |
| callback=lambda t: t(), |
| source=None, |
| symbolic_context=None, |
| ): |
| # TODO: zero tensors? We appear to have eliminated them by |
| # excluding complex for now |
| |
| if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t): |
| if t.device.type != "xla" and any( |
| [ |
| t.is_quantized, |
| t._is_view() and t._base is not None and t._base.is_sparse, |
| torch._is_functional_tensor(t), |
| t.device.type == "lazy", |
| # We need a way to test if a tensor is batched but there |
| # is no official API to do it |
| # torch._C._is_batched(t), |
| ] |
| ): |
| # TODO: sparse should support meta |
| # NB: technically to('meta') does work, but our logging |
| # instrumentation will see the meta conversions and the |
| # tests all break, so we just exclude this. In any case, |
| # the to() conversion isn't really right anyhow. |
| |
| if torch._is_functional_tensor(t) and t.device.type != "lazy": |
| if t._is_view(): |
| raise RuntimeError( |
| "Cannot safely fakify a view because this process drops the view information right now." |
| ) |
| |
| st = peek_interpreter_stack() |
| assert ( |
| st is None or st.key() == TransformType.Functionalize |
| ), "Expect st to be either None or have Functionalize transform key." |
| if st is None: |
| # the case of AOTAutograd |
| torch._sync(t) |
| unwrap_t = torch._from_functional_tensor(t) |
| with torch._dispatch.python.suspend_functionalization(): |
| fake_t = self.meta_tensor( |
| unwrap_t, |
| shape_env=shape_env, |
| callback=callback, |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| out = torch._to_functional_tensor(fake_t) |
| torch._mirror_autograd_meta_to(fake_t, out) |
| return out |
| else: |
| # torch.func.functionalize |
| reapply_views = torch._C._functionalization_reapply_views_tls() |
| unwrap_t = _unwrap_functional_tensor(t, reapply_views) |
| pop_st_ctx = ( |
| torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack() |
| ) |
| with pop_st_ctx: |
| fake_t = self.meta_tensor( |
| unwrap_t, |
| shape_env=shape_env, |
| callback=callback, |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| return _wrap_functional_tensor(fake_t, current_level()) |
| self.miss += 1 |
| return NotImplemented |
| else: |
| self.hit += 1 |
| |
| disable_functorch = torch._C._DisableFuncTorch |
| with disable_functorch(): |
| r = self.meta_tensor( |
| t, |
| shape_env=shape_env, |
| callback=callback, |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| if type(t) is torch.nn.Parameter: |
| # NB: Cannot directly use Parameter constructor |
| # because that would force a detach, not desirable |
| r._is_param = True |
| return r |
| elif torch.overrides.is_tensor_like(t): |
| self.miss += 1 |
| return NotImplemented |
| else: |
| # non-Tensor types don't count as hit or miss |
| return t |
| |
| |
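| # NB: imported at the bottom of the module; `utils` is referenced lazily |
| # inside meta_tensor() (see is_c_of_r above). |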
| import torch._prims_common as utils |