|  | import dataclasses | 
|  | import itertools | 
|  | import re | 
|  |  | 
|  | from dataclasses import dataclass | 
|  | from enum import auto, Enum | 
|  | from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union | 
|  |  | 
|  | from torchgen.utils import assert_never, NamespaceHelper, OrderedSet | 
|  |  | 
|  | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # | 
|  | # | 
|  | #                           DATA MODEL | 
|  | # | 
|  | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # | 
|  | # | 
|  | # Some general principles for our data model. | 
|  | # | 
|  | # - Stop using C++ data types as the internal data representation | 
|  | #   format.  Instead, the internal data structures are centered | 
|  | #   around JIT schema representation.  This avoids a big problem | 
|  | #   with the old codegen where we read in all the types from | 
|  | #   native_functions.yaml and then immediately had to retranslate | 
|  | #   them into C++ types. | 
|  | # | 
|  | # - More semantic data representation.  Instead of representing | 
|  | #   everything as dicts and strings, we define dataclasses for | 
|  | #   every interesting entity the code generation has to deal with. | 
|  | #   These dataclasses have strong semantic invariants: for example, | 
|  | #   we generally require them to roundtrip losslessly into the | 
|  | #   form they were parsed from.  These structures are immutable | 
|  | #   and you're expected to populate information once during | 
|  | #   construction. | 
|  |  | 
|  |  | 
|  | # Represent a source location; used for better error reporting | 
@dataclass(frozen=True)
class Location:
    """Immutable (file, line) pair identifying where an entry came from.

    Used to produce better error messages; ``str()`` renders it as
    ``<file>:<line>``.
    """

    # Name or path of the file the entry was read from
    file: str
    # Line number inside that file
    line: int

    def __str__(self) -> str:
        # Conventional "file:line" rendering
        return "{}:{}".format(self.file, self.line)
|  |  | 
|  |  | 
|  | # Valid values of the 'variants' field in native_functions.yaml | 
class Variant(Enum):
    """Values accepted by the 'variants' field in native_functions.yaml."""

    # Member names match the strings written in the YAML ("function", "method")
    function = auto()
    method = auto()
|  |  | 
|  |  | 
# Default C++ namespace for kernels
DEFAULT_KERNEL_NAMESPACE = "at::native"

# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = [
    "CPU",
    "CUDA",
    "HIP",
    "XLA",
    "MTIA",
    "MPS",
    "IPU",
    "XPU",
    "HPU",
    "VE",
    "Lazy",
    "Meta",
    "PrivateUse1",
    "PrivateUse2",
    "PrivateUse3",
]
# Prefixes combined with each backend component; "" yields the bare backend key
FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]

# This list guards dispatches that can be used in derivatives.yaml
# For now we omit AutogradFunctionality and AutogradOther
AUTOGRAD_KEYS = [
    "AutogradNestedTensor",
    *(f"Autograd{component}" for component in BACKEND_COMPONENTS),
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
|  |  | 
|  |  | 
|  | # This doesn't have to be in sync with the header, it only needs to contain | 
|  | # entries that we actually use in the codegen or want pyi entries for | 
class DispatchKey(Enum):
    """Dispatch keys known to the code generator.

    ``CatchAll`` is an alias of ``Undefined``.  Every member after
    ``Undefined`` is numbered with ``auto()``, so the declaration order below
    determines the numeric values and must not be shuffled.
    """

    Undefined = 0
    CatchAll = Undefined

    FPGA = auto()
    ORT = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    NestedTensor = auto()
    Dense = auto()

    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    Conjugate = auto()
    Negative = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    BatchedNestedTensor = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()

    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()

    # BEGIN autogenerated
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MTIA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMTIA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMTIA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMTIA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMTIA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    # END autogenerated

    def __str__(self) -> str:
        # Render as the bare member name (no "DispatchKey." prefix)
        return self.name

    def lower(self) -> str:
        """Return the member name converted to lower case."""
        return self.name.lower()

    @staticmethod
    def parse(value: str) -> "DispatchKey":
        """Look up a dispatch key by its exact (case-sensitive) name.

        Alias names are accepted and resolve to the member they alias.
        Raises AssertionError when no member carries that name.
        """
        # __members__ maps every declared name, aliases included, to its member
        found = DispatchKey.__members__.get(value)
        if found is None:
            raise AssertionError(f"unknown dispatch key {value}")
        return found
|  |  | 
|  |  | 
class _TorchDispatchModeKey(Enum):
    """Private enumeration of the FAKE, PROXY and FUNCTIONAL mode keys."""

    # Values are assigned by auto() in declaration order
    FAKE = auto()
    PROXY = auto()
    FUNCTIONAL = auto()
|  |  | 
|  |  | 
def codegen_per_backend_entries() -> str:
    """Render the enum entries expected for each functionality/backend pair.

    Produces one ``<functionality><backend> = auto()`` line, indented by four
    spaces, for every combination of FUNCTIONALITY_KEYS (outer loop) and
    BACKEND_COMPONENTS (inner loop), and joins the lines with newlines.
    """
    return "\n".join(
        f"    {functionality}{backend} = auto()"
        for functionality in FUNCTIONALITY_KEYS
        for backend in BACKEND_COMPONENTS
    )
|  |  | 
|  |  | 
# Import-time sanity check: every functionality/backend combination must exist
# as an attribute of DispatchKey.  On the first missing one, print the full
# expected listing and abort with an error that embeds it.
for fk in FUNCTIONALITY_KEYS:
    for bc in BACKEND_COMPONENTS:
        if hasattr(DispatchKey, fk + bc):
            continue
        r = codegen_per_backend_entries()
        print(r)
        raise RuntimeError(
            f"Missing {fk}{bc} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
        )
|  |  | 
|  |  | 
# Keys tested by is_structured_dispatch_key
STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU}
# Keys tested by is_ufunc_dispatch_key
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# List of supported dispatch keys (order is preserved as declared)
dispatch_keys = [
    # CPU-side keys
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    # Accelerator keys
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    # Quantized keys
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    # Composite keys
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    # Nested tensor keys
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]
|  |  | 
|  |  | 
|  | # Dispatch keys that "support all backends".  These codegen slightly differently | 
|  | # than backend specific keys. | 
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True when ``dk`` is one of the four Composite* dispatch keys."""
    generic_keys = {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }
    return dk in generic_keys
|  |  | 
|  |  | 
|  | # CUDA specific dispatch keys | 
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True when ``dk`` is one of the CUDA dispatch keys listed below."""
    cuda_keys = {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }
    return dk in cuda_keys
|  |  | 
|  |  | 
|  | # Structured kernel generation is only supported for certain key types; | 
|  | # otherwise use old-style | 
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is a member of STRUCTURED_DISPATCH_KEYS."""
    structured_keys = STRUCTURED_DISPATCH_KEYS
    return dk in structured_keys
|  |  | 
|  |  | 
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is a member of UFUNC_DISPATCH_KEYS.

    UFUNC_DISPATCH_KEYS is its own set and need not equal
    STRUCTURED_DISPATCH_KEYS.
    """
    ufunc_keys = UFUNC_DISPATCH_KEYS
    return dk in ufunc_keys
|  |  | 
|  |  | 
|  | # This is oddly named ScalarType and not DType for symmetry with C++ | 
class ScalarType(Enum):
    """Scalar element types; member names are the strings accepted when parsing."""

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        # Render as the bare member name
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> Optional["ScalarType"]:
        """Return the member named exactly ``value``, or None if there is none."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> "ScalarType":
        """Like maybe_parse, but asserts that the name is known."""
        parsed = ScalarType.maybe_parse(value)
        assert parsed is not None, f"unknown dtype {value}"
        return parsed

    @staticmethod
    def parse_set(values: str) -> OrderedSet["ScalarType"]:
        """Parse a ``", "``-separated list of dtype names and dtype class names.

        A token that is a key of DTYPE_CLASSES contributes every dtype of
        that class; any other token must be a single ScalarType name.
        Results are accumulated in an OrderedSet.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for token in values.split(", "):
            if token not in DTYPE_CLASSES:
                # Plain dtype name
                dtypes.add(ScalarType.parse(token))
                continue
            # Named group of dtypes
            dtypes.update(DTYPE_CLASSES[token])
        return dtypes
|  |  | 
|  |  | 
# Named groups of dtypes that ScalarType.parse_set accepts in place of a
# single dtype name.  The three base groups are declared first; the composite
# groups are then derived from them with set union.
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {
    # NB: Integral doesn't include boolean
    "Integral": OrderedSet(
        [
            ScalarType.Byte,
            ScalarType.Char,
            ScalarType.Int,
            ScalarType.Long,
            ScalarType.Short,
        ]
    ),
    # NB: Floating doesn't include low precision types
    "Floating": OrderedSet([ScalarType.Float, ScalarType.Double]),
    "Complex": OrderedSet([ScalarType.ComplexFloat, ScalarType.ComplexDouble]),
}
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)
|  |  | 
|  |  | 
|  | # Represents the valid entries for ufunc_inner_loop in native_functions.yaml. | 
|  | # NB: if you add a new UfuncKey, you will teach torchgen.dest.ufunc how | 
|  | # to process it.  Most logic will ignore keys they don't understand, so your | 
|  | # new key will get silently ignored until you hook in logic to deal with it. | 
class UfuncKey(Enum):
    """Valid keys of the ufunc_inner_loop entry in native_functions.yaml."""

    # Low level keys: each one stands for exactly one particular
    # instantiation of the kernel produced by codegen
    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    CPUScalar = auto()
    CPUVector = auto()

    # Keys users will usually specify; they implicitly "fill in" the
    # low level keys
    ScalarOnly = auto()  # CUDA*, CPUScalar
    Generic = auto()  # CUDA*, CPU*

    def __str__(self) -> str:
        # Render as the bare member name
        return self.name

    @staticmethod
    def parse(value: str) -> "UfuncKey":
        """Look up a ufunc key by its exact name.

        Raises AssertionError when no member carries that name.
        """
        found = UfuncKey.__members__.get(value)
        if found is None:
            raise AssertionError(f"unknown ufunc key {value}")
        return found
|  |  | 
|  |  | 
class DeviceCheckType(Enum):
    """How the automatic device check should be emitted for an operator."""

    # Explicit values: these are not assigned by auto()
    NoCheck = 0
    ExactSame = 1
|  |  | 
|  |  | 
class ViewSchemaKind(Enum):
    """Three-way classification: aliasing, aliasing_inplace or non_aliasing."""

    # Values are assigned by auto() in declaration order
    aliasing = auto()
    aliasing_inplace = auto()
    non_aliasing = auto()
|  |  | 
|  |  | 
|  | # The basic input to the code generation is native_functions.yaml. | 
|  | # The name "native", BTW, comes from the distinction between native | 
|  | # functions and legacy TH functions.  The legacy TH functions are gone, | 
|  | # but the "native" descriptor has stuck. | 
|  | # | 
|  | # NativeFunction models a single entry in native_functions.yaml.  Its | 
|  | # fields roughly correspond to what you would see in the YAML itself, | 
|  | # but after canonicalization and parsing has occurred. | 
|  | # | 
|  | # You can see some of the overall design patterns for how we setup | 
|  | # dataclasses in this class, but we will defer a complete discussion | 
|  | # of this at FunctionSchema. | 
|  | @dataclass(frozen=True) | 
|  | class NativeFunction: | 
|  | # The namespace for this operator. For example, if we have "at::add" | 
|  | # then the namespace would be "at". This enables ops to be registered | 
|  | # through the same DSL with a custom namespace. If not specified, the | 
|  | # default namespace would be "at". | 
|  | namespace: str | 
|  |  | 
|  | # The function schema of the operator in question.  This schema | 
|  | # has been parsed; see FunctionSchema for more about its structure. | 
|  | # (This type is quoted as we are forward referencing a type | 
|  | # defined later in the file.  I opted for this ordering of the | 
|  | # classes for expository clarity.) | 
|  | func: "FunctionSchema" | 
|  |  | 
|  | # Whether or not to generate mutable tensor arguments like regular | 
|  | # ones | 
|  | use_const_ref_for_mutable_tensors: bool | 
|  |  | 
|  | # Whether or not to omit automatic generation of a DeviceGuard | 
|  | device_guard: bool | 
|  |  | 
|  | # How to emit automatic generation of device check | 
|  | device_check: DeviceCheckType | 
|  |  | 
|  | # What python module to put the function in | 
|  | python_module: Optional[str] | 
|  |  | 
|  | # TODO: figure out what this does | 
|  | category_override: Optional[str] | 
|  |  | 
|  | # If no variants are specified in native_functions.yaml, this is | 
|  | # assumed to be {'function'}. | 
|  | variants: Set[Variant] | 
|  |  | 
|  | # Whether or not we should skip generating registrations for | 
|  | # this kernel.  This is a bit of a double-edged sword, as manual | 
|  | # registrations don't participate in codegen-based selective build! | 
|  | manual_kernel_registration: bool | 
|  |  | 
|  | # Whether or not to skip generating TensorMethod/Functions bindings | 
|  | # for this kernel.  Technically, this doesn't actually skip generating | 
|  | # the binding; instead, the binding gets generated to __dispatch_{funcname} | 
|  | # so you can make use of the normal binding if you need it. | 
|  | manual_cpp_binding: bool | 
|  |  | 
|  | # The location in the YAML file where this native function entry was | 
|  | # defined.  This is for conveniently reporting error messages! | 
|  | loc: "Location" | 
|  |  | 
|  | # A list of operators that are expected to be auto-generated for this NativeFunction. | 
|  | # Note: This list isn't actually directly used by the codegen to generate anything. | 
|  | # Instead, the codegen figures out what operators to generate purely based off of | 
|  | # function schema, and uses the autogen declarations to error check. | 
|  | # We expect every NativeFunction that gets auto-generated be explicitly called out | 
|  | # in native_functions.yaml | 
|  | autogen: List["OperatorName"] | 
|  |  | 
|  | # If non-empty, this kernel is subject to ufunc codegen. | 
|  | # Sorted by ufunc_key | 
|  | ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"] | 
|  |  | 
|  | # Whether or not this out function is a "structured kernel".  Structured | 
|  | # kernels are defined a little differently from normal kernels; in | 
|  | # particular, their shape checking logic is defined separately from | 
|  | # the kernel.  Only out functions can be structured; other functions | 
|  | # delegate to the out function using the structured_delegate keyword. | 
|  | # Every structured kernel must have at least an out and a functional | 
|  | # variant. | 
|  | structured: bool | 
|  |  | 
|  | # Whether or not this non-out function is a structured kernel, defined | 
|  | # in terms of the out kernel referenced by the string here. | 
|  | structured_delegate: Optional["OperatorName"] | 
|  |  | 
|  | # Only valid for structured kernels.  Specifies alternative of what | 
|  | # to inherit from when defining the meta class for the structured | 
|  | # operator.  This will usually be TensorIteratorBase.  This also | 
|  | # changes the semantics of set_output to call the parent class. | 
|  | structured_inherits: Optional[str] | 
|  |  | 
|  | # Structured kernels can declare elements as "precomputed". These elements | 
|  | # are returned by the meta function in one struct and passed to the impl | 
|  | # function in lieu of certain kernel arguments that these precomputed | 
|  | # elements supersede. Information about the names and types of these | 
|  | # precomputed elements and how they correspond to kernel arguments is stored | 
|  | # in this member, if applicable. | 
|  | precomputed: Optional["Precompute"] | 
|  |  | 
|  | # Argument names whose default  should be excluded from the C++ interface. | 
|  | # Intended for resolving overload ambiguities between signatures. | 
|  | cpp_no_default_args: Set[str] | 
|  |  | 
|  | # Note [Abstract ATen methods] | 
|  | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | 
|  | # An abstract ATen method is one whose dispatch differs between | 
|  | # types.  These are implemented in derived types (with a | 
|  | # standard (throwing) definition in Type).  A concrete ATen | 
|  | # method is one which has the same dispatch for all types; | 
|  | # we just implement it in the base Type.  This is exposed | 
|  | # in Declarations.yaml via a field named 'abstract'. | 
|  | is_abstract: bool | 
|  |  | 
|  | # Whether or not the NativeFunction contains a backend-agnostic kernel | 
|  | has_composite_implicit_autograd_kernel: bool | 
|  | has_composite_implicit_autograd_nested_tensor_kernel: bool | 
|  | has_composite_explicit_autograd_kernel: bool | 
|  | has_composite_explicit_autograd_non_functional_kernel: bool | 
|  |  | 
|  | # Tags are used to describe semantic information about (groups of) operators, | 
|  | # That aren't easily inferrable directly from the operator's schema. | 
|  | tags: Set[str] | 
|  |  | 
|  | # NB: The benefit of defining a dataclass is that we automatically get | 
|  | # a constructor defined for all the fields we specify.  No need | 
|  | # to explicitly write it out. | 
|  |  | 
|  | # We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex. | 
|  | @staticmethod | 
|  | def from_yaml( | 
|  | ei: Dict[str, object], | 
|  | loc: "Location", | 
|  | valid_tags: Set[str], | 
|  | ignore_keys: Optional[Set[DispatchKey]] = None, | 
|  | ) -> Tuple[ | 
|  | "NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]] | 
|  | ]: | 
|  | """ | 
|  | Parse a NativeFunction from a dictionary as directly parsed | 
|  | from native_functions.yaml | 
|  | """ | 
|  | e = ei.copy() | 
|  |  | 
|  | funcs = e.pop("func") | 
|  | assert isinstance(funcs, str), f"not a str: {funcs}" | 
|  | # only support one level of namespace. E.g., aten::add | 
|  | namespace_helper = NamespaceHelper.from_namespaced_entity( | 
|  | namespaced_entity=funcs, max_level=1 | 
|  | ) | 
|  | namespace = namespace_helper.get_cpp_namespace(default="aten") | 
|  | func = FunctionSchema.parse(namespace_helper.entity_name) | 
|  |  | 
|  | cpp_no_default_args_list = e.pop("cpp_no_default_args", []) | 
|  | assert isinstance(cpp_no_default_args_list, list) | 
|  | cpp_no_default_args = set(cpp_no_default_args_list) | 
|  |  | 
|  | use_const_ref_for_mutable_tensors = e.pop( | 
|  | "use_const_ref_for_mutable_tensors", False | 
|  | ) | 
|  | assert isinstance(use_const_ref_for_mutable_tensors, bool) | 
|  |  | 
|  | variants_s = e.pop("variants", "function") | 
|  | assert isinstance(variants_s, str) | 
|  | variants: Set[Variant] = set() | 
|  | for v in variants_s.split(", "): | 
|  | if v == "function": | 
|  | variants.add(Variant.function) | 
|  | elif v == "method": | 
|  | variants.add(Variant.method) | 
|  | else: | 
|  | raise AssertionError(f"illegal variant {v}") | 
|  |  | 
|  | manual_kernel_registration = e.pop("manual_kernel_registration", False) | 
|  | assert isinstance( | 
|  | manual_kernel_registration, bool | 
|  | ), f"not a bool: {manual_kernel_registration}" | 
|  |  | 
|  | manual_cpp_binding = e.pop("manual_cpp_binding", False) | 
|  | assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}" | 
|  |  | 
|  | device_guard = e.pop("device_guard", True) | 
|  | assert isinstance(device_guard, bool), f"not a bool: {device_guard}" | 
|  |  | 
|  | device_check_s = e.pop("device_check", None) | 
|  | assert device_check_s is None or isinstance( | 
|  | device_check_s, str | 
|  | ), f"not a str: {device_check_s}" | 
|  | device_check: DeviceCheckType | 
|  | if device_check_s is None: | 
|  | device_check = DeviceCheckType.ExactSame | 
|  | else: | 
|  | device_check = DeviceCheckType[device_check_s] | 
|  |  | 
|  | structured = e.pop("structured", False) | 
|  | assert isinstance(structured, bool), f"not a bool: {structured}" | 
|  |  | 
|  | structured_delegate_s = e.pop("structured_delegate", None) | 
|  | assert structured_delegate_s is None or isinstance( | 
|  | structured_delegate_s, str | 
|  | ), f"not a str: {structured_delegate_s}" | 
|  | assert structured_delegate_s is None or "::" not in structured_delegate_s, ( | 
|  | "namespace is not supported in structured delegate," | 
|  | " using the same namespace as the native function" | 
|  | ) | 
|  | structured_delegate: Optional[OperatorName] = None | 
|  | if structured_delegate_s is not None: | 
|  | structured_delegate = OperatorName.parse(structured_delegate_s) | 
|  |  | 
|  | structured_inherits = e.pop("structured_inherits", None) | 
|  | assert structured_inherits is None or isinstance( | 
|  | structured_inherits, str | 
|  | ), f"not a str: {structured_inherits}" | 
|  | assert structured_inherits is None or "::" not in structured_inherits, ( | 
|  | "namespace is not supported in structured inherits," | 
|  | " using the same namespace as the native function" | 
|  | ) | 
|  |  | 
|  | python_module = e.pop("python_module", None) | 
|  | assert python_module is None or isinstance( | 
|  | python_module, str | 
|  | ), f"not a str: {python_module}" | 
|  | assert ( | 
|  | python_module is None or Variant.method not in variants | 
|  | ), "functions in modules cannot be methods" | 
|  |  | 
|  | category_override = e.pop("category_override", None) | 
|  | assert category_override is None or isinstance( | 
|  | category_override, str | 
|  | ), f"not a str: {category_override}" | 
|  |  | 
|  | precomputed_dict = e.pop("precomputed", None) | 
|  | assert precomputed_dict is None or structured is True | 
|  | precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None | 
|  |  | 
|  | tags_inp = e.pop("tags", []) | 
|  | if isinstance(tags_inp, str): | 
|  | tags_inp = [tags_inp] | 
|  | assert isinstance(tags_inp, list) | 
|  |  | 
|  | # All aten ops generated by torchgen receive the pt2_compliant tag. | 
|  | if namespace == "aten" and "pt2_compliant_tag" in valid_tags: | 
|  | tags_inp.append("pt2_compliant_tag") | 
|  |  | 
|  | tags: Set[str] = set() | 
|  | for t in tags_inp: | 
|  | assert len(valid_tags) > 0 | 
|  | # TODO: verify that the tag is valid and has an entry in tags.yaml | 
|  | if t in valid_tags: | 
|  | tags.add(t) | 
|  | else: | 
|  | raise AssertionError(f"illegal tag {t}") | 
|  |  | 
|  | from torchgen.api import cpp | 
|  |  | 
|  | raw_dispatch = e.pop("dispatch", None) | 
|  | assert raw_dispatch is None or isinstance(raw_dispatch, dict), e | 
|  | dispatch: Dict[DispatchKey, BackendMetadata] = {} | 
|  | num_dispatch_keys: int = 0 | 
|  | if raw_dispatch is not None: | 
|  | assert not manual_kernel_registration, ( | 
|  | "cannot specify both manual_kernel_registration and dispatch; with " | 
|  | "manual registration, dispatch has no effect!" | 
|  | ) | 
|  | redundant_composite_implicit_autograd = False | 
|  | for ks, v in raw_dispatch.items(): | 
|  | if ks == "__line__": | 
|  | continue  # not worth tracking line numbers for dispatch entries | 
|  | assert isinstance(ks, str), e | 
|  | for k in ks.split(","): | 
|  | dispatch_key = DispatchKey.parse(k.strip()) | 
|  | num_dispatch_keys += 1 | 
|  |  | 
|  | if ignore_keys and dispatch_key in ignore_keys: | 
|  | continue | 
|  | assert dispatch_key in dispatch_keys, ( | 
|  | f"Dispatch key {dispatch_key} of kernel {v} " | 
|  | "is not a supported dispatch key." | 
|  | ) | 
|  | # We only allow at most 3 levels of namespace for kernels. | 
|  | # We will append "native" to a custom kernel namespace. | 
|  | namespace_helper = NamespaceHelper.from_namespaced_entity( | 
|  | v, max_level=3 | 
|  | ) | 
|  | kernel_namespace = namespace_helper.get_cpp_namespace(default="at") | 
|  | # Why is 'structured' included? External backends (e.g. | 
|  | # XLA) opt into which ops are structured independently | 
|  | # of which in-tree ops are structured | 
|  | dispatch[dispatch_key] = BackendMetadata( | 
|  | kernel=namespace_helper.entity_name, | 
|  | structured=structured | 
|  | and is_structured_dispatch_key(dispatch_key), | 
|  | cpp_namespace=(kernel_namespace + "::native"), | 
|  | ) | 
|  | if ( | 
|  | dispatch_key is DispatchKey.CompositeImplicitAutograd | 
|  | and v == cpp.name(func) | 
|  | ): | 
|  | redundant_composite_implicit_autograd = True | 
|  |  | 
|  | # We count the number of dispatch keys which have not been ignored to prevent a dispatch table | 
|  | # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit, | 
|  | # from being treated as redundant. | 
|  | assert not ( | 
|  | num_dispatch_keys == 1 and redundant_composite_implicit_autograd | 
|  | ), ( | 
|  | "unnecessary dispatch table for this function; just delete the dispatch " | 
|  | "key entirely" | 
|  | ) | 
|  | # if a function is a structured delegate, deleting the dispatch | 
|  | # table is NOT semantics preserving | 
|  | assert ( | 
|  | structured_delegate | 
|  | or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} | 
|  | or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint() | 
|  | or num_dispatch_keys != 1 | 
|  | ), ( | 
|  | f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} " | 
|  | f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected " | 
|  | "name, then delete the dispatch table" | 
|  | ) | 
|  | elif not structured and structured_delegate is None: | 
|  | name = str(func.name.name) | 
|  | assert not ( | 
|  | name.startswith("new_") | 
|  | or name.endswith("_like") | 
|  | # TODO: maybe it's better to test the return | 
|  | or ( | 
|  | func.arguments.tensor_options | 
|  | and not func.arguments.has_tensor_arg() | 
|  | ) | 
|  | ), ( | 
|  | f"expected {name} to have a CompositeExplicitAutograd " | 
|  | "dispatch entry, but there was no dispatch table.  Factory functions " | 
|  | "should not have implicit dispatch as they should not be decomposed " | 
|  | "for __torch_dispatch__" | 
|  | ) | 
|  | dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata( | 
|  | cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE | 
|  | ) | 
|  |  | 
|  | composites_in_dispatch = [ | 
|  | d | 
|  | for d in dispatch | 
|  | if d == DispatchKey.CompositeExplicitAutograd | 
|  | or d == DispatchKey.CompositeExplicitAutogradNonFunctional | 
|  | or d == DispatchKey.CompositeImplicitAutograd | 
|  | or d == DispatchKey.CompositeImplicitAutogradNestedTensor | 
|  | ] | 
|  |  | 
|  | assert len(composites_in_dispatch) <= 1 or ( | 
|  | len(composites_in_dispatch) == 2 | 
|  | and ( | 
|  | DispatchKey.CompositeExplicitAutogradNonFunctional | 
|  | not in composites_in_dispatch | 
|  | ) | 
|  | and ( | 
|  | DispatchKey.CompositeImplicitAutogradNestedTensor | 
|  | in composites_in_dispatch | 
|  | ) | 
|  | ), ( | 
|  | "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " | 
|  | "or CompositeImplicitAutograd on a single kernel; each " | 
|  | "strictly subsumes the other.  If you wanted to provide an explicit autograd " | 
|  | "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" | 
|  | ) | 
|  |  | 
|  | autogen_str = e.pop("autogen", "") | 
|  | assert isinstance(autogen_str, str) | 
|  | autogen = ( | 
|  | [] | 
|  | if autogen_str == "" | 
|  | else [OperatorName.parse(x) for x in autogen_str.split(", ")] | 
|  | ) | 
|  |  | 
|  | raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {}) | 
|  | ufunc_inner_loop = {} | 
|  | if isinstance(raw_ufunc_inner_loop, str): | 
|  | ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse( | 
|  | raw_ufunc_inner_loop, UfuncKey.Generic | 
|  | ) | 
|  | elif isinstance(raw_ufunc_inner_loop, dict): | 
|  | for k, vo in raw_ufunc_inner_loop.items(): | 
|  | if k == "__line__": | 
|  | continue | 
|  | assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}" | 
|  | assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}" | 
|  | ufunc_key = UfuncKey.parse(k) | 
|  | ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key) | 
|  | else: | 
|  | raise AssertionError( | 
|  | f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}" | 
|  | ) | 
|  | # Program the BackendIndex for the implicit dispatch entry from ufunc | 
|  | if ufunc_inner_loop: | 
|  | assert structured, "ufunc must be structured" | 
|  |  | 
|  | # Delay import ufunc here to avoid circular import issue | 
|  | # See: https://github.com/pytorch/pytorch/issues/81294 | 
|  | import torchgen.api.ufunc as ufunc | 
|  |  | 
|  | for dispatch_key in UFUNC_DISPATCH_KEYS: | 
|  | assert ( | 
|  | dispatch_key not in dispatch | 
|  | ), f"ufunc should not have explicit dispatch entry for {dispatch_key}" | 
|  | dispatch[dispatch_key] = BackendMetadata( | 
|  | kernel=ufunc.schema_kernel_name(func, dispatch_key), | 
|  | structured=True, | 
|  | cpp_namespace=DEFAULT_KERNEL_NAMESPACE, | 
|  | ) | 
|  |  | 
|  | if structured_delegate: | 
|  | # Structured functions MUST have a dispatch table | 
|  | is_abstract = True | 
|  | else: | 
|  | is_abstract = ( | 
|  | dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} | 
|  | and dispatch.keys() | 
|  | != {DispatchKey.CompositeImplicitAutogradNestedTensor} | 
|  | and dispatch.keys() | 
|  | != { | 
|  | DispatchKey.CompositeImplicitAutograd, | 
|  | DispatchKey.CompositeImplicitAutogradNestedTensor, | 
|  | } | 
|  | ) | 
|  |  | 
|  | has_composite_implicit_autograd_kernel = ( | 
|  | DispatchKey.CompositeImplicitAutograd in dispatch.keys() | 
|  | ) | 
|  | has_composite_implicit_autograd_nested_tensor_kernel = ( | 
|  | DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys() | 
|  | ) | 
|  | has_composite_explicit_autograd_kernel = ( | 
|  | DispatchKey.CompositeExplicitAutograd in dispatch.keys() | 
|  | ) | 
|  | has_composite_explicit_autograd_non_functional_kernel = ( | 
|  | DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch.keys() | 
|  | ) | 
|  |  | 
|  | # We aren't going to store dispatch metadata inline in NativeFunctions; | 
|  | # instead it is separately indexed by backend (so other backends can | 
|  | # add more dispatch entries after the fact).  Reindex the individual | 
|  | # metadata by OperatorName! | 
|  | backend_metadata = {k: {func.name: v} for k, v in dispatch.items()} | 
|  |  | 
|  | # don't care if it exists or not; make it easier to use this function | 
|  | # with other yaml parsers that aren't setting __line__ in the dict | 
|  | e.pop("__line__", None) | 
|  | assert not e, f"leftover entries: {e}" | 
|  |  | 
|  | # Asserts that we can't do in post_init, because they rely on backend-specific info | 
|  | if structured_delegate is not None: | 
|  | for key in STRUCTURED_DISPATCH_KEYS: | 
|  | assert key not in dispatch, ( | 
|  | f"if structured_delegate, then must not have {key} in dispatch dictionary " | 
|  | "(it is delegated!)" | 
|  | ) | 
|  |  | 
|  | return ( | 
|  | NativeFunction( | 
|  | func=func, | 
|  | use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors, | 
|  | variants=variants, | 
|  | structured=structured, | 
|  | structured_delegate=structured_delegate, | 
|  | structured_inherits=structured_inherits, | 
|  | precomputed=precomputed, | 
|  | autogen=autogen, | 
|  | ufunc_inner_loop=ufunc_inner_loop, | 
|  | manual_kernel_registration=manual_kernel_registration, | 
|  | manual_cpp_binding=manual_cpp_binding, | 
|  | python_module=python_module, | 
|  | category_override=category_override, | 
|  | device_guard=device_guard, | 
|  | device_check=device_check, | 
|  | loc=loc, | 
|  | cpp_no_default_args=cpp_no_default_args, | 
|  | is_abstract=is_abstract, | 
|  | has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, | 
|  | has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel, | 
|  | has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, | 
|  | has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, | 
|  | tags=tags, | 
|  | namespace=namespace, | 
|  | ), | 
|  | backend_metadata, | 
|  | ) | 
|  |  | 
def validate_unstructured(self) -> None:
    """Assert that this function carries no structured-kernel annotations.

    Raises:
        AssertionError: if ``structured`` is set, or if a
            ``structured_delegate`` is declared.
    """
    # TODO: probably better to accumulate these errors and report them all
    # at once
    assert not self.structured, (
        "This function is structured, but there was "
        "no valid functional variant of it."
    )
    # A declared delegate is exactly the situation the message below
    # describes, so the check has to require that none is present.
    assert not self.structured_delegate, (
        "This function delegates to another structured out function, "
        "but no valid function was found (the delegate may not exist, or it has the wrong type)"
    )
|  |  | 
|  | # __post_init__ functions in dataclasses can be used to do extra | 
|  | # validation after construction. | 
|  | # | 
|  | # Notice that we don't do any type validation here.  In fact, we | 
|  | # rely exclusively on mypy to check if you've done types correctly! | 
|  | # Validation is for nontrivial invariants that cannot be (conveniently) | 
|  | # encoded in the type system. | 
def __post_init__(self) -> None:
    """Check the invariants that tie this function's fields together.

    Raises:
        AssertionError: when any of the checks below is violated.
    """
    schema = self.func

    if schema.arguments.out:
        assert self.variants == {Variant.function}, (
            "Native functions with out arguments MUST "
            "be declared with only function variant; e.g., variants: function; "
            "otherwise you will tickle a Python argument binding bug "
            "(which usually manifests itself as the result variable being undefined.)"
        )

    if self.structured:
        assert schema.kind() == SchemaKind.out, (
            "Put structured field on the out= "
            "variant of a function; did you mean structured_delegate?"
        )
        assert (
            self.device_guard
        ), "device_guard: False is not respected by structured kernels"

    if self.structured_delegate:
        assert schema.kind() != SchemaKind.out, (
            "structured_delegate field not allowed "
            "on out= functions; did you mean structured?"
        )
        assert (
            self.device_guard
        ), "device_guard: False is not respected by structured kernels"

    # The two schema-kind checks above already exclude this combination;
    # it is spelled out anyway.
    assert (
        not self.structured or not self.structured_delegate
    ), "Cannot have both structured and structured_delegate on function"

    # Every name in cpp_no_default_args must belong to an argument that
    # actually has a default.
    args_with_defaults = set()
    for argument in schema.schema_order_arguments():
        if argument.default is not None:
            args_with_defaults.add(argument.name)
    unknown_args = set.difference(self.cpp_no_default_args, args_with_defaults)
    assert len(unknown_args) == 0, f"Invalid cpp_no_default_args: {unknown_args}"

    if self.structured_inherits is not None:
        assert (
            self.structured
        ), "structured_inherits must also imply structured: True"

    op_name = str(schema.name)
    if op_name.startswith("_foreach"):
        assert self.device_check == DeviceCheckType.NoCheck, (
            "foreach kernels fall back to slow path when tensor are on different devices, "
            "device_check not allowed to be enabled"
        )

    # NB: if your function accidentally has rand/dropout/... in its name
    # but is not actually random, feel free to amend this to special case
    looks_random = "rand" in op_name
    if not looks_random:
        mentions_dropout = "dropout" in op_name or any(
            "dropout" in arg.name for arg in schema.arguments.flat_all
        )
        looks_random = (
            mentions_dropout
            # Backwards of dropout is typically deterministic
            and "backward" not in op_name
            and str(schema.name.name) not in ["_cudnn_init_dropout_state"]
        )
    if not looks_random:
        looks_random = schema.arguments.has_generator_arg()
    if looks_random:
        assert "nondeterministic_seeded" in self.tags, op_name
|  |  | 
@property
def has_composite_kernel(self) -> bool:
    """Whether any of the three composite kernel flags checked below is set."""
    # A CompositeImplicitAutogradNestedTensor kernel on its own does not
    # count here; paired with CompositeImplicitAutograd it is already
    # covered by the first operand, so the nested-tensor flag is not read.
    return (
        self.has_composite_implicit_autograd_kernel
        or self.has_composite_explicit_autograd_kernel
        or self.has_composite_explicit_autograd_non_functional_kernel
    )
|  |  | 
@property
def is_view_op(self) -> bool:
    """Whether this operator counts as a view.

    True when at least one of these holds: some return has an annotation
    that is not a write; the function is tagged ``inplace_view`` and is
    neither ``resize_`` nor ``resize_as_``; or some argument's annotation
    lists ``*`` in ``alias_set_after``.
    """
    schema = self.func

    returns_readonly_alias = False
    for ret in schema.returns:
        if ret.annotation is not None and not ret.annotation.is_write:
            returns_readonly_alias = True
            break

    # See Note [resize_ in Functionalization] for more details
    tagged_inplace_view = "inplace_view" in self.tags and str(
        schema.name
    ) not in ("resize_", "resize_as_")

    has_wildcard_alias = False
    for argument in schema.schema_order_arguments():
        annotation = argument.annotation
        if annotation is not None and "*" in annotation.alias_set_after:
            has_wildcard_alias = True
            break

    return returns_readonly_alias or tagged_inplace_view or has_wildcard_alias
|  |  | 
@property
def view_schema_kind(self) -> ViewSchemaKind:
    """Combine ``is_view_op`` with the name's in-place flag into a ViewSchemaKind."""
    if not self.is_view_op:
        return ViewSchemaKind.non_aliasing
    if self.func.name.name.inplace:
        # A view op with an in-place name is required to carry the tag.
        assert "inplace_view" in self.tags
        return ViewSchemaKind.aliasing_inplace
    return ViewSchemaKind.aliasing
|  |  | 
@property
def root_name(self) -> str:
    """Return the ``base`` attribute of this function's operator name."""
    base_operator_name = self.func.name.name
    return base_operator_name.base
|  |  | 
@property
def part_of_structured_group(self) -> bool:
    """True if the function is structured itself or names a structured delegate."""
    names_delegate = self.structured_delegate is not None
    return self.structured or names_delegate
|  |  | 
|  |  | 
# Classification of a function schema; FunctionSchema.kind() decides which
# member applies.
class SchemaKind(Enum):
    # Returns a newly allocated output (the fall-through case of kind()).
    functional = auto()
    # The operator name is flagged in-place; modifies the self argument.
    inplace = auto()
    # Writes the result into an explicitly provided out argument.
    out = auto()
    # Neither inplace nor out=, but some positional argument after self
    # carries a write annotation.
    mutable = auto()
    # Has an out argument whose name starts with "_scratch_".
    scratch = auto()
|  |  | 
|  |  | 
|  | # A structured kernel is guaranteed to have a functional and out variant, and | 
|  | # optionally an inplace variant. | 
|  | # | 
|  | # NB: we create NativeFunctionsGroup *even if* the function is not | 
|  | # actually annotated structured.  Test the structured boolean to see if it | 
|  | # actually is structured or not. | 
@dataclass(frozen=True)
class NativeFunctionsGroup:
    # The variants that make up the group.  `functional` and `out` are always
    # present; `inplace` and `mutable` may be absent.
    functional: NativeFunction
    inplace: Optional[NativeFunction]
    mutable: Optional[NativeFunction]
    out: NativeFunction

    @property
    def structured(self) -> bool:
        """Return the ``structured`` flag of the out variant."""
        # Whether or not the operator has a meta() function. This information is backend-agnostic.
        return self.out.structured

    def __post_init__(self) -> None:
        """Validate that the grouped functions are consistent with each other.

        Raises:
            AssertionError: if signatures, schema kinds, namespaces,
                structured-ness or structured delegates disagree.
            RuntimeError: if the functions tagged "generated" do not match
                the names listed in the members' ``autogen`` fields.
        """
        # Every member must share the functional variant's signature and agree
        # with the group on whether it belongs to a structured group.
        test_sig: FunctionSchema = self.functional.func.signature()
        for f in self.functions():
            if test_sig != f.func.signature():
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                )

            if self.structured != f.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {f.func.name}"
                )
        # Each slot must hold a function of the matching schema kind, and all
        # members must live in the same namespace as the functional variant.
        assert self.functional.func.kind() == SchemaKind.functional
        assert self.out.func.kind() == SchemaKind.out
        assert self.functional.namespace == self.out.namespace
        if self.inplace is not None:
            assert self.inplace.func.kind() == SchemaKind.inplace
            assert self.inplace.namespace == self.functional.namespace

        if self.mutable is not None:
            assert self.mutable.func.kind() == SchemaKind.mutable
            assert self.mutable.namespace == self.functional.namespace
            # See Note [Overload Ambiguity With Functional Variants]
            assert self.functional.func.name.name.functional_overload

        if self.structured:
            # For now, structured composite kernels are not supported (need some
            # design work to figure out how to make the composite case work)
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

            # The functional (and inplace, if any) variant must delegate to
            # this group's out variant.
            assert self.functional.structured_delegate == self.out.func.name, (
                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                f"but its actual delegate is {self.out.func.name}"
            )
            if self.inplace is not None:
                assert self.inplace.structured_delegate == self.out.func.name

        # Names of members tagged "generated" must equal the union of the
        # names listed in the members' autogen fields; both sides are compared
        # as sorted, comma-joined strings.
        generated_fns = sorted(
            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
        )
        generated_fns_str = ", ".join(str(x) for x in generated_fns)
        expected_generated_fns: Set[str] = set()
        for f in self.functions():
            expected_generated_fns.update(str(op) for op in f.autogen)
        expected_generated_fns_str = ", ".join(
            str(x) for x in sorted(expected_generated_fns)
        )
        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
            # NOTE(review): `f` below is the variable left over from the loop
            # above, i.e. the last function yielded by functions() -- confirm
            # that this is the entry the message is meant to name.
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
            )
        if expected_generated_fns_str != generated_fns_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
            )

    def signature(self) -> "FunctionSchema":
        """Return the signature of the out variant's schema."""
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        """Yield functional, out, then inplace and mutable when they are present."""
        yield self.functional
        yield self.out
        if self.inplace is not None:
            yield self.inplace
        if self.mutable is not None:
            yield self.mutable

    @property
    def root_name(self) -> str:
        """Return the root name of the functional variant."""
        return self.functional.root_name

    @staticmethod
    def from_dict(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Optional["NativeFunctionsGroup"]:
        """Build a group from a mapping of schema kind to function.

        Returns None when the mapping has a single entry or has no out
        variant.  The mapping must be non-empty, must contain a functional
        variant, and may only use the functional, inplace, mutable and out
        keys; the caller's dict is not modified.
        """
        assert d
        if len(d) == 1:
            return None
        d = dict(d)  # non-destructive updates please
        functional = d.pop(SchemaKind.functional, None)
        inplace = d.pop(SchemaKind.inplace, None)
        mutable = d.pop(SchemaKind.mutable, None)
        out = d.pop(SchemaKind.out, None)
        # Anything left over is a kind this group cannot hold.
        assert not d
        assert functional is not None
        # There are a few operators which only have functional/inplace variants;
        # these don't count as structured for our purposes here
        if out is None:
            return None
        # assuming all variants have the same namespace
        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )
|  |  | 
|  |  | 
@dataclass(frozen=True)
class BackendMetadata:
    # Name of the kernel that implements a given operator on this backend.
    # For in-tree backends the name is taken straight from the 'dispatch'
    # field in native_functions.yaml.  That field may be left out, which is
    # the same as having written:
    #
    #   dispatch:
    #       CompositeImplicitAutograd: $operator_name
    kernel: str
    # Whether this particular backend implements the operator with a
    # structured kernel.  In-tree backends all share one value, the one
    # listed in native_functions.yaml, whereas external backends like XLA
    # can independently toggle which ops are structured.
    structured: bool

    # Namespace the kernel lives in; default value: DEFAULT_KERNEL_NAMESPACE
    cpp_namespace: str

    def supports_symint(self) -> bool:
        """Report whether the kernel name contains the substring ``_symint``."""
        marker = "_symint"
        return marker in self.kernel
|  |  | 
|  |  | 
@dataclass(frozen=True)
class UfuncInnerLoop:
    name: str
    supported_dtypes: OrderedSet[ScalarType]
    # The key is kept next to the name because it changes how the name is
    # interpreted, so later processing wants both at hand
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
        """Parse a string of the form ``"<name> (<dtypes>)"``.

        The text after the first space must be wrapped in parentheses; its
        contents are split on ", " and each piece is handed to
        ``ScalarType.parse_set``, with the results unioned together.
        """
        loop_name, dtypes_text = value.split(" ", 1)
        assert dtypes_text[0] == "("
        assert dtypes_text[-1] == ")"
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for dtype_token in dtypes_text[1:-1].split(", "):
            dtypes |= ScalarType.parse_set(dtype_token)
        return UfuncInnerLoop(
            name=loop_name,
            supported_dtypes=dtypes,
            ufunc_key=ufunc_key,
        )
|  |  | 
|  |  | 
|  | # BackendIndex represents a backend. | 
|  | # The BackendIndex encodes per-operator information that is potentially different | 
|  | # for each backend. The most obvious example is the name of the kernel | 
|  | # (the 'dispatch' entry in native_functions.yaml). | 
|  | # However, there can be other examples of different backends having different information. | 
|  | # External backends can choose to opt their kernels to be structured independently from in-tree backends, | 
|  | # which means that this information isn't inherently tied to a NativeFunction- it's different per backend. | 
@dataclass(frozen=True)
class BackendIndex:
    dispatch_key: DispatchKey
    # Picks which variant of an operator group implements the others; this
    # matters mostly for structured kernels.  All in-tree ops use out
    # kernels, while XLA uses functional kernels.
    use_out_as_primary: bool
    # Whether the backend requires a device guard, and device checks.
    # For in-tree backends, this is currently just CUDA/HIP
    # For out-of-tree backends, this is currently just Intel XPU
    device_guard: bool
    # True for out-of-tree backends (XLA), False for in-tree ones (CPU/CUDA)
    external: bool
    # Per-operator, backend-specific information
    index: Dict["OperatorName", BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
        child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
    ) -> None:
        """Copy every entry of ``child_index`` into ``parent_index`` in place.

        Raises:
            AssertionError: if an operator is already registered in the
                parent under the same dispatch key.
        """
        for key, child_ops in child_index.items():
            for op_name, metadata in child_ops.items():
                parent_ops = parent_index[key]
                assert (
                    op_name not in parent_ops
                ), f"duplicate operator {op_name} for dispatch key {key}"
                parent_ops[op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        """Return the group's out variant or functional variant, per ``use_out_as_primary``."""
        return g.out if self.use_out_as_primary else g.functional

    def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """True if ``get_kernel`` finds metadata for ``g``."""
        return self.get_kernel(g) is not None

    def get_kernel(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Optional[BackendMetadata]:
        """Look up the metadata indexed for ``g``, or None if there is none.

        For a group, the lookup uses the function chosen by ``primary``.
        """
        if isinstance(g, NativeFunction):
            target = g
        elif isinstance(g, NativeFunctionsGroup):
            target = self.primary(g)
        else:
            assert_never(g)
        op_name = target.func.name
        if op_name in self.index:
            return self.index[op_name]
        return None

    def native_function_class_name(self) -> Optional[str]:
        """Return ``"<dispatch_key>NativeFunctions"`` for external backends, else None."""
        if not self.external:
            # TODO: This discrepancy isn't required; we could also generate
            # a class for in-tree kernels. It'll just require carefully
            # updating every kernel definition + callsite of every in-tree aten kernel.
            return None
        return f"{str(self.dispatch_key)}NativeFunctions"
|  |  | 
|  |  | 
|  | # The function schema is undoubtedly the most important data structure | 
|  | # in all of the codegen, as it defines the type signature for operators, | 
|  | # and most of the code generation we do is type directed (e.g., look at | 
|  | # the types, decide what to do.  Think about how we code generate | 
|  | # C++ function stubs!) | 
|  | # | 
|  | # We will also see in this class the general structure for how we model | 
|  | # data in this code generation.  A few notable properties to point out | 
|  | # ahead of time: | 
|  | # | 
|  | #   - These dataclasses are a *lossless* representation of the strings | 
|  | #     they are parsed from.  In fact, we assert that given the | 
|  | #     information stored in the dataclass, we can exactly reconstruct | 
|  | #     the string we parsed from (and assert this inside the parse | 
|  | #     definition).  There are a few reasons for this: | 
|  | # | 
#       - If you find that it is difficult to reconstruct the string
#         given a dataclass, that is a clue that your data
#         representation is wrong.
|  | # | 
|  | #       - It helps ensure that all relevant information is present | 
|  | #         in the dataclass, so that downstream users aren't tempted | 
|  | #         to reparse the original string to get some information | 
|  | #         that was omitted. | 
|  | # | 
|  | #       - It forces you to represent the data in-memory in the same way | 
|  | #         it is recorded textually, which makes the dataclasses easier | 
|  | #         to understand for someone who is familiar with the | 
|  | #         textual format.  (As a tradeoff, it means you have to model | 
|  | #         the syntax, even when it is inconvenient.  But maybe that means | 
|  | #         the syntax is bad!)  If you don't understand the internal | 
|  | #         representation, go look at the printing code to see how | 
|  | #         it maps onto the surface syntax! | 
|  | # | 
|  | #       - It makes it easy to test the parsing code, as parsing code | 
|  | #         that is inconsistent with the string code will fail early | 
#         and loudly.  (As a tradeoff, it makes the parsing code a bit
#         brittle: in particular, with trivial whitespace changes you
#         are likely to trigger an assert error.)
|  | # | 
|  | #     In general, try to make the __str__ code as simple as possible | 
|  | #     (even at the cost of more complex parsing logic.)  Additionally, | 
|  | #     try to minimize redundancy in data representation.  (Precomputed | 
|  | #     fields are OK though: they are defined as a simple function on | 
|  | #     the canonical representation in question.) | 
|  | # | 
|  | #   - These dataclasses are all frozen; once constructed their | 
|  | #     values never change.  This makes it easy to tell where any | 
|  | #     given data came from: just look to the constructor.  As a | 
|  | #     tradeoff, you can't easily "decorate" a schema with extra | 
|  | #     information from a post-facto analysis.  We impose this | 
|  | #     restriction to make these structures more understandable. | 
|  | # | 
|  | @dataclass(frozen=True) | 
|  | class FunctionSchema: | 
|  | # The name of the operator this function schema describes. | 
|  | name: "OperatorName" | 
|  |  | 
|  | arguments: "Arguments" | 
|  |  | 
|  | # TODO: Need to handle collisions with argument names at some point | 
|  | returns: Tuple["Return", ...] | 
|  |  | 
def schema_order_arguments(self) -> Iterator["Argument"]:
    """Iterate over flat positional, then flat keyword-only, then out arguments."""
    args = self.arguments
    groups = (args.flat_positional, args.flat_kwarg_only, args.out)
    return itertools.chain.from_iterable(groups)
|  |  | 
# Matches a schema declaration of the form "<name>(<args>) -> <returns>" and
# exposes the three parts as the groups `name`, `args` and `returns`.  `name`
# is the run of characters before the first "(".
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
|  |  | 
@staticmethod
def parse(func: str) -> "FunctionSchema":
    """Parse a schema string into a FunctionSchema.

    Raises:
        AssertionError: if ``decl_re`` does not match exactly once, or if
            the parsed schema does not print back to ``func`` exactly.
    """
    # We should probably get a proper parser here
    matches = FunctionSchema.decl_re.findall(func)
    assert len(matches) == 1, f"Invalid function schema: {func}"
    name_text, args_text, returns_text = matches[0]
    schema = FunctionSchema(
        name=OperatorName.parse(name_text),
        arguments=Arguments.parse(args_text),
        returns=parse_returns(returns_text),
    )
    # Round-trip check: printing the result must reproduce the input.
    assert str(schema) == func, f"{str(schema)} != {func}"
    return schema
|  |  | 
def returns_are_aliased(self) -> bool:
    """Return True if any return carries a write annotation."""
    # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
    # The predicate itself is what any() tests, so the answer does not
    # depend on the truthiness of the return objects.
    return any(
        r.annotation is not None and r.annotation.is_write for r in self.returns
    )
|  |  | 
def __post_init__(self) -> None:
    """Validate how arguments, returns and the operator name fit together.

    Raises:
        AssertionError: when any of the schema invariants checked below is
            violated.
    """
    # The i-th out argument and the i-th return must carry the same annotation.
    for arg, ret in zip(self.arguments.out, self.returns):
        assert arg.annotation == ret.annotation, (
            "Out arguments must have matching return Tensor; furthermore, "
            "the ith-argument needs to correspond to the ith return"
        )
    # We also enforce that if you have any mutable, positional args, then they are not returned.
    # This makes it easier to group these functions properly with their functional/out= counterparts.
    for a in self.arguments.post_self_positional_mutable:
        assert not any(
            a.annotation == r.annotation for r in self.returns
        ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
    # Invariant: we expect out arguments to appear as keyword arguments in the schema.
    # This means that all mutable returns should be aliased to a keyword argument
    # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
    # See Note [is_out_fn]
    out_and_self = list(self.arguments.out) + [
        arg for arg in self.arguments.flat_positional if arg.name == "self"
    ]
    mutable_returns = [
        ret
        for ret in self.returns
        if ret.annotation is not None and ret.annotation.is_write
    ]
    immutable_returns = [
        ret
        for ret in self.returns
        if ret.annotation is None or not ret.annotation.is_write
    ]
    # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
    # because:
    # (1) It's more annoying to handle properly
    # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
    # Instead, we expect the (a!) argument to not be returned.
    assert (
        len(mutable_returns) == 0 or len(immutable_returns) == 0
    ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
    for ret in mutable_returns:
        assert any(ret.annotation == arg.annotation for arg in out_and_self), (
            'All mutable returns must be aliased either to a keyword argument, or to "self". '
            "Did you forget to mark an out argument as keyword-only?"
        )
    if self.arguments.out:
        # out= ops that return their mutable inputs are only really useful for method chaining.
        # And method chaining is only really useful if the thing you're returning is a plain Tensor.
        # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
        # and all other types of out= op schemas should return void.
        # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
        if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
            # Both halves of the message sit inside one parenthesized
            # expression so that the whole text is attached to the assert.
            assert len(self.returns) == 0, (
                "out= ops that accept tensor lists as out arguments "
                "are expected to have no return type (since you can't do method chaining on them)"
            )
        else:
            # mutable keyword arguments whose name has _scratch_ prefix are
            # scratch tensors for memory planning and should not be returned
            assert len(
                [
                    arg
                    for arg in self.arguments.out
                    if not arg.name.startswith("_scratch_")
                ]
            ) == len(
                self.returns
            ), "Must return as many arguments as there are out arguments, or no return at all"

    if self.name.name.inplace:
        # An in-place op needs a self argument annotated as a write.
        self_a = self.arguments.self_arg
        assert (
            self_a
            and self_a.argument.annotation
            and self_a.argument.annotation.is_write
        )
        if self_a.argument.type == BaseType(BaseTy.Tensor):
            # All inplace ops with an ordinary `Tensor self` argument should return self,
            # to allow for method chaining.
            assert (
                len(self.returns) == 1
                and self.returns[0].annotation == self_a.argument.annotation
            )
        else:
            # You can't method chain on non-tensor self arguments though (like a List[Tensor])
            # so in all other cases we expect the return type to be none.
            assert len(self.returns) == 0

    if self.arguments.tensor_options is not None:
        assert self.kind() == SchemaKind.functional, (
            "Found an operator that is not functional or out variant, but has tensor options arguments."
            "This is not allowed- tensor options arguments are only allowed for factory functions."
            f"schema: {str(self)}"
        )
    if self.is_functional_fn():
        assert self.kind() == SchemaKind.functional, (
            "Found an operator that is not functional, but its overload contains the string 'functional'."
            "This is a special keyword in the codegen, please use a different overload name."
            f"schema: {str(self)}"
        )
|  |  | 
def is_functional_fn(self) -> bool:
    """Report whether the overload name contains the text ``functional``."""
    overload = self.name.overload_name
    return "functional" in overload
|  |  | 
def is_out_fn(self) -> bool:
    """Report whether the schema has any out arguments."""
    # Note [is_out_fn]
    #
    # out functions are the variants which take an explicit out= argument
    # to populate into.  Knowing whether a schema is an out function
    # matters for several reasons:
    #
    #   - They codegen differently in C++ API
    #       - codegen to at::add_out rather than at::add
    #       - out argument is moved to front of C++ argument list
    #
    # out functions are DEFINED to be any function with a keyword-only
    # argument that is mutable.  In principle this can give a false
    # positive: a function may mutate a kwarg-only argument that is not
    # its "true" output.  A more robust definition that would handle that
    # case would also look at:
    #
    #   - The output types.  Out functions take in the arguments
    #     they mutate and then return them again; this is sort
    #     of "definitionally" what makes something an out function.
    #     Historically, we DO check this for consistency.
    #   - Correspondence with pure variant.  An out function
    #     should have a signature equivalent to its pure variant,
    #     but just with extra kwargs for the output elements.  This
    #     is difficult to actually check for and historically
    #     we only do this check in tools/
    return True if self.arguments.out else False
|  |  | 
def kind(self) -> SchemaKind:
    """
    Classify this schema.  A functional schema returns a newly
    allocated output; an inplace schema modifies the self argument
    inplace; an out schema writes the result into an explicitly
    provided out argument.  A schema with an out argument named with
    the "_scratch_" prefix is scratch, and one that is none of the
    above but has a write-annotated positional argument after self is
    mutable.
    """
    out_args = self.arguments.out
    has_out = bool(out_args)
    has_scratch = any(arg.name.startswith("_scratch_") for arg in out_args)
    inplace_name = self.name.name.inplace
    has_mutable_positional = False
    for a in self.arguments.post_self_positional:
        if a.annotation is not None and a.annotation.is_write:
            has_mutable_positional = True
            break
    assert not (has_out and inplace_name)
    # out= and inplace schemas can also have post_self_positional mutable args,
    # but out= and inplace take precedence when deciding the schema kind.
    # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
    # to also worry about mutable post_self_positional arguments,
    # but it seems like a much bigger lift to classify them as having a new schema kind.
    # The number of ops that fit in this strange category is small enough that
    # we can probably manually write code for them instead of forcing the codegen to handle them.
    if inplace_name:
        return SchemaKind.inplace
    if has_scratch:
        assert (
            has_out
        ), "invariant: all scratch operators are expected to be out= operators too"
        return SchemaKind.scratch
    if has_out:
        assert (
            not has_scratch
        ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
        return SchemaKind.out
    if has_mutable_positional:
        return SchemaKind.mutable
    return SchemaKind.functional
|  |  | 
|  | # For every return: | 
|  | # - If the return aliases an input, we return the input name | 
|  | # - Otherwise, we return None. | 
|  | # If return names were enforced to be consistent with aliasing information, then we wouldn't need this. | 
def aliased_return_names(self) -> List[Optional[str]]:
    """
    Map every return to the name of the input it aliases.

    The result has one entry per return, in order: the name of the
    single argument whose annotation equals the return's annotation, or
    None when no argument matches.

    Raises:
        AssertionError: if a return's annotation matches more than one
            argument.
    """

    def aliased_input(ret: "Return") -> Optional[str]:
        # Arguments without an annotation can never alias a return.
        matches = []
        for arg in self.arguments.flat_all:
            if arg.annotation is not None and arg.annotation == ret.annotation:
                matches.append(arg)
        if not matches:
            return None
        if len(matches) > 1:
            aliased_names = ", ".join(arg.name for arg in matches)
            raise AssertionError(
                f"Found a return ({ret.name})that aliases multiple inputs ({aliased_names})"
            )
        return matches[0].name

    return [aliased_input(ret) for ret in self.returns]
|  |  | 
def signature(
    self,
    *,
    strip_default: bool = False,
    strip_view_copy_name: bool = False,
    keep_return_names: bool = False,
) -> "FunctionSchema":
    """
    Certain schemas are 'related', in that they are simply
    inplace/out/functional versions of the same function.  This method
    factors these schemas into the "core" functional signature which
    is equal across all versions.

    Here is what normalization happens to the schema to convert
    it to a signature:
    - The overload name is stripped (name is retained, since
      it expresses semantic content about what the function does)
    - Inplace is set False
    - Out arguments are stripped
    - Mutable post_self_positional args are converted to returns
    - Mutability annotations are stripped  (this is sound
      because you cannot overload on mutability annotation)
    - Return names are stripped since they are not overloadable and
      some variants have return names but some not
    - TensorOptions are dropped
      because out= variants of factory functions don't include them
      (and we want to be able to pair up factory functions with their out variants)

    Finally, we want to be able to pair up related "view" and their
    corresponding "view_copy" operators. We do this by optionally
    stripping the trailing "_copy" from the base name.

    Args:
        strip_default: drop default values from every argument.
        strip_view_copy_name: remove a trailing "_copy" from the base name.
        keep_return_names: keep the original return names, and name the
            returns synthesized from mutable inputs "<arg>_out".

    Example of a mutable op before and after:

    f.func (Mutable operator):
    _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950

    f.func (Corresponding functional operator):
    _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950

    f.func.signature() output:
    _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
    """

    def strip_ret_annotation(r: Return) -> Return:
        return Return(
            name=r.name if keep_return_names else None,
            type=r.type,
            annotation=None,
        )

    copy_suffix = "_copy"
    base_name = self.name.name.base
    if strip_view_copy_name and base_name.endswith(copy_suffix):
        # Only drop the trailing suffix; str.replace would also remove any
        # earlier "_copy" occurrence inside the name.
        base_name = base_name[: -len(copy_suffix)]

    # find mutable inputs that are not originally returned, and convert them to returns
    returns_from_mutable_inputs = tuple(
        # When we're grouping functions we strip the return names,
        # but when we're generating the actual functional variants then we follow
        # a convention for what to name the returns
        Return(
            name=f"{a.name}_out" if keep_return_names else None,
            type=a.type,
            annotation=None,
        )
        for a in itertools.chain(
            # Order is important here (otherwise e.g. inplace with mutable args
            # and out= with mutable args won't have the same signature)
            [self.arguments.self_arg.argument]
            if self.arguments.self_arg is not None
            else [],
            self.arguments.out,
            self.arguments.post_self_positional,
        )
        if a.annotation is not None
        and a.annotation.is_write
        and not any(a.annotation == r.annotation for r in self.returns)
    )
    original_returns = tuple(map(strip_ret_annotation, self.returns))
    # Ordering is important here. We expect the "mutable input" returns to come last.
    returns = original_returns + returns_from_mutable_inputs

    args_sig = self.arguments.signature(strip_default=strip_default)
    # See Note [bernoulli.p schema]
    if str(self.name) == "bernoulli.p":
        args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))

    return FunctionSchema(
        name=OperatorName(
            name=BaseOperatorName(
                base=base_name,
                inplace=False,
                dunder_method=self.name.name.dunder_method,
            ),
            overload_name="",  # stripped
        ),
        arguments=args_sig,
        returns=returns,
    )
|  |  | 
def view_signature(self) -> "FunctionSchema":
    """Return ``signature()`` computed with ``strip_view_copy_name`` enabled."""
    options = {"strip_view_copy_name": True}
    return self.signature(**options)
|  |  | 
def with_name(self, name: "OperatorName") -> "FunctionSchema":
    """Build a new FunctionSchema with ``name`` and this schema's arguments and returns."""
    renamed = FunctionSchema(name=name, arguments=self.arguments, returns=self.returns)
    return renamed
|  |  | 
@property
def modifies_arguments(self) -> bool:
    """True when ``kind()`` is inplace, out, or mutable."""
    schema_kind = self.kind()
    return schema_kind in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
|  |  | 
def has_symint(self) -> bool:
    """Report whether the argument list says it has a SymInt argument."""
    schema_arguments = self.arguments
    return schema_arguments.has_symint_arg()
|  |  | 
def __str__(self) -> str:
    """Render the schema as ``name(arguments) -> returns``."""
    rendered_arguments = str(self.arguments)
    rendered_returns = ", ".join(str(r) for r in self.returns)
    # A single return is printed bare; any other count is parenthesized.
    if len(self.returns) != 1:
        rendered_returns = "(" + rendered_returns + ")"
    return f"{self.name}({rendered_arguments}) -> {rendered_returns}"
|  |  | 
|  |  | 
|  | # Here is the rest of the data model, described more briefly. | 
|  |  | 
|  |  | 
|  | # Simplified version for what actually shows up in built-ins. | 
|  | # Look at alias_info.h for expanded syntax.  If you need the structure, | 
|  | # you also need to make this structure recursive so it can be lined | 
|  | # up with the type components too.  For primitives this isn't really | 
|  | # necessary | 
@dataclass(frozen=True)
class Annotation:
    # Typically only has one element.  Not actually a set so
    # we can conveniently assume it is canonically ordered
    alias_set: Tuple[str, ...]
    # True when the annotation carries the '!' (write) marker
    is_write: bool
    # Alias set written after ' -> '; empty when there is no arrow section
    alias_set_after: Tuple[str, ...]

    @staticmethod
    def parse(ann: str) -> "Annotation":
        """
        Parse an alias annotation such as "a", "a!", "a|b" or "a! -> a|b".

        Raises:
            AssertionError: if ``ann`` does not match the grammar, if a
                write annotation has more than one alias, if both the
                before and after sets have more than one alias, or if the
                parsed value does not print back to ``ann``.
        """
        # TODO: implement a proper parser if this gets more ugly
        # Regex Explanation:
        # Example: "a! -> a|b"
        # Group #1: alias before optional '|', required. Matches the first
        #   character 'a' in the example
        # Group #2: the whole run of '|x' aliases following group #1, matches
        #   the empty string in the example.  The repetition lives inside a
        #   non-capturing group so that every alias is kept; a repeated
        #   capturing group would only remember its last repetition.
        # Group #3: optional "is write" flag, matches '!' in the example.
        # Group #4: optional section containing arrow, matches " -> a|b" in the
        #   example.
        # Group #5: optional alias after set, supports wildcard, matches "a|b"
        #   in the example.
        # Group #6: optional sub-section of alias after set, matches "|b" in the
        #   example.
        m = re.match(
            r"^([a-z])((?:\|[a-z])*)(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann
        )

        assert m is not None, f"unrecognized alias annotation {ann}"
        before_alias = m.group(1) + m.group(2)
        alias_set = tuple(before_alias.split("|"))
        is_write = m.group(3) == "!"
        assert not (
            is_write and len(alias_set) > 1
        ), f"alias set larger than 1 is not mutable, got {ann} instead."
        after_set = tuple(m.group(5).split("|")) if m.group(5) else tuple()
        assert not (
            len(alias_set) > 1 and len(after_set) > 1
        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
        r = Annotation(
            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
        )
        # Parsing must round-trip losslessly.
        assert str(r) == ann, f"{r} != {ann}"
        return r

    def __str__(self) -> str:
        """Print the annotation in the same syntax that ``parse`` accepts."""
        alias_set = "|".join(self.alias_set)
        if self.is_write:
            alias_set = f"{alias_set}!"
        alias_set_after = "|".join(self.alias_set_after)
        if alias_set_after:
            alias_set = f"{alias_set} -> {alias_set_after}"
        return alias_set
|  |  | 
|  |  | 
|  | # The base class for the type system.  This is also loosely modeled | 
|  | # off of jit_type.h, but we've simplified the hierarchy to focus | 
|  | # in on the aspects of the type system that matter for code generation | 
|  | # (for example, there's no SingleElementType subclass anymore). | 
|  | # You never actually construct a Type; usually it's going to be one | 
|  | # of the subclasses.  If Python had ADTs this would be one! | 
@dataclass(frozen=True)
class Type:
    @staticmethod
    def parse(t: str) -> "Type":
        """Parse ``t`` into a Type and assert that it prints back to ``t``."""
        parsed = Type._parse(t)
        assert str(parsed) == t, f"{parsed} != {t}"
        return parsed

    @staticmethod
    def _parse(t: str) -> "Type":
        """
        Parse ``t`` without the round-trip check.

        Forms are tried in order: optional ("X?"), list ("X[]" / "X[N]"),
        custom class ("__torch__.torch.classes.X"), then a BaseTy name.

        Raises:
            RuntimeError: if ``t`` matches none of the forms.
        """
        optional_match = re.match(r"^(.+)\?$", t)
        if optional_match is not None:
            inner = optional_match.group(1)
            return OptionalType(Type.parse(inner))

        list_match = re.match(r"^(.+)\[([0-9]+)?\]$", t)
        if list_match is not None:
            elem_s, size_s = list_match.groups()
            size = None if size_s is None else int(size_s)
            return ListType(elem=Type.parse(elem_s), size=size)

        # '__torch__.torch.classes.' is the prefix for custom class
        custom_match = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
        if custom_match is not None:
            return CustomClassType(custom_match.group(1))

        try:
            base = BaseTy[t]
        except KeyError as e:
            raise RuntimeError(f"unrecognized type {t}") from e
        return BaseType(base)

    def __str__(self) -> str:
        raise NotImplementedError

    # WARNING: These concepts are not very well-defined.  For example,
    # is "int?" nullable? How about "int?[]".  They are defined
    # so we can conveniently generate legacy Declarations.yaml but
    # really we should probably just remove these at some point

    def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
        raise NotImplementedError

    def is_tensor_like(self) -> bool:
        tensor_ty = BaseTy.Tensor
        return self.is_base_ty_like(tensor_ty)

    def is_generator_like(self) -> bool:
        generator_ty = BaseTy.Generator
        return self.is_base_ty_like(generator_ty)

    def is_symint_like(self) -> bool:
        symint_ty = BaseTy.SymInt
        return self.is_base_ty_like(symint_ty)

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> Optional["ListType"]:
        raise NotImplementedError
|  |  | 
|  |  | 
|  | # Base types are simple, atomic types with no further structure | 
class BaseTy(Enum):
    # Member names double as the schema spelling of each base type:
    # Type._parse looks a type string up with BaseTy[t], and
    # BaseType.__str__ prints the member's name back out.
    Generator = auto()
    ScalarType = auto()
    Tensor = auto()
    int = auto()
    Dimname = auto()
    DimVector = auto()
    float = auto()
    str = auto()
    bool = auto()
    Layout = auto()
    Device = auto()
    DeviceIndex = auto()
    Scalar = auto()
    MemoryFormat = auto()
    QScheme = auto()
    Storage = auto()
    Stream = auto()
    SymInt = auto()
    ConstQuantizerPtr = auto()  # TODO: rename
|  |  | 
|  |  | 
@dataclass(frozen=True)
class BaseType(Type):
    # The atomic base type this Type wraps
    name: BaseTy

    def __str__(self) -> str:
        """Print the name of the wrapped BaseTy member."""
        member_name = self.name.name
        return f"{member_name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """True only when this is exactly ``base_ty``."""
        return self.name == base_ty

    def is_nullable(self) -> bool:
        """A base type is never nullable."""
        return False

    def is_list_like(self) -> Optional["ListType"]:
        """A base type is never a list."""
        return None

    def is_symint_like(self) -> bool:
        """True when the wrapped base type is SymInt."""
        symint_ty = BaseTy.SymInt
        return self.name == symint_ty
|  |  | 
|  |  | 
|  | # Optional types may be specified, or may also be validly given None | 
@dataclass(frozen=True)
class OptionalType(Type):
    # The type being made optional
    elem: Type

    def __str__(self) -> str:
        """Print the element type followed by '?'."""
        inner = self.elem
        return f"{inner}?"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """Delegate to the element type."""
        inner = self.elem
        return inner.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        """Delegate to the element type."""
        inner = self.elem
        return inner.is_symint_like()

    def is_nullable(self) -> bool:
        """An optional type is always nullable."""
        return True

    def is_list_like(self) -> Optional["ListType"]:
        """Delegate to the element type."""
        inner = self.elem
        return inner.is_list_like()
|  |  | 
|  |  | 
|  | # A type representing a PyTorch custom class | 
@dataclass(frozen=True)
class CustomClassType(Type):
    # Class name without the '__torch__.torch.classes.' prefix
    class_name: str

    def __str__(self) -> str:
        """
        Return the class name with the __torch__.torch.classes prefix.
        """
        return "__torch__.torch.classes.{}".format(self.class_name)

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """A custom class is never like any base type."""
        return False

    def is_symint_like(self) -> bool:
        """A custom class is never SymInt-like."""
        return False

    def is_nullable(self) -> bool:
        """
        Assume a custom class is not nullable.
        """
        return False

    def is_list_like(self) -> Optional["ListType"]:
        """A custom class is never a list."""
        return None
|  |  | 
|  |  | 
|  | # List types specify that we may have multiples of an element.  We | 
|  | # also support explicit sizes on list types, but these have | 
|  | # some nontrivial semantics!  (However, for C++ API purposes, explicit | 
|  | # sizes are mostly erased from the type system.) | 
|  | # | 
|  | # DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g., | 
|  | # int[] elaborates differently than bool[3]! | 
@dataclass(frozen=True)
class ListType(Type):
    # Type of each list element
    elem: Type
    # Explicit size written between the brackets, or None for "[]"
    size: Optional[int]

    def __str__(self) -> str:
        """Print as ``elem[size]``, or ``elem[]`` when there is no size."""
        # Test against None rather than truthiness: an explicit size of 0
        # must still be printed, otherwise "X[0]" does not round-trip.
        size = f"{self.size}" if self.size is not None else ""
        return f"{self.elem}[{size}]"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """Delegate to the element type."""
        return self.elem.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        """Delegate to the element type."""
        return self.elem.is_symint_like()

    def is_nullable(self) -> bool:
        """A list is nullable when its element type is."""
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional["ListType"]:
        """A list type is list-like; return the type itself."""
        return self
|  |  | 
|  |  | 
@dataclass(frozen=True)
class Argument:
    # NB: I didn't put kwarg_only as a boolean field here, unlike
    # c10::Argument, so that printing works correctly

    name: str
    type: Type
    # Default value exactly as written in the schema, or None when absent
    default: Optional[str]

    # The semantics of the annotation field are a little strange.
    #
    # Alias annotations parametrize Tensors (since Tensors are the only things
    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
    # which may be optional (i.e., the alias annotation should bind first to
    # Tensor, before the optional postfix annotation).
    #
    # However, despite being a property of Tensor, we (and c10::Argument)
    # store the annotation at the top level of the Argument, rather than
    # inside the embedded Tensor type.  In the C++ version of this
    # class, we then go through great lengths to mimic the type
    # structure in the annotation structure so we can correlate
    # annotations with types.
    #
    # Now, it turns out, in all applications in code generation, the
    # structure of annotated types is very simple.  So we just hard
    # code it here.  But if we ever do get anything more complex, this
    # model will have to change!
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Argument":
        """
        Parse one argument written as "<type> <name>" or
        "<type> <name>=<default>".

        Raises:
            AssertionError: if an annotated Tensor has an unsupported
                suffix, or if the parsed value does not print back to
                ``arg``.
        """
        name: str
        default: Optional[str]
        # The name (plus optional default) is everything after the last space.
        type_and_annot, name_and_default = arg.rsplit(" ", 1)
        if "=" in name_and_default:
            # Split on the first '=' only, so that a default which itself
            # contains '=' stays intact instead of breaking the unpacking.
            name, default = name_and_default.split("=", 1)
        else:
            name = name_and_default
            default = None
        # TODO: deduplicate annotation matching with Return
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        parsed_type = Type.parse(type_s)
        r = Argument(
            name=name,
            type=parsed_type,
            default=default,
            annotation=annotation,
        )
        # Parsing must round-trip losslessly.
        assert str(r) == arg, f"{str(r)} != {arg}"
        return r

    @property
    def is_write(self) -> bool:
        """True when the argument has an annotation marked as a write."""
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        """Print the argument in the same syntax that ``parse`` accepts."""
        type = f"{self.type}"
        if self.annotation:
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
            # The annotation binds to Tensor, before any '?' / '[]' suffix.
            type = type.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return type
        else:
            mb_default = ""
            if self.default:
                mb_default = f"={self.default}"
            return f"{type} {self.name}{mb_default}"
|  |  | 
|  |  | 
@dataclass(frozen=True)
class Return:
    # Optional name of the return value
    name: Optional[str]
    type: Type
    # Alias annotation of an annotated Tensor return, if any
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Return":
        """
        Parse one return written as "<type>" or "<type> <name>".

        Raises:
            AssertionError: if an annotated Tensor has an unsupported
                suffix, or if the parsed value does not print back to
                ``arg``.
        """
        # Everything after the last space is the name; no space, no name.
        head, sep, tail = arg.rpartition(" ")
        name: Optional[str]
        if sep:
            type_and_annot, name = head, tail
        else:
            type_and_annot, name = arg, None

        annotation: Optional[Annotation] = None
        type_s = type_and_annot
        tensor_match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        if tensor_match:
            alias_s, suffix = tensor_match.group(1), tensor_match.group(2)
            # If you update this, make sure the __str__ still works too
            assert suffix in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + suffix
            annotation = Annotation.parse(alias_s)

        parsed = Return(
            name=name,
            type=Type.parse(type_s),
            annotation=annotation,
        )
        assert str(parsed) == arg, f"{str(parsed)} != {arg}"
        return parsed

    @property
    def is_write(self) -> bool:
        """True when the return has an annotation marked as a write."""
        if self.annotation is None:
            return False
        return self.annotation.is_write

    def __str__(self) -> str:
        """Print the return in the same syntax that ``parse`` accepts."""
        rendered = f"{self.type}"
        if self.annotation:
            assert rendered in ["Tensor", "Tensor?", "Tensor[]"]
            rendered = rendered.replace("Tensor", f"Tensor({self.annotation})")
        return rendered if self.name is None else f"{rendered} {self.name}"
|  |  | 
|  |  | 
|  | # Represents the self argument for functions that may be methods | 
@dataclass(frozen=True)
class SelfArgument:
    # The wrapped argument.  Arguments.parse wraps the first positional
    # argument named "self" in this type.
    argument: Argument
|  |  | 
|  |  | 
|  | # Bundle of arguments that represent a TensorOptions.  This is mostly | 
|  | # relevant for the public C++ API but we bake it into the core data | 
|  | # model because other APIs often have to interact with it | 
@dataclass(frozen=True)
class TensorOptionsArguments:
    # The four arguments that together make up a TensorOptions bundle
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

    def all(self) -> Sequence[Argument]:
        """Return the four arguments as a list: dtype, layout, device, pin_memory."""
        members: List[Argument] = []
        members.append(self.dtype)
        members.append(self.layout)
        members.append(self.device)
        members.append(self.pin_memory)
        return members
|  |  | 
|  |  | 
|  | @dataclass(frozen=True) | 
|  | class Arguments: | 
|  | # pre_self_positional is usually empty, but is notably non-empty | 
|  | # for where.self, where the condition argument comes before the | 
|  | # self argument | 
|  | pre_self_positional: Tuple[Argument, ...] | 
|  | self_arg: Optional[SelfArgument] | 
|  | post_self_positional: Tuple[Argument, ...] | 
|  |  | 
|  | pre_tensor_options_kwarg_only: Tuple[Argument, ...] | 
|  | tensor_options: Optional[TensorOptionsArguments] | 
|  | # post_tensor_options is typically memory format, which should be | 
|  | # part of tensor options but isn't right now, and is usually | 
|  | # placed after the tensor options arguments | 
|  | post_tensor_options_kwarg_only: Tuple[Argument, ...] | 
|  |  | 
|  | # Unlike in the previous codegen, we have factored out 'out' arguments | 
|  | # in the canonical representation, removing them from kwarg | 
|  | # arguments.  This choice is justified by numerous downstream | 
|  | # transformations which treat out arguments specially; additionally, | 
|  | # you can see that canonicity is not violated! | 
|  | out: Tuple[Argument, ...]  # these are also kwarg-only | 
|  |  | 
|  | @property | 
|  | def flat_non_out(self) -> Sequence[Argument]: | 
|  | ret: List[Argument] = [] | 
|  | ret.extend(self.flat_positional) | 
|  | ret.extend(self.flat_kwarg_only) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def flat_positional(self) -> Sequence[Argument]: | 
|  | ret: List[Argument] = [] | 
|  | ret.extend(self.pre_self_positional) | 
|  | if self.self_arg is not None: | 
|  | ret.append(self.self_arg.argument) | 
|  | ret.extend(self.post_self_positional) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def post_self_positional_mutable(self) -> Sequence[Argument]: | 
|  | return [a for a in self.post_self_positional if a.is_write] | 
|  |  | 
|  | # NB: doesn't contain out arguments | 
|  | @property | 
|  | def flat_kwarg_only(self) -> Sequence[Argument]: | 
|  | ret: List[Argument] = [] | 
|  | ret.extend(self.pre_tensor_options_kwarg_only) | 
|  | if self.tensor_options is not None: | 
|  | ret.extend(self.tensor_options.all()) | 
|  | ret.extend(self.post_tensor_options_kwarg_only) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def flat_all(self) -> Sequence[Argument]: | 
|  | ret: List[Argument] = [] | 
|  | ret.extend(self.flat_positional) | 
|  | ret.extend(self.flat_kwarg_only) | 
|  | ret.extend(self.out) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def non_out( | 
|  | self, | 
|  | ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: | 
|  | ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] | 
|  | ret.extend(self.positional) | 
|  | ret.extend(self.kwarg_only) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def positional(self) -> Sequence[Union[Argument, SelfArgument]]: | 
|  | ret: List[Union[Argument, SelfArgument]] = [] | 
|  | ret.extend(self.pre_self_positional) | 
|  | if self.self_arg is not None: | 
|  | ret.append(self.self_arg) | 
|  | ret.extend(self.post_self_positional) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: | 
|  | ret: List[Union[Argument, TensorOptionsArguments]] = [] | 
|  | ret.extend(self.pre_tensor_options_kwarg_only) | 
|  | if self.tensor_options is not None: | 
|  | ret.append(self.tensor_options) | 
|  | ret.extend(self.post_tensor_options_kwarg_only) | 
|  | return ret | 
|  |  | 
|  | @property | 
|  | def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: | 
|  | ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] | 
|  | ret.extend(self.positional) | 
|  | ret.extend(self.kwarg_only) | 
|  | ret.extend(self.out) | 
|  | return ret | 
|  |  | 
|  | def mutable_arg_names(self) -> List[str]: | 
|  | return [ | 
|  | a.name | 
|  | for a in self.flat_all | 
|  | if a.annotation is not None and a.annotation.is_write | 
|  | ] | 
|  |  | 
|  | def has_tensor_arg(self) -> bool: | 
|  | return any(a.type.is_tensor_like() for a in self.flat_non_out) | 
|  |  | 
|  | def has_symint_arg(self) -> bool: | 
|  | return any(a.type.is_symint_like() for a in self.flat_non_out) | 
|  |  | 
|  | def has_generator_arg(self) -> bool: | 
|  | return any(a.type.is_generator_like() for a in self.flat_non_out) | 
|  |  | 
|  | def signature(self, *, strip_default: bool = False) -> "Arguments": | 
|  | # dataclasses.replace could be used here, but it is less | 
|  | # type safe so for now I've opted to type everything out | 
|  | def strip_arg_annotation(a: Argument) -> Argument: | 
|  | return Argument( | 
|  | name=a.name, | 
|  | type=a.type, | 
|  | default=a.default if not strip_default else None, | 
|  | annotation=None, | 
|  | ) | 
|  |  | 
|  | return Arguments( | 
|  | pre_self_positional=tuple( | 
|  | map(strip_arg_annotation, self.pre_self_positional) | 
|  | ), | 
|  | self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument)) | 
|  | if self.self_arg is not None | 
|  | else None, | 
|  | post_self_positional=tuple( | 
|  | map(strip_arg_annotation, self.post_self_positional) | 
|  | ), | 
|  | # Since TensorOptions are dropped, the post_tensor_options_kwargs are | 
|  | # converted to pre_tensor_options_kwargs | 
|  | pre_tensor_options_kwarg_only=tuple( | 
|  | map(strip_arg_annotation, self.pre_tensor_options_kwarg_only) | 
|  | ) | 
|  | + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), | 
|  | # TensorOptions are dropped in signature, | 
|  | # so we can pair factory functions with their out= variants. | 
|  | tensor_options=None, | 
|  | post_tensor_options_kwarg_only=tuple(), | 
|  | # out arguments are dropped in signature | 
|  | out=(), | 
|  | ) | 
|  |  | 
|  | def remove_self_annotation(self) -> "Arguments": | 
|  | assert self.self_arg is not None | 
|  | return dataclasses.replace( | 
|  | self, | 
|  | self_arg=SelfArgument( | 
|  | dataclasses.replace(self.self_arg.argument, annotation=None) | 
|  | ), | 
|  | ) | 
|  |  | 
|  | def with_out_args(self, outs: List[Argument]) -> "Arguments": | 
|  | assert len(self.out) == 0 | 
|  | return dataclasses.replace( | 
|  | self, | 
|  | out=tuple(outs), | 
|  | ) | 
|  |  | 
|  | @staticmethod | 
|  | def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: | 
|  | positional: List[Argument] = [] | 
|  | kwarg_only: List[Argument] = [] | 
|  | out: List[Argument] = [] | 
|  | arguments_acc = positional | 
|  |  | 
|  | # TODO: Use a real parser here; this will get bamboozled | 
|  | # by signatures that contain things like std::array<bool, 2> (note the space) | 
|  | for arg in args.split(", "): | 
|  | if not arg: | 
|  | continue | 
|  | if arg == "*": | 
|  | assert ( | 
|  | arguments_acc is positional | 
|  | ), "invalid syntax: kwarg-only specifier * can only occur once" | 
|  | arguments_acc = kwarg_only | 
|  | continue | 
|  | parg = Argument.parse(arg) | 
|  | # Currently, we rely directly on the invariant that there are NO | 
|  | # kwarg-only mutating arguments.  If you want to relax this, | 
|  | # we will need a more semantic way of matching that takes | 
|  | # into account return arguments.  In that case, you will have | 
|  | # to manage out computation a level up, in FunctionSchema.  See Note | 
|  | # [is_out_fn] | 
|  | if parg.annotation is not None and parg.annotation.is_write: | 
|  | if arguments_acc is positional: | 
|  | pass  # do nothing | 
|  | elif arguments_acc is kwarg_only: | 
|  | arguments_acc = out | 
|  | else: | 
|  | assert arguments_acc is not out | 
|  | arguments_acc.append(parg) | 
|  |  | 
|  | return positional, kwarg_only, out | 
|  |  | 
|  | @staticmethod | 
|  | def parse(args: str) -> "Arguments": | 
|  | """ | 
|  | Input: 'int x, int y, int z' | 
|  | """ | 
|  |  | 
|  | # We do this in two phases.  First we parse into three | 
|  | # main categories: positional, kwarg_only, out. | 
|  | # Then, we reparse positional and kwarg_only to separate | 
|  | # out the self argument and tensor options arguments. | 
|  |  | 
|  | positional, kwarg_only, out = Arguments._preparse(args) | 
|  |  | 
|  | # Split self argument | 
|  | self_ix = None | 
|  | for i, a in enumerate(positional): | 
|  | if a.name == "self": | 
|  | self_ix = i | 
|  | break | 
|  | pre_self_positional: List[Argument] | 
|  | self_arg: Optional[SelfArgument] | 
|  | post_self_positional: List[Argument] | 
|  | if self_ix is not None: | 
|  | pre_self_positional = positional[:self_ix] | 
|  | self_arg = SelfArgument(positional[self_ix]) | 
|  | post_self_positional = positional[self_ix + 1 :] | 
|  | else: | 
|  | pre_self_positional = [] | 
|  | self_arg = None | 
|  | post_self_positional = positional | 
|  |  | 
|  | # Group tensor options arguments | 
|  | pre_tensor_options_kwarg_only: List[Argument] = [] | 
|  | tensor_options: Optional[TensorOptionsArguments] = None | 
|  | post_tensor_options_kwarg_only: List[Argument] = [] | 
|  | kwarg_only_acc = pre_tensor_options_kwarg_only | 
|  |  | 
|  | def pred(name: str, ty: Type) -> Callable[[Argument], bool]: | 
|  | return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] | 
|  |  | 
|  | predicates = [  # order matters | 
|  | pred("dtype", Type.parse("ScalarType")), | 
|  | pred("layout", Type.parse("Layout")), | 
|  | pred("device", Type.parse("Device")), | 
|  | pred("pin_memory", Type.parse("bool")), | 
|  | ] | 
|  |  | 
|  | i = 0 | 
|  | while i < len(kwarg_only): | 
|  | # If there is enough space... | 
|  | if i <= len(kwarg_only) - len(predicates): | 
|  | # And the next len(predicates) arguments look like TensorOptions arguments | 
|  | if all( | 
|  | p(a) | 
|  | for p, a in zip(predicates, kwarg_only[i : i + len(predicates)]) | 
|  | ): | 
|  | assert kwarg_only_acc is pre_tensor_options_kwarg_only | 
|  | # Group them together as one argument | 
|  | tensor_options = TensorOptionsArguments( | 
|  | dtype=kwarg_only[i], | 
|  | layout=kwarg_only[i + 1], | 
|  | device=kwarg_only[i + 2], | 
|  | pin_memory=kwarg_only[i + 3], | 
|  | ) | 
|  | i += len(predicates) | 
|  | kwarg_only_acc = post_tensor_options_kwarg_only | 
|  | continue | 
|  | kwarg_only_acc.append(kwarg_only[i]) | 
|  | i += 1 | 
|  |  | 
|  | return Arguments( | 
|  | pre_self_positional=tuple(pre_self_positional), | 
|  | self_arg=self_arg, | 
|  | post_self_positional=tuple(post_self_positional), | 
|  | pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), | 
|  | tensor_options=tensor_options, | 
|  | post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), | 
|  | out=tuple(out), | 
|  | ) | 
|  |  | 
def __str__(self) -> str:
    """Render the whole argument list as one comma-separated string.

    The positional arguments come first.  When there is at least one
    keyword-only or out argument, a bare ``*`` marker is emitted next,
    followed by the keyword-only arguments and finally the out arguments.
    """
    positional = [str(arg) for arg in self.flat_positional]
    # The "*" marker is only emitted when something actually follows it.
    marker = ["*"] if self.flat_kwarg_only or self.out else []
    kwarg_only = [str(arg) for arg in self.flat_kwarg_only]
    outs = [str(arg) for arg in self.out]
    return ", ".join(positional + marker + kwarg_only + outs)
|  |  | 
def __post_init__(self) -> None:
    """Check the structural invariants of a freshly built instance.

    Raises:
        AssertionError: if there are pre-self positionals without a self
            argument, post-tensor-options kwargs without tensor options,
            or a pre-self positional carrying a write annotation.
    """
    # TODO: These invariants are weirdly asymmetric?
    # TODO: Fancier types?
    # "pre self" arguments only make sense when a self argument exists.
    assert self.self_arg is not None or not self.pre_self_positional
    # Likewise, "post tensor options" needs tensor options to be present.
    assert (
        self.tensor_options is not None or not self.post_tensor_options_kwarg_only
    )

    # To keep things simple, arguments that precede self must not be
    # written to: collect every pre-self positional whose annotation is
    # marked as a write and insist that there are none.
    mutable_pre_self_positionals = []
    for a in self.pre_self_positional:
        if a.annotation is not None and a.annotation.is_write:
            mutable_pre_self_positionals.append(a)
    assert (
        not mutable_pre_self_positionals
    ), "mutable pre_self_positional arguments are not currently supported in the schema"
|  |  | 
|  |  | 
# Operator stems XXX for which __iXXX__ is a valid inplace (augmented
# assignment) dunder method.  The list follows PEP 203, "New methods":
# https://www.python.org/dev/peps/pep-0203/#new-methods
# NB: PyTorch hasn't actually implemented all of these
AUGMENTED_ASSIGNMENT_NAMES = [
    # arithmetic
    "add", "sub", "mul", "div", "mod", "pow",
    # shifts
    "lshift", "rshift",
    # bitwise
    "and", "xor", "or",
]
|  |  | 
|  |  | 
# A BaseOperatorName is the operator name with the overload name stripped
# off.  Rather than keeping it as a plain string, we store the semantic
# pieces we derive from that string: the bare base name, whether the
# operator is inplace (add_) and whether it is a double-underscore
# method (__add__).
@dataclass(frozen=True)
class BaseOperatorName:
    base: str
    inplace: bool
    dunder_method: bool
    # Note [Overload Ambiguity With Functional Variants]
    # A handful of operators have both a "mutable" and a "functional" variant.
    # (native_batch_norm is a good example, although this isn't the case today).
    # For those operators, the mutable and functional variant take in the same set of
    # arguments, but have different alias annotations.
    # This makes it ambiguous when you try to resolve an OverloadPacket into an overload,
    # given a set of input arguments.
    #
    # So instead of making the "functional" variant in this case a real overload, e.g:
    #   native_batch_norm (mutable variant)
    #   native_batch_norm.functional (functional variant)
    # we make it a new base operator,
    #   native_batch_norm_functional (functional variant)
    #
    # In an ideal world, we would probably invert this so the operators were:
    #   native_batch_norm.mutable (mutable variant)
    #   native_batch_norm (functional variant)
    #
    # Doing that is BC-breaking though, so we're stuck with the above modeling.
    functional_overload: bool = False

    @staticmethod
    def parse(op: str) -> "BaseOperatorName":
        """Split an operator name (no overload part) into its components.

        Recognised shapes are ``__name__`` / ``__iname__`` dunder methods,
        a trailing ``_`` for inplace operators and a trailing
        ``_functional`` for functional variants.

        Raises:
            AssertionError: if ``op`` is empty, ends in ``_out``, is a
                non-inplace dunder whose stem starts with "i", combines
                ``_functional`` with inplace/dunder, or does not print
                back to exactly ``op``.
        """
        assert op != ""
        assert not op.endswith("_out"), (
            "_out suffix is reserved and not permitted for operator names; "
            "did you mean to specify an out overload name instead?"
        )
        m = re.match(r"^__([^_]+)__$", op)
        dunder_method = m is not None
        if m is None:
            # Plain name: a single trailing underscore marks an inplace op.
            inplace = op[-1] == "_"
            base = op[:-1] if inplace else op
        else:
            base = m.group(1)
            # __iXXX__ is inplace only when XXX is a known augmented
            # assignment stem; the leading "i" is then dropped.
            inplace = base in [f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES]
            if inplace:
                base = base[1:]
            else:
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support  (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert base[0] != "i"

        # See Note [Overload Ambiguity With Functional Variants]
        functional_suffix = "_functional"
        functional_overload = base.endswith(functional_suffix)
        if functional_overload:
            base = base[: -len(functional_suffix)]
            # This seems complicated and unnecessary, so banning dunder methods
            # for now on ops that have a functional + mutable variant (like native_batch_norm).
            assert not dunder_method and not inplace

        r = BaseOperatorName(
            base=base,
            inplace=inplace,
            dunder_method=dunder_method,
            functional_overload=functional_overload,
        )
        # The parsed form must print back to exactly the input.
        assert str(r) == op, f"{str(r)} != {op}"
        return r

    def __str__(self) -> str:
        """Rebuild the textual operator name from the stored components."""
        if self.dunder_method:
            # Dunder methods mark inplace with an "i" prefix inside the
            # underscores; functional_overload plays no role here.
            i = "i" if self.inplace else ""
            return f"__{i}{self.base}__"
        # Plain names use a suffix; inplace wins over functional_overload.
        if self.inplace:
            i = "_"
        elif self.functional_overload:
            i = "_functional"
        else:
            i = ""
        return f"{self.base}{i}"
|  |  | 
|  |  | 
# An OperatorName pairs a base operator name with its overload string
# (which is typically not user visible).
@dataclass(frozen=True)
class OperatorName:
    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> "OperatorName":
        """Parse ``base`` or ``base.overload`` into an OperatorName.

        Everything after the first "." is the overload name; without a
        "." the overload name is the empty string.

        Raises:
            AssertionError: if the result does not print back to exactly
                ``op_name`` (or if the base name fails to parse).
        """
        # partition() splits at the first "." and yields an empty overload
        # name when there is no "." at all.
        name, _, overload_name = op_name.partition(".")
        r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
        assert str(r) == op_name, f"{str(r)} != {op_name}"
        return r

    def __str__(self) -> str:
        """Return ``base.overload``, or just ``base`` if there is no overload."""
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}.{self.overload_name}"

    # NB: This must be synchronized with the naming scheme in
    # aten/src/ATen/templates/Operators.h
    # Given a function schema "aten::op.overload(...)",
    # If there is no overload name, this returns f"{op}"
    # If there is an overload name, this returns f"{op}_{overload}"
    def unambiguous_name(self) -> str:
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}_{self.overload_name}"

    def remove_inplace(self) -> "OperatorName":
        """Return this name with the inplace flag cleared.

        The base name is rebuilt from ``base`` and ``dunder_method`` only,
        so ``functional_overload`` falls back to its default (False).
        NOTE(review): confirm that dropping functional_overload is intended.
        """
        out_of_place = BaseOperatorName(
            base=self.name.base,
            inplace=False,
            dunder_method=self.name.dunder_method,
        )
        return OperatorName(name=out_of_place, overload_name=self.overload_name)

    def with_overload(self, overload: str) -> "OperatorName":
        """Return this name with the overload name replaced by ``overload``.

        The base name is rebuilt with ``inplace=False`` and the default
        ``functional_overload``; only ``base`` and ``dunder_method`` carry
        over.
        NOTE(review): confirm that resetting those two flags is intended.
        """
        rebuilt = BaseOperatorName(
            base=self.name.base,
            inplace=False,
            dunder_method=self.name.dunder_method,
        )
        return OperatorName(name=rebuilt, overload_name=overload)
|  |  | 
|  |  | 
def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
    """Tell whether ``f`` qualifies for a generated out/inplace wrapper.

    That is the case when ``f`` is not of functional kind, the backend
    index ``b`` has no kernel for ``f`` itself, and ``b`` does have a
    kernel for the group's functional variant.
    """
    # Functional ops never qualify.
    if f.func.kind() is SchemaKind.functional:
        return False
    # Neither do ops the backend already has a kernel for.
    if b.has_kernel(f):
        return False
    # What remains depends on the functional kernel being available.
    return b.has_kernel(g.functional)
|  |  | 
|  |  | 
# NativeFunction objects that are views (f.is_view_op returns True)
# are collected into a `NativeFunctionsViewGroup`, which gives easy access
# to the (optional) generated view_copy NativeFunction and the (optional)
# inplace variant that belong to the same view.
# See Note [Codegen'd {view}_copy Operators]
#
# A consequence of this representation: for a view-like op to be part of
# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
# but don't have a corresponding aliasing `narrow.out` op.
# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
@dataclass(frozen=True)
class NativeFunctionsViewGroup:
    view: NativeFunction
    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
    # (we already get them "for free" through decomposition)
    view_copy: Optional[NativeFunction]
    # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
    view_inplace: Optional[NativeFunction]

    def __post_init__(self) -> None:
        """Assert that the members of the group are mutually consistent."""
        assert self.view.is_view_op
        if self.view_copy is not None:
            # A supplied copy variant must be named *_copy, match the view's
            # signature once the copy suffix is stripped, and carry the
            # 'view_copy' tag.
            assert self.view_copy.func.name.name.base.endswith("_copy")
            assert self.view.func.signature() == self.view_copy.func.signature(
                strip_view_copy_name=True
            )
            # NOTE(review): the message below formats a tuple and prints the
            # *view's* tags rather than the view_copy's; left unchanged here.
            assert "view_copy" in self.view_copy.tags, (
                f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
                " See Note [view_copy NativeFunction] for details."
            )
        else:
            # Without a copy variant, the view must be one for which no
            # copy variant is expected.
            assert not gets_generated_view_copy(self.view), (
                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
                f" {get_view_copy_name(self.view)!s}."
                " See Note [view_copy NativeFunctions] for details."
            )
        if self.view_inplace is not None:
            assert self.view.func.signature() == self.view_inplace.func.signature()

        # A composite view requires its inplace variant (if any) to have
        # the same kind of composite kernel.
        if (
            self.view.has_composite_implicit_autograd_kernel
            and self.view_inplace is not None
        ):
            assert self.view_inplace.has_composite_implicit_autograd_kernel, (
                f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
            )
        if (
            self.view.has_composite_implicit_autograd_nested_tensor_kernel
            and self.view_inplace is not None
        ):
            assert (
                self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
            ), (
                f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
            )

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        """Yield the view, then the inplace variant, then the copy variant.

        Missing (None) members are skipped; the copy variant is also
        skipped when ``include_copy`` is False.
        """
        yield self.view
        optional_members = [self.view_inplace]
        if include_copy:
            optional_members.append(self.view_copy)
        yield from (fn for fn in optional_members if fn is not None)

    @property
    def root_name(self) -> str:
        """The root name of the view op."""
        return self.view.root_name

    @property
    def composite(self) -> bool:
        """Whether the view op has a CompositeImplicitAutograd kernel."""
        # We currently assert that the "group" is consistent.
        # If the view op is composite, then its view_inplace op is too.
        return self.view.has_composite_implicit_autograd_kernel
|  |  | 
|  |  | 
def gets_generated_view_copy(f: NativeFunction) -> bool:
    """Tell whether the view op ``f`` gets a generated copy variant.

    The answer is False when any of these holds:
      - ``f`` is not an aliasing (view) operator at all;
      - ``f`` has a CompositeImplicitAutograd kernel (no copy variant is
        needed, because it can decompose into base view ops);
      - ``f`` is tagged "inplace_view" (inplace views need no copy variant).
    """
    return not (
        not f.is_view_op
        or f.has_composite_implicit_autograd_kernel
        or "inplace_view" in f.tags
    )
|  |  | 
|  |  | 
# Given a NativeFunction that corresponds to a view op,
# returns the OperatorName of the corresponding "copy" variant of the op:
# same overload name, base name with "_copy" appended, never inplace.
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
    # Right now, when asking for a view op's corresponding "view_copy" name
    # we assert for sanity that the op is allowed to have a generated view_copy variant.
    # (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op).
    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
    # narrow is hardcoded here for now to maintain the assert,
    # but we could also just get rid of the assert.
    has_explicit_view_copy_operator = str(f.func.name) in ["narrow"]
    assert has_explicit_view_copy_operator or gets_generated_view_copy(f)

    copy_base_name = BaseOperatorName(
        base=f"{f.func.name.name.base}_copy",
        inplace=False,
        dunder_method=f.func.name.name.dunder_method,
    )
    return OperatorName(
        name=copy_base_name,
        overload_name=f.func.name.overload_name,
    )
|  |  | 
|  |  | 
|  | # Helper functions for parsing argument lists (both inputs and returns) | 
|  |  | 
|  |  | 
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """
    Parse the return declaration of a schema into a tuple of Returns.

    Input: '()'
    Output: ()

    One pair of enclosing parentheses is stripped if present; what
    remains is split on ', ' and each piece is handed to Return.parse.
    """
    if return_decl == "()":
        return ()
    parenthesized = return_decl[0] == "(" and return_decl[-1] == ")"
    inner = return_decl[1:-1] if parenthesized else return_decl
    return tuple(map(Return.parse, inner.split(", ")))
|  |  | 
|  |  | 
# A Precompute instance holds a map from kernel argument name to the
# list of Argument instances that should replace that kernel argument
# in the impl function, plus a list of arguments that are added
# without replacing anything.
@dataclass(frozen=True)
class Precompute:
    # A map from kernel argument name -> a list of precomputed
    # elements that replaces/supersedes it.
    replace: Dict[str, List[Argument]]
    # List of precomputed args added without replacement
    add: List[Argument]

    @staticmethod
    def parse(src: object) -> "Precompute":
        """Build a Precompute from a list of declaration strings.

        Raises:
            AssertionError: if ``src`` is not a list, an entry is not a
                string, a line other than the last lacks " -> ", or the
                replacement lines do not print back to their input.
        """
        assert isinstance(src, list)

        # src is a list of strings of the format:
        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
        #   [{add decl}[, {add decl}, ...]]
        # The last line is optional and contains the precomputed parameters that are
        # added without replacement.
        # The other lines are parsed to get the names of which precomputed elements
        # should replace which kernel arguments.
        add_args = []
        if " -> " not in src[-1]:
            add_args = [Argument.parse(decl.strip()) for decl in src[-1].split(",")]
            # Only the replacement lines remain to be processed below.
            src = src[:-1]

        replace = {}
        for line in src:
            assert isinstance(line, str)
            assert " -> " in line, (
                "precomputed parameters without replacement"
                " are allowed only in the last line"
            )

            arg, with_list_raw = line.split(" -> ")
            replace[arg] = [
                Argument.parse(decl.strip()) for decl in with_list_raw.split(",")
            ]

        r = Precompute(replace=replace, add=add_args)
        # Round-trip check; to_list() covers the replacement lines only.
        assert r.to_list() == src, "r.to_list() != src"
        return r

    def __post_init__(self) -> None:
        # the template parameters are upper so if these are the
        # same then it is ambiguous
        assert all(a.name.upper() != a.name for a in self.add)
        assert all(
            a.name.upper() != a.name
            for args in self.replace.values()
            for a in args
        )

    def to_list(self) -> List[str]:
        """Return one "{kernel param} -> {replacements}" line per replace entry.

        Entries of ``add`` are not included.
        """
        replace_list = []
        for kernel_param, replacement_params in self.replace.items():
            replacements = ", ".join(map(str, replacement_params))
            replace_list.append(f"{kernel_param} -> {replacements}")
        return replace_list