| # mypy: ignore-errors |
| |
| import functools |
| import itertools |
| import math |
| import sys |
| from typing import Callable, Union |
| |
| import torch |
| import torch._custom_op |
| import torch._logging |
| |
| from torch._ops import OpOverload |
| from torch._prims_common import ( |
| elementwise_dtypes, |
| ELEMENTWISE_TYPE_PROMOTION_KIND, |
| is_boolean_dtype, |
| is_float_dtype, |
| is_integer_dtype, |
| ) |
| |
| from torch._subclasses.fake_tensor import ( |
| DataDependentOutputException, |
| DynamicOutputShapeException, |
| FakeTensor, |
| in_kernel_invocation_manager, |
| run_fallback_kernel, |
| UnsupportedOperatorException, |
| ) |
| from torch.fx.operator_schemas import normalize_function |
| |
| from torch.utils._stats import count_label |
| |
| pytree = torch.utils._pytree |
| |
| __all__ = [ |
| "op_implementations_checks", |
| "get_fast_op_impls", |
| "stride_incorrect_op", |
| "has_meta", |
| ] |
| |
| op_implementations_dict = {} |
| op_implementations_checks = [] |
| |
| |
| aten = torch._ops.ops.aten |
| |
| |
| def ordered_set(*items): |
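|     """Return a dict whose keys form an insertion-ordered set.""" |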
| return dict.fromkeys(items, True) |
| |
| |
| # Returns True if the backend device supports non-contiguous tensors |
| def is_noncontiguous_supported(device): |
| if device.type == "hpu": |
| return False |
| return True |
| |
| |
| _like_tensor_constructors = ordered_set( |
| aten.empty_like.default, |
| aten.empty_like.out, |
| aten.full_like.default, |
| aten.full_like.out, |
| aten.ones_like.default, |
| aten.ones_like.out, |
| aten.rand_like.default, |
| aten.rand_like.out, |
| aten.randn_like.default, |
| aten.randn_like.out, |
| aten.randint_like.default, |
| aten.randint_like.out, |
| aten.randint_like.low_dtype, |
| aten.randint_like.low_dtype_out, |
| aten.zeros_like.default, |
| aten.zeros_like.out, |
| aten.new_empty.default, |
| aten.new_empty.out, |
| aten.new_empty_strided.default, |
| aten.new_empty_strided.out, |
| aten.new_full.default, |
| aten.new_full.out, |
| aten.new_zeros.default, |
| aten.new_zeros.out, |
| aten.new_ones.default, |
| aten.new_ones.out, |
| ) |
| |
| |
| _device_not_kwarg_ops = ordered_set( |
| aten._resize_output_.default, |
| aten._nested_tensor_from_tensor_list.default, |
| aten._nested_tensor_from_tensor_list.out, |
| aten.pin_memory.default, |
| aten.is_pinned.default, |
| aten.to.device, |
| aten.to.prim_Device, |
| aten._pin_memory.default, |
| aten._pin_memory.out, |
| aten._resize_output.default, |
| aten._resize_output.out, |
| ) |
| |
| # this op is never actually used |
| _non_kwarg_device_constructors = (aten._list_to_tensor,) |
| |
| |
| def contains_tensor_types(type): |
| tensor_type = torch._C.TensorType.get() |
| return type.isSubtypeOf(tensor_type) or any( |
| contains_tensor_types(e) for e in type.containedTypes() |
| ) |
| |
| |
| @functools.lru_cache(None) |
| def _is_tensor_constructor(func: OpOverload): |
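|     """Treat ``func`` as a tensor constructor if no argument type contains a |
|     Tensor and it returns exactly one Tensor.""" |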
| assert isinstance(func, OpOverload) |
| schema = func._schema |
| if any(contains_tensor_types(arg.type) for arg in schema.arguments): |
| return False |
| # TODO: no real reason to restrict multiple outputs |
| return ( |
| len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() |
| ) |
| |
| |
| def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): |
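|     """Register a fake-tensor implementation for one or more ops. |
| |
|     ``run_impl_check`` may be a single OpOverload, a list/tuple of OpOverloads, |
|     or a predicate ``Callable[[OpOverload], bool]``; predicates are tried in |
|     registration order. Illustrative usage (``aten.my_op`` is hypothetical):: |
| |
|         @register_op_impl(aten.my_op.default) |
|         def my_op_fake(fake_mode, func, *args, **kwargs): |
|             ... |
|     """ |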
| def impl_decorator(op_impl): |
| if isinstance(run_impl_check, OpOverload): |
| assert ( |
| run_impl_check not in op_implementations_dict |
| ), f"duplicate registration: {run_impl_check}" |
| op_implementations_dict[run_impl_check] = op_impl |
| elif isinstance(run_impl_check, (list, tuple)): |
| for op in run_impl_check: |
| register_op_impl(op)(op_impl) |
| else: |
| assert callable(run_impl_check) |
| op_implementations_checks.append((run_impl_check, op_impl)) |
| |
| return op_impl |
| |
| return impl_decorator |
| |
| |
| @register_op_impl(op_implementations_dict.__contains__) |
| def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): |
| return op_implementations_dict[func](fake_mode, func, *args, **kwargs) |
| |
| |
| @register_op_impl(_is_tensor_constructor) |
| @register_op_impl([*_like_tensor_constructors]) |
| def constructors(fake_mode, func, *args, **kwargs): |
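|     """Run a tensor constructor on the meta device and rewrap the result as a |
|     FakeTensor on the requested device (the input's device for _like ops, |
|     otherwise defaulting to CPU).""" |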
| assert func not in _non_kwarg_device_constructors |
| _, new_kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| if "names" in kwargs: |
| raise UnsupportedOperatorException( |
| "torch.compile doesn't support named tensors" |
| ) |
| |
| if func in _like_tensor_constructors: |
| default_device = new_kwargs["input"].device |
| # TODO: file issue |
| args = (new_kwargs.pop("input"),) |
| else: |
| # cpu is default device if none is specified |
| default_device = torch.device("cpu") |
| args = () |
| out_device = new_kwargs.pop("device", None) |
| out_device = out_device if out_device is not None else default_device |
| new_kwargs["device"] = torch.device("meta") |
| # _like constructors have fake tensor inputs (maybe this causes the non-like |
| # to fail? hmmm) |
| with in_kernel_invocation_manager(fake_mode): |
| r = func(*args, **new_kwargs) |
| return FakeTensor(fake_mode, r, out_device) |
| |
| |
| @register_op_impl(aten.to.prim_Device) |
| @register_op_impl(aten.to.device) |
| def non_kwarg_to(fake_mode, func, *args, **kwargs): |
| _, new_kwargs = normalize_function( |
| func, args, kwargs, normalize_to_only_use_kwargs=True |
| ) |
| input_device = new_kwargs["device"] |
| out_device = input_device if input_device else new_kwargs["input"].device |
| new_kwargs["device"] = torch.device("meta") |
| inp = new_kwargs.pop("input") |
| with in_kernel_invocation_manager(fake_mode): |
| r = func(inp, **new_kwargs) |
| # TODO: I think this does the wrong thing if r is inp |
| return fake_mode.fake_tensor_converter.from_meta_and_device( |
| fake_mode, r, out_device |
| ) |
| |
| |
| def stride_incorrect_op(op): |
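|     """Return True for aten/prims ops with "fft" in their name (other than |
|     aten._fft_c2c), whose meta implementations produce incorrect strides.""" |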
| if op.namespace not in ("aten", "prims"): |
| return False |
| if op is aten._fft_c2c.default: |
| return False |
| |
| op_name = op.name() |
| if "fft" in op_name: |
| return True |
| return False |
| |
| |
| # These operators have meta implementations with incorrect strides |
| @register_op_impl(stride_incorrect_op) |
| def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs): |
|     # This is a workaround for meta implementations with incorrect strides |
| |
| def is_symbolic(x): |
| if isinstance(x, FakeTensor): |
| return x._has_symbolic_sizes_strides |
| if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)): |
| return True |
| return False |
| |
| # For static shapes, we can fall back to eager for the real strides |
| if fake_mode.allow_fallback_kernels: |
| require_dynamic = any( |
| is_symbolic(x) for x in itertools.chain(args, kwargs.values()) |
| ) |
| if not require_dynamic: |
| flat_args, args_spec = pytree.tree_flatten((args, kwargs)) |
| return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None) |
| |
| raise UnsupportedOperatorException(func) |
| |
| |
| # Don't fall back to the default device handling, |
| # since the device of `the_template` is ignored |
| @register_op_impl(aten.resize_as_.default) |
| def resize_as_(fake_mode, func, *args, **kwargs): |
| with in_kernel_invocation_manager(fake_mode): |
| return func(*args, **kwargs) |
| |
| |
| @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) |
| def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): |
| # TODO: remove me |
| return constructors(fake_mode, func, *args, **kwargs) |
| |
| |
| # index.Tensor is data-dependent only under some conditions |
| @register_op_impl( |
| lambda func: torch.Tag.dynamic_output_shape in func.tags |
| and func |
| not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor] |
| ) |
| def dyn_shape(fake_mode, func, *args, **kwargs): |
| raise DynamicOutputShapeException(func) |
| |
| |
| @register_op_impl(aten.repeat_interleave.Tensor) |
| def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): |
| if output_size is None: |
| if ( |
| fake_mode.shape_env is None |
| or not fake_mode.shape_env.allow_dynamic_output_shape_ops |
| ): |
| raise DynamicOutputShapeException(func) |
| |
| output_size = fake_mode.shape_env.create_unbacked_symint() |
| |
| # Avoid importing sympy at a module level |
| from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size |
| |
| _constrain_range_for_size(output_size) |
| # TODO: consider a memo |
| return repeats.new_empty(output_size) |
| |
| |
| @register_op_impl(torch.ops.aten._local_scalar_dense.default) |
| def local_scalar_dense(fake_mode, func, arg): |
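|     """Fake impl of ``aten._local_scalar_dense`` (i.e. ``Tensor.item()``): return |
|     a fresh unbacked SymFloat/SymInt/SymBool for the dtype, or raise |
|     DataDependentOutputException when scalar outputs are not allowed.""" |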
| if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs: |
| # Without symints/symfloats, cannot handle this |
| raise DataDependentOutputException(func) |
| if is_float_dtype(arg.dtype): |
| return fake_mode.shape_env.create_unbacked_symfloat() |
| elif is_integer_dtype(arg.dtype): |
| return fake_mode.shape_env.create_unbacked_symint() |
| elif is_boolean_dtype(arg.dtype): |
| return fake_mode.shape_env.create_unbacked_symbool() |
| else: |
| raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") |
| |
| |
| @register_op_impl(torch.ops.aten.nonzero.default) |
| def nonzero(fake_mode, func, arg): |
| if ( |
| fake_mode.shape_env is None |
| or not fake_mode.shape_env.allow_dynamic_output_shape_ops |
| ): |
| # Without symints/symfloats, cannot handle this |
| raise DynamicOutputShapeException(func) |
| |
| if arg.nonzero_memo is None: |
| nnz = fake_mode.shape_env.create_unbacked_symint() |
| |
| # This is unsound, but it works well in practice |
| # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# |
| # TODO: Add a config knob to turn off this unsound behavior |
| # |
| # NB: If numel < 2, the bounds here might be COMPLETELY |
| # disjoint with what can actually occur. But this is fine: |
| # remember, the hypothesis is that if your later code works |
| # with N >= 2, it will work with N = 1 and N = 0. |
| maxval = sys.maxsize - 1 |
| |
| # Avoid importing sympy at a module level |
| from torch.fx.experimental.symbolic_shapes import ( |
| _constrain_range_for_size, |
| has_free_symbols, |
| ) |
| |
| if not has_free_symbols(arg.numel()): |
| # Don't upgrade the range if numel is less than two, since we then |
| # have an empty range which makes things go explodey. We also |
| # don't allow for 2 because that would specialize the unbacked |
| # SymInt to 2, which is also likely to be buggy. |
| if arg.numel() > 2: |
| maxval = int(arg.numel()) |
| |
| _constrain_range_for_size(nnz, max=maxval) |
| |
| arg._nonzero_memo = nnz |
| arg._nonzero_memo_vc = arg._version |
| |
| return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64) |
| |
| |
| @register_op_impl(torch.ops.aten.masked_select.default) |
| def masked_select(fake_mode, func, self, mask): |
| if ( |
| fake_mode.shape_env is None |
| or not fake_mode.shape_env.allow_dynamic_output_shape_ops |
| ): |
| # Without symints/symfloats, cannot handle this |
| raise DynamicOutputShapeException(func) |
| |
| nnz = fake_mode.shape_env.create_unbacked_symint() |
| |
| # see nonzero for commentary |
| maxval = sys.maxsize - 1 |
| |
| # Avoid importing sympy at a module level |
| from torch.fx.experimental.symbolic_shapes import ( |
| _constrain_range_for_size, |
| has_free_symbols, |
| ) |
| |
| if not has_free_symbols(self.numel()): |
| if self.numel() > 2: |
| maxval = int(self.numel()) |
| |
| _constrain_range_for_size(nnz, max=maxval) |
| |
| return self.new_empty((nnz,)) |
| |
| |
| # NB: this must be ordered after local_scalar_dense |
| @register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags) |
| def data_dep(fake_mode, func, *args, **kwargs): |
| raise DataDependentOutputException(func) |
| |
| |
| # Bool Indices get Expanded as Masks |
| # See: IndexingUtils.h:expandTensors |
| def check_no_bool_index_tensors(func, self, indices): |
| for index in indices: |
| if index is not None and index.dtype in (torch.bool, torch.uint8): |
| raise DynamicOutputShapeException(func) |
| |
| |
| def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): |
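|     """Run ``func`` with the meta kernel and wrap the output as a FakeTensor on |
|     the input's device; if the op returns its input (e.g. copy_), return it |
|     unchanged.""" |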
| _, new_kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| out_device = new_kwargs["input"].device |
| with in_kernel_invocation_manager(fake_mode): |
| out = func(*args, **kwargs) |
| if not is_noncontiguous_supported(out_device): |
| out = out.new_empty(out.shape) |
| |
| if out is new_kwargs["input"]: |
| return out # copy_ |
| return FakeTensor(fake_mode, out, out_device) |
| |
| |
| _is_builtin_namespaces = ordered_set("aten", "prims", "prim") |
| |
| |
| def is_builtin(op): |
| return op.namespace in _is_builtin_namespaces |
| |
| |
| def has_meta(func): |
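|     """Return True if ``func`` has a computed kernel for the Meta dispatch key.""" |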
| return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") |
| |
| |
| @register_op_impl( |
| lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func) |
| ) |
| def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): |
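|     """Run a foreach op's meta kernel, then rewrap each output on the common |
|     device of the i-th element of the input tensor lists.""" |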
| tensor_lists = [] |
| for arg in itertools.chain(args, kwargs.values()): |
| if ( |
| isinstance(arg, (list, tuple)) |
| and len(arg) |
| and isinstance(arg[0], torch.Tensor) |
| ): |
| tensor_lists.append(arg) |
| |
| try: |
| with in_kernel_invocation_manager(fake_mode): |
| out_meta = func(*args, **kwargs) |
|     except NotImplementedError: |
| return NotImplemented |
| |
| if not out_meta: |
| return out_meta |
| |
| assert tensor_lists |
| out_fake = [] |
| |
| for i, meta_t in enumerate(out_meta): |
| device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) |
| out_fake.append( |
| fake_mode.fake_tensor_converter.from_meta_and_device( |
| fake_mode, meta_t, device |
| ) |
| ) |
| |
| return out_fake |
| |
| |
| # Don't fall back to the default device handling, |
| # since this op can take non-zero-sized CPU |
| # index tensors together with a CUDA self tensor |
| @register_op_impl(aten.index.Tensor) |
| def index_tensor(fake_mode, func, *args, **kwargs): |
| from torch._meta_registrations import meta_index_Tensor |
| |
| _, new_kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| out_device = new_kwargs["input"].device |
| # ensure nonzero call goes to fake tensor |
| with fake_mode: |
| out = meta_index_Tensor(*args, **kwargs) |
| return out.to(out_device) |
| |
| |
| # Can take mixed meta/non-meta arguments; the meta registration |
| # will roughly do the right thing even when given real devices |
| @register_op_impl(aten._embedding_bag.default) |
| def embedding_bag(fake_mode, func, *args, **kwargs): |
| from torch._meta_registrations import meta_embedding_bag |
| |
| with fake_mode: |
| return meta_embedding_bag(*args, **kwargs) |
| |
| |
| # Takes inputs on multiple devices; don't fall back to the default device handling |
| @register_op_impl(aten._unsafe_index_put.default) |
| @register_op_impl(aten.copy.default) |
| @register_op_impl(aten.copy_.default) |
| @register_op_impl(aten.slice_scatter.default) |
| def multi_device_op_default(fake_mode, func, *args, **kwargs): |
| return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) |
| |
| |
| # Same as multi_device_op_default, but returns the input |
| @register_op_impl(aten.copy.out) |
| @register_op_impl(aten.slice_scatter.out) |
| def multi_device_op_out(fake_mode, func, *args, **kwargs): |
| with in_kernel_invocation_manager(fake_mode): |
| out = func(*args, **kwargs) |
| |
| _, new_kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| return new_kwargs["input"] |
| |
| |
| @register_op_impl(aten.index_put.default) |
| @register_op_impl(aten.index_put_.default) |
| def index_put_impl(fake_mode, func, *args, **kwargs): |
| _, new_kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| values = new_kwargs["values"] |
| self_device = new_kwargs["input"].fake_device |
| torch._check( |
| self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1), |
| lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})", |
| ) |
| |
| out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) |
| if func is aten.index_put_.default: |
| return new_kwargs["input"] |
| else: |
| return out |
| |
| |
| @register_op_impl(aten._nested_tensor_from_tensor_list.default) |
| @register_op_impl(aten._nested_tensor_from_tensor_list.out) |
| def nested_tensors_unsupported(fake_mode, func, *args, **kwargs): |
| raise UnsupportedOperatorException( |
| "torch.compile does not support strided NestedTensor" |
| ) |
| |
| |
| @register_op_impl( |
| [ |
| x |
| for x in _device_not_kwarg_ops |
| if x |
| not in ( |
| # these are already registered elsewhere |
| aten.to.device, |
| aten.to.prim_Device, |
| aten._nested_tensor_from_tensor_list.default, |
| aten._nested_tensor_from_tensor_list.out, |
| ) |
| ] |
| ) |
| def nyi(fake_mode, func, *args, **kwargs): |
| assert func not in _device_not_kwarg_ops, f"NYI: {func}" |
| |
| |
| @register_op_impl([aten.convolution.default, aten.convolution_backward.default]) |
| def conv(fake_mode, func, *args, **kwargs): |
| _, kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| device = kwargs["input"].fake_device |
| # need to re-enable mode so the tensors report fake device |
| with fake_mode: |
|         # if unsqueezing of the input is done in Convolution.cpp we get a segfault |
| k = kwargs["weight"].ndim |
| batch = kwargs["input"].shape[0] |
| |
| # Avoid importing sympy at a module level |
| from torch.fx.experimental.symbolic_shapes import has_hint |
| |
| if not has_hint(batch): |
| # TODO: We can make this a little more faithful with best effort |
| # channels last detection (but only if it's statically obvious!) |
| mem_fmt = None |
| elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: |
| mem_fmt = None |
| else: |
| if func is aten.convolution.default: |
| conv_backend = torch._C._select_conv_backend(**kwargs) |
| else: |
| conv_backend = torch._C._select_conv_backend( |
| kwargs["input"], |
| kwargs["weight"], |
| bias=None, |
| stride=kwargs["stride"], |
| padding=kwargs["padding"], |
| dilation=kwargs["dilation"], |
| transposed=kwargs["transposed"], |
| output_padding=kwargs["output_padding"], |
| groups=kwargs["groups"], |
| bias_sizes=kwargs["bias_sizes"], |
| ) |
| mem_fmt = torch._C._conv_determine_backend_memory_format( |
| kwargs["input"], kwargs["weight"], conv_backend |
| ) |
| |
| def convert(t, mem_fmt): |
| if t is None: |
| return t |
| if mem_fmt is not None: |
| t = t.to(memory_format=mem_fmt) |
| return FakeTensor(fake_mode, t, device) |
| |
| with in_kernel_invocation_manager(fake_mode): |
| out = func(**kwargs) |
| |
| if func is aten.convolution.default: |
| return convert(out, mem_fmt) |
| else: |
| return ( |
| convert(out[0], mem_fmt), |
| convert(out[1], mem_fmt), |
| convert(out[2], None), |
| ) |
| |
| |
| @register_op_impl(aten._scaled_dot_product_flash_attention.default) |
| def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): |
| _, kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| query = kwargs["query"] |
| key = kwargs["key"] |
| return_debug_mask = kwargs["return_debug_mask"] |
| # unused: value, dropout_p, is_causal, scale |
| |
| def convert_tensor(t, device): |
| return FakeTensor(fake_mode, t, device) |
| |
| batch_size = query.size(0) |
| num_heads = query.size(1) |
| max_seqlen_batch_q = query.size(2) |
| head_dim = query.size(3) |
| max_seqlen_batch_k = key.size(2) |
| |
| query_t = query.transpose(1, 2) |
| # empty_like already returns a fake tensor so we don't need to convert it |
| attention = torch.empty_like(query_t).transpose(1, 2) |
| logsumexp = convert_tensor( |
| torch.empty( |
| (batch_size, num_heads, max_seqlen_batch_q), |
| dtype=torch.float, |
| device="meta", |
| ), |
| device=query.device, |
| ) |
| |
| if return_debug_mask: |
| blocksize_c = 128 if head_dim > 64 else 256 |
| max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) |
| if max_seqlen_batch_k <= 128: |
| max_seqlen_k = 128 |
| elif max_seqlen_batch_k <= 256: |
| max_seqlen_k = 256 |
| debug_mask = convert_tensor( |
| torch.empty( |
| (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), |
| dtype=query.dtype, |
| device="meta", |
| ), |
| device=query.device, |
| ) |
| else: |
| debug_mask = convert_tensor( |
| torch.empty(0, dtype=query.dtype, device="meta"), |
| query.device, |
| ) |
| |
|     # Note [Seed and Offset]: the device for the seed and offset below depends on |
|     # whether we are capturing or not, but at the time of tracing we don't know if |
|     # we are going to use cudagraphs, so we return meta tensors here. It's possible |
|     # we'll need some special handling for sdpa in inductor. |
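|     # The tuple below roughly mirrors the op's schema: (output, logsumexp, |
|     # cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, |
|     # debug_attn_mask); cum_seq_q/cum_seq_k are None on this dense path. |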
| |
| return ( |
| attention, |
| logsumexp, |
| None, |
| None, |
| max_seqlen_batch_q, |
| max_seqlen_batch_k, |
| convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
| convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
| debug_mask, |
| ) |
| |
| |
| @register_op_impl(aten._scaled_dot_product_efficient_attention.default) |
| def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): |
| _, kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| query = kwargs["query"] |
| key = kwargs["key"] |
| value = kwargs["value"] |
| compute_log_sumexp = kwargs["compute_log_sumexp"] |
| # unused: attn_bias, dropout_p, is_causal, scale |
| |
| def convert_tensor(t, device): |
| return FakeTensor(fake_mode, t, device) |
| |
| query = query.transpose(1, 2) |
| key = key.transpose(1, 2) |
| value = value.transpose(1, 2) |
| |
| B = query.size(0) |
| M = query.size(1) |
| N = key.size(1) |
| num_heads = query.size(-2) |
| K = query.size(-1) |
| Kv = value.size(-1) |
| |
| res = convert_tensor( |
| torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), |
| query.device, |
| ) |
| |
| logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 |
| logsum_exp = convert_tensor( |
| torch.empty( |
| (B, num_heads, logsumexp_dim), |
| dtype=torch.float, |
| device="meta", |
| ), |
| query.device, |
| ) |
| |
| res = res.transpose(1, 2) |
| |
| # See Note [Seed and Offset]: |
| seed = convert_tensor( |
| torch.empty((), dtype=torch.long, device="meta"), query.device |
| ) |
| offset = convert_tensor( |
| torch.empty((), dtype=torch.long, device="meta"), query.device |
| ) |
| |
| return res, logsum_exp, seed, offset |
| |
| |
| @register_op_impl(aten._flash_attention_forward.default) |
| def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): |
| _, kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| query = kwargs["query"] |
| key = kwargs["key"] |
| cum_seq_q = kwargs["cum_seq_q"] |
| cum_seq_k = kwargs["cum_seq_k"] |
| max_q = kwargs["max_q"] |
| max_k = kwargs["max_k"] |
| return_debug_mask = kwargs["return_debug_mask"] |
| # unused: value, dropout_p, is_causal, scale |
| |
| def convert_tensor(t, device): |
| return FakeTensor(fake_mode, t, device) |
| |
| # NB: there are two underlying paths: |
| # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) |
| # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total |
| # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total |
| batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 |
| max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q |
| max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k |
| num_heads = query.size(-2) |
| head_dim = query.size(-1) |
| |
| # Cuda Path |
| # note: empty_like already returns a fake tensor, we don't need to wrap it |
| attention = torch.empty_like(query) |
| logsumexp = convert_tensor( |
| torch.empty( |
| (batch_size, num_heads, max_seqlen_batch_q), |
| dtype=torch.float, |
| device="meta", |
| ), |
| device=query.device, |
| ) |
| |
| if return_debug_mask: |
| blocksize_c = 128 if head_dim > 64 else 256 |
| max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) |
| if max_seqlen_batch_k <= 128: |
| max_seqlen_k = 128 |
| elif max_seqlen_batch_k <= 256: |
| max_seqlen_k = 256 |
| debug_mask = convert_tensor( |
| torch.empty( |
| (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), |
| dtype=query.dtype, |
| device="meta", |
| ), |
| query.device, |
| ) |
| else: |
| debug_mask = convert_tensor( |
| torch.empty(0, dtype=query.dtype, device="meta"), |
| query.device, |
| ) |
| |
| # See Note [Seed and Offset]: |
| return ( |
| attention, |
| logsumexp, |
| convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
| convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
| debug_mask, |
| ) |
| |
| |
| @register_op_impl(aten._efficient_attention_forward.default) |
| def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): |
| _, kwargs = normalize_function( |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| query = kwargs["query"] |
| key = kwargs["key"] |
| value = kwargs["value"] |
| cu_seqlens_q = kwargs["cu_seqlens_q"] |
| max_seqlen_q = kwargs["max_seqlen_q"] |
| max_seqlen_k = kwargs["max_seqlen_k"] |
| compute_log_sumexp = kwargs["compute_log_sumexp"] |
| # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k |
| |
| def convert_tensor(t, device): |
| return FakeTensor(fake_mode, t, device) |
| |
| B = query.size(0) |
| M = query.size(1) |
| N = key.size(1) |
| num_heads = query.size(-2) |
| K = query.size(-1) |
| Kv = value.size(-1) |
| |
| res = convert_tensor( |
| torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), |
| query.device, |
| ) |
| |
| logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B |
| actual_max_seqlen_q = M |
| if cu_seqlens_q is not None: |
| assert max_seqlen_q is not None |
| actual_max_seqlen_q = max_seqlen_q |
| actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N |
| logsumexp_dim = ( |
| math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 |
| ) |
| logsum_exp = convert_tensor( |
| torch.empty( |
| (logsumexp_batch_dim, num_heads, logsumexp_dim), |
| dtype=torch.float, |
| device="meta", |
| ), |
| query.device, |
| ) |
| |
| # See Note [Seed and Offset]: |
| seed = convert_tensor( |
| torch.empty((), dtype=torch.long, device="meta"), query.device |
| ) |
| offset = convert_tensor( |
| torch.empty((), dtype=torch.long, device="meta"), query.device |
| ) |
| |
| return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k |
| |
| |
| FAST_OP_IMPLEMENTATIONS = {} |
| |
| |
| # Unlike register_op_impl, these don't do the slow iteration for |
| # run_impl_check, and these run BEFORE decompositions |
| def register_fast_op_impl(func: OpOverload): |
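|     """Register ``op_impl`` as the fast-path fake implementation for ``func``; |
|     entries are keyed by exact OpOverload and run before decompositions (see |
|     ``get_fast_op_impls`` below for the actual registrations).""" |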
| def impl_decorator(op_impl): |
| FAST_OP_IMPLEMENTATIONS[func] = op_impl |
| return op_impl |
| |
| return impl_decorator |
| |
| |
| # infer_size_impl in ExpandUtils |
| def infer_size(a, b): |
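|     """Compute the broadcasted output shape of ``a`` and ``b`` (mirrors |
|     ExpandUtils::infer_size_impl). For example, infer_size((2, 1, 4), (3, 1)) |
|     returns (2, 3, 4).""" |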
| from torch.fx.experimental.symbolic_shapes import guard_size_oblivious |
| |
| dimsA = len(a) |
| dimsB = len(b) |
| ndim = max(dimsA, dimsB) |
| expandedSizes = [0] * ndim |
| for i in range(ndim - 1, -1, -1): |
| offset = ndim - 1 - i |
| dimA = dimsA - 1 - offset |
| dimB = dimsB - 1 - offset |
| sizeA = a[dimA] if dimA >= 0 else 1 |
| sizeB = b[dimB] if dimB >= 0 else 1 |
| |
| # NB: It is very important to test for broadcasting, before testing |
| # sizeA == sizeB. This is because the broadcasting tests are likely |
| # to be statically known (in particular, if sizeA/sizeB is unbacked |
| # but size-like, we will unsoundly assume they never equal 1), but |
| # the sizeA == sizeB test may not be statically known. However, once |
| # we have established that no broadcasting is happening, the |
| # sizeA == sizeB is now expect_true and we can defer it as a runtime |
| # assert (this works because Python will return the terminal |
| # expression of an or statement as-is, without bool()'ing it; if this |
| # were not the case, we'd need to write this using torch.sym_or() or |
| # something like that). |
| torch._check( |
| guard_size_oblivious(sizeA == 1) |
| or guard_size_oblivious(sizeB == 1) |
| or sizeA == sizeB, |
| lambda: f"The size of tensor a ({sizeA}) " |
| f"must match the size of tensor b ({sizeB}) " |
| f"at non-singleton dimension {i})", |
| ) |
| expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA |
| return tuple(expandedSizes) |
| |
| |
| def make_fast_binary_impl(slow_ref): |
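|     """Build a fast fake impl for a binary elementwise op: infer the broadcast |
|     shape, promoted dtype, and common device directly, and fall back to |
|     ``slow_ref`` (run under the fake mode) whenever the fast path doesn't apply.""" |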
| def fast_binary_impl(mode, *args, **kwargs): |
| def slow(msg): |
| count_label(f"slow {msg}") |
| with mode: |
| return slow_ref(*args, **kwargs) |
| |
| count_label("attempt fast") |
| |
| # Fast path (based off of TensorIterator fast path). |
| # Unfortunately, there is no way to easily deduplicate |
| # this with either the TensorIterator C++ implementation |
| # (which we don't want to SymIntify, and also the algorithm |
| # here is slightly different from TensorIterator to allow |
| # for broadcasting), nor the PrimTorch implementation |
| # (which does not actually implement a fast path.) |
| |
| operands = args |
| |
| # compute_shape |
| has_scalars = False |
| has_tensors = False |
| final_shape = None |
| for op in operands: |
| shape = op.shape if isinstance(op, torch.Tensor) else () |
| if len(shape) == 0: |
| has_scalars = True |
| else: |
| has_tensors = True |
| if final_shape is None: |
| final_shape = shape |
| # TODO: Minor optimization: track if the shapes |
| # were equal so you can skip the equality check |
| # below if unnecessary |
| final_shape = infer_size(final_shape, shape) |
| assert final_shape is not None |
| |
| # Do some extra safety checks to see if the output |
| # stride is obvious |
| for op in operands: |
| if isinstance(op, torch.Tensor) and op.shape == final_shape: |
| break |
| else: |
| return slow("both tensors nontrivially broadcast") |
| |
| # compute_types |
| cpu = torch.device("cpu") |
| common_device = cpu |
| common_dtype = None |
| output_dtype = None |
| has_different_input_dtypes = False |
| for op in operands: |
| if not isinstance(op, torch.Tensor): |
| # Use elementwise_dtypes for the tricky case |
| has_different_input_dtypes = True |
| continue |
| if common_device == cpu and not op.device.type == "cpu": |
| common_device = op.device |
| # Slightly simplified here as target_dtype cannot vary |
| if common_dtype is None: |
| common_dtype = op.dtype |
| elif common_dtype != op.dtype: |
| has_different_input_dtypes = True |
| |
| if has_different_input_dtypes: |
| # compute promotion |
| # TODO: we don't need the compute type |
| _, common_dtype = elementwise_dtypes( |
| *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT |
| ) |
| |
| # check all tensors on same device |
|         # CPU scalars are assumed to be allowed |
| current_cpu_scalars_on_non_cpu = 0 |
| max_cpu_scalars_on_non_cpu = 1 # hard coded atm |
| for op in operands: |
| if not isinstance(op, torch.Tensor): |
| continue |
| if common_device != cpu and op.dim() == 0 and op.device == cpu: |
| if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu: |
| return slow("error") |
| current_cpu_scalars_on_non_cpu += 1 |
| elif op.device != common_device: |
| return slow("error") |
| |
| # compute_fast_setup_type |
| is_contiguous = True |
| is_channels_last = True |
|         # TODO: is_non_overlapping_and_dense (not bound from Python) |
| # no inplace, no out, everything defined |
| |
| if is_noncontiguous_supported(common_device): |
| for op in operands: |
| if not isinstance(op, torch.Tensor): |
| continue |
| is_contiguous = is_contiguous and op.is_contiguous( |
| memory_format=torch.contiguous_format |
| ) |
| is_channels_last = is_channels_last and op.is_contiguous( |
| memory_format=torch.channels_last |
| ) |
| if is_contiguous: |
| # do contiguous |
| count_label("fast is_contiguous") |
| return FakeTensor( |
| mode, |
| torch.empty( |
| final_shape, |
| dtype=common_dtype, |
| device="meta", |
| memory_format=torch.contiguous_format, |
| ), |
| device=common_device, |
| ) |
| if is_channels_last: |
| count_label("fast channels_last") |
| # do channels last |
| return FakeTensor( |
| mode, |
| torch.empty( |
| final_shape, |
| dtype=common_dtype, |
| device="meta", |
| memory_format=torch.channels_last, |
| ), |
| device=common_device, |
| ) |
| |
| return slow("no contiguity match") |
| |
| return fast_binary_impl |
| |
| |
| @functools.lru_cache(None) |
| def get_fast_op_impls(): |
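|     """Register the fast binary fake impls for aten.{add,sub,mul,div}.Tensor and |
|     return FAST_OP_IMPLEMENTATIONS (lru_cached, so registration runs only once).""" |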
| import torch._refs |
| |
| register_fast_op_impl(torch.ops.aten.add.Tensor)( |
| make_fast_binary_impl(torch._refs.add) |
| ) |
| register_fast_op_impl(torch.ops.aten.sub.Tensor)( |
| make_fast_binary_impl(torch._refs.sub) |
| ) |
| register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] |
| register_fast_op_impl(torch.ops.aten.div.Tensor)( |
| make_fast_binary_impl(torch._refs.div) |
| ) |
| return FAST_OP_IMPLEMENTATIONS |