Add `pad_sequence` as a native function (#57868)

Summary:
https://github.com/pytorch/pytorch/issues/56229

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57868

Reviewed By: mruberry

Differential Revision: D28334174

Pulled By: ngimel

fbshipit-source-id: f1647718ada596686117703b682c0af7e92e16f5
diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp
index 34da156..798672c 100644
--- a/aten/src/ATen/native/PackedSequence.cpp
+++ b/aten/src/ATen/native/PackedSequence.cpp
@@ -184,4 +184,46 @@
   return std::make_tuple(output, lengths_t);
 }
 
+// Native implementation of torch.nn.utils.rnn.pad_sequence: pads a list
+// of variable-length Tensors with `padding_value` so they share dim 0.
+// Assumes all tensors in `sequences` have the same trailing dimensions
+// and options; both are taken from sequences[0].
+// Returns a tensor of shape (batch, max_len, ...) when `batch_first`,
+// otherwise (max_len, batch, ...).
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value) {
+  const int64_t sequences_size = sequences.size();
+  TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
+  IntArrayRef max_size = sequences[0].sizes();
+  IntArrayRef trailing_dims = max_size.slice(1);
+  // The padded length is the longest sequence along dim 0.
+  int64_t max_len = std::max_element(
+    sequences.begin(),
+    sequences.end(),
+    [](const Tensor &a, const Tensor &b) {
+      return a.size(0) < b.size(0);
+    }
+  )->size(0);
+
+  DimVector out_dims;
+  if (batch_first) {
+    out_dims = {sequences_size, max_len};
+  } else {
+    out_dims = {max_len, sequences_size};
+  }
+  out_dims.insert(out_dims.end(), trailing_dims.begin(), trailing_dims.end());
+
+  Tensor out = at::full(out_dims, padding_value, sequences[0].options());
+  for (int64_t i = 0; i < sequences_size; i++) {
+    const Tensor& currseq = sequences[i];
+    const int64_t length_i = currseq.size(0);
+    // use index notation to prevent duplicate references to the tensor
+    if (batch_first) {
+      out.select(0, i).narrow(0, 0, length_i).copy_(currseq);
+    } else {
+      out.narrow(0, 0, length_i).select(1, i).copy_(currseq);
+    }
+  }
+  return out;
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 543910b..4279e2f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9820,3 +9820,7 @@
   variants: function
   dispatch:
     CPU, CUDA: segment_reduce_backward_kernel
+
+- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
+  python_module: nn
+  variants: function
diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in
index 518fbac..0f00a3d 100644
--- a/torch/_C/_nn.pyi.in
+++ b/torch/_C/_nn.pyi.in
@@ -23,3 +23,7 @@
 @overload
 def _parse_to(tensor: Tensor, non_blocking: _bool, copy: _bool, *,
               memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ...
+
+# Defined in aten/src/ATen/native/PackedSequence.cpp
+def pad_sequence(sequences: List[Tensor], batch_first: bool = False,
+                 padding_value: float = ...) -> Tensor: ...
diff --git a/torch/csrc/api/include/torch/nn/utils/rnn.h b/torch/csrc/api/include/torch/nn/utils/rnn.h
index ab67f42..19b0fce 100644
--- a/torch/csrc/api/include/torch/nn/utils/rnn.h
+++ b/torch/csrc/api/include/torch/nn/utils/rnn.h
@@ -280,38 +280,7 @@
     ArrayRef<Tensor> sequences,
     bool batch_first = false,
     double padding_value = 0) {
-  // assuming trailing dimensions and type of all the Tensors
-  // in sequences are same and fetching those from sequences[0]
-  auto max_size = sequences[0].sizes();
-  auto trailing_dims = max_size.slice(1);
-  auto max_len = std::max_element(
-    sequences.begin(),
-    sequences.end(),
-    [](const Tensor& a, const Tensor& b) {
-      return a.size(0) < b.size(0);
-    }
-  )->size(0);
-
-  std::vector<int64_t> out_dims;
-  if (batch_first) {
-    out_dims = {(int64_t)sequences.size(), max_len};
-  } else {
-    out_dims = {max_len, (int64_t)sequences.size()};
-  }
-  out_dims.insert(out_dims.end(), trailing_dims.begin(), trailing_dims.end());
-
-  auto out_tensor = torch::full({out_dims}, padding_value, sequences[0].options());
-  for (size_t i = 0; i < sequences.size(); i++) {
-    auto tensor = sequences[i];
-    int64_t length = tensor.size(0);
-    // use index notation to prevent duplicate references to the tensor
-    if (batch_first) {
-      out_tensor.select(0, i).narrow(0, 0, length).copy_(tensor);
-    } else {
-      out_tensor.narrow(0, 0, length).select(1, i).copy_(tensor);
-    }
-  }
-  return out_tensor;
+  return at::pad_sequence(sequences, batch_first, padding_value);
 }
 
 /// Packs a list of variable length Tensors
@@ -338,7 +307,7 @@
     lengths[i] = sequences[i].size(0);
   }
   return pack_padded_sequence(
-    pad_sequence(sequences), lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted);
+    at::pad_sequence(sequences), lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted);
 }
 
 } // namespace rnn
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 5592385..affddfc 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -2,11 +2,15 @@
 import warnings
 
 import torch
+from torch import Tensor
 from ... import _VF
 from ..._jit_internal import Optional
 
+from typing import List, Tuple
 
-PackedSequence_ = namedtuple('PackedSequence',
+
+
+PackedSequence_ = namedtuple('PackedSequence_',
                              ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
 
 # type annotation for PackedSequence_ to make it compatible with TorchScript
@@ -356,24 +360,7 @@
 
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    max_len = max([s.size(0) for s in sequences])
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-
-    out_tensor = sequences[0].new_full(out_dims, padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-
-    return out_tensor
+    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
 
 
 def pack_sequence(sequences, enforce_sorted=True):
diff --git a/torch/nn/utils/rnn.pyi b/torch/nn/utils/rnn.pyi
index 50b6d6e..2c1c6c9 100644
--- a/torch/nn/utils/rnn.pyi
+++ b/torch/nn/utils/rnn.pyi
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import Any, Optional, overload, Union, TypeVar, Tuple, Sequence
+from typing import Any, List, Optional, overload, Union, TypeVar, Tuple, Sequence
 from torch import Tensor
 from torch.types import _dtype, _device
 
@@ -65,7 +65,7 @@
                         total_length: Optional[int] = ...) -> Tuple[Tensor, ...]: ...
 
 
-def pad_sequence(sequences: Sequence[Tensor], batch_first: bool = ..., padding_value: int = ...) -> Tensor: ...
+def pad_sequence(sequences: List[Tensor], batch_first: bool = False, padding_value: float = ...) -> Tensor: ...
 
 
 def pack_sequence(sequences: Sequence[Tensor], enforce_sorted: bool = ...) -> PackedSequence: ...