blob: 3f3ba7ff01f0dfaa0096ef322dcd7cae584127e7 [file] [log] [blame]
import dataclasses
import inspect
import sys
from typing import Any, Callable, Tuple
import torch
@dataclasses.dataclass
class Kernel:
"""Models a (function, source location)"""
func: Callable
source: str
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RegistrationHandle:
"""Does something when someone calls .destroy() on it"""
def __init__(self, on_destroy: Callable):
self._on_destroy = on_destroy
def destroy(self) -> None:
self._on_destroy()
def get_source(stacklevel: int) -> str:
"""Get a string that represents the caller.
Example: "/path/to/foo.py:42"
Use stacklevel=1 to get the caller's source
Use stacklevel=2 to get the caller's caller's source
etc.
"""
frame = inspect.getframeinfo(sys._getframe(stacklevel))
source = f"{frame.filename}:{frame.lineno}"
return source
def parse_namespace(qualname: str) -> Tuple[str, str]:
splits = qualname.split("::")
if len(splits) != 2:
raise ValueError(
f"Expected `qualname` to be of the form "
f'"namespace::name", but got {qualname}. '
f"The qualname passed to the torch.library APIs must consist "
f"of a namespace and a name, e.g. aten::sin"
)
return splits[0], splits[1]
def lookup_op(qualname: str) -> torch._ops.OpOverloadPacket:
namespace, name = parse_namespace(qualname)
if "." in name:
name, overload = name.split(".")
else:
overload = "default"
ns = getattr(torch.ops, namespace)
packet = getattr(ns, name)
return getattr(packet, overload)
def is_builtin(op: torch._ops.OpOverload) -> bool:
assert isinstance(op, torch._ops.OpOverload)
return op.namespace in {"aten", "prim", "prims"}
def is_functional_schema(schema: Any) -> bool:
"""Check if the schema is functional.
An operator is functional if:
- it does not mutate any of its inputs
- it does not return a view on any of its inputs
- it has at least one return
"""
# Lazy import because not all PyTorch builds have torchgen
from torchgen.model import FunctionSchema, SchemaKind
assert isinstance(schema, (str, FunctionSchema))
if isinstance(schema, str):
schema = FunctionSchema.parse(schema)
if schema.kind() != SchemaKind.functional:
return False
rets = schema.returns
is_non_mutating_view = len(rets) > 0 and any(
r.annotation is not None and not r.annotation.is_write for r in rets
)
if is_non_mutating_view:
return False
if not schema.returns:
return False
return True