| """Module for handling symbolic function registration.""" |
| |
| import warnings |
| from typing import ( |
| Callable, |
| Collection, |
| Dict, |
| Generic, |
| Optional, |
| Sequence, |
| Set, |
| TypeVar, |
| Union, |
| ) |
| |
| from torch.onnx import _constants, errors |
| from torch.onnx._internal import _beartype |
| |
| OpsetVersion = int |
| |
| |
| def _dispatch_opset_version( |
| target: OpsetVersion, registered_opsets: Collection[OpsetVersion] |
| ) -> Optional[OpsetVersion]: |
| """Finds the registered opset given a target opset version and the available opsets. |
| |
| Args: |
| target: The target opset version. |
| available_opsets: The available opsets. |
| |
| Returns: |
| The registered opset version. |
| """ |
| if not registered_opsets: |
| return None |
| |
| descending_registered_versions = sorted(registered_opsets, reverse=True) |
| # Linear search for the opset version, which is fine since the number of opset |
| # versions is small. |
| |
| if target >= _constants.ONNX_BASE_OPSET: |
| # Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9). |
| # When a custom op is register at opset 1, we want to be able to discover it as a |
| # fallback for all opsets >= ONNX_BASE_OPSET. |
| for version in descending_registered_versions: |
| if version <= target: |
| return version |
| return None |
| |
| # target < opset 9. This is the legacy behavior to support opset 7 and opset 8. |
| # for caffe2 support. We search up toward opset 9. |
| for version in reversed(descending_registered_versions): |
| # Count back up until _constants.ONNX_BASE_OPSET |
| if target <= version <= _constants.ONNX_BASE_OPSET: |
| return version |
| |
| return None |
| |
| |
_K = TypeVar("_K")
_V = TypeVar("_V")


class OverrideDict(Generic[_K, _V], Collection[_K]):
    """A two-layer mapping of built-in entries and custom entries that shadow them.

    Lookup, membership, iteration, length and truthiness all operate on a merged
    view in which an override wins over the base entry stored under the same key.
    Removing an override exposes the base entry again, when there is one.
    """

    def __init__(self):
        # Built-in entries.
        self._base: Dict[_K, _V] = {}
        # Custom entries; they take precedence over ``_base``.
        self._overrides: Dict[_K, _V] = {}
        # The effective view served to readers; every mutator keeps it in sync.
        self._merged: Dict[_K, _V] = {}

    def set_base(self, key: _K, value: _V) -> None:
        """Stores a base entry; it becomes visible unless the key is overridden."""
        self._base[key] = value
        if self.overridden(key):
            return
        self._merged[key] = value

    def in_base(self, key: _K) -> bool:
        """Tells whether the base layer holds an entry for the key."""
        return key in self._base

    def override(self, key: _K, value: _V) -> None:
        """Stores an override for the key and makes it the visible value."""
        for layer in (self._overrides, self._merged):
            layer[key] = value

    def remove_override(self, key: _K) -> None:
        """Drops the override for the key and falls back to the base entry, if any."""
        for layer in (self._overrides, self._merged):
            layer.pop(key, None)  # type: ignore[arg-type]
        if self.in_base(key):
            self._merged[key] = self._base[key]

    def overridden(self, key: _K) -> bool:
        """Tells whether the override layer holds an entry for the key."""
        return key in self._overrides

    def __getitem__(self, key: _K) -> _V:
        return self._merged[key]

    def get(self, key: _K, default: Optional[_V] = None):
        """Returns the visible value for the key, or ``default`` when it is absent."""
        try:
            return self._merged[key]
        except KeyError:
            return default

    def __contains__(self, key: object) -> bool:
        return key in self._merged

    def __iter__(self):
        return iter(self._merged.keys())

    def __len__(self) -> int:
        return len(self._merged)

    def __repr__(self) -> str:
        return f"OverrideDict(base={self._base}, overrides={self._overrides})"

    def __bool__(self) -> bool:
        return len(self._merged) > 0
| |
class _SymbolicFunctionGroup:
    """The symbolic functions registered under one name, keyed by opset version.

    Built-in functions and custom overrides live together in an ``OverrideDict``.
    A lookup for an opset is resolved by ``_dispatch_opset_version``, which scans
    the registered versions, so its cost grows with the number of versions
    registered for the op.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        # Opset version -> symbolic function, with custom overrides layered on top.
        self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()

    def __repr__(self) -> str:
        return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"

    def __getitem__(self, key: OpsetVersion) -> Callable:
        """Like ``get``, but raises ``KeyError`` when no function serves the opset."""
        func = self.get(key)
        if func is not None:
            return func
        raise KeyError(key)

    # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
    # a problem.
    def get(self, opset: OpsetVersion) -> Optional[Callable]:
        """Finds the function that serves the given opset version.

        Args:
            opset: The opset version to look up.

        Returns:
            The function registered at the version chosen by
            ``_dispatch_opset_version``, or None when no version qualifies.
        """
        matched = _dispatch_opset_version(opset, self._functions)
        return None if matched is None else self._functions[matched]

    def add(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a built-in symbolic function.

        A warning is emitted when a built-in function already exists for the
        opset; the new function replaces it either way.

        Args:
            func: The function to add.
            opset: The opset version of the function to add.
        """
        already_registered = self._functions.in_base(opset)
        if already_registered:
            message = (
                f"Symbolic function '{self._name}' already registered for opset {opset}. "
                f"Replacing the existing function with new function. This is unexpected. "
                f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}."
            )
            warnings.warn(message, errors.OnnxExporterWarning)
        self._functions.set_base(opset, func)

    def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a custom symbolic function that overrides the one at the same opset.

        Args:
            func: The symbolic function to register.
            opset: The corresponding opset version.
        """
        self._functions.override(opset, func)

    def remove_custom(self, opset: OpsetVersion) -> None:
        """Removes a custom symbolic function.

        Only a warning is emitted when no custom function exists for the opset.

        Args:
            opset: The opset version of the custom function to remove.
        """
        if self._functions.overridden(opset):
            self._functions.remove_override(opset)
            return
        warnings.warn(
            f"No custom function registered for '{self._name}' opset {opset}"
        )

    def get_min_supported(self) -> OpsetVersion:
        """Returns the lowest opset version that has a function in this group.

        The minimum is taken over every version currently visible in the group.
        ``min`` raises ``ValueError`` when the group holds no function at all.
        """
        return min(self._functions)
| |
| |
class SymbolicRegistry:
    """Registry for symbolic functions.

    Qualified names ('domain::op') map to groups that hold the functions
    registered under that name for different opset versions. The registry is the
    entry point for registering, unregistering and querying those functions.
    """

    def __init__(self) -> None:
        # Qualified name -> the group of functions registered under it.
        self._registry: Dict[str, _SymbolicFunctionGroup] = {}

    def register(
        self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
    ) -> None:
        """Registers a symbolic function.

        Args:
            name: The qualified name of the function to register. In the form of 'domain::op'.
                E.g. 'aten::add'.
            opset: The opset version of the function to register.
            func: The symbolic function to register.
            custom: Whether the function is a custom function that overrides existing ones.

        Raises:
            ValueError: If the separator '::' is not in the name.
        """
        if "::" not in name:
            raise ValueError(
                f"The name must be in the form of 'domain::op', not '{name}'"
            )

        # Create the group the first time this name is seen.
        group = self._registry.get(name)
        if group is None:
            group = _SymbolicFunctionGroup(name)
            self._registry[name] = group

        adder = group.add_custom if custom else group.add
        adder(func, opset)

    def unregister(self, name: str, opset: OpsetVersion) -> None:
        """Unregisters a custom symbolic function.

        Nothing happens when the name has never been registered.

        Args:
            name: The qualified name of the function to unregister.
            opset: The opset version of the function to unregister.
        """
        group = self._registry.get(name)
        if group is not None:
            group.remove_custom(opset)

    def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
        """Returns the function group for the given name, or None if there is none."""
        return self._registry.get(name)

    def is_registered_op(self, name: str, version: int) -> bool:
        """Returns whether some registered function serves the op at the opset version."""
        group = self.get_function_group(name)
        return group is not None and group.get(version) is not None

    def all_functions(self) -> Set[str]:
        """Returns the set of all registered function names."""
        return set(self._registry.keys())
| |
| |
@_beartype.beartype
def onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
    custom: bool = False,
) -> Callable:
    """Creates a decorator that registers a symbolic function in the module registry.

    Usage::

        @onnx_symbolic("aten::symbolic_b", opset=10, decorate=[quantized_aten_handler(scale=1/128, zero_point=0)])
        @symbolic_helper.parse_args("v", "v", "b")
        def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value:
            ...

    Args:
        name: The qualified name of the function in the form of 'domain::op'.
            E.g. 'aten::add'.
        opset: The opset version, or the opset versions, to register the function at.
        decorate: A sequence of decorators to apply, in order, to the function
            before it is registered.
        custom: Whether the function is a custom symbolic function.

    Returns:
        A decorator. It registers the function it receives, wrapped by the
        ``decorate`` callables, at every requested opset version and hands the
        original function back.

    Raises:
        ValueError: If the separator '::' is not in the name. It is raised when
            the returned decorator is applied.
    """

    def wrapper(func: Callable) -> Callable:
        extra_decorators = () if decorate is None else decorate
        decorated = func
        for apply_decorator in extra_decorators:
            decorated = apply_decorator(decorated)

        # A bare opset version stands for a one-element sequence of versions.
        versions = (opset,) if isinstance(opset, OpsetVersion) else opset
        for version in versions:
            registry.register(name, version, decorated, custom=custom)

        # Hand back the undecorated function: the decorators in "decorate" apply
        # only to the instance stored in the registry.
        return func

    return wrapper
| |
| |
@_beartype.beartype
def custom_onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
) -> Callable:
    """Creates a decorator that registers a custom symbolic function.

    This is ``onnx_symbolic`` with ``custom=True``.

    Args:
        name: The qualified name of the function.
        opset: The opset version, or the opset versions, of the function.
        decorate: A sequence of decorators to apply to the function.

    Returns:
        The decorator.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    return onnx_symbolic(name=name, opset=opset, decorate=decorate, custom=True)
| |
| |
# The module-level registry for all symbolic functions. The decorator returned by
# ``onnx_symbolic`` registers functions into this instance.
registry = SymbolicRegistry()