| import argparse |
| import collections |
| from pprint import pformat |
| from typing import Dict, List, Sequence |
| |
| from torchgen.api.python import ( |
| PythonSignatureGroup, |
| PythonSignatureNativeFunctionPair, |
| returns_named_tuple_pyi, |
| ) |
| from torchgen.gen import parse_native_yaml, parse_tags_yaml |
| |
| from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant |
| from torchgen.utils import FileManager |
| |
| from tools.autograd.gen_python_functions import ( |
| group_overloads, |
| load_signatures, |
| should_generate_py_binding, |
| ) |
| |
| """ |
| This module implements generation of type stubs for PyTorch, |
| enabling use of autocomplete in IDEs like PyCharm, which otherwise |
| don't understand C extension modules. |
| |
| At the moment, this module only handles type stubs for torch and |
| torch.Tensor. It should eventually be expanded to cover all functions |
| that are autogenerated. |
| |
| Here's our general strategy: |
| |
| - We start off with a hand-written __init__.pyi.in file. This |
| file contains type definitions for everything we cannot automatically |
| generate, including pure Python definitions directly in __init__.py |
| (the latter case should be pretty rare). |
| |
| - We go through automatically bound functions based on the |
| type information recorded in native_functions.yaml and |
| generate type hints for them (generate_type_hints) |
| |
| There are a number of type hints which we've special-cased; |
| read gen_pyi for the gory details. |
| """ |
| |
| |
| def get_py_torch_functions( |
| python_funcs: Sequence[PythonSignatureNativeFunctionPair], |
| method: bool = False, |
| ) -> Sequence[PythonSignatureGroup]: |
| """ |
| Get declarations (grouped by name) which should be generated |
| as either functions in the "torch" module or methods on Tensor. |
| """ |
| |
| def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool: |
| return ( |
| should_generate_py_binding(python_func.function) |
| and not python_func.function.python_module |
| and Variant.function in python_func.function.variants |
| ) |
| |
| def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: |
| return ( |
| should_generate_py_binding(python_func.function) |
| and not python_func.function.python_module |
| and Variant.method in python_func.function.variants |
| ) |
| |
| should_bind = should_bind_method if method else should_bind_function |
| return group_overloads([f for f in python_funcs if should_bind(f)]) |
| |
| |
| # TODO: Consider defining some aliases for our Union[...] types, to make |
| # the stubs easier to read. |
| |
| DEVICE_PARAM = "device: Device = None" |
| FACTORY_PARAMS = f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False, pin_memory: _bool = False" |
| |
| # NOTE: specifying indices for Tensor.__getitem__ |
| # We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi: |
| # |
| # key: ( |
| # None |
| # | slice |
| # | ellipsis |
| # | SupportsIndex |
| # | _ArrayLikeInt_co |
| # | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...] |
| # ) |
| # |
| # where: |
| # |
| # _ArrayLikeInt_co = _DualArrayLike[ |
| # dtype[Union[bool_, integer[Any]]], |
| # Union[bool, int], |
| # ] |
| # |
| # and |
| # |
| # _DualArrayLike = Union[ |
| # _SupportsArray[_DType], |
| # _NestedSequence[_SupportsArray[_DType]], |
| # _T, |
| # _NestedSequence[_T], |
| # ] |
| # |
| # Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple. |
| # We can substitute and simplify: |
| # _SupportsArray -> Tensor |
| # _ArrayLikeInt_co -> bool | int | Tensor | _NestedSequence[bool | int] | _NestedSequence[Tensor] |
| # which leaves us with key: T | tuple[T, ...], where T is: |
| # T = ( |
| # None | bool | int | slice | ellipsis | SupportsIndex |
| # | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int] |
| # ) |
| |
| # NOTE: In stub files, `ellipsis` refers to the type of `Ellipsis` (i.e. `type(Ellipsis)`). |
| _leaf_types = "Union[None, _bool, _int, slice, ellipsis, Tensor]" # not SupportsIndex! |
| _index = f"Union[SupportsIndex, {_leaf_types}, _NestedSequence[{_leaf_types}]]" |
| INDICES = f"indices: Union[{_index}, tuple[{_index}, ...]]" |
| |
| blocklist = [ |
| "__init_subclass__", |
| "__new__", |
| "__subclasshook__", |
| "cdist", |
| "device", |
| "grad", |
| "requires_grad", |
| "range", |
| # defined in functional |
| "einsum", |
| # Somehow, these are defined in both _C and in functional. Ick! |
| "broadcast_tensors", |
| # Manually define named tensor type stubs in __init__.pyi.in |
| "align_tensors", |
| "meshgrid", |
| "cartesian_prod", |
| "block_diag", |
| "norm", |
| "chain_matmul", |
| "stft", |
| "tensordot", |
| "split", |
| "unique_consecutive", |
| "atleast_1d", |
| "atleast_2d", |
| "atleast_3d", |
| # These are handled specially by python_arg_parser.cpp |
| "add", |
| "add_", |
| "add_out", |
| "sub", |
| "sub_", |
| "sub_out", |
| "mul", |
| "mul_", |
| "mul_out", |
| "div", |
| "div_", |
| "div_out", |
| "true_divide", |
| "true_divide_", |
| "true_divide_out", |
| "floor_divide", |
| "floor_divide_", |
| "floor_divide_out", |
| "to", |
| "_to_copy", |
| "copy_", |
| ] |
| |
| binary_ops = ( |
| "add", |
| "sub", |
| "mul", |
| "div", |
| "pow", |
| "lshift", |
| "rshift", |
| "mod", |
| "truediv", |
| "matmul", |
| "floordiv", |
| "radd", |
| "rsub", |
| "rmul", |
| "rtruediv", |
| "rfloordiv", |
| "rpow", # reverse arithmetic |
| "and", |
| "or", |
| "xor", |
| "rand", |
| "ror", |
| "rxor", # logic |
| "iadd", |
| "iand", |
| "idiv", |
| "ilshift", |
| "imul", |
| "ior", |
| "irshift", |
| "isub", |
| "ixor", |
| "ifloordiv", |
| "imod", # inplace ops |
| ) |
| symmetric_comparison_ops = ("eq", "ne") |
| asymmetric_comparison_ops = ("ge", "gt", "lt", "le") |
| comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops |
| |
| unary_ops = ("neg", "abs", "invert") |
| to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero") |
| all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops |
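| # In gen_pyi(), each name in all_ops is turned into a `__{op}__` dunder hint via sig_for_ops(). |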
| |
| |
| def sig_for_ops(opname: str) -> List[str]: |
| """sig_for_ops(opname : str) -> List[str] |
| |
| Returns signatures for operator special functions (__add__ etc.)""" |
| |
| # we have to do this by hand, because they are hand-bound in Python |
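| # |
| # Illustrative examples (derived from the op tables above): |
| #   sig_for_ops("__add__")  -> ["def __add__(self, other: Any) -> Tensor: ..."] |
| #   sig_for_ops("__eq__")   -> ["def __eq__(self, other: Any) -> Tensor: ... # type: ignore[override]"] |
| #   sig_for_ops("__bool__") -> ["def __bool__(self) -> builtins.bool: ..."] |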
| |
| assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}" |
| |
| name = opname[2:-2] |
| if name in binary_ops: |
| return [f"def {opname}(self, other: Any) -> Tensor: ..."] |
| elif name in comparison_ops: |
| sig = f"def {opname}(self, other: Any) -> Tensor: ..." |
| if name in symmetric_comparison_ops: |
| # unsafe override https://github.com/python/mypy/issues/5704 |
| sig += " # type: ignore[override]" |
| return [sig] |
| elif name in unary_ops: |
| return [f"def {opname}(self) -> Tensor: ..."] |
| elif name in to_py_type_ops: |
| if name in {"bool", "float", "complex"}: |
| tname = name |
| elif name == "nonzero": |
| tname = "bool" |
| else: |
| tname = "int" |
| if tname in {"float", "int", "bool", "complex"}: |
| tname = "builtins." + tname |
| return [f"def {opname}(self) -> {tname}: ..."] |
| else: |
| raise ValueError(f"unknown op {opname!r}") |
| |
| |
| def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: |
| type_hints: List[str] = [] |
| |
| # Some deprecated ops that are on the blocklist are still included in pyi |
| if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: |
| return type_hints |
| |
| # deprecated signatures have separate entries for their functional and out variants |
| # (as opposed to the native ops, which fuse the two into a single signature). |
| # generate the functional variant here, if an out variant exists. |
| if sig_group.signature.deprecated and sig_group.outplace is not None: |
| type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) |
| type_hints.append(type_hint) |
| |
| # PythonSignatureGroups that have both a functional and an out variant get a single signature with an optional out argument. |
| # Generate the out variant if one exists; otherwise, generate the functional variant. |
| type_hint = sig_group.signature.signature_str_pyi( |
| skip_outputs=sig_group.outplace is None |
| ) |
| type_hints.append(type_hint) |
| |
| # Some operators also additionally have a vararg variant of their signature |
| type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( |
| skip_outputs=sig_group.outplace is None |
| ) |
| if type_hint_vararg: |
| type_hints.append(type_hint_vararg) |
| |
| return type_hints |
| |
| |
| def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]: |
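| """Build the overload hints for a max-pool-style function whose return type |
| depends on the boolean `return_indices` flag. Three overloads are produced: |
| - return_indices: Literal[False] = False -> Tensor |
| - return_indices: Literal[True] as a positional arg (no defaults before it) -> Tuple[Tensor, Tensor] |
| - return_indices: Literal[True] as a keyword-only arg -> Tuple[Tensor, Tensor] |
| """ |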
| flag_pos = arg_list.index("{return_indices}") |
| # If return_indices is a positional arg, everything before it should have no default |
| arg_list_positional = ( |
| [ |
| ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", ")) |
| for arg in arg_list[: flag_pos + 1] |
| ] |
| + ["/"] |
| + arg_list[flag_pos + 1 :] |
| ) |
| # Otherwise, force return_indices to be keyword-only |
| arg_list_keyword = arg_list.copy() |
| arg_list_keyword.insert(flag_pos, "*") |
| tmpl = "def {name}({args}) -> {{return_type}}: ..." |
| return { |
| name: [ |
| tmpl.format(name=name, args=", ".join(arg_list)).format( |
| return_indices="return_indices: Literal[False] = False", |
| return_type="Tensor", |
| ), |
| tmpl.format(name=name, args=", ".join(arg_list_positional)).format( |
| return_indices="return_indices: Literal[True]", |
| return_type="Tuple[Tensor, Tensor]", |
| ), |
| tmpl.format(name=name, args=", ".join(arg_list_keyword)).format( |
| return_indices="return_indices: Literal[True]", |
| return_type="Tuple[Tensor, Tensor]", |
| ), |
| ] |
| } |
| |
| |
| def gen_nn_functional(fm: FileManager) -> None: |
| INPUT = "input: Tensor" |
| KERNEL_SIZE = "kernel_size: Union[_int, _size]" |
| STRIDE_PADDING = ", ".join( |
| [ |
| "stride: Optional[Union[_int, _size]] = None", |
| "padding: Union[_int, _size] = 0", |
| ] |
| ) |
| |
| # TODO the list for `torch._C._nn` is nonexhaustive |
| unsorted_c_nn_function_hints: Dict[str, List[str]] = {} |
| |
| for d in (2, 3): |
| unsorted_c_nn_function_hints.update( |
| { |
| f"avg_pool{d}d": [ |
| f"def avg_pool{d}d({{}}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| f"{INPUT}", |
| f"{KERNEL_SIZE}", |
| f"{STRIDE_PADDING}", |
| "ceil_mode: bool = False", |
| "count_include_pad: bool = True", |
| "divisor_override: Optional[int] = None", |
| ] |
| ) |
| ) |
| ], |
| f"fractional_max_pool{d}d": [ |
| f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format( |
| ", ".join( |
| [ |
| f"{INPUT}", |
| f"{KERNEL_SIZE}", |
| "output_size: Union[_int, _size]", |
| "_random_samples: Tensor", |
| ] |
| ), |
| "Tuple[Tensor, Tensor]", |
| ) |
| ], |
| f"adaptive_max_pool{d}d": [ |
| f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format( |
| ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]), |
| "Tuple[Tensor, Tensor]", |
| ) |
| ], |
| } |
| ) |
| |
| unsorted_c_nn_function_hints.update( |
| { |
| "hardtanh": [ |
| "def hardtanh({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "min_val: float = ...", |
| "max_val: float = ...", |
| "*", |
| "out: Optional[Tensor] = None", |
| ] |
| ) |
| ) |
| ], |
| "hardtanh_": [ |
| "def hardtanh_({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "min_val: float = ...", |
| "max_val: float = ...", |
| ] |
| ) |
| ) |
| ], |
| "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."], |
| "leaky_relu": [ |
| "def leaky_relu({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "negative_slope: float = ...", |
| "*", |
| "out: Optional[Tensor] = None", |
| ] |
| ) |
| ) |
| ], |
| "leaky_relu_": [ |
| f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..." |
| ], |
| "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."], |
| "gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."], |
| "softplus": [ |
| "def softplus({}) -> Tensor: ...".format( |
| ", ".join( |
| ["input: Tensor", "beta: int = ...", "threshold: int = ..."] |
| ) |
| ) |
| ], |
| "softshrink": [ |
| "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..." |
| ], |
| "hardsigmoid": [ |
| f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..." |
| ], |
| "linear": [ |
| "def linear({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "weight: Tensor", |
| "bias: Optional[Tensor] = None", |
| ] |
| ) |
| ) |
| ], |
| "pad": [ |
| "def pad({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "pad: Sequence[int]", |
| "mode: str = ...", |
| "value: Optional[float] = None", |
| ] |
| ) |
| ) |
| ], |
| "one_hot": [ |
| "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..." |
| ], |
| "scaled_dot_product_attention": [ |
| "def scaled_dot_product_attention({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "query: Tensor", |
| "key: Tensor", |
| "value: Tensor", |
| "attn_mask: Optional[Tensor] = None", |
| "dropout_p: float = 0.0", |
| "is_causal: bool = False", |
| "scale: Optional[float] = None", |
| ] |
| ) |
| ) |
| ], |
| } |
| ) |
| |
| c_nn_function_hints: List[str] = [] |
| for _, hints in sorted(unsorted_c_nn_function_hints.items()): |
| if len(hints) > 1: |
| hints = ["@overload\n" + h for h in hints] |
| c_nn_function_hints += hints |
| |
| # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered |
| # through an `_add_docstr` call |
| torch_imports = [ |
| "conv1d", |
| "conv2d", |
| "conv3d", |
| "conv_transpose1d", |
| "conv_transpose2d", |
| "conv_transpose3d", |
| "conv_tbc", |
| "avg_pool1d", |
| "adaptive_avg_pool1d", |
| "relu_", |
| "selu_", |
| "celu_", |
| "prelu", |
| "rrelu_", |
| "hardshrink", |
| "bilinear", |
| "pixel_shuffle", |
| "pixel_unshuffle", |
| "channel_shuffle", |
| "native_channel_shuffle", |
| "pairwise_distance", |
| "pdist", |
| "cosine_similarity", |
| ] |
| imported_hints = [f"from .. import {_} as {_}" for _ in torch_imports] |
| |
| # Functions imported into `torch.nn.functional` from `torch._C._nn` |
| c_nn_imports = [ |
| "avg_pool2d", |
| "avg_pool3d", |
| "hardtanh_", |
| "elu_", |
| "leaky_relu_", |
| "gelu", |
| "softplus", |
| "softshrink", |
| "linear", |
| "pad", |
| "one_hot", |
| "scaled_dot_product_attention", |
| ] |
| imported_hints += [f"from .._C._nn import {_} as {_}" for _ in c_nn_imports] |
| # This is from `torch._C._nn` but renamed |
| imported_hints.append("from .._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid") |
| |
| # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional` |
| unsorted_dispatched_hints: Dict[str, List[str]] = {} |
| |
| for d in (1, 2, 3): |
| unsorted_dispatched_hints.update( |
| **get_max_pool_dispatch( |
| f"max_pool{d}d", |
| [ |
| f"{INPUT}", |
| f"{KERNEL_SIZE}", |
| f"{STRIDE_PADDING}", |
| "dilation: Union[_int, _size] = 1", |
| "ceil_mode: bool = False", |
| "{return_indices}", |
| ], |
| ), |
| **get_max_pool_dispatch( |
| f"fractional_max_pool{d}d", |
| [ |
| f"{INPUT}", |
| f"{KERNEL_SIZE}", |
| "output_size: Optional[Union[_int, _size]] = None", |
| "output_ratio: Optional[_ratio_any_t] = None", |
| "{return_indices}", |
| "_random_samples: Optional[Tensor] = None", |
| ], |
| ), |
| **get_max_pool_dispatch( |
| f"adaptive_max_pool{d}d", |
| [f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"], |
| ), |
| ) |
| |
| # There's no fractional_max_pool1d |
| del unsorted_dispatched_hints["fractional_max_pool1d"] |
| |
| dispatched_hints: List[str] = [] |
| for _, hints in sorted(unsorted_dispatched_hints.items()): |
| if len(hints) > 1: |
| hints = ["@overload\n" + h for h in hints] |
| dispatched_hints += hints |
| |
| fm.write_with_template( |
| "torch/nn/functional.pyi", |
| "torch/nn/functional.pyi.in", |
| lambda: { |
| "imported_hints": imported_hints, |
| "dispatched_hints": dispatched_hints, |
| }, |
| ) |
| fm.write_with_template( |
| "torch/_C/_nn.pyi", |
| "torch/_C/_nn.pyi.in", |
| lambda: { |
| "c_nn_function_hints": c_nn_function_hints, |
| }, |
| ) |
| |
| |
| def gen_pyi( |
| native_yaml_path: str, |
| tags_yaml_path: str, |
| deprecated_yaml_path: str, |
| fm: FileManager, |
| ) -> None: |
| """gen_pyi() |
| |
| This function generates a pyi file for torch. |
| """ |
| |
| # Some of this logic overlaps with generate_python_signature in |
| # tools/autograd/gen_python_functions.py; however, this |
| # function is all about generating mypy type signatures, whereas |
| # the other function generates a custom format for argument |
| # checking. If you update this, consider whether your change |
| # also needs to update the other file. |
| |
| # Dictionary for NamedTuple definitions |
| namedtuples: Dict[str, str] = {} |
| |
| # Generate type signatures for top-level functions |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) |
| |
| for n, n1, n2 in [ |
| ("csr", "crow", "col"), |
| ("csc", "ccol", "row"), |
| ("bsr", "crow", "col"), |
| ("bsc", "ccol", "row"), |
| ]: |
| unsorted_function_hints.update( |
| { |
| f"sparse_{n}_tensor": [ |
| f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| f"{n1}_indices: Union[Tensor, List]", |
| f"{n2}_indices: Union[Tensor, List]", |
| "values: Union[Tensor, List]", |
| "size: Optional[_size] = None", |
| "*", |
| "dtype: Optional[_dtype] = None", |
| "device: Union[_device, str, None] = None", |
| "requires_grad: _bool = False", |
| "check_invariants: Optional[_bool] = None", |
| ] |
| ), |
| ) |
| ], |
| } |
| ) |
| |
| unsorted_function_hints.update( |
| { |
| "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."], |
| "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."], |
| "asarray": [ |
| "def asarray({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "obj: Any", |
| "*", |
| "dtype: Optional[_dtype] = None", |
| "device: Union[_device, str, None] = None", |
| "copy: Optional[_bool] = None", |
| "requires_grad: _bool = False", |
| ] |
| ) |
| ) |
| ], |
| "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."], |
| "frombuffer": [ |
| "def frombuffer({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "buffer: Any", |
| "*", |
| "dtype: _dtype", |
| "count: int = -1", |
| "offset: int = 0", |
| "device: Union[_device, str, None] = None", |
| "requires_grad: _bool = False", |
| ] |
| ) |
| ) |
| ], |
| "numel": ["def numel(self: Tensor) -> _int: ..."], |
| "as_tensor": [ |
| "def as_tensor({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "data: Any", |
| "dtype: Optional[_dtype] = None", |
| DEVICE_PARAM, |
| ] |
| ) |
| ) |
| ], |
| "get_num_threads": ["def get_num_threads() -> _int: ..."], |
| "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."], |
| "init_num_threads": ["def init_num_threads() -> None: ..."], |
| "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."], |
| "set_num_interop_threads": [ |
| "def set_num_interop_threads(num: _int) -> None: ..." |
| ], |
| # These functions are explicitly disabled by |
| # SKIP_PYTHON_BINDINGS because they are hand bound. |
| # Correspondingly, we must hand-write their signatures. |
| "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."], |
| "sparse_coo_tensor": [ |
| "def sparse_coo_tensor({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "indices: Tensor", |
| "values: Union[Tensor, List]", |
| "size: Optional[_size] = None", |
| "*", |
| "dtype: Optional[_dtype] = None", |
| "device: Union[_device, str, None] = None", |
| "requires_grad: _bool = False", |
| "check_invariants: Optional[_bool] = None", |
| "is_coalesced: Optional[_bool] = None", |
| ] |
| ) |
| ) |
| ], |
| "sparse_compressed_tensor": [ |
| "def sparse_compressed_tensor({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "compressed_indices: Union[Tensor, List]", |
| "plain_indices: Union[Tensor, List]", |
| "values: Union[Tensor, List]", |
| "size: Optional[_size] = None", |
| "*", |
| "dtype: Optional[_dtype] = None", |
| "layout: Optional[_layout] = None", |
| "device: Union[_device, str, None] = None", |
| "requires_grad: _bool = False", |
| "check_invariants: Optional[_bool] = None", |
| ] |
| ) |
| ) |
| ], |
| "_sync": ["def _sync(t: Tensor) -> None: ..."], |
| "_is_functional_tensor": [ |
| "def _is_functional_tensor(t: Tensor) -> _bool: ..." |
| ], |
| "_from_functional_tensor": [ |
| "def _from_functional_tensor(t: Tensor) -> Tensor: ..." |
| ], |
| "_to_functional_tensor": [ |
| "def _to_functional_tensor(t: Tensor) -> Tensor: ..." |
| ], |
| "_functionalize_replace": [ |
| "def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..." |
| ], |
| "_functionalize_commit_update": [ |
| "def _functionalize_commit_update(t: Tensor) -> None: ..." |
| ], |
| "_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."], |
| "_enable_functionalization": [ |
| "def _enable_functionalization(*, reapply_views: _bool = False): ..." |
| ], |
| "_disable_functionalization": ["def _disable_functionalization(): ..."], |
| "range": [ |
| "def range({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "start: Number", |
| "end: Number", |
| "step: Number = 1", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ) |
| ], |
| "arange": [ |
| "def arange({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "start: Number", |
| "end: Number", |
| "step: Number", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| "def arange({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "start: Number", |
| "end: Number", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| "def arange({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "end: Number", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| ], |
| "linspace": [ |
| "def linspace({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "start: Number", |
| "end: Number", |
| "steps: Optional[_int] = None", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ) |
| ], |
| "logspace": [ |
| "def logspace({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "start: Number", |
| "end: Number", |
| "steps: Optional[_int] = None", |
| "base: _float = 10.0", |
| "*", |
| "out: Optional[Tensor] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ) |
| ], |
| "randint": [ |
| "def randint({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "low: _int", |
| "high: _int", |
| "size: _size", |
| "*", |
| "generator: Optional[Generator] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| "def randint({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "high: _int", |
| "size: _size", |
| "*", |
| "generator: Optional[Generator] = None", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| ], |
| "full": [ |
| "def full({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "size: _size", |
| "fill_value: Union[Number, _complex]", |
| "*", |
| "out: Optional[Tensor] = None", |
| "layout: _layout = strided", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| "def full({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "size: _size", |
| "fill_value: Union[Number, _complex]", |
| "*", |
| "names: List[Union[str, None]]", |
| "layout: _layout = strided", |
| FACTORY_PARAMS, |
| ] |
| ) |
| ), |
| ], |
| "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."], |
| "is_inference_mode_enabled": [ |
| "def is_inference_mode_enabled() -> _bool: ..." |
| ], |
| "nonzero": [ |
| "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...", |
| "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", |
| ], |
| "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], |
| "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], |
| "saddmm": [ |
| "def saddmm({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Tensor", |
| "mat1: Tensor", |
| "mat2: Tensor", |
| "*", |
| "beta: Number = 1", |
| "alpha: Number = 1", |
| "out: Optional[Tensor] = None", |
| ] |
| ) |
| ) |
| ], |
| "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], |
| "div": [ |
| "def div({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "input: Union[Tensor, Number]", |
| "other: Union[Tensor, Number]", |
| "*", |
| "rounding_mode: Optional[str] = None", |
| "out: Optional[Tensor] = None", |
| ] |
| ) |
| ) |
| ], |
| } |
| ) |
| for binop in ["mul", "true_divide", "floor_divide"]: |
| unsorted_function_hints[binop].append( |
| f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], " |
| "*, out: Optional[Tensor] = None) -> Tensor: ..." |
| ) |
| for binop in ["add", "sub"]: |
| unsorted_function_hints[binop].append( |
| f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], " |
| "*, alpha: Optional[Number] = 1, out: Optional[Tensor] = None) -> Tensor: ..." |
| ) |
| |
| native_functions = parse_native_yaml( |
| native_yaml_path, tags_yaml_path |
| ).native_functions |
| native_functions = list(filter(should_generate_py_binding, native_functions)) |
| |
| function_signatures = load_signatures( |
| native_functions, deprecated_yaml_path, method=False, pyi=True |
| ) |
| sig_groups = get_py_torch_functions(function_signatures) |
| for group in sorted(sig_groups, key=lambda g: g.signature.name): |
| name = group.signature.name |
| unsorted_function_hints[name] += generate_type_hints(group) |
| |
| named_tuple = returns_named_tuple_pyi(group.signature) |
| if named_tuple is not None and not group.signature.deprecated: |
| # deprecated namedtuples are currently not included for torch functions |
| tuple_name, tuple_def = named_tuple |
| if tuple_name in namedtuples: |
| assert namedtuples[tuple_name] == tuple_def |
| else: |
| namedtuples[tuple_name] = tuple_def |
| |
| def replace_special_case(hint: str) -> str: |
| # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h |
| hint = hint.replace("at::Reduction::Mean", "1") |
| hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None") |
| # Match both: |
| # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None" |
| # ": Union[Tuple[Tensor, ...], List[Tensor]] = None" |
| hint = hint.replace( |
| "Tuple[Tensor, ...], List[Tensor]] = None", |
| "Tuple[Tensor, ...], List[Tensor], None] = None", |
| ) |
| return hint |
| |
| function_hints = [] |
| for name, hints in sorted(unsorted_function_hints.items()): |
| hints = [replace_special_case(h) for h in hints] |
| if len(hints) > 1: |
| hints = ["@overload\n" + h for h in hints] |
| function_hints += hints |
| |
| # Generate type signatures for Tensor methods |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) |
| unsorted_tensor_method_hints.update( |
| { |
| "size": [ |
| "def size(self) -> Size: ...", |
| "def size(self, dim: _int) -> _int: ...", |
| ], |
| "stride": [ |
| "def stride(self) -> Tuple[_int, ...]: ...", |
| "def stride(self, _int) -> _int: ...", |
| ], |
| "new_ones": [ |
| f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..." |
| ], |
| "new_tensor": [ |
| f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..." |
| ], |
| # new and __init__ have the same signatures and differ only in return type |
| # Adapted from legacy_tensor_ctor and legacy_tensor_new |
| "new": [ |
| f"def new(self, *args: Any, {DEVICE_PARAM}) ->Tensor: ...", |
| "def new(self, storage: Storage) -> Tensor: ...", |
| "def new(self, other: Tensor) -> Tensor: ...", |
| f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...", |
| ], |
| "__init__": [ |
| f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...", |
| "def __init__(self, storage: Storage) -> None: ...", |
| "def __init__(self, other: Tensor) -> None: ...", |
| f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...", |
| ], |
| "as_subclass": ["def as_subclass(self, cls: Type[S]) -> S: ..."], |
| "_make_subclass": [ |
| "@staticmethod \ndef _make_subclass({}) -> S: ...".format( |
| ", ".join( |
| [ |
| "cls: Type[S]", |
| "data: Tensor", |
| "require_grad: _bool = False", |
| "dispatch_strides: _bool = False", |
| "dispatch_device: _bool = False", |
| "device_for_backend_keys: Optional[_device] = None", |
| ] |
| ) |
| ) |
| ], |
| "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."], |
| "__setitem__": [ |
| f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..." |
| ], |
| "tolist": ["def tolist(self) -> List: ..."], |
| "requires_grad_": [ |
| "def requires_grad_(self, mode: _bool = True) -> Tensor: ..." |
| ], |
| "element_size": ["def element_size(self) -> _int: ..."], |
| "data_ptr": ["def data_ptr(self) -> _int: ..."], |
| "dim": ["def dim(self) -> _int: ..."], |
| "nonzero": [ |
| "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...", |
| "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", |
| ], |
| "numel": ["def numel(self) -> _int: ..."], |
| "ndimension": ["def ndimension(self) -> _int: ..."], |
| "nelement": ["def nelement(self) -> _int: ..."], |
| "cuda": [ |
| "def cuda({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "self", |
| "device: Optional[Union[_device, _int, str]] = None", |
| "non_blocking: _bool = False", |
| ] |
| ) |
| ) |
| ], |
| "numpy": ["def numpy(self, *, force: _bool = False) -> Any: ..."], |
| "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."], |
| "map_": [ |
| "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..." |
| ], |
| "map2_": [ |
| "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..." |
| ], |
| "storage": ["def untyped_storage(self) -> UntypedStorage: ..."], |
| "storage_type": ["def storage_type(self) -> Storage: ..."], |
| "type": [ |
| "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...", |
| "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...", |
| ], |
| "get_device": ["def get_device(self) -> _int: ..."], |
| "contiguous": [ |
| "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..." |
| ], |
| "has_names": ["def has_names(self) -> _bool: ..."], |
| "is_contiguous": [ |
| "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..." |
| ], |
| "_is_view": ["def _is_view(self) -> _bool: ..."], |
| "is_cpu": ["is_cpu: _bool"], |
| "is_cuda": ["is_cuda: _bool"], |
| "is_leaf": ["is_leaf: _bool"], |
| "is_nested": ["is_nested: _bool"], |
| "is_sparse": ["is_sparse: _bool"], |
| "is_sparse_csr": ["is_sparse_csr: _bool"], |
| "is_quantized": ["is_quantized: _bool"], |
| "is_meta": ["is_meta: _bool"], |
| "is_mps": ["is_mps: _bool"], |
| "is_mtia": ["is_mtia: _bool"], |
| "is_ort": ["is_ort: _bool"], |
| "is_mkldnn": ["is_mkldnn: _bool"], |
| "is_vulkan": ["is_vulkan: _bool"], |
| "is_ipu": ["is_ipu: _bool"], |
| "storage_offset": ["def storage_offset(self) -> _int: ..."], |
| "to": [ |
| "def to(self, dtype: _dtype, non_blocking: _bool = False, copy: _bool = False) -> Tensor: ...", |
| "def to({}) -> Tensor: ...".format( |
| ", ".join( |
| [ |
| "self", |
| "device: Optional[Union[_device, str]] = None", |
| "dtype: Optional[_dtype] = None", |
| "non_blocking: _bool = False", |
| "copy: _bool = False", |
| ] |
| ) |
| ), |
| "def to(self, other: Tensor, non_blocking: _bool = False, copy: _bool = False) -> Tensor: ...", |
| ], |
| "item": ["def item(self) -> Number: ..."], |
| "copy_": [ |
| "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..." |
| ], |
| "set_": [ |
| "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], " |
| "offset: _int, size: _size, stride: _size) -> Tensor: ...", |
| "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...", |
| ], |
| "split": [ |
| "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...", |
| "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...", |
| ], |
| "div": [ |
| "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." |
| ], |
| "div_": [ |
| "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." |
| ], |
| } |
| ) |
| for binop in ["mul", "true_divide", "floor_divide"]: |
| for inplace in [False, True]: |
| out_suffix = ", *, out: Optional[Tensor] = None" |
| if inplace: |
| binop += "_" |
| out_suffix = "" |
| unsorted_tensor_method_hints[binop].append( |
| f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})" |
| " -> Tensor: ..." |
| ) |
| for binop in ["add", "sub"]: |
| for inplace in [False, True]: |
| out_suffix = ", out: Optional[Tensor] = None" |
| if inplace: |
| binop += "_" |
| out_suffix = "" |
| unsorted_tensor_method_hints[binop].append( |
| f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], " |
| f"*, alpha: Optional[Number] = 1{out_suffix})" |
| " -> Tensor: ..." |
| ) |
| simple_conversions = [ |
| "byte", |
| "char", |
| "cpu", |
| "double", |
| "float", |
| "half", |
| "int", |
| "long", |
| "short", |
| "bool", |
| "bfloat16", |
| ] |
| for name in simple_conversions: |
| unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...") |
| |
| # pyi tensor methods don't currently include deprecated signatures for some reason |
| # TODO: we should probably add them in |
| tensor_method_signatures = load_signatures( |
| native_functions, |
| deprecated_yaml_path, |
| method=True, |
| skip_deprecated=True, |
| pyi=True, |
| ) |
| tensor_method_sig_groups = get_py_torch_functions( |
| tensor_method_signatures, method=True |
| ) |
| |
| for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): |
| name = group.signature.name |
| unsorted_tensor_method_hints[name] += generate_type_hints(group) |
| |
| named_tuple = returns_named_tuple_pyi(group.signature) |
| if named_tuple is not None and not group.signature.deprecated: |
| # deprecated namedtuples are currently not included for Tensor methods either |
| tuple_name, tuple_def = named_tuple |
| if tuple_name in namedtuples: |
| assert namedtuples[tuple_name] == tuple_def |
| else: |
| namedtuples[tuple_name] = tuple_def |
| |
| for op in all_ops: |
| name = f"__{op}__" |
| unsorted_tensor_method_hints[name] += sig_for_ops(name) |
| |
| tensor_method_hints = [] |
| for name, hints in sorted(unsorted_tensor_method_hints.items()): |
| if len(hints) > 1: |
| hints = ["@overload\n" + h for h in hints] |
| tensor_method_hints += hints |
| |
| # TODO: Missing type hints for nn |
| |
| # Generate namedtuple definitions |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| namedtuple_defs = [f"{defn}\n" for defn in namedtuples.values()] |
| |
| # Generate type signatures for legacy classes |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| legacy_storage_base_hints = ["class StorageBase(object): ..."] |
| |
| legacy_class_hints = [] |
| for c in ( |
| "DoubleTensor", |
| "FloatTensor", |
| "LongTensor", |
| "IntTensor", |
| "ShortTensor", |
| "HalfTensor", |
| "CharTensor", |
| "ByteTensor", |
| "BoolTensor", |
| ): |
| legacy_class_hints.append(f"class {c}(Tensor): ...") |
| |
| # Generate type signatures for dtype classes |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| # TODO: don't explicitly list dtypes here; get it from canonical |
| # source |
| dtype_class_hints = [ |
| f"{n}: dtype = ..." |
| for n in [ |
| "float32", |
| "float", |
| "float64", |
| "double", |
| "float16", |
| "bfloat16", |
| "float8_e4m3fn", |
| "float8_e5m2", |
| "half", |
| "uint8", |
| "int8", |
| "int16", |
| "short", |
| "int32", |
| "int", |
| "int64", |
| "long", |
| "complex32", |
| "complex64", |
| "chalf", |
| "cfloat", |
| "complex128", |
| "cdouble", |
| "quint8", |
| "qint8", |
| "qint32", |
| "bool", |
| "quint4x2", |
| "quint2x4", |
| ] |
| ] |
| |
| # Generate __all__ directive |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| # Include only the functions that have hints, to prevent undefined |
| # symbols from being included in the `__all__` directive. |
| hinted_function_names = [ |
| name for name, hint in unsorted_function_hints.items() if hint |
| ] |
| all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names) |
| all_directive = pformat(all_symbols, width=100, compact=True).split("\n") |
| all_directive[0] = f"__all__ = {all_directive[0]}" |
| |
| # Dispatch key hints |
| # ~~~~~~~~~~~~~~~~~~ |
| dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey] |
| torch_dispatch_mode_key_hints = [ |
| f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey |
| ] |
| |
| # Tags Enum type hints |
| # ~~~~~~~~~~~~~~~~~~~~ |
| |
| tag_names = sorted(parse_tags_yaml(tags_yaml_path)) |
| tag_attributes = "\n".join( |
| f"{name}: _int = {index}" for index, name in enumerate(tag_names) |
| ) |
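| # Produces one attribute per tag, e.g. `some_tag: _int = 0` (actual names come from tags.yaml). |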
| |
| # Write out the stub |
| # ~~~~~~~~~~~~~~~~~~ |
| |
| env = { |
| "namedtuple_defs": namedtuple_defs, |
| "function_hints": function_hints, |
| "tensor_method_hints": tensor_method_hints, |
| "legacy_class_hints": legacy_class_hints, |
| "legacy_storage_base_hints": legacy_storage_base_hints, |
| "dtype_class_hints": dtype_class_hints, |
| "dispatch_key_hints": dispatch_key_hints, |
| "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints, |
| "all_directive": all_directive, |
| "tag_attributes": tag_attributes, |
| } |
| fm.write_with_template( |
| "torch/_C/__init__.pyi", |
| "torch/_C/__init__.pyi.in", |
| lambda: { |
| "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in", |
| **env, |
| }, |
| ) |
| fm.write_with_template( |
| "torch/_C/_VariableFunctions.pyi", |
| "torch/_C/_VariableFunctions.pyi.in", |
| lambda: { |
| "generated_comment": "@" |
| + "generated from torch/_C/_VariableFunctions.pyi.in", |
| **env, |
| }, |
| ) |
| fm.write_with_template( |
| "torch/_VF.pyi", |
| "torch/_C/_VariableFunctions.pyi.in", |
| lambda: { |
| "generated_comment": "@" |
| + "generated from torch/_C/_VariableFunctions.pyi.in", |
| **env, |
| }, |
| ) |
| fm.write_with_template( |
| "torch/return_types.pyi", |
| "torch/_C/return_types.pyi.in", |
| lambda: { |
| "generated_comment": "@" + "generated from torch/_C/return_types.pyi", |
| **env, |
| }, |
| ) |
| gen_nn_functional(fm) |
| |
| |
| def main() -> None: |
| parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch") |
| parser.add_argument( |
| "--native-functions-path", |
| metavar="NATIVE", |
| default="aten/src/ATen/native/native_functions.yaml", |
| help="path to native_functions.yaml", |
| ) |
| parser.add_argument( |
| "--tags-path", |
| metavar="TAGS", |
| default="aten/src/ATen/native/tags.yaml", |
| help="path to tags.yaml", |
| ) |
| parser.add_argument( |
| "--deprecated-functions-path", |
| metavar="DEPRECATED", |
| default="tools/autograd/deprecated.yaml", |
| help="path to deprecated.yaml", |
| ) |
| parser.add_argument( |
| "--out", metavar="OUT", default=".", help="path to output directory" |
| ) |
| args = parser.parse_args() |
| fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False) |
| gen_pyi( |
| args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm |
| ) |
| |
| |
| if __name__ == "__main__": |
| main() |