| # mypy: allow-untyped-defs |
| from typing import Callable, Optional |
| |
| from .fake_impl import FakeImplHolder |
| from .utils import RegistrationHandle |
| |
| __all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] |
| |
| |
| class SimpleLibraryRegistry: |
| """Registry for the "simple" torch.library APIs |
| |
| The "simple" torch.library APIs are a higher-level API on top of the |
| raw PyTorch DispatchKey registration APIs that includes: |
| - fake impl |
| |
| Registrations for these APIs do not go into the PyTorch dispatcher's |
| table because they may not directly involve a DispatchKey. For example, |
| the fake impl is a Python function that gets invoked by FakeTensor. |
| Instead, we manage them here. |
| |
| SimpleLibraryRegistry is a mapping from a fully qualified operator name |
| (including the overload) to SimpleOperatorEntry. |
| """ |
| |
| def __init__(self): |
| self._data = {} |
| |
| def find(self, qualname: str) -> "SimpleOperatorEntry": |
| if qualname not in self._data: |
| self._data[qualname] = SimpleOperatorEntry(qualname) |
| return self._data[qualname] |
| |
| |
| singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() |
| |
| |
| class SimpleOperatorEntry: |
| """This is 1:1 to an operator overload. |
| |
| The fields of SimpleOperatorEntry are Holders where kernels can be |
| registered to. |
| """ |
| |
| def __init__(self, qualname: str): |
| self.qualname: str = qualname |
| self.fake_impl: FakeImplHolder = FakeImplHolder(qualname) |
| self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = ( |
| GenericTorchDispatchRuleHolder(qualname) |
| ) |
| |
| # For compatibility reasons. We can delete this soon. |
| @property |
| def abstract_impl(self): |
| return self.fake_impl |
| |
| |
| class GenericTorchDispatchRuleHolder: |
| def __init__(self, qualname): |
| self._data = {} |
| self.qualname = qualname |
| |
| def register( |
| self, torch_dispatch_class: type, func: Callable |
| ) -> RegistrationHandle: |
| if self.find(torch_dispatch_class): |
| raise RuntimeError( |
| f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}" |
| ) |
| self._data[torch_dispatch_class] = func |
| |
| def deregister(): |
| del self._data[torch_dispatch_class] |
| |
| return RegistrationHandle(deregister) |
| |
| def find(self, torch_dispatch_class): |
| return self._data.get(torch_dispatch_class, None) |
| |
| |
| def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]: |
| return singleton.find(op.__qualname__).torch_dispatch_rules.find( |
| torch_dispatch_class |
| ) |