pad_sequence: fix regression - support tensor (#72436)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/71365
Based on https://github.com/pytorch/pytorch/pull/72343
Thanks jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72436
Reviewed By: bdhirsh
Differential Revision: D34117724
Pulled By: jbschlosser
fbshipit-source-id: e5d6599d0791025e18ab36ae16c417a91554bf64
(cherry picked from commit ffe8a0e41b7906920e392a9588347215ac44f46f)
diff --git a/test/test_nn.py b/test/test_nn.py
index 28f44c9..fb7a172 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -138,6 +138,21 @@
RuntimeError,
lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True))
+ def test_pad_sequence_with_tensor_sequences(self):
+ seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
+ (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
+ )
+ seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
+ torch.tensor([[[7, 6]], [[-7, -1]]])
+ )
+ self.assertEqual(seq_tuple_input, seq_tensor_input)
+ self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))
+
+ def test_pad_sequence_with_non_iterable_sequences(self):
+ msg = r"Expected iterable for input sequences, but got arg of type"
+ with self.assertRaisesRegex(RuntimeError, msg):
+ torch.nn.utils.rnn.pad_sequence(5)
+
def test_total_length(self):
padded, lengths = self._padded_sequence(torch.FloatTensor)
max_length = max(lengths)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 981e8bb..9999f23 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -6,7 +6,7 @@
from ... import _VF
from ..._jit_internal import Optional
-from typing import List, Tuple
+from typing import List, Tuple, Union, Iterable
@@ -321,7 +321,7 @@
def pad_sequence(sequences, batch_first=False, padding_value=0.0):
- # type: (List[Tensor], bool, float) -> Tensor
+ # type: (Union[List[Tensor], Tensor], bool, float) -> Tensor
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
@@ -358,6 +358,21 @@
Tensor of size ``B x T x *`` otherwise
"""
+ if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+ # JIT doesn't support `Iterable`
+ if not isinstance(sequences, Iterable):
+ msg = ('pad_sequence: Expected iterable for input sequences, but got arg of type: '
+ f'{type(sequences)}')
+ raise RuntimeError(msg)
+
+ # In JIT context this leads to,
+ # RuntimeError: cannot statically infer the expected size of a list in this context
+ sequences = tuple(sequences)
+ else:
+ # For JIT, we only support Union[Tensor, Tuple[Tensor]]
+ if isinstance(sequences, torch.Tensor):
+ sequences = sequences.unbind(0)
+
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)