Add `pad_sequence` as a native function (#57868)
Summary:
https://github.com/pytorch/pytorch/issues/56229
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57868
Reviewed By: mruberry
Differential Revision: D28334174
Pulled By: ngimel
fbshipit-source-id: f1647718ada596686117703b682c0af7e92e16f5
diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp
index 34da156..798672c 100644
--- a/aten/src/ATen/native/PackedSequence.cpp
+++ b/aten/src/ATen/native/PackedSequence.cpp
@@ -184,4 +184,39 @@
return std::make_tuple(output, lengths_t);
}
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value) {
+ const int64_t sequences_size = sequences.size();
+ TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
+ IntArrayRef max_size = sequences[0].sizes();
+ IntArrayRef trailing_dims = max_size.slice(1);
+ int64_t max_len = std::max_element(
+ sequences.begin(),
+ sequences.end(),
+ [](const Tensor &a, const Tensor &b) {
+ return a.size(0) < b.size(0);
+ }
+ )->size(0);
+
+ DimVector out_dims;
+ if (batch_first) {
+ out_dims = {sequences_size, max_len};
+ } else {
+ out_dims = {max_len, sequences_size};
+ }
+ out_dims.insert(out_dims.end(), trailing_dims.begin(), trailing_dims.end());
+
+ Tensor out = at::full(out_dims, padding_value, sequences[0].options());
+ for (int64_t i = 0; i < sequences_size; i++) {
+ const Tensor currseq = sequences[i];
+ const int64_t length_i = currseq.size(0);
+ // use index notation to prevent duplicate references to the tensor
+ if (batch_first) {
+ out.select(0, i).narrow(0, 0, length_i).copy_(currseq);
+ } else {
+ out.narrow(0, 0, length_i).select(1, i).copy_(currseq);
+ }
+ }
+ return out;
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 543910b..4279e2f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9820,3 +9820,7 @@
variants: function
dispatch:
CPU, CUDA: segment_reduce_backward_kernel
+
+- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
+ python_module: nn
+ variants: function
diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in
index 518fbac..0f00a3d 100644
--- a/torch/_C/_nn.pyi.in
+++ b/torch/_C/_nn.pyi.in
@@ -23,3 +23,7 @@
@overload
def _parse_to(tensor: Tensor, non_blocking: _bool, copy: _bool, *,
memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ...
+
+# Defined in aten/src/ATen/native/PackedSequence.cpp
+def pad_sequence(sequences: List[Tensor], batch_first: bool = False,
+ padding_value: float = ...) -> Tensor: ...
diff --git a/torch/csrc/api/include/torch/nn/utils/rnn.h b/torch/csrc/api/include/torch/nn/utils/rnn.h
index ab67f42..19b0fce 100644
--- a/torch/csrc/api/include/torch/nn/utils/rnn.h
+++ b/torch/csrc/api/include/torch/nn/utils/rnn.h
@@ -280,38 +280,7 @@
ArrayRef<Tensor> sequences,
bool batch_first = false,
double padding_value = 0) {
- // assuming trailing dimensions and type of all the Tensors
- // in sequences are same and fetching those from sequences[0]
- auto max_size = sequences[0].sizes();
- auto trailing_dims = max_size.slice(1);
- auto max_len = std::max_element(
- sequences.begin(),
- sequences.end(),
- [](const Tensor& a, const Tensor& b) {
- return a.size(0) < b.size(0);
- }
- )->size(0);
-
- std::vector<int64_t> out_dims;
- if (batch_first) {
- out_dims = {(int64_t)sequences.size(), max_len};
- } else {
- out_dims = {max_len, (int64_t)sequences.size()};
- }
- out_dims.insert(out_dims.end(), trailing_dims.begin(), trailing_dims.end());
-
- auto out_tensor = torch::full({out_dims}, padding_value, sequences[0].options());
- for (size_t i = 0; i < sequences.size(); i++) {
- auto tensor = sequences[i];
- int64_t length = tensor.size(0);
- // use index notation to prevent duplicate references to the tensor
- if (batch_first) {
- out_tensor.select(0, i).narrow(0, 0, length).copy_(tensor);
- } else {
- out_tensor.narrow(0, 0, length).select(1, i).copy_(tensor);
- }
- }
- return out_tensor;
+ return at::pad_sequence(sequences, batch_first, padding_value);
}
/// Packs a list of variable length Tensors
@@ -338,7 +307,7 @@
lengths[i] = sequences[i].size(0);
}
return pack_padded_sequence(
- pad_sequence(sequences), lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted);
+ at::pad_sequence(sequences), lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted);
}
} // namespace rnn
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 5592385..affddfc 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -2,11 +2,15 @@
import warnings
import torch
+from torch import Tensor
from ... import _VF
from ..._jit_internal import Optional
+from typing import List, Tuple
-PackedSequence_ = namedtuple('PackedSequence',
+
+
+PackedSequence_ = namedtuple('PackedSequence_',
['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
# type annotation for PackedSequence_ to make it compatible with TorchScript
@@ -356,24 +360,7 @@
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
- max_size = sequences[0].size()
- trailing_dims = max_size[1:]
- max_len = max([s.size(0) for s in sequences])
- if batch_first:
- out_dims = (len(sequences), max_len) + trailing_dims
- else:
- out_dims = (max_len, len(sequences)) + trailing_dims
-
- out_tensor = sequences[0].new_full(out_dims, padding_value)
- for i, tensor in enumerate(sequences):
- length = tensor.size(0)
- # use index notation to prevent duplicate references to the tensor
- if batch_first:
- out_tensor[i, :length, ...] = tensor
- else:
- out_tensor[:length, i, ...] = tensor
-
- return out_tensor
+ return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
def pack_sequence(sequences, enforce_sorted=True):
diff --git a/torch/nn/utils/rnn.pyi b/torch/nn/utils/rnn.pyi
index 50b6d6e..2c1c6c9 100644
--- a/torch/nn/utils/rnn.pyi
+++ b/torch/nn/utils/rnn.pyi
@@ -1,5 +1,5 @@
from collections import namedtuple
-from typing import Any, Optional, overload, Union, TypeVar, Tuple, Sequence
+from typing import Any, List, Optional, overload, Union, TypeVar, Tuple, Sequence
from torch import Tensor
from torch.types import _dtype, _device
@@ -65,7 +65,7 @@
total_length: Optional[int] = ...) -> Tuple[Tensor, ...]: ...
-def pad_sequence(sequences: Sequence[Tensor], batch_first: bool = ..., padding_value: int = ...) -> Tensor: ...
+def pad_sequence(sequences: List[Tensor], batch_first: bool = False, padding_value: float = ...) -> Tensor: ...
def pack_sequence(sequences: Sequence[Tensor], enforce_sorted: bool = ...) -> PackedSequence: ...