| import dataclasses |
| import inspect |
| import sys |
| from typing import Any, Callable, Tuple |
| |
| import torch |
| |
| |
| @dataclasses.dataclass |
| class Kernel: |
| """Models a (function, source location)""" |
| |
| func: Callable |
| source: str |
| |
| def __call__(self, *args, **kwargs): |
| return self.func(*args, **kwargs) |
| |
| |
| class RegistrationHandle: |
| """Does something when someone calls .destroy() on it""" |
| |
| def __init__(self, on_destroy: Callable): |
| self._on_destroy = on_destroy |
| |
| def destroy(self) -> None: |
| self._on_destroy() |
| |
| |
| def get_source(stacklevel: int) -> str: |
| """Get a string that represents the caller. |
| |
| Example: "/path/to/foo.py:42" |
| |
| Use stacklevel=1 to get the caller's source |
| Use stacklevel=2 to get the caller's caller's source |
| etc. |
| """ |
| frame = inspect.getframeinfo(sys._getframe(stacklevel)) |
| source = f"{frame.filename}:{frame.lineno}" |
| return source |
| |
| |
| def parse_namespace(qualname: str) -> Tuple[str, str]: |
| splits = qualname.split("::") |
| if len(splits) != 2: |
| raise ValueError( |
| f"Expected `qualname` to be of the form " |
| f'"namespace::name", but got {qualname}. ' |
| f"The qualname passed to the torch.library APIs must consist " |
| f"of a namespace and a name, e.g. aten::sin" |
| ) |
| return splits[0], splits[1] |
| |
| |
| def lookup_op(qualname: str) -> torch._ops.OpOverloadPacket: |
| namespace, name = parse_namespace(qualname) |
| if "." in name: |
| name, overload = name.split(".") |
| else: |
| overload = "default" |
| ns = getattr(torch.ops, namespace) |
| packet = getattr(ns, name) |
| return getattr(packet, overload) |
| |
| |
| def is_builtin(op: torch._ops.OpOverload) -> bool: |
| assert isinstance(op, torch._ops.OpOverload) |
| return op.namespace in {"aten", "prim", "prims"} |
| |
| |
| def is_functional_schema(schema: Any) -> bool: |
| """Check if the schema is functional. |
| |
| An operator is functional if: |
| - it does not mutate any of its inputs |
| - it does not return a view on any of its inputs |
| - it has at least one return |
| """ |
| |
| # Lazy import because not all PyTorch builds have torchgen |
| from torchgen.model import FunctionSchema, SchemaKind |
| |
| assert isinstance(schema, (str, FunctionSchema)) |
| if isinstance(schema, str): |
| schema = FunctionSchema.parse(schema) |
| |
| if schema.kind() != SchemaKind.functional: |
| return False |
| rets = schema.returns |
| is_non_mutating_view = len(rets) > 0 and any( |
| r.annotation is not None and not r.annotation.is_write for r in rets |
| ) |
| if is_non_mutating_view: |
| return False |
| if not schema.returns: |
| return False |
| return True |