| """ |
| Python implementation of __torch_function__ |
| |
| While most of the torch API and handling for __torch_function__ happens |
| at the C++ level, some of the torch API is written in Python so we need |
| python-level handling for __torch_function__ overrides as well. The main |
| developer-facing functions in this file are handle_torch_function and |
| has_torch_function. See torch/functional.py and test/test_overrides.py |
| for usage examples. |
| |
| NOTE: heavily inspired by NumPy's ``__array_function__`` (see: |
| https://github.com/pytorch/pytorch/issues/24015 and |
| https://www.numpy.org/neps/nep-0018-array-function-protocol.html |
| ) |
| |
| """ |
| |
| |
| def _get_overloaded_args(relevant_args): |
| """Returns a list of arguments on which to call __torch_function__. |
| |
| Checks arguments in relevant_args for __torch_function__ implementations, |
| storing references to the arguments and their types in overloaded_args and |
| overloaded_types in order of calling precedence. Only distinct types are |
| considered. If a type is a subclass of another type it will have higher |
| precedence, otherwise the precedence order is the same as the order of |
| arguments in relevant_args, that is, from left-to-right in the argument list. |
| |
| The precedence-determining algorithm implemented in this function is |
| described in `NEP-0018`_. |
| |
| See torch::append_overloaded_arg for the equivalent function in the C++ |
| implementation. |
| |
| Parameters |
| ---------- |
| relevant_args : iterable of array-like |
| Iterable of array-like arguments to check for __torch_function__ |
| methods. |
| |
| Returns |
| ------- |
| overloaded_types : collection of types |
| Types of arguments from relevant_args with __torch_function__ methods. |
| overloaded_args : list |
| Arguments from relevant_args on which to call __torch_function__ |
| methods, in the order in which they should be called. |
| |
| .. _NEP-0018: |
| https://numpy.org/neps/nep-0018-array-function-protocol.html |
| |
| """ |
| # Runtime is O(num_arguments * num_unique_types) |
| overloaded_types = [] |
| overloaded_args = [] |
| for arg in relevant_args: |
| arg_type = type(arg) |
| # We only collect arguments if they have a unique type, which ensures |
| # reasonable performance even with a long list of possibly overloaded |
| # arguments. |
| if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')): |
| # Create lists explicitly for the first type (usually the only one |
| # done) to avoid setting up the iterator for overloaded_args. |
| if overloaded_types: |
| overloaded_types.append(arg_type) |
| # By default, insert argument at the end, but if it is |
| # subclass of another argument, insert it before that argument. |
| # This ensures "subclasses before superclasses". |
| index = len(overloaded_args) |
| for i, old_arg in enumerate(overloaded_args): |
| if issubclass(arg_type, type(old_arg)): |
| index = i |
| break |
| overloaded_args.insert(index, arg) |
| else: |
| overloaded_types = [arg_type] |
| overloaded_args = [arg] |
| |
| return overloaded_args |
| |
| |
def handle_torch_function(
        public_api, relevant_args, *args, **kwargs):
    """Implement a function with checks for __torch_function__ overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Arguments
    ---------
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked. It is passed unchanged as the first argument of every
        ``__torch_function__`` call.
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    The first result other than ``NotImplemented`` returned by an overloaded
    argument's ``__torch_function__`` method, trying the overloaded
    arguments in precedence order.

    Raises
    ------
    TypeError : if no implementation is found, i.e. every
        ``__torch_function__`` returned ``NotImplemented`` or no argument
        overloads it at all.

    """
    # Check for __torch_function__ methods.
    overloaded_args = _get_overloaded_args(relevant_args)

    # Call overrides in precedence order; the first one that does not return
    # NotImplemented wins.
    for overloaded_arg in overloaded_args:
        # Pass `public_api` itself so __torch_function__ implementations can
        # do equality/identity comparisons against public torch functions.
        result = overloaded_arg.__torch_function__(public_api, args, kwargs)

        if result is not NotImplemented:
            return result

    # getattr keeps the documented TypeError for callables lacking
    # __module__/__name__ (e.g. functools.partial objects) instead of
    # letting an AttributeError escape.
    func_name = '{}.{}'.format(getattr(public_api, '__module__', None),
                               getattr(public_api, '__name__', repr(public_api)))
    raise TypeError("no implementation found for '{}' on types that implement "
                    '__torch_function__: {}'
                    .format(func_name, list(map(type, overloaded_args))))
| |
def has_torch_function(relevant_args):
    """Check for __torch_function__ implementations in the elements of an iterable.

    The method is looked up on each element's type, the same way
    ``_get_overloaded_args`` selects arguments for dispatch. For the same
    arguments, a True result therefore means ``handle_torch_function`` has at
    least one argument to dispatch to.

    Arguments
    ---------
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.

    Returns
    -------
    True if the type of any element of relevant_args has a
    __torch_function__ implementation, False otherwise (including when
    relevant_args is empty).
    """
    # Checking the instance would also accept instance-only attributes and
    # class objects passed as arguments, neither of which dispatch uses.
    return any(hasattr(type(a), '__torch_function__') for a in relevant_args)