blob: 1f3a37ccbfedf3c74b7a9f3e1285e1ffe2bfcd85 [file] [log] [blame]
import torch
import torch._C as _C
from torch.utils._pytree import tree_flatten
"""
This is a dispatcher (in Python)
- You can define new operations (in Python) without schemas
- It interfaces with the PyTorch dispatcher
"""
class PyDispatcher:
# operator is a PyOperator
@staticmethod
def call(operator, *args, **kwargs):
dispatch_key_set = compute_keyset(args, kwargs)
kernel = operator.lookup(dispatch_key_set)
return kernel(*args, **kwargs)
# operator is a PyOperator
@staticmethod
def redispatch(operator, dispatch_key_set, *args, **kwargs):
kernel = operator.lookup(dispatch_key_set)
return kernel(*args, **kwargs)
def compute_keyset(args, kwargs):
tensors = get_tensors(args, kwargs)
return key_extractor(tensors)
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
key_set = _C._dispatch_tls_local_include_set() # type: ignore[attr-defined]
for tensor in tensors:
key_set = key_set | _C._dispatch_keys(tensor) # type: ignore[attr-defined]
key_set = key_set - _C._dispatch_tls_local_exclude_set() # type: ignore[attr-defined]
return key_set
def to_flat_tuple(args, kwargs):
flat_args, _ = tree_flatten(args)
flat_kwargs, _ = tree_flatten(kwargs)
flat_all = flat_args + flat_kwargs
return flat_all
def get_tensors(args, kwargs):
flat_all = to_flat_tuple(args, kwargs)
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
return tuple(tensor_args)