blob: 4de8550b37a87d0dc178678576686193fb59ad63 [file] [log] [blame]
from collections import namedtuple
import torch
from torch.autograd import Variable
PackedSequence_ = namedtuple('PackedSequence', 'data batch_sizes')


class PackedSequence(PackedSequence_):
    r"""Holds the data and list of batch_sizes of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are
        meant to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`. For instance, given data ``abc`` and
        ``d``, the ``PackedSequence`` would be ``adbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Variable): Variable containing packed sequence
        batch_sizes (list[int]): list of integers holding information about
            the batch size at each sequence step
    """
def pack_padded_sequence(input, lengths, batch_first=False):
    r"""Packs a Variable containing padded sequences of variable length.

    Input can be of size ``TxBx*`` where T is the length of the longest
    sequence (equal to ``lengths[0]``), B is the batch size, and * is any
    number of dimensions (including 0). If ``batch_first`` is True ``BxTx*``
    inputs are expected.

    The sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the
    shortest one.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Variable can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Arguments:
        input (Variable): padded batch of variable length sequences.
        lengths (list[int]): list of sequences lengths of each batch element.
        batch_first (bool, optional): if ``True``, the input is expected in
            BxTx* format.

    Returns:
        a :class:`PackedSequence` object

    Raises:
        ValueError: if the last (shortest) length is not positive, if
            ``lengths`` does not have one entry per batch element, if
            ``lengths[0]`` exceeds the padded time dimension T, or if
            ``lengths`` is not sorted in decreasing order.
    """
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <=0")
    if batch_first:
        input = input.transpose(0, 1)

    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")

    # Without this check, a length beyond the padded time dimension makes the
    # slice below silently return fewer rows than batch_sizes records, which
    # yields a PackedSequence whose data and batch_sizes disagree.
    max_len = input.size(0)
    if lengths[0] > max_len:
        raise ValueError("lengths array contains a length ({}) greater than "
                         "the padded sequence length ({})".format(lengths[0], max_len))

    steps = []
    batch_sizes = []
    prev_l = 0
    # Walk lengths from shortest to longest. Each strictly larger length l
    # contributes time steps prev_l..l-1, during which only the first
    # (batch_size - i) sequences are still active.
    for i, l in enumerate(reversed(lengths)):
        if l > prev_l:
            c_batch_size = batch_size - i
            step = input[prev_l:l, :c_batch_size].contiguous()
            steps.append(step.view(-1, *input.size()[2:]))
            batch_sizes.extend([c_batch_size] * (l - prev_l))
            prev_l = l
        elif prev_l > l:
            # l appears earlier in `lengths` than prev_l, yet is smaller,
            # so `lengths` is not in decreasing order.
            raise ValueError("lengths array has to be sorted in decreasing order")
    return PackedSequence(torch.cat(steps), batch_sizes)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0):
    r"""Pads a packed batch of variable length sequences.

    This is the inverse operation of :func:`pack_padded_sequence`.

    The returned Variable's data will be of size TxBx*, where T is the length
    of the longest sequence and B is the batch size. If ``batch_first`` is
    True, the data is transposed into BxTx* format.

    Batch elements will be ordered decreasingly by their length.

    Arguments:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in
            BxTx* format.
        padding_value (float, optional): value for padded elements

    Returns:
        Tuple of Variable containing the padded sequence, and a list of
        lengths of each sequence in the batch.
    """
    packed, sizes = sequence
    trailing = packed.size()[1:]
    padded = Variable(
        packed.data.new(len(sizes), sizes[0], *trailing).fill_(padding_value))

    seq_lengths = []
    offset = 0
    run_start = 0
    run_size = sizes[0]
    # The appended 0 acts as a sentinel: it flushes the final run of equal
    # batch sizes and records the length of every sequence still active.
    for step, size in enumerate(sizes + [0]):
        if size == run_size:
            continue
        # Time steps run_start..step-1 all share batch size run_size, so their
        # packed elements are contiguous and can be copied in one block.
        run_len = step - run_start
        count = run_size * run_len
        chunk = packed[offset:offset + count]
        padded[run_start:step, :run_size] = chunk.view(run_len, run_size, *trailing)
        offset += count
        run_start = step
        # A drop in batch size means that many sequences ended at this step.
        finished = run_size - size
        if finished > 0:
            seq_lengths.extend([step] * finished)
        run_size = size

    seq_lengths.reverse()
    if batch_first:
        padded = padded.transpose(0, 1)
    return padded, seq_lengths
def pad_sequence(sequences, batch_first=False, padding_value=0.0):
    r"""Pad a list of variable length Variables with ``padding_value``.

    ``pad_sequence`` stacks a list of Variables along a new dimension and
    pads them to equal length. If the input is a list of sequences of size
    ``L x *``, the output is of size ``T x B x *`` if ``batch_first`` is
    False, and ``B x T x *`` otherwise. The list of sequences should be
    sorted in the order of decreasing length.

    B is batch size. It's equal to the number of elements in ``sequences``.
    T is length of the longest sequence.
    L is length of the sequence.
    * is any number of trailing dimensions, including none.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = Variable(torch.ones(25, 300))
        >>> b = Variable(torch.ones(22, 300))
        >>> c = Variable(torch.ones(15, 300))
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Variable of size ``T x B x *`` or
        ``B x T x *`` where T is the length of the longest sequence.
        Function assumes trailing dimensions and type of all the Variables
        in sequences are same.

    Arguments:
        sequences (list[Variable]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if
            True, or in ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements.
            Default: 0.

    Returns:
        Variable of size ``T x B x *`` if batch_first is False.
        Variable of size ``B x T x *`` otherwise.

    Raises:
        ValueError: if ``sequences`` is not sorted by decreasing length.
    """
    # assuming trailing dimensions and type of all the Variables
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
    max_len, trailing_dims = max_size[0], max_size[1:]
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    out_variable = Variable(sequences[0].data.new(*out_dims).fill_(padding_value))

    prev_l = max_len
    for i, variable in enumerate(sequences):
        length = variable.size(0)
        # temporary sort check, can be removed when we handle sorting internally;
        # T is taken from sequences[0], so a longer later sequence would not fit
        if prev_l < length:
            raise ValueError("lengths array has to be sorted in decreasing order")
        prev_l = length
        # use index notation to prevent duplicate references to the variable
        if batch_first:
            out_variable[i, :length, ...] = variable
        else:
            out_variable[:length, i, ...] = variable
    return out_variable
def pack_sequence(sequences):
    r"""Packs a list of variable length Variables.

    ``sequences`` should be a list of Variables of size ``Lx*``, where L is
    the length of a sequence and * is any number of trailing dimensions,
    including zero. They should be sorted in the order of decreasing length.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = Variable(torch.Tensor([1,2,3]))
        >>> b = Variable(torch.Tensor([4,5]))
        >>> c = Variable(torch.Tensor([6]))
        >>> pack_sequence([a, b, c])
        PackedSequence(data=
         1
         4
         6
         2
         5
         3
        [torch.FloatTensor of size 6]
        , batch_sizes=[3, 2, 1])

    Arguments:
        sequences (list[Variable]): A list of sequences of decreasing length.

    Returns:
        a :class:`PackedSequence` object
    """
    # Pad everything into one T x B x * Variable, then pack it using each
    # sequence's own length along dimension 0.
    lengths = [v.size(0) for v in sequences]
    return pack_padded_sequence(pad_sequence(sequences), lengths)