import re

from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn
from enum import Enum
import itertools

# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))

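# For example (an illustrative sketch; _kind_to_str is hypothetical and
# not part of the codegen): if the SchemaKind enum defined below ever
# grows a new member that this function doesn't handle, mypy will report
# an error at the assert_never call, because the fall-through value can
# no longer be narrowed to NoReturn.
#
#   def _kind_to_str(k: 'SchemaKind') -> str:
#       if k is SchemaKind.functional:
#           return 'functional'
#       elif k is SchemaKind.inplace:
#           return 'inplace'
#       elif k is SchemaKind.out:
#           return 'out'
#       else:
#           assert_never(k)
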
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           DATA MODEL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Some general principles for our data model:
#
# - Stop using C++ data types as the internal data representation
#   format.  Instead, the internal data structures are centered
#   around the JIT schema representation.  This avoids a big problem
#   with the old codegen, where we read in all the types from
#   native_functions.yaml and then immediately had to retranslate
#   them into C++ types.
#
# - More semantic data representation.  Instead of representing
#   everything as dicts and strings, we define dataclasses for
#   every interesting entity the code generation has to deal with.
#   These dataclasses have strong semantic invariants: for example,
#   we generally require them to roundtrip losslessly into the
#   form they were parsed from.  These structures are immutable,
#   and you're expected to populate information once during
#   construction.

# Represent a source location; used for better error reporting
@dataclass(frozen=True)
class Location:
    file: str
    line: int

    def __str__(self) -> str:
        return "{}:{}".format(self.file, self.line)

# Valid values of the 'variants' field in native_functions.yaml
Variant = Enum('Variant', ('function', 'method'))

# Valid values of the 'use_c10_dispatcher' field
UseC10Dispatcher = Enum('UseC10Dispatcher', (
    'full',
    'with_codegenerated_unboxing_wrapper'
))

# The basic input to the code generation is native_functions.yaml.
# The name "native", BTW, comes from the distinction between native
# functions and legacy TH functions.  The legacy TH functions are gone,
# but the "native" descriptor has stuck.
#
# NativeFunction models a single entry in native_functions.yaml.  Its
# fields roughly correspond to what you would see in the YAML itself,
# but after canonicalization and parsing has occurred.
#
# You can see some of the overall design patterns for how we set up
# dataclasses in this class, but we will defer a complete discussion
# of this to FunctionSchema.
@dataclass(frozen=True)
class NativeFunction:
    # The function schema of the operator in question.  This schema
    # has been parsed; see FunctionSchema for more about its structure.
    # (This type is quoted as we are forward referencing a type
    # defined later in the file.  I opted for this ordering of the
    # classes for expository clarity.)
    func: 'FunctionSchema'

    # Corresponds to the 'use_c10_dispatcher' field.  The default
    # is 'with_codegenerated_unboxing_wrapper'.
    use_c10_dispatcher: UseC10Dispatcher

    # Whether or not to omit automatic generation of a DeviceGuard
    device_guard: bool

    # What python module to put the function in
    python_module: Optional[str]

    # TODO: figure out what this does
    category_override: Optional[str]

    # If no variants are specified in native_functions.yaml, this is
    # assumed to be {'function'}.
    variants: Set[Variant]

    # Whether or not we should skip generating registrations for
    # this kernel.  This is a bit of a double-edged sword, as manual
    # registrations don't participate in codegen-based selective build!
    manual_kernel_registration: bool

    # Distinguish between a missing dispatch dict (historically, this
    # means to register a catch-all kernel) and a present but empty
    # dispatch dict (this means register nothing; arguably, this should
    # subsume manual_kernel_registration).
    #
    # TODO: str key could be replaced with a more explicit enum
    dispatch: Optional[Dict[str, str]]

    # The location in the YAML file where this native function entry was
    # defined.  This is for conveniently reporting error messages!
    loc: 'Location'
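
    # For example, an entry along these lines (illustrative, not
    # necessarily verbatim from native_functions.yaml):
    #
    #   - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
    #     variants: function
    #     dispatch:
    #       CPU: add_out
    #       CUDA: add_out
    #
    # parses into a NativeFunction whose func is the parsed FunctionSchema,
    # whose variants is {Variant.function}, and whose dispatch dict maps
    # both 'CPU' and 'CUDA' to 'add_out'.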

    # NB: The benefit of defining a dataclass is that we automatically get
    # a constructor defined for all the fields we specify.  No need
    # to explicitly write it out.

    @staticmethod
    def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
        """
        Parse a NativeFunction from a dictionary as directly parsed
        from native_functions.yaml
        """
        e = ei.copy()

        funcs = e.pop('func')
        assert isinstance(funcs, str), f'not a str: {funcs}'
        func = FunctionSchema.parse(funcs)

        use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None)
        if use_c10_dispatcher_s is None:
            use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper
        elif use_c10_dispatcher_s == 'full':
            use_c10_dispatcher = UseC10Dispatcher.full
        else:
            raise AssertionError(
                f'use_c10_dispatcher must be unset or set to full, got {use_c10_dispatcher_s}')

        variants_s = e.pop('variants', 'function')
        assert isinstance(variants_s, str)
        variants: Set[Variant] = set()
        for v in variants_s.split(', '):
            if v == 'function':
                variants.add(Variant.function)
            elif v == 'method':
                variants.add(Variant.method)
            else:
                raise AssertionError(f'illegal variant {v}')

        manual_kernel_registration = e.pop('manual_kernel_registration', False)
        assert isinstance(manual_kernel_registration, bool), f'not a bool: {manual_kernel_registration}'

        device_guard = e.pop('device_guard', True)
        assert isinstance(device_guard, bool), f'not a bool: {device_guard}'

        python_module = e.pop('python_module', None)
        assert python_module is None or isinstance(python_module, str), f'not a str: {python_module}'

        category_override = e.pop('category_override', None)
        assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}'

        raw_dispatch = e.pop('dispatch', None)
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
        dispatch: Optional[Dict[str, str]] = None
        if raw_dispatch is not None:
            dispatch = {}
            for ks, v in raw_dispatch.items():
                if ks == '__line__':
                    continue  # not worth tracking line numbers for dispatch entries
                assert isinstance(ks, str), e
                assert isinstance(v, str), e
                for k in ks.split(","):
                    dispatch[k.strip()] = v
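            # So a comma-separated key fans out into one entry per
            # dispatch key; e.g. {'CPU, CUDA': 'add_kernel'} (with an
            # illustrative kernel name) becomes
            # {'CPU': 'add_kernel', 'CUDA': 'add_kernel'}.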

        e.pop('__line__')
        assert not e, f"leftover entries: {e}"

        return NativeFunction(
            func=func,
            use_c10_dispatcher=use_c10_dispatcher,
            variants=variants,
            manual_kernel_registration=manual_kernel_registration,
            python_module=python_module,
            category_override=category_override,
            dispatch=dispatch,
            device_guard=device_guard,
            loc=loc,
        )

    # __post_init__ functions in dataclasses can be used to do extra
    # validation after construction.
    #
    # Notice that we don't do any type validation here.  In fact, we
    # rely exclusively on mypy to check if you've done types correctly!
    # Validation is for nontrivial invariants that cannot be (conveniently)
    # encoded in the type system.
    def __post_init__(self) -> None:
        if self.func.out_arguments:
            assert self.variants == {Variant.function}, "Native functions with out arguments MUST " \
                "be declared with only function variant; e.g., variants: function; " \
                "otherwise you will tickle a Python argument binding bug " \
                "(which usually manifests itself as the result variable being undefined.)"

SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out'))

# The function schema is undoubtedly the most important data structure
# in all of the codegen, as it defines the type signature for operators,
# and most of the code generation we do is type directed (e.g., look at
# the types and decide what to do; think about how we code generate
# C++ function stubs!)
#
# We will also see in this class the general structure for how we model
# data in this code generation.  A few notable properties to point out
# ahead of time:
#
# - These dataclasses are a *lossless* representation of the strings
#   they are parsed from.  In fact, we assert that given the
#   information stored in the dataclass, we can exactly reconstruct
#   the string we parsed from (and assert this inside the parse
#   definition).  There are a few reasons for this:
#
#     - If you find that it is difficult to reconstruct the string
#       given a dataclass, that is a clue that your data
#       representation is wrong.
#
#     - It helps ensure that all relevant information is present
#       in the dataclass, so that downstream users aren't tempted
#       to reparse the original string to get some information
#       that was omitted.
#
#     - It forces you to represent the data in-memory in the same way
#       it is recorded textually, which makes the dataclasses easier
#       to understand for someone who is familiar with the
#       textual format.  (As a tradeoff, it means you have to model
#       the syntax, even when it is inconvenient.  But maybe that means
#       the syntax is bad!)  If you don't understand the internal
#       representation, go look at the printing code to see how
#       it maps onto the surface syntax!
#
#     - It makes it easy to test the parsing code, as parsing code
#       that is inconsistent with the printing code will fail early
#       and loudly.  (As a tradeoff, it makes the parsing code a bit
#       brittle; in particular, with trivial whitespace changes you
#       are likely to trigger an assert error.)
#
#   In general, try to make the __str__ code as simple as possible
#   (even at the cost of more complex parsing logic.)  Additionally,
#   try to minimize redundancy in data representation.  (Precomputed
#   fields are OK though: they are defined as a simple function on
#   the canonical representation in question.)
#
# - These dataclasses are all frozen; once constructed their
#   values never change.  This makes it easy to tell where any
#   given data came from: just look to the constructor.  As a
#   tradeoff, you can't easily "decorate" a schema with extra
#   information from a post-facto analysis.  We impose this
#   restriction to make these structures more understandable.
#
@dataclass(frozen=True)
class FunctionSchema:
    # The name of the operator this function schema describes.
    name: 'OperatorName'

    arguments: Tuple['Argument', ...]
    kwarg_only_arguments: Tuple['Argument', ...]  # but not including out args
    # Unlike in the previous codegen, we have factored out 'out' arguments
    # in the canonical representation, removing them from kwarg
    # arguments.  This choice is justified by numerous downstream
    # transformations which treat out arguments specially; additionally,
    # you can see that canonicity is not violated!
    out_arguments: Tuple['Argument', ...]  # these are also kwarg-only

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple['Return', ...]

    def schema_order_arguments(self) -> Iterator['Argument']:
        return itertools.chain(self.arguments, self.kwarg_only_arguments, self.out_arguments)

    @staticmethod
    def parse(func: str) -> 'FunctionSchema':
        # We should probably get a proper parser here
        assert ' -> ' in func, "function schema missing return type (spaces are mandatory)"
        func_decl, return_decl = [x.strip() for x in func.split(' -> ')]
        ops, args = func_decl.split('(', 1)
        assert args[-1] == ")", "Expecting closing )"
        args = args[:-1]
        name = OperatorName.parse(ops)
        arguments, kwarg_only_arguments, out_arguments = parse_arguments(args)
        returns = parse_returns(return_decl)
        r = FunctionSchema(
            name=name,
            arguments=arguments,
            kwarg_only_arguments=kwarg_only_arguments,
            out_arguments=out_arguments,
            returns=returns
        )
        assert str(r) == func, f'{str(r)} != {func}'
        return r
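
    # For example (an illustrative schema), parsing roundtrips exactly,
    # per the losslessness principle above:
    #
    #   >>> s = 'add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)'
    #   >>> str(FunctionSchema.parse(s)) == s
    #   True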

    def __post_init__(self) -> None:
        for arg, ret in zip(self.out_arguments, self.returns):
            assert arg.annotation == ret.annotation, \
                "Out arguments must have matching return Tensor; furthermore, " \
                "the i-th out argument needs to correspond to the i-th return"
        if self.out_arguments:
            assert len(self.out_arguments) == len(self.returns), \
                "Must return as many arguments as there are out arguments"
        if self.name.name.inplace:
            # TODO: fixme
            if str(self.name) not in [
                    '_amp_foreach_non_finite_check_and_unscale_',
                    '_foreach_add_scalar_list_',
                    '_foreach_sub_scalar_list_',
                    '_foreach_mul_scalar_list_',
                    '_foreach_div_scalar_list_',
                    '_foreach_add_.Scalar',
                    '_foreach_sub_.Scalar',
                    '_foreach_mul_.Scalar',
                    '_foreach_div_.Scalar',
                    '_foreach_add_.List',
                    '_foreach_sub_.List',
                    '_foreach_mul_.List',
                    '_foreach_div_.List',
                    '_foreach_exp_',
                    '_foreach_sqrt_',
                    '_foreach_addcmul_',
                    '_foreach_addcdiv_']:
                assert len(self.returns) == 1

    def is_out_fn(self) -> bool:
        # Note [is_out_fn]
        #
        # out functions are the variants which take an explicit out= argument
        # to populate into.  We need to know if a schema corresponds to an
        # out function for several reasons:
        #
        #   - They codegen differently in the C++ API
        #       - codegen to at::add_out rather than at::add
        #       - out argument is moved to the front of the C++ argument list
        #
        # out functions are DEFINED to be any function with a keyword-only
        # argument that is mutable.  In principle, this could lead to a
        # false positive if you define a function that mutates a
        # kwarg-only argument which isn't the "true" output of the
        # function.  A more robust definition that would work in this
        # case would also look at:
        #
        #   - The output types.  Out functions take in the arguments
        #     they mutate and then return them again; this is sort
        #     of "definitionally" what makes something an out function.
        #     Historically, we DO check this for consistency.
        #   - Correspondence with a pure variant.  An out function
        #     should have a signature equivalent to its pure variant,
        #     but just with extra kwargs for the output elements.  This
        #     is difficult to actually check for, and historically
        #     we only do this check in tools/
        return bool(self.out_arguments)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this?  A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.
        """
        is_inplace = self.name.name.inplace
        is_out = bool(self.out_arguments)
        assert not (is_inplace and is_out)
        if is_inplace:
            return SchemaKind.inplace
        elif is_out:
            return SchemaKind.out
        else:
            return SchemaKind.functional
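
    # For example (illustrative schemas):
    #
    #   >>> FunctionSchema.parse('add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor').kind()
    #   <SchemaKind.functional: 1>
    #   >>> FunctionSchema.parse('add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)').kind()
    #   <SchemaKind.inplace: 2>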

    # WARNING: This method is not currently tested in any meaningful way
    def signature(self) -> 'FunctionSchema':
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function.  This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.

        Here is the normalization that happens to the schema to convert
        it to a signature:
        - The overload name is stripped (name is retained, since
          it expresses semantic content about what the function does)
        - Inplace is set to False
        - Out arguments are stripped
        - Mutability annotations are stripped (this is sound
          because you cannot overload on mutability annotation)

        This function is based off of get_signature in
        tools.autograd.load_derivatives
        """

        # dataclasses.replace could be used here, but it is less
        # type safe so for now I've opted to type everything out
        def strip_arg_annotation(a: Argument) -> Argument:
            return Argument(
                name=a.name,
                type=a.type,
                default=a.default,  # hmmm
                annotation=None,
            )

        def strip_ret_annotation(r: Return) -> Return:
            return Return(
                name=r.name,
                type=r.type,
                annotation=None,
            )

        return FunctionSchema(
            name=OperatorName(
                name=BaseOperatorName(
                    base=self.name.name.base,
                    inplace=False,
                    dunder_method=self.name.name.dunder_method,
                ),
                overload_name="",  # stripped
            ),
            arguments=tuple(map(strip_arg_annotation, self.arguments)),
            kwarg_only_arguments=tuple(map(strip_arg_annotation, self.kwarg_only_arguments)),
            out_arguments=(),  # stripped
            returns=tuple(map(strip_ret_annotation, self.returns)),
        )
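
    # For example (illustrative), an inplace schema normalizes to the
    # functional form:
    #
    #   >>> s = 'add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)'
    #   >>> str(FunctionSchema.parse(s).signature())
    #   'add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor'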

    def __str__(self) -> str:
        all_arguments: List[str] = []
        all_arguments.extend(map(str, self.arguments))
        if self.kwarg_only_arguments or self.out_arguments:
            all_arguments.append('*')
        all_arguments.extend(map(str, self.kwarg_only_arguments))
        all_arguments.extend(map(str, self.out_arguments))
        all_arguments_str = ', '.join(all_arguments)
        if len(self.returns) == 1:
            returns = str(self.returns[0])  # omit parentheses
        else:
            returns = '(' + ', '.join(map(str, self.returns)) + ')'
        return f'{self.name}({all_arguments_str}) -> {returns}'

# Here is the rest of the data model, described more briefly.

# Simplified version for what actually shows up in built-ins.
# Look at alias_info.h for expanded syntax.  If you need the structure,
# you also need to make this structure recursive so it can be lined
# up with the type components too.  For primitives this isn't really
# necessary
@dataclass(frozen=True)
class Annotation:
    # Typically only has one element.  Not actually a set so
    # we can conveniently assume it is canonically ordered
    alias_set: Tuple[str, ...]
    is_write: bool

    @staticmethod
    def parse(ann: str) -> 'Annotation':
        m = re.match(r'^([a-z])(!?)$', ann)
        assert m is not None, f'unrecognized alias annotation {ann}'
        alias_set = (m.group(1),)
        is_write = m.group(2) == '!'
        r = Annotation(alias_set=alias_set, is_write=is_write)
        assert str(r) == ann, f'{r} != {ann}'
        return r

    def __str__(self) -> str:
        alias_set = '|'.join(self.alias_set)
        is_write = '!' if self.is_write else ''
        return f'{alias_set}{is_write}'

# The base class for the type system.  This is also loosely modeled
# off of jit_type.h, but we've simplified the hierarchy to focus
# on the aspects of the type system that matter for code generation
# (for example, there's no SingleElementType subclass anymore).
# You never actually construct a Type; usually it's going to be one
# of the subclasses.  If Python had ADTs this would be one!
@dataclass(frozen=True)
class Type:
    @staticmethod
    def parse(t: str) -> 'Type':
        r = Type._parse(t)
        assert str(r) == t, f'{r} != {t}'
        return r

    @staticmethod
    def _parse(t: str) -> 'Type':
        m = re.match(r'^(.+)\?$', t)
        if m is not None:
            return OptionalType(Type.parse(m.group(1)))
        m = re.match(r'^(.+)\[([0-9]+)?\]$', t)
        if m is not None:
            size = int(m.group(2)) if m.group(2) is not None else None
            return ListType(elem=Type.parse(m.group(1)), size=size)
        try:
            return BaseType(BaseTy[t])
        except KeyError:
            raise RuntimeError(f"unrecognized type {t}")

    def __str__(self) -> str:
        raise NotImplementedError

    # WARNING: These concepts are not very well-defined.  For example,
    # is "int?" nullable?  How about "int?[]"?  They are defined
    # so we can conveniently generate legacy Declarations.yaml, but
    # really we should probably just remove these at some point

    def is_tensor_like(self) -> bool:
        raise NotImplementedError

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> Optional['ListType']:
        raise NotImplementedError
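
# For example (illustrative): Type.parse('int[2]') gives
# ListType(elem=BaseType(name=BaseTy.int), size=2).  Note that the
# trailing '?' is peeled off first, so 'Tensor[]?' parses as an optional
# list of Tensors, while 'Tensor?[]' parses as a list of optional Tensors.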

# Base types are simple, atomic types with no further structure
BaseTy = Enum('BaseTy', (
    'Generator',
    'ScalarType',
    'Tensor',
    'int',
    'Dimname',
    'float',
    'str',
    'bool',
    'Layout',
    'Device',
    'Scalar',
    'MemoryFormat',
    'QScheme',
    'Storage',
    'ConstQuantizerPtr',  # TODO: rename
))

@dataclass(frozen=True)
class BaseType(Type):
    name: BaseTy

    def __str__(self) -> str:
        return f'{self.name.name}'

    def is_tensor_like(self) -> bool:
        return self.name == BaseTy.Tensor

    def is_nullable(self) -> bool:
        return False

    def is_list_like(self) -> Optional['ListType']:
        return None

# Optional types may be specified, or may also be validly given None
@dataclass(frozen=True)
class OptionalType(Type):
    elem: Type

    def __str__(self) -> str:
        return f'{self.elem}?'

    def is_tensor_like(self) -> bool:
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        return True

    def is_list_like(self) -> Optional['ListType']:
        return self.elem.is_list_like()

# List types specify that we may have multiples of an element.  We
# also support explicit sizes on list types, but these have
# some nontrivial semantics!  (However, for C++ API purposes, explicit
# sizes are mostly erased from the type system.)
#
# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
# int[] elaborates differently than bool[3]!
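# (For instance, int[] typically elaborates to IntArrayRef, whereas
# bool[3] elaborates to std::array<bool, 3>; this mapping is illustrative
# and lives in the C++ translation code, not here.)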
@dataclass(frozen=True)
class ListType(Type):
    elem: Type
    size: Optional[int]

    def __str__(self) -> str:
        size = f'{self.size}' if self.size is not None else ''
        return f'{self.elem}[{size}]'

    def is_tensor_like(self) -> bool:
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional['ListType']:
        return self

@dataclass(frozen=True)
class Argument:
    # NB: I didn't put kwarg_only as a boolean field here, unlike
    # c10::Argument, so that printing works correctly

    name: str
    type: Type
    default: Optional[str]

    # The semantics of the annotation field are a little strange.
    #
    # Alias annotations parametrize Tensors (since Tensors are the only things
    # that can alias.)  This motivates why I write Tensor(a!)? (and not, for
    # example, Tensor?(a!)): the (a!) describes aliasing on the tensor,
    # which may be optional (i.e., the alias annotation should bind first to
    # Tensor, before the optional postfix annotation).
    #
    # However, despite being a property of Tensor, we (and c10::Argument)
    # store the annotation at the top level of the Argument, rather than
    # inside the embedded Tensor type.  In the C++ version of this
    # class, we then go to great lengths to mimic the type
    # structure in the annotation structure so we can correlate
    # annotations with types.
    #
    # Now, it turns out, in all applications in code generation, the
    # structure of annotated types is very simple.  So we just hard
    # code it here.  But if we ever do get anything more complex, this
    # model will have to change!
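    #
    # For example (illustrative):
    #
    #   >>> str(Argument.parse('Tensor(a!)? self'))
    #   'Tensor(a!)? self'
    #
    # Here the type is parsed as OptionalType(Tensor), while the (a!)
    # annotation is stored at the top level of this Argument.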
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Argument':
        name: str
        default: Optional[str]
        type_and_annot, name_and_default = arg.rsplit(' ', 1)
        if '=' in name_and_default:
            name, default = name_and_default.split('=', 1)
        else:
            name = name_and_default
            default = None
        # TODO: deduplicate annotation matching with Return
        match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
            type_s = 'Tensor' + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Argument(
            name=name,
            type=type,
            default=default,
            annotation=annotation,
        )
        assert str(r) == arg, f'{str(r)} != {arg}'
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f'{self.type}'
        if self.annotation:
            assert type in ['Tensor', 'Tensor?', 'Tensor[]']
            type = type.replace('Tensor', f'Tensor({self.annotation})')
        # NB: unlike Return, an Argument always has a name
        mb_default = ''
        if self.default:
            mb_default = f'={self.default}'
        return f"{type} {self.name}{mb_default}"


@dataclass(frozen=True)
class Return:
    name: Optional[str]
    type: Type
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Return':
        name: Optional[str]
        if ' ' in arg:
            type_and_annot, name = arg.rsplit(' ', 1)
        else:
            type_and_annot = arg
            name = None
        match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
            type_s = 'Tensor' + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Return(
            name=name,
            type=type,
            annotation=annotation,
        )
        assert str(r) == arg, f'{str(r)} != {arg}'
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f'{self.type}'
        if self.annotation:
            assert type in ['Tensor', 'Tensor?', 'Tensor[]']
            type = type.replace('Tensor', f'Tensor({self.annotation})')
        if self.name is None:
            return type
        else:
            return f"{type} {self.name}"


# Names that validly are __iXXX__ indicating inplace operations.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
# NB: PyTorch hasn't actually implemented all of these
AUGMENTED_ASSIGNMENT_NAMES = ['add', 'sub', 'mul', 'div', 'mod', 'pow', 'lshift', 'rshift', 'and', 'xor', 'or']

# A BaseOperatorName is what we think of as the operator name, without
# the overload name.  Unusually, we don't represent this as just a
# string; instead, we directly represent a few important semantic
# bits of information we derive from the string: namely whether
# or not it's inplace (add_) and whether or not it's a double-underscore
# method (__add__)
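#
# For example (illustrative):
#
#   >>> BaseOperatorName.parse('add_')
#   BaseOperatorName(base='add', inplace=True, dunder_method=False)
#   >>> BaseOperatorName.parse('__iadd__')
#   BaseOperatorName(base='add', inplace=True, dunder_method=True)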
@dataclass(frozen=True)
class BaseOperatorName:
    base: str
    inplace: bool
    dunder_method: bool

    @staticmethod
    def parse(op: str) -> 'BaseOperatorName':
        assert op != ''
        assert not op.endswith('_out'), \
            "_out suffix is reserved and not permitted for operator names; " \
            "did you mean to specify an out overload name instead?"
        m = re.match(r'^__([^_]+)__$', op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # temporary, this is not intrinsically true but
                # has been historically true for the dunder methods
                # we support (but, if we ever got, say, __int__, this
                # would be wrong!)
                assert base[0] != 'i'
        else:
            dunder_method = False
            base = op
            if base[-1] == '_':
                inplace = True
                base = base[:-1]
            else:
                inplace = False
        r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
        assert str(r) == op, f'{str(r)} != {op}'
        return r

    def __str__(self) -> str:
        if self.dunder_method:
            i = 'i' if self.inplace else ''
            return f'__{i}{self.base}__'
        else:
            i = '_' if self.inplace else ''
            return f'{self.base}{i}'

# Operator name is the base operator name along with the (typically not
# user visible) overload string.
@dataclass(frozen=True)
class OperatorName:
    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> 'OperatorName':
        if '.' in op_name:
            name, overload_name = op_name.split('.', 1)
        else:
            name = op_name
            overload_name = ''
        r = OperatorName(
            name=BaseOperatorName.parse(name),
            overload_name=overload_name
        )
        assert str(r) == op_name, f'{str(r)} != {op_name}'
        return r

    def __str__(self) -> str:
        if self.overload_name:
            return f"{self.name}.{self.overload_name}"
        else:
            return f"{self.name}"

# Helper functions for parsing argument lists (both inputs and returns)

def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """
    Input: '()'
    Output: ()
    """
    if return_decl == '()':
        return ()
    if return_decl[0] == '(' and return_decl[-1] == ')':
        return_decl = return_decl[1:-1]
    return tuple(Return.parse(arg) for arg in return_decl.split(', '))

def parse_arguments(args: str) -> Tuple[Tuple[Argument, ...], Tuple[Argument, ...], Tuple[Argument, ...]]:
    """
    Input: 'int x, int y, int z'
    Output: positional args, kwarg-only args, out args
    """
    arguments: List[Argument] = []
    kwarg_only_arguments: List[Argument] = []
    out_arguments: List[Argument] = []
    arguments_acc = arguments

    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    for arg in args.split(', '):
        if not arg:
            continue
        if arg == '*':
            assert arguments_acc is arguments, "invalid syntax: kwarg-only specifier * can only occur once"
            arguments_acc = kwarg_only_arguments
            continue
        parg = Argument.parse(arg)
        # Currently, we rely directly on the invariant that there are NO
        # kwarg-only mutating arguments.  If you want to relax this,
        # we will need a more semantic way of matching that takes
        # into account return arguments.  In that case, you will have
        # to manage out_arguments computation a level up, in
        # FunctionSchema.  See Note [is_out_fn]
        if parg.annotation is not None and parg.annotation.is_write:
            if arguments_acc is arguments:
                pass  # do nothing
            elif arguments_acc is kwarg_only_arguments:
                arguments_acc = out_arguments
        else:
            # NB: this else belongs to the write check above: a non-mutable
            # argument must not appear after the out arguments have started
            assert arguments_acc is not out_arguments
        arguments_acc.append(parg)

    return tuple(arguments), tuple(kwarg_only_arguments), tuple(out_arguments)
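
# For example (illustrative), the argument string of an out variant,
#
#   'Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out'
#
# buckets into positional arguments (self, other), kwarg-only arguments
# (alpha), and out arguments (out); 'out' lands in the out bucket because
# it is a mutable (a!) argument appearing after the '*'.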