blob: 7c05a60d76684606ed3b057d08c82a0dc15d94cd [file] [log] [blame]
import torch
from ._utils import _range
def split(tensor, split_size, dim=0):
    """Splits the tensor into equally sized chunks (if possible).

    Last chunk will be smaller if the tensor size along a given dimension
    is not divisible by ``split_size``.

    Arguments:
        tensor (Tensor): tensor to split.
        split_size (int): size of a single chunk. Must be positive.
        dim (int): dimension along which to split the tensor. Negative
            values count from the last dimension.

    Returns:
        tuple: the pieces, each obtained with ``tensor.narrow`` along ``dim``.
        The tuple is empty when the tensor has size 0 along ``dim``.

    Raises:
        ValueError: if ``split_size`` is not positive.
    """
    if split_size <= 0:
        # A zero size would divide by zero below and a negative size would
        # silently produce no chunks at all, so reject both explicitly.
        raise ValueError(
            "split expects split_size to be positive, but got "
            "split_size={}".format(split_size))
    if dim < 0:
        dim += tensor.dim()
    dim_size = tensor.size(dim)
    # Ceiling division: a partial trailing chunk still counts as a chunk.
    num_splits = (dim_size + split_size - 1) // split_size
    # Whatever is left once all the full-sized chunks have been taken.
    last_split_size = dim_size - split_size * (num_splits - 1)

    def get_split_size(i):
        return split_size if i < num_splits - 1 else last_split_size

    return tuple(
        tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i)))
        for i in _range(0, num_splits))
def chunk(tensor, chunks, dim=0):
    """Splits a tensor into a number of chunks along a given dimension.

    The chunk size is the tensor size along ``dim`` divided by ``chunks``,
    rounded up, so fewer than ``chunks`` pieces may be returned and the last
    one may be smaller than the others.

    Arguments:
        tensor (Tensor): tensor to split.
        chunks (int): number of chunks to return. Must be positive.
        dim (int): dimension along which to split the tensor. Negative
            values count from the last dimension.

    Returns:
        tuple: the result of ``split`` called with the computed chunk size.

    Raises:
        ValueError: if ``chunks`` is not positive.
    """
    if chunks <= 0:
        # Zero would divide by zero and a negative count would yield a
        # negative chunk size, so fail early with a clear message.
        raise ValueError(
            "chunk expects chunks to be positive, but got "
            "chunks={}".format(chunks))
    if dim < 0:
        dim += tensor.dim()
    # Ceiling division. Clamp to 1 so that a dimension of size 0 does not
    # hand a zero chunk size to split.
    split_size = max(1, (tensor.size(dim) + chunks - 1) // chunks)
    return split(tensor, split_size, dim)
def stack(sequence, dim=0):
    """Concatenates sequence of tensors along a new dimension.

    All tensors need to be of the same size.

    Arguments:
        sequence (Sequence): sequence of tensors to concatenate.
        dim (int): dimension to insert. Has to be between 0 and the number
            of dimensions of concatenated tensors (inclusive). Negative
            values count from the end of the result, so ``-1`` places the
            new dimension last.

    Returns:
        Tensor: the inputs joined with ``torch.cat`` after each one has been
        given a new dimension at position ``dim``.

    Raises:
        TypeError: if ``sequence`` is empty.
    """
    if len(sequence) == 0:
        raise TypeError("stack expects a non-empty sequence of tensors")
    if dim < 0:
        # The result has one more dimension than the inputs, so a negative
        # index has to wrap around ``dim() + 1`` positions, not ``dim()``.
        dim += sequence[0].dim() + 1
    return torch.cat([t.unsqueeze(dim) for t in sequence], dim)
def unbind(tensor, dim=0):
    """Removes a tensor dimension.

    Returns a tuple of all slices along a given dimension, already without it.

    Arguments:
        tensor (Tensor): tensor to unbind.
        dim (int): dimension to remove. Negative values count from the
            last dimension.

    Returns:
        tuple: one ``tensor.select(dim, i)`` result for every index ``i``
        along ``dim``.
    """
    if dim < 0:
        # Wrap negative indices here rather than relying on ``size`` and
        # ``select`` to accept them.
        dim += tensor.dim()
    return tuple(tensor.select(dim, i) for i in _range(tensor.size(dim)))
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
    """Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.

    Returns a tuple indexed by:
        0: The pivots, as a batch of permutation matrices.
        1: The L tensor (unit diagonal, strictly lower part of ``LU_data``).
        2: The U tensor (upper part of ``LU_data``, diagonal included).

    Any entry whose unpacking was switched off is ``None``.

    Arguments:
        LU_data (Tensor): The packed LU factorization data. Its size is
            unpacked into three values; the second one is used as the side
            length of the square masks.
        LU_pivots (Tensor): The packed LU factorization pivots, indexed as
            ``LU_pivots[batch, column]``.
        unpack_data (bool): Flag indicating if the data should be unpacked.
        unpack_pivots (bool): Flag indicating if the pivots should be unpacked.
    """
    batch_count, n, _ = LU_data.size()

    lower = upper = None
    if unpack_data:
        # Byte masks shared by every matrix in the batch.
        upper_mask = torch.triu(torch.ones(n, n)).type_as(LU_data).byte()
        upper_mask = upper_mask.unsqueeze(0).expand(batch_count, n, n)
        strict_lower_mask = 1 - upper_mask
        diag_mask = torch.eye(n).type_as(LU_data).byte()
        diag_mask = diag_mask.unsqueeze(0).expand(batch_count, n, n)

        lower = LU_data.new(LU_data.size()).zero_()
        upper = LU_data.new(LU_data.size()).zero_()

        # L: ones on the diagonal, packed values strictly below it.
        lower[diag_mask] = 1.0
        lower[strict_lower_mask] = LU_data[strict_lower_mask]
        # U: packed values on and above the diagonal.
        upper[upper_mask] = LU_data[upper_mask]

    permutation = None
    if unpack_pivots:
        # Start from identity matrices and replay the column swaps in order.
        identity = torch.eye(n).type_as(LU_data)
        permutation = identity.unsqueeze(0).repeat(batch_count, 1, 1)
        for batch in range(batch_count):
            for col in range(n):
                # Stored pivots are shifted down by one to index columns.
                target = LU_pivots[batch, col] - 1
                saved = permutation[batch, :, col].clone()
                permutation[batch, :, col] = permutation[batch, :, target]
                permutation[batch, :, target] = saved

    return permutation, lower, upper