blob: aa033f2c4a4b3092a758c4f413904412f5d821d5 [file] [log] [blame]
from __future__ import print_function
import os
import collections
import glob
import yaml
import re
import argparse
from ..autograd.utils import YamlLoader, CodeTemplate, write
from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
from ..autograd.gen_autograd import load_aten_declarations
"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.
At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which come are autogenerated.
Here's our general strategy:
- We start off with a hand-written __init__.pyi.in file. This
file contains type definitions for everything we cannot automatically
generate, including pure Python definitions directly in __init__.py
(the latter case should be pretty rare).
- We go through automatically bound functions based on the
type information recorded in Declarations.yaml and
generate type hints for them (generate_type_hints)
There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs to read on the human eye.
needed_modules = set()
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False"
# this could be more precise w.r.t list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
blacklist = [
'__init_subclass__',
'__new__',
'__subclasshook__',
'cdist',
'clamp',
'clamp_',
'device',
'grad',
'requires_grad',
'range',
# defined in functional
'einsum',
# reduction argument; these bindings don't make sense
'binary_cross_entropy_with_logits',
'ctc_loss',
'cosine_embedding_loss',
'hinge_embedding_loss',
'kl_div',
'margin_ranking_loss',
'triplet_margin_loss',
# Somehow, these are defined in both _C and in functional. Ick!
'broadcast_tensors',
# type hints for named tensors are broken: https://github.com/pytorch/pytorch/issues/27846
'align_tensors',
'meshgrid',
'cartesian_prod',
'norm',
'chain_matmul',
'stft',
'tensordot',
'norm',
'split',
'unique_consecutive',
# These are handled specially by python_arg_parser.cpp
'add',
'add_',
'add_out',
'sub',
'sub_',
'sub_out',
'mul',
'mul_',
'mul_out',
'div',
'div_',
'div_out',
]
def type_to_python(typename, size=None):
"""type_to_python(typename: str, size: str) -> str
Transforms a Declarations.yaml type name into a Python type specification
as used for type hints.
"""
typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *'
# Disambiguate explicitly sized int/tensor lists from implicitly
# sized ones. These permit non-list inputs too. (IntArrayRef[] and
# TensorList[] are not real types; this is just for convenience.)
if typename in {'IntArrayRef', 'TensorList'} and size is not None:
typename += '[]'
typename = {
'Device': 'Union[_device, str, None]',
'Generator*': 'Generator',
'IntegerTensor': 'Tensor',
'Scalar': 'Number',
'ScalarType': '_dtype',
'Storage': 'Storage',
'BoolTensor': 'Tensor',
'IndexTensor': 'Tensor',
'Tensor': 'Tensor',
'MemoryFormat': 'memory_format',
'IntArrayRef': '_size',
'IntArrayRef[]': 'Union[_int, _size]',
'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
'bool': '_bool',
'double': '_float',
'int64_t': '_int',
'accreal': 'Number',
'real': 'Number',
'void*': '_int', # data_ptr
'void': 'None',
'std::string': 'str',
'Dimname': 'Union[str, None]',
'DimnameList': 'List[Union[str, None]]',
'QScheme': '_qscheme',
}[typename]
return typename
def arg_to_type_hint(arg):
"""arg_to_type_hint(arg) -> str
This takes one argument in a Declarations and returns a string
representing this argument in a type hint signature.
"""
name = arg['name']
if name == 'from': # from is a Python keyword...
name += '_'
typename = type_to_python(arg['dynamic_type'], arg.get('size'))
if arg.get('is_nullable'):
typename = 'Optional[' + typename + ']'
if 'default' in arg:
default = arg['default']
if default == 'nullptr':
default = None
elif default == 'c10::nullopt':
default = None
elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
if arg['dynamic_type'] == 'Tensor' and default == '{}':
default = None
elif arg['dynamic_type'] == 'IntArrayRef':
default = '(' + default[1:-1] + ')'
else:
raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
elif default == 'MemoryFormat::Contiguous':
default = 'contiguous_format'
elif default == 'QScheme::PER_TENSOR_AFFINE':
default = 'per_tensor_affine'
default = '={}'.format(default)
else:
default = ''
return name + ': ' + typename + default
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
'matmul', 'floordiv',
'radd', 'rmul', 'rfloordiv', # reverse arithmetic
'and', 'or', 'xor', # logic
'iadd', 'iand', 'idiv', 'ilshift', 'imul',
'ior', 'irshift', 'isub', 'itruediv', 'ixor', # inplace ops
)
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname):
"""sig_for_ops(opname : str) -> List[str]
Returns signatures for operator special functions (__add__ etc.)"""
# we have to do this by hand, because they are hand-bound in Python
assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
name = opname[2:-2]
if name in binary_ops:
return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
elif name in comparison_ops:
# unsafe override https://github.com/python/mypy/issues/5704
return ['def {}(self, other: Any) -> Tensor: ... # type: ignore'.format(opname)]
elif name in unary_ops:
return ['def {}(self) -> Tensor: ...'.format(opname)]
elif name in to_py_type_ops:
if name in {'bool', 'float'}:
tname = name
elif name == 'nonzero':
tname = 'bool'
else:
tname = 'int'
if tname in {'float', 'int', 'bool'}:
tname = 'builtins.' + tname
return ['def {}(self) -> {}: ...'.format(opname, tname)]
else:
raise Exception("unknown op", opname)
def generate_type_hints(fname, decls, is_tensor=False):
"""generate_type_hints(fname, decls, is_tensor=False)
Generates type hints for the declarations pertaining to the function
:attr:`fname`. attr:`decls` are the declarations from the parsed
Declarations.yaml.
The :attr:`is_tensor` flag indicates whether we are parsing
members of the Tensor class (true) or functions in the
`torch` namespace (default, false).
This function currently encodes quite a bit about the semantics of
the translation C++ -> Python.
"""
if fname in blacklist:
return []
type_hints = []
dnames = ([d['name'] for d in decls])
has_out = fname + '_out' in dnames
if has_out:
decls = [d for d in decls if d['name'] != fname + '_out']
for decl in decls:
render_kw_only_separator = True # whether we add a '*' if we see a keyword only argument
python_args = []
has_tensor_options = 'TensorOptions' in (a['dynamic_type'] for a in decl['arguments'])
for a in decl['arguments']:
if a['dynamic_type'] != 'TensorOptions':
if a.get('kwarg_only', False) and render_kw_only_separator:
python_args.append('*')
render_kw_only_separator = False
try:
python_args.append(arg_to_type_hint(a))
except Exception:
print("Error while processing function {}".format(fname))
raise
if is_tensor:
if 'self: Tensor' in python_args:
python_args.remove('self: Tensor')
python_args = ['self'] + python_args
else:
raise Exception("method without self is unexpected")
if has_out:
if render_kw_only_separator:
python_args.append('*')
render_kw_only_separator = False
python_args.append('out: Optional[Tensor]=None')
if has_tensor_options:
if render_kw_only_separator:
python_args.append('*')
render_kw_only_separator = False
python_args += ["dtype: _dtype=None",
"layout: _layout=strided",
"device: Union[_device, str, None]=None",
"requires_grad:_bool=False"]
python_args_s = ', '.join(python_args)
python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
if len(python_returns) > 1:
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
elif len(python_returns) == 1:
python_returns_s = python_returns[0]
else:
python_returns_s = 'None'
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
numargs = len(decl['arguments'])
vararg_pos = int(is_tensor)
have_vararg_version = (numargs > vararg_pos and
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
(numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
(not is_tensor or decl['arguments'][0]['name'] == 'self'))
type_hints.append(type_hint)
if have_vararg_version:
# Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
# is an IntArrayRef or TensorList, it will be used as a vararg variant.
# The following outputs the vararg variant, the "pass a list variant" is output above.
# The other thing is that in Python, the varargs are annotated with the element type, not the list type.
typelist = decl['arguments'][vararg_pos]['dynamic_type']
if typelist == 'IntArrayRef':
vararg_type = '_int'
else:
vararg_type = 'Tensor'
# replace first argument and eliminate '*' if present
python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
': ' + vararg_type] + python_args[vararg_pos + 2:])
python_args_s = ', '.join(python_args)
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
type_hints.append(type_hint)
return type_hints
def gen_nn_modules(out):
def replace_forward(m):
# We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
# would otherwise correctly point out that Module's descendants' `forward` declarations
# conflict with `Module`s. Specificlaly, `Module` defines `forward(self, *args)` while the
# descandantes define more specific forms, such as `forward(self, input: Tensor)`, which
# violates Liskov substitutability. The 'mypy' team recommended this solution for now.
forward_def = m.group(0) + " # type: ignore"
call_def = re.sub(r'def forward', 'def __call__', forward_def)
new_def = "{}\n{}".format(forward_def, call_def)
return new_def
pattern = re.compile(r'^\s*def forward\(self.*$', re.MULTILINE)
for fname in glob.glob("torch/nn/modules/*.pyi.in"):
with open(fname, 'r') as f:
src = f.read()
res = pattern.sub(replace_forward, src)
fname_out = fname[:-3]
with open(os.path.join(out, fname_out), 'w') as f:
f.write(res)
def gen_nn_functional(out):
# Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
# through an `_add_docstr` call
imports = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
'avg_pool1d',
'relu_',
'selu_',
'celu_',
'rrelu_',
'pixel_shuffle',
'pdist',
'cosine_similarity',
]
# Functions generated by `torch._jit_internal.boolean_dispatch`
dispatches = [
'fractional_max_pool2d',
'fractional_max_pool3d',
'max_pool1d',
'max_pool2d',
'max_pool3d',
'adaptive_max_pool1d',
'adaptive_max_pool2d',
'adaptive_max_pool3d',
]
# Functions directly imported from `torch._C`
from_c = [
'avg_pool2d',
'avg_pool3d',
'hardtanh_',
'elu_',
'leaky_relu_',
'logsigmoid',
'softplus',
'softshrink',
'one_hot',
]
import_code = ["from .. import {0} as {0}".format(_) for _ in imports]
# TODO make these types more precise
dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
stubs = CodeTemplate.from_file(os.path.join('torch', 'nn', 'functional.pyi.in'))
env = {
'imported_hints': import_code,
'dispatched_hints': dispatch_code
}
write(out, 'torch/nn/functional.pyi', stubs, env)
def gen_nn_pyi(out):
gen_nn_functional(out)
gen_nn_modules(out)
def gen_pyi(declarations_path, out):
"""gen_pyi()
This function generates a pyi file for torch.
"""
# Some of this logic overlaps with generate_python_signature in
# tools/autograd/gen_python_functions.py; however, this
# function is all about generating mypy type signatures, whereas
# the other function generates are custom format for argument
# checking. If you are update this, consider if your change
# also needs to update the other file.
# Load information from YAML
declarations = load_aten_declarations(declarations_path)
# Generate type signatures for top-level functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_function_hints = collections.defaultdict(list)
unsorted_function_hints.update({
'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
'numel': ['def numel(self: Tensor) -> _int: ...'],
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
'get_num_threads': ['def get_num_threads() -> _int: ...'],
'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
# These functions are explicitly disabled by
# SKIP_PYTHON_BINDINGS because they are hand bound.
# Correspondingly, we must hand-write their signatures.
'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'range': ['def range(start: Number, end: Number,'
' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
'arange': ['def arange(start: Number, end: Number, step: Number, *,'
' out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS),
'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS),
'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
.format(FACTORY_PARAMS),
'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
})
for binop in ['add', 'sub', 'mul', 'div']:
unsorted_function_hints[binop].append(
'def {}(input: Union[Tensor, Number],'
' other: Union[Tensor, Number],'
' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
unsorted_function_hints[binop].append(
'def {}(input: Union[Tensor, Number],'
' value: Number,'
' other: Union[Tensor, Number],'
' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
function_declarations = get_py_torch_functions(declarations)
for name in sorted(function_declarations.keys()):
unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])
# Generate type signatures for deprecated functions
# TODO: Maybe we shouldn't generate type hints for deprecated
# functions :) However, examples like those addcdiv rely on these.
with open('tools/autograd/deprecated.yaml', 'r') as f:
deprecated = yaml.load(f, Loader=YamlLoader)
for d in deprecated:
name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
function_hints = []
for name, hints in sorted(unsorted_function_hints.items()):
if len(hints) > 1:
hints = ['@overload\n' + h for h in hints]
function_hints += hints
# Generate type signatures for Tensor methods
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_tensor_method_hints = collections.defaultdict(list)
unsorted_tensor_method_hints.update({
'size': ['def size(self) -> Size: ...',
'def size(self, _int) -> _int: ...'],
'stride': ['def stride(self) -> Tuple[_int]: ...',
'def stride(self, _int) -> _int: ...'],
'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
'__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
'__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
" -> None: ...".format(INDICES)],
'tolist': ['def tolist(self) -> List: ...'],
'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
'element_size': ['def element_size(self) -> _int: ...'],
'dim': ['def dim(self) -> _int: ...'],
'numel': ['def numel(self) -> _int: ...'],
'ndimension': ['def ndimension(self) -> _int: ...'],
'nelement': ['def nelement(self) -> _int: ...'],
'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: _bool=False) -> Tensor: ...'],
'numpy': ['def numpy(self) -> Any: ...'],
'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
'storage': ['def storage(self) -> Storage: ...'],
'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: _bool=False)'
' -> Union[str, Tensor]: ...'],
'get_device': ['def get_device(self) -> _int: ...'],
'contiguous': ['def contiguous(self) -> Tensor: ...'],
'is_contiguous': ['def is_contiguous(self) -> _bool: ...'],
'is_cuda': ['is_cuda: _bool'],
'is_leaf': ['is_leaf: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
],
'item': ["def item(self) -> Number: ..."],
})
for binop in ['add', 'sub', 'mul', 'div']:
for inplace in [True, False]:
out_suffix = ', *, out: Optional[Tensor]=None'
if inplace:
name += '_'
out_suffix = ''
unsorted_tensor_method_hints[name].append(
'def {}(self, other: Union[Tensor, Number]{})'
' -> Tensor: ...'.format(name, out_suffix))
unsorted_tensor_method_hints[name].append(
'def {}(self, value: Number,'
' other: Union[Tensor, Number]{})'
' -> Tensor: ...'.format(name, out_suffix))
simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
'half', 'int', 'long', 'short', 'bool']
for name in simple_conversions:
unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
tensor_method_declarations = get_py_variable_methods(declarations)
for name in sorted(tensor_method_declarations.keys()):
unsorted_tensor_method_hints[name] += \
generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)
for op in all_ops:
name = '__{}__'.format(op)
unsorted_tensor_method_hints[name] += sig_for_ops(name)
tensor_method_hints = []
for name, hints in sorted(unsorted_tensor_method_hints.items()):
if len(hints) > 1:
hints = ['@overload\n' + h for h in hints]
tensor_method_hints += hints
# TODO: Missing type hints for nn
# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_class_hints = []
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
legacy_class_hints.append('class {}(Storage): ...'.format(c))
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
# Generate type signatures for dtype classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: don't explicitly list dtypes here; get it from canonical
# source
dtype_class_hints = ['{}: dtype = ...'.format(n)
for n in
['float32', 'float', 'float64', 'double', 'float16', 'half',
'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
'complex32', 'complex64', 'complex128', 'quint8', 'qint8', 'qint32', 'bool']]
# Write out the stub
# ~~~~~~~~~~~~~~~~~~
env = {
'function_hints': function_hints,
'tensor_method_hints': tensor_method_hints,
'legacy_class_hints': legacy_class_hints,
'dtype_class_hints': dtype_class_hints,
}
TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
gen_nn_pyi(out)
def main():
parser = argparse.ArgumentParser(
description='Generate type stubs for PyTorch')
parser.add_argument('--declarations-path', metavar='DECL',
default='torch/share/ATen/Declarations.yaml',
help='path to Declarations.yaml')
parser.add_argument('--out', metavar='OUT',
default='.',
help='path to output directory')
args = parser.parse_args()
gen_pyi(args.declarations_path, args.out)
if __name__ == '__main__':
main()