Raise type error message for `interpolate` if `size` contains non-integer elements (#99243)
Raise type error message for interpolate when output size is a tuple containing elements that are not `int`
Fixes #98287
Check is only performed if `size` is an instance of `list` or `tuple`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99243
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze, https://github.com/MovsisyanM, https://github.com/albanD
diff --git a/test/test_nn.py b/test/test_nn.py
index 72bd45e..fe90348 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6798,6 +6798,14 @@
def test_interpolate(self):
+ def _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs):
+ test_sizes = [float(out_size),
+ torch.tensor(out_size, dtype=torch.float)]
+ for size in test_sizes:
+ self.assertRaisesRegex(TypeError,
+ "(expected size to be one of int or).*",
+ F.interpolate, in_t, size=(size,) * dim, **kwargs)
+
def _test_interpolate_helper(in_t, scale_factor, layer):
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
dim = len(in_t.shape) - 2
@@ -6811,6 +6819,7 @@
F.interpolate(in_t, scale_factor=scale_factor, **kwargs))
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
+ _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs)
def _make_input(dim, device):
size = [1, 1]
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a2cc4a0..d094b6c 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -3784,6 +3784,15 @@
upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes)
+def _is_integer(x) -> bool:
+ r"""Type check the input number is an integer.
+ Will return True for int, SymInt and Tensors with integer elements.
+ """
+ if isinstance(x, (int, torch.SymInt)):
+ return True
+ return isinstance(x, Tensor) and not x.is_floating_point()
+
+
@_overload # noqa: F811
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950
pass
@@ -3913,8 +3922,13 @@
f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
"Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
"output size in (o1, o2, ...,oK) format."
-
)
+ if not torch.jit.is_scripting():
+ if not all(_is_integer(x) for x in size):
+ raise TypeError(
+ "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
+ f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
+ )
output_size = size
else:
output_size = [size for _ in range(dim)]