pad_sequence: fix regression - support tensor (#72436)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/71365

Based on https://github.com/pytorch/pytorch/pull/72343

Thanks jbschlosser

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72436

Reviewed By: bdhirsh

Differential Revision: D34117724

Pulled By: jbschlosser

fbshipit-source-id: e5d6599d0791025e18ab36ae16c417a91554bf64
(cherry picked from commit ffe8a0e41b7906920e392a9588347215ac44f46f)
diff --git a/test/test_nn.py b/test/test_nn.py
index 28f44c9..fb7a172 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -138,6 +138,21 @@
             RuntimeError,
             lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True))
 
+    def test_pad_sequence_with_tensor_sequences(self):
+        seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
+            (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
+        )
+        seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
+            torch.tensor([[[7, 6]], [[-7, -1]]])
+        )
+        self.assertEqual(seq_tuple_input, seq_tensor_input)
+        self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))
+
+    def test_pad_sequence_with_non_iterable_sequences(self):
+        msg = r"Expected iterable for input sequences, but got arg of type"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.nn.utils.rnn.pad_sequence(5)
+
     def test_total_length(self):
         padded, lengths = self._padded_sequence(torch.FloatTensor)
         max_length = max(lengths)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 981e8bb..9999f23 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -6,7 +6,7 @@
 from ... import _VF
 from ..._jit_internal import Optional
 
-from typing import List, Tuple
+from typing import List, Tuple, Union, Iterable
 
 
 
@@ -321,7 +321,7 @@
 
 
 def pad_sequence(sequences, batch_first=False, padding_value=0.0):
-    # type: (List[Tensor], bool, float) -> Tensor
+    # type: (Union[List[Tensor], Tensor], bool, float) -> Tensor
     r"""Pad a list of variable length Tensors with ``padding_value``
 
     ``pad_sequence`` stacks a list of Tensors along a new dimension,
@@ -358,6 +358,21 @@
         Tensor of size ``B x T x *`` otherwise
     """
 
+    if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+        # JIT doesn't support `Iterable`
+        if not isinstance(sequences, Iterable):
+            msg = ('pad_sequence: Expected iterable for input sequences, but got arg of type: '
+                   f'{type(sequences)}')
+            raise RuntimeError(msg)
+
+        # In JIT context this leads to,
+        # RuntimeError: cannot statically infer the expected size of a list in this context
+        sequences = tuple(sequences)
+    else:
+        # For JIT, we only support Union[Tensor, Tuple[Tensor]]
+        if isinstance(sequences, torch.Tensor):
+            sequences = sequences.unbind(0)
+
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
     return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)