| # mypy: ignore-errors |
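
"""Cross-reference utilities for validating FakeTensor execution.

``CrossRefFakeMode`` is a ``TorchDispatchMode`` that re-runs each dispatched
op on FakeTensor copies of its inputs and asserts that the fake results agree
with the real ones (output count, aliasing, requires_grad, storage offsets,
and tensor metadata).
"""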
| |
| import functools |
| import warnings |
| from typing import Callable, Union |
| |
| import torch |
| import torch.utils._pytree as pytree |
| from torch._ops import OpOverload |
| from torch._subclasses.fake_tensor import ( |
| FakeTensorMode, |
| tree_flatten_only, |
| UnsupportedFakeTensorException, |
| ) |
| from torch.utils._python_dispatch import TorchDispatchMode |
| |
| |
| aten = torch._ops.ops.aten |
| |
| |
| def outputs_alias_inputs(outputs, inputs): |
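    """Return True if any tensor output shares a storage with any tensor input."""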
| input_storages = { |
| inp._typed_storage()._cdata |
| for inp in tree_flatten_only(torch.Tensor, inputs) |
| if torch._C._has_storage(inp) |
| } |
| return any( |
| torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages |
| for out in tree_flatten_only(torch.Tensor, outputs) |
| ) |
| |
| |
| def outputs_are_inputs(outputs, inputs): |
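    """Return True if any tensor output is (by identity) one of the tensor inputs."""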
| input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)} |
| return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs)) |
| |
| |
| def output_alias_each_other(outputs): |
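    """Return True if any two tensor outputs share a storage."""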
| storages = set() |
| for out in tree_flatten_only(torch.Tensor, outputs): |
| if not torch._C._has_storage(out): |
| continue |
| stor = out._typed_storage()._cdata |
| if stor in storages: |
| return True |
| storages.add(stor) |
| return False |
| |
| |
| def is_sdpa_error(func, idx, e): |
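    """Return True for known, benign metadata mismatches from the SDPA ops.

    The flash/efficient attention kernels report a device mismatch on the
    output indices checked below (which appear to be the philox RNG
    seed/offset) when run under FakeTensorMode; the cross-ref check skips
    those outputs instead of failing.
    """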
| if ( |
| ( |
| func is aten._scaled_dot_product_flash_attention.default |
| or func is aten._flash_attention_forward.default |
| ) |
| and idx in (6, 7) |
| and "Devices" in repr(e) |
| ): |
| return True |
| if ( |
| ( |
| func is aten._scaled_dot_product_efficient_attention.default |
| or func is aten._efficient_attention_forward.default |
| ) |
| and idx in (2, 3) |
| and "Devices" in repr(e) |
| ): |
| return True |
| return False |
| |
| |
| class CrossRefFakeMode(TorchDispatchMode): |
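    """Cross-references real op execution against FakeTensor execution.

    For every dispatched op, the real tensor inputs are converted to
    FakeTensors, the op is re-run under ``FakeTensorMode``, and the fake
    results are checked against the real ones: number of outputs, aliasing
    relationships (if ``check_aliasing``), ``requires_grad``, storage
    offsets, and tensor metadata (strides included if ``check_strides``).

    Minimal usage sketch (illustrative, not from the original source)::

        with CrossRefFakeMode(check_strides=False):
            torch.add(torch.randn(3), torch.randn(3))
    """
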
| def __init__( |
| self, |
| ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None, |
| *, |
| check_strides=True, |
| check_aliasing=True, |
| ): |
| self.ignore_op_fn = ( |
| ignore_op_fn if ignore_op_fn is not None else lambda fn: False |
| ) |
| self.check_strides = check_strides |
| self.check_aliasing = check_aliasing |
| |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
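        """Run ``func`` on fake and real inputs and cross-check the results."""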
| kwargs = kwargs or {} |
| |
| fake_r = None |
| |
| # empty_like excluded for now due to sparse complex |
        # aten._to_dense.default is getting called with csc inputs
| if ( |
| func |
| not in ( |
| aten.lift_fresh.default, |
| aten.lift_fresh_copy.default, |
| aten.set_.source_Storage_storage_offset, |
| ) |
| and not self.ignore_op_fn(func) |
| and torch.Tag.dynamic_output_shape not in func.tags |
| and torch.Tag.inplace_view not in func.tags |
| and torch.Tag.data_dependent_output not in func.tags |
| ): |
| # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv |
| |
| try: |
| # TODO: enable_python_dispatcher() here |
| with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode: |
| fake_args, fake_kwargs = pytree.tree_map_only( |
| torch.Tensor, |
| functools.partial(fake_mode.from_tensor, static_shapes=True), |
| (args, kwargs), |
| ) |
| with warnings.catch_warnings(): |
| fake_r = func(*fake_args, **fake_kwargs) |
| except UnsupportedFakeTensorException: |
| pass |
| |
| context = ( |
| f"When comparing the output of {func} on FakeTensor and concrete Tensors, " |
| f"found" |
| ) |
| r = func(*args, **kwargs) |
| if fake_r is not None: |
| r_flat = pytree.tree_leaves(r) |
| f_flat = pytree.tree_leaves(fake_r) |
| assert len(f_flat) == len( |
| r_flat |
| ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" |
| |
| if self.check_aliasing: |
| r_aliasing = outputs_alias_inputs(r, (args, kwargs)) |
| f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs)) |
| assert ( |
| r_aliasing == f_aliasing |
| ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}" |
| |
| r_identity_eq = outputs_are_inputs(r, (args, kwargs)) |
| f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs)) |
| assert ( |
| r_identity_eq == f_identity_eq |
| ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}" |
| |
| r_output_alias_each_other = output_alias_each_other(r) |
| f_output_alias_each_other = output_alias_each_other(fake_r) |
| assert r_output_alias_each_other == f_output_alias_each_other, ( |
| f"{context} mismatch in outputs_alias_each_other check " |
| f"{f_output_alias_each_other} != {r_output_alias_each_other}" |
| ) |
| |
| for idx, (r_out, fake_out) in enumerate( |
| zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) |
| ): |
| r_is_ten = isinstance(r_out, torch.Tensor) |
| assert r_is_ten == isinstance( |
| fake_out, torch.Tensor |
| ), f"{context} mismatched number of tensor outputs" |
| if r_is_ten: |
| assert r_out.requires_grad == fake_out.requires_grad, ( |
| f"{context} mismatched requires_grad-ness of outputs. " |
| f"This usually means that you have added autograd support " |
| f"for your operator at a dispatch key other than Autograd, " |
| f"which will lead to problems" |
| ) |
| if torch._C._has_storage(r_out): |
| r_offset = r_out.storage_offset() |
| f_offset = fake_out.storage_offset() |
| assert ( |
| r_offset == f_offset |
| ), f"{context} mismatched storage offset" |
| |
| try: |
| torch._prims.utils.compare_tensor_meta( |
| r_out, |
| fake_out, |
| check_strides=self.check_strides, |
| allow_rhs_unbacked=True, |
| ) |
| except Exception as e: |
| if is_sdpa_error(func, idx, e): |
| continue |
| error_message = ( |
| f"{context} mismatched tensor metadata: {e}" |
| if len(r_flat) == 1 |
| else f"{context} mismatched tensor metadata for output[{idx}]: {e}" |
| ) |
| raise RuntimeError(error_message) from e |
| return r |