blob: 161252fdb3737905453508994d4bc8924058d13d [file] [log] [blame]
import torch
import inspect
import numbers
import typing
import enum
from typing import Any, Callable, Dict, List, Optional
_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}
def _nonzero_schemas():
signatures = []
def nonzero(self):
pass
signatures.append(inspect.signature(nonzero))
def nonzero(self, *, as_tuple : bool): # type: ignore
pass
signatures.append(inspect.signature(nonzero))
return signatures
_manual_overrides[torch.nonzero] = _nonzero_schemas()
class _FakeGlobalNamespace:
def __getattr__(self, name):
if name == 'torch':
return torch
raise RuntimeError('Expected a torch namespace lookup')
_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
'number' : numbers.Number, 'Future' : torch.jit.Future,
'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
'__torch__': _FakeGlobalNamespace(),
't': typing.TypeVar('t')} # type: ignore
for k in dir(typing):
_type_eval_globals[k] = getattr(typing, k)
def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
"""
Convert a TorchScript type to a Python type (including subtypes) via
eval'ing the annotation_str. _type_eval_globals sets up expressions
like "List" and "Future" to map to actual types (typing.List and jit.Future)
"""
return eval(ts_type.annotation_str, _type_eval_globals)
def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
parameters : List[inspect.Parameter] = []
for arg in ts_schema.arguments:
arg_type = _torchscript_type_to_python_type(arg.type)
default = arg.default_value if arg.has_default_value() else inspect.Parameter.empty
# TODO: Figure out if this is safe. It seems like when generating the type signatures for
# PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
# argument name. Downstream, if someone converts that positional argument to a keyword
# argument, the name mismatch will break things, so here we're going to normalize the
# name to "input"
name = arg.name if arg.name != 'self' else 'input'
kind = inspect.Parameter.KEYWORD_ONLY if arg.kwarg_only else inspect.Parameter.POSITIONAL_OR_KEYWORD
parameters.append(inspect.Parameter(name=name, kind=kind, default=default, annotation=arg_type))
return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
if len(return_types) == 0:
return_type = None
elif len(return_types) == 1:
return_type = return_types[0]
else:
return_type = tuple(return_types)
return inspect.Signature(parameters, return_annotation=return_type)
def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature]]:
"""
Given an operator on the `torch` namespace, return a list of `inspect.Signature`
objects corresponding to the overloads of that op.. May return `None` if a signature
could not be retrieved.
Args:
op (Callable): An operator on the `torch` namespace to look up a signature for
Returns:
Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
operator, or None if the operator signatures could not be retrieved.
"""
override = _manual_overrides.get(op)
if override:
return override
aten_fn = torch.jit._builtins._find_builtin(op)
if aten_fn is None:
return None
schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
return signatures