Fix type annotation errors in torch.functional (#43446)
Summary:
Closes gh-42968
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43446
Reviewed By: albanD
Differential Revision: D23280962
Pulled By: malfet
fbshipit-source-id: de5386a95a20ecc814c39cbec3e4252112340b3a
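The recurring theme in this patch is reconciling mypy with the TorchScript frontend: the JIT evaluates annotations at runtime and cannot compile some typing constructs such as `Union`, so the diff introduces aliases like `_Indices`, `_unique_impl_out`, and `_ListOrSeq` under `if TYPE_CHECKING:` (mypy sees the precise type, the JIT sees one it can handle) and adds `# type: ignore` to `_VF.*` call sites whose attributes mypy cannot resolve statically. A minimal, self-contained sketch of the aliasing trick; the alias and function names below are illustrative, not taken from the diff:

```python
from typing import TYPE_CHECKING, List, Sequence

if TYPE_CHECKING:
    # Only static checkers (mypy) evaluate this branch: use the richer type.
    _IntSeq = Sequence[int]
else:
    # The TorchScript frontend evaluates this branch at runtime and gets a
    # type it knows how to compile.
    _IntSeq = List[int]

def total(values: _IntSeq) -> int:
    # The same source satisfies both tools, since each resolves _IntSeq to a
    # form it understands.
    return sum(values)
```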
diff --git a/mypy.ini b/mypy.ini
index b9c87ac..ce6bd2a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -54,9 +54,6 @@
[mypy-torch.distributed.*]
ignore_errors = True
-[mypy-torch.functional.*]
-ignore_errors = True
-
[mypy-torch.testing._internal.*]
ignore_errors = True
diff --git a/test/onnx/expect/TestOperators.test_meshgrid.expect b/test/onnx/expect/TestOperators.test_meshgrid.expect
index 465b780..4107cc5 100644
--- a/test/onnx/expect/TestOperators.test_meshgrid.expect
+++ b/test/onnx/expect/TestOperators.test_meshgrid.expect
@@ -17,7 +17,7 @@
}
}
node {
- input: "x"
+ input: "0"
input: "3"
output: "4"
name: "Reshape_1"
@@ -38,7 +38,7 @@
}
}
node {
- input: "y"
+ input: "1"
input: "5"
output: "6"
name: "Reshape_3"
@@ -59,7 +59,7 @@
}
}
node {
- input: "z"
+ input: "2"
input: "7"
output: "8"
name: "Reshape_5"
@@ -221,7 +221,7 @@
}
name: "torch-jit-export"
input {
- name: "x"
+ name: "0"
type {
tensor_type {
elem_type: 1
@@ -234,7 +234,7 @@
}
}
input {
- name: "y"
+ name: "1"
type {
tensor_type {
elem_type: 1
@@ -247,7 +247,7 @@
}
}
input {
- name: "z"
+ name: "2"
type {
tensor_type {
elem_type: 1
diff --git a/torch/functional.py b/torch/functional.py
index b8c1d3c..8202940 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,7 +1,10 @@
-from typing import Tuple, Optional
+from typing import (
+ Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
+)
import torch
import torch.nn.functional as F
+from torch.types import _size
from ._lowrank import svd_lowrank, pca_lowrank
from .overrides import has_torch_function, handle_torch_function
from ._jit_internal import boolean_dispatch, List
@@ -65,7 +68,7 @@
if not torch.jit.is_scripting():
if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
return handle_torch_function(broadcast_tensors, tensors, *tensors)
- return _VF.broadcast_tensors(tensors)
+ return _VF.broadcast_tensors(tensors) # type: ignore
def split(tensor, split_size_or_sections, dim=0):
@@ -117,9 +120,15 @@
# call here.
return tensor.split(split_size_or_sections, dim)
+
+if TYPE_CHECKING:
+ _Indices = _size
+else:
+ _Indices = List[int]
+
+
# equivalent to itertools.product(indices)
-def _indices_product(indices):
- # type: (List[int]) -> (List[List[int]])
+def _indices_product(indices: _Indices) -> List[List[int]]:
empty_list = torch.jit.annotate(List[int], [])
result = [empty_list]
for idx in indices:
@@ -130,6 +139,7 @@
result = result_temp
return result
+
def _index_tensor_with_indices_list(tensor, indices):
# type: (Tensor, List[int]) -> Tensor
out = tensor
@@ -137,6 +147,7 @@
out = out[index]
return out
+
def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
# type: (Tensor, Tensor, bool, bool) -> (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])
r"""Unpacks the data and pivots from a LU factorization of a tensor.
@@ -204,10 +215,12 @@
m, n = shape[-2:]
k = min(m, n)
if unpack_data:
- U = LU_data.triu()
+ U: Optional[Tensor] = LU_data.triu()
+ assert U is not None
if m != k:
U = U.narrow(-2, 0, k)
- L = LU_data.tril()
+ L: Optional[Tensor] = LU_data.tril()
+ assert L is not None
if k != n:
L = L.narrow(-1, 0, k)
L.diagonal(dim1=-2, dim2=-1).fill_(1)
@@ -217,9 +230,11 @@
if unpack_pivots:
LU_pivots_zero_idx = LU_pivots - 1
if LU_data.dim() > 2:
- P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) \
- .expand(shape[:-1] + (m,)) \
- .clone(memory_format=torch.contiguous_format)
+ P: Optional[Tensor] = torch.eye(m, device=LU_data.device,
+ dtype=LU_data.dtype) \
+ .expand(shape[:-1] + (m,)) \
+ .clone(memory_format=torch.contiguous_format)
+ assert P is not None
# TODO: rewrite when TorchScript supports product and map as
# product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
@@ -321,15 +336,24 @@
return handle_torch_function(einsum, operands, equation, *operands)
if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
# the old interface of passing the operands as one list argument
- operands = operands[0]
+ _operands = operands[0]
# recurse in case operands contains a value that has a torch function
# in the original implementation this line is omitted
- return einsum(equation, *operands)
+ return einsum(equation, *_operands)
- return _VF.einsum(equation, operands)
+ return _VF.einsum(equation, operands) # type: ignore
-def meshgrid(*tensors):
+if TYPE_CHECKING:
+ # The JIT doesn't understand Union, so only add type annotation for mypy
+ def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:
+ return _meshgrid(*tensors)
+else:
+ def meshgrid(*tensors):
+ return _meshgrid(*tensors)
+
+
+def _meshgrid(*tensors):
r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.
@@ -363,8 +387,8 @@
return handle_torch_function(meshgrid, tensors, *tensors)
if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
# the old interface of passing the operands as one list argument
- tensors = tensors[0]
- return _VF.meshgrid(tensors)
+ tensors = tensors[0] # type: ignore
+ return _VF.meshgrid(tensors) # type: ignore
def stft(input, n_fft, hop_length=None, win_length=None, window=None,
@@ -524,15 +548,24 @@
window=window, center=center, normalized=normalized, onesided=onesided,
length=length)
- return _VF.istft(
- input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)
+ return _VF.istft(input, n_fft, hop_length, win_length, window, center, normalized, onesided, length) # type: ignore
del torch.unique_dim
-def _unique_impl(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
- # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]
+if TYPE_CHECKING:
+ # These _impl functions return a variable number of tensors as output with
+ # __torch_function__; tuple unpacking is done already rather than being
+ # done by the caller of the _impl function
+ _unique_impl_out = Any
+else:
+ _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
+
+
+def _unique_impl(input: Tensor, sorted: bool = True,
+ return_inverse: bool = False, return_counts: bool = False,
+ dim: Optional[int] = None) -> _unique_impl_out:
r"""Returns the unique elements of the input tensor.
.. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
@@ -598,7 +631,7 @@
return_counts=return_counts, dim=dim)
if dim is not None:
- output, inverse_indices, counts = _VF.unique_dim(
+ output, inverse_indices, counts = _VF.unique_dim( # type: ignore
input,
dim,
sorted=sorted,
@@ -615,8 +648,9 @@
return output, inverse_indices, counts
-def _unique_consecutive_impl(input, return_inverse=False, return_counts=False, dim=None):
- # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]
+def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
+ return_counts: bool = False,
+ dim: Optional[int] = None) -> _unique_impl_out:
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
.. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -671,7 +705,7 @@
return handle_torch_function(
unique_consecutive, (input,), input, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
- output, inverse_indices, counts = _VF.unique_consecutive(
+ output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
return output, inverse_indices, counts
@@ -686,6 +720,7 @@
output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output, counts
+
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
@@ -696,6 +731,7 @@
output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output
+
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
@@ -706,6 +742,7 @@
output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output, inverse_indices
+
_return_inverse_false = boolean_dispatch(
arg_name='return_counts',
arg_index=3,
@@ -748,6 +785,7 @@
output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output, counts
+
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
@@ -758,6 +796,7 @@
output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output
+
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
@@ -768,6 +807,7 @@
output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output, inverse_indices
+
_consecutive_return_inverse_false = boolean_dispatch(
arg_name='return_counts',
arg_index=1,
@@ -857,7 +897,7 @@
raise RuntimeError("tensordot expects dims >= 0, but got dims={}".format(dims))
dims_a = list(range(-dims, 0))
dims_b = list(range(dims))
- return _VF.tensordot(a, b, dims_a, dims_b)
+ return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore
def cartesian_prod(*tensors):
"""Do cartesian product of the given sequence of tensors. The behavior is similar to
@@ -890,7 +930,7 @@
if not torch.jit.is_scripting():
if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
return handle_torch_function(cartesian_prod, tensors, *tensors)
- return _VF.cartesian_prod(tensors)
+ return _VF.cartesian_prod(tensors) # type: ignore
def block_diag(*tensors):
"""Create a block diagonal matrix from provided tensors.
@@ -924,7 +964,7 @@
"""
if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
return handle_torch_function(block_diag, tensors, *tensors)
- return torch._C._VariableFunctions.block_diag(tensors)
+ return torch._C._VariableFunctions.block_diag(tensors) # type: ignore
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
@@ -974,11 +1014,11 @@
return handle_torch_function(
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
- return _VF.cdist(x1, x2, p, None)
+ return _VF.cdist(x1, x2, p, None) # type: ignore
elif compute_mode == 'use_mm_for_euclid_dist':
- return _VF.cdist(x1, x2, p, 1)
+ return _VF.cdist(x1, x2, p, 1) # type: ignore
elif compute_mode == 'donot_use_mm_for_euclid_dist':
- return _VF.cdist(x1, x2, p, 2)
+ return _VF.cdist(x1, x2, p, 2) # type: ignore
else:
raise ValueError("{} is not a valid value for compute_mode".format(compute_mode))
@@ -1014,7 +1054,7 @@
return handle_torch_function(atleast_1d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
- return _VF.atleast_1d(tensors)
+ return _VF.atleast_1d(tensors) # type: ignore
def atleast_2d(*tensors):
r"""
@@ -1049,7 +1089,7 @@
return handle_torch_function(atleast_2d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
- return _VF.atleast_2d(tensors)
+ return _VF.atleast_2d(tensors) # type: ignore
def atleast_3d(*tensors):
r"""
@@ -1093,28 +1133,44 @@
return handle_torch_function(atleast_3d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
- return _VF.atleast_3d(tensors)
+ return _VF.atleast_3d(tensors) # type: ignore
-# TODO: type dim as BroadcastingList when https://github.com/pytorch/pytorch/issues/33782 is fixed
-@overload # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
- # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
- pass
-@overload # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
- # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+if TYPE_CHECKING:
pass
+ # There's no good way to use this type annotation; cannot rename norm() to
+ # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
+ # for mypy for now.
+ # def norm(input: Tensor,
+ # p: Optional[Union[str, Number]] = "fro",
+ # dim: Optional[Union[int, List[int]]] = None,
+ # keepdim: bool = False,
+ # out: Optional[Tensor] = None,
+ # dtype: _dtype = None) -> Tensor:
+ # return _norm_impl(input, p, dim, keepdim, out, dtype)
+else:
+ # TODO: type dim as BroadcastingList when
+ # https://github.com/pytorch/pytorch/issues/33782 is fixed
+ @overload # noqa: 749
+ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
+ # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+ pass
-@overload # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
- # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
- pass
+ @overload # noqa: 749
+ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
+ # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+ pass
-@overload # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
- # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
- pass
+ @overload # noqa: 749
+ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
+ # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+ pass
+
+ @overload # noqa: 749
+ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
+ # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+ pass
+
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: 749
r"""Returns the matrix norm or vector norm of a given tensor.
@@ -1183,15 +1239,14 @@
ndim = input.dim()
-
# catch default case
if dim is None and out is None and dtype is None and p is not None:
if isinstance(p, str):
if p == "fro":
- return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
+ return _VF.frobenius_norm(input, dim=(), keepdim=keepdim) # type: ignore
if not isinstance(p, str):
_dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
- return _VF.norm(input, p, dim=_dim, keepdim=keepdim)
+ return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
# remove the overloads where dim is an int and replace with BroadcastingList1
@@ -1202,7 +1257,7 @@
else:
_dim = dim
else:
- _dim = None
+ _dim = None # type: ignore
if isinstance(p, str):
if p == "fro":
@@ -1212,22 +1267,22 @@
if _dim is None:
_dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
if out is None:
- return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
+ return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore
else:
- return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
+ return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore
elif p == "nuc":
if dtype is not None:
raise ValueError("dtype argument is not supported in nuclear norm")
if _dim is None:
if out is None:
- return _VF.nuclear_norm(input, keepdim=keepdim)
+ return _VF.nuclear_norm(input, keepdim=keepdim) # type: ignore
else:
- return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
+ return _VF.nuclear_norm(input, keepdim=keepdim, out=out) # type: ignore
else:
if out is None:
- return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
+ return _VF.nuclear_norm(input, _dim, keepdim=keepdim) # type: ignore
else:
- return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
+ return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore
raise RuntimeError("only valid string values are 'fro' and 'nuc', found {}".format(p))
else:
if _dim is None:
@@ -1235,14 +1290,14 @@
if out is None:
if dtype is None:
- return _VF.norm(input, p, _dim, keepdim=keepdim)
+ return _VF.norm(input, p, _dim, keepdim=keepdim) # type: ignore
else:
- return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)
+ return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) # type: ignore
else:
if dtype is None:
- return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)
+ return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) # type: ignore
else:
- return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)
+ return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore
def chain_matmul(*matrices):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
@@ -1276,7 +1331,7 @@
if not torch.jit.is_scripting():
if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):
return handle_torch_function(chain_matmul, matrices, *matrices)
- return _VF.chain_matmul(matrices)
+ return _VF.chain_matmul(matrices) # type: ignore
def _lu_impl(A, pivot=True, get_infos=False, out=None):
@@ -1353,12 +1408,17 @@
# If get_infos is True, then we don't need to check for errors and vice versa
return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
-def _check_list_size(out_len, get_infos, out):
- # type: (int, bool, List[Tensor]) -> None
+
+if TYPE_CHECKING:
+ _ListOrSeq = Sequence[Tensor]
+else:
+ _ListOrSeq = List[Tensor]
+
+def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
get_infos_int = 1 if get_infos else 0
if out_len - get_infos_int != 2:
raise TypeError("expected tuple of {} elements but got {}"
- .format(2 + int(get_infos), len(out_len)))
+ .format(2 + int(get_infos), out_len))
if not isinstance(out, (tuple, list)):
raise TypeError("argument 'out' must be tuple of Tensors, not {}"
.format(type(out).__name__))
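Several assignments in the `lu_unpack` hunk above annotate a local as `Optional[Tensor]` and immediately `assert` it is not `None`: the annotation keeps the variable compatible with the branch of the surrounding function that leaves it as `None`, and the assert narrows it back to `Tensor` so the following method calls type-check. A small standalone sketch of that narrowing pattern, with illustrative names not taken from the diff:

```python
from typing import Optional
import torch
from torch import Tensor

def upper_triangle(data: Tensor, unpack: bool) -> Optional[Tensor]:
    if unpack:
        # Annotated as Optional[Tensor] so the variable stays compatible with
        # the None branch below; the assert narrows it back to Tensor so the
        # .narrow() call type-checks.
        U: Optional[Tensor] = data.triu()
        assert U is not None
        m, n = data.shape[-2], data.shape[-1]
        if m != n:
            U = U.narrow(-2, 0, min(m, n))
    else:
        U = None
    return U
```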
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 73dd008..18dc666 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -175,7 +175,7 @@
lines_with_type = list(filter(lambda line: 'type' in line[1], lines))
if len(type_lines) == 0:
- type_pattern = re.compile('#[\t ]*type[\t ]*:')
+ type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')
wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
if len(wrong_type_lines) > 0:
raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
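The torch/jit/annotations.py change adds a negative lookahead so that a bare `# type: ignore` comment is no longer reported as a malformed annotation prefix, while other `# type:` spellings still match. A quick standalone check of the new pattern's behaviour (the sample strings are illustrative):

```python
import re

# Same pattern as the one added above.
type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')

assert type_pattern.search("x = f()  # type: ignore") is None       # skipped now
assert type_pattern.search("x = f()  # type: int") is not None      # still matches
assert type_pattern.search("x = f()  # type: (Tensor) -> Tensor") is not None
```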
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index 40bf9e8..a3db0d1 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -1,6 +1,6 @@
from torch import Tensor
from torch.types import _size
-from typing import Any, Optional, Tuple, Dict, List, Callable
+from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence
from .common_types import _ratio_any_t
# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys.
@@ -300,7 +300,7 @@
def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[Any] = ...) -> Tensor: ...
-def pad(input: Tensor, pad: List[int], mode: str = ..., value: float = ...) -> Tensor: ...
+def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: float = ...) -> Tensor: ...
def pairwise_distance(x1: Tensor, x2: Tensor, p: float = ..., eps: float = ..., keepdim: bool = ...) -> Tensor: ...
diff --git a/torch/overrides.py b/torch/overrides.py
index a0a44ee..3d2dadd 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -180,6 +180,7 @@
Tensor.unflatten,
}
+
@functools.lru_cache(None)
def get_testing_overrides() -> Dict[Callable, Callable]:
"""Return a dict containing dummy overrides for all overridable functions
diff --git a/torch/quantization/_equalize.py b/torch/quantization/_equalize.py
index a2ed616..51cbe07 100644
--- a/torch/quantization/_equalize.py
+++ b/torch/quantization/_equalize.py
@@ -127,10 +127,10 @@
if curr_modules.keys() != prev_modules.keys():
raise ValueError("The keys to the given mappings must have the same set of names of modules")
- summed_norms = 0
+ summed_norms = torch.tensor(0.)
if None in prev_modules.values():
return False
for name in curr_modules.keys():
difference = curr_modules[name].weight.sub(prev_modules[name].weight)
summed_norms += torch.norm(difference)
- return summed_norms < threshold
+ return bool(summed_norms < threshold)
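The torch/quantization/_equalize.py hunk switches the accumulator to a tensor and casts the final comparison, since `torch.norm` returns a `Tensor` and comparing a `Tensor` with a float produces a boolean tensor rather than a Python `bool`. A minimal sketch of the same convergence-check pattern, with hypothetical names:

```python
from typing import Dict
import torch
from torch import Tensor

def weights_converged(curr: Dict[str, Tensor], prev: Dict[str, Tensor],
                      threshold: float = 1e-4) -> bool:
    # Start from a zero-dim Tensor so the running sum has a single, stable type.
    summed_norms = torch.tensor(0.)
    for name in curr:
        summed_norms += torch.norm(curr[name] - prev[name])
    # Tensor < float yields a zero-dim bool Tensor; cast explicitly so the
    # function really returns a Python bool.
    return bool(summed_norms < threshold)
```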