| import torch |
| from ._utils import _range |
| from operator import mul |
| from functools import reduce |
| import math |
| |
# Names exported by ``from <this module> import *``.
__all__ = [
    'split',
    'chunk',
    'stack',
    'unbind',
    'btriunpack',
    'matmul',
    'det',
    'stft',
    'hann_window',
    'hamming_window',
    'bartlett_window',
    'where',
]
| |
| |
def split(tensor, split_size, dim=0):
    r"""Splits the tensor into chunks all of size :attr:`split_size` (if possible).

    The last chunk will be smaller if the tensor size along the given dimension
    is not divisible by :attr:`split_size`.

    Arguments:
        tensor (Tensor): the tensor to split
        split_size (int): size of a single chunk
        dim (int): dimension along which to split the tensor

    Returns:
        tuple: the chunks, each obtained with :meth:`Tensor.narrow`
    """
    if dim < 0:
        dim += tensor.dim()
    total = tensor.size(dim)
    # ceil(total / split_size) using integer arithmetic only
    count = (total + split_size - 1) // split_size
    # size of the final chunk: whatever is left after the full-size ones
    tail = split_size - (split_size * count - total)

    pieces = []
    for idx in _range(0, count):
        length = split_size if idx < count - 1 else tail
        pieces.append(tensor.narrow(int(dim), int(idx * split_size), int(length)))
    return tuple(pieces)
| |
| |
def chunk(tensor, chunks, dim=0):
    r"""Splits a tensor into a specific number of chunks.

    Every chunk has ``ceil(size / chunks)`` elements along :attr:`dim`, except
    possibly the last one, which may be smaller. Because of this rounding, fewer
    than :attr:`chunks` chunks may be returned.

    Arguments:
        tensor (Tensor): the tensor to split
        chunks (int): number of chunks to return
        dim (int): dimension along which to split the tensor

    Returns:
        tuple: the chunks, as produced by :func:`split`
    """
    axis = dim + tensor.dim() if dim < 0 else dim
    extent = tensor.size(axis)
    # ceil(extent / chunks) using integer arithmetic only
    per_chunk = (extent + chunks - 1) // chunks
    return split(tensor, per_chunk, axis)
| |
| |
def stack(sequence, dim=0, out=None):
    r"""Concatenates sequence of tensors along a new dimension.

    All tensors need to be of the same size.

    Arguments:
        sequence (Sequence): sequence of tensors to concatenate
        dim (int): dimension to insert. Has to be between 0 and the number
            of dimensions of concatenated tensors (inclusive)
        out (Tensor, optional): the output tensor, forwarded to :func:`torch.cat`

    Returns:
        Tensor: the stacked result (``out`` itself when it is given)

    Raises:
        ValueError: if :attr:`sequence` is empty.
    """
    if not len(sequence):
        raise ValueError("stack expects a non-empty sequence of tensors")
    if dim < 0:
        # A new axis is being inserted, so negative indices count from ndim + 1.
        dim += sequence[0].dim() + 1
    expanded = [item.unsqueeze(dim) for item in sequence]
    extra = {} if out is None else {'out': out}
    return torch.cat(expanded, dim, **extra)
| |
| |
def unbind(tensor, dim=0):
    r"""Removes a tensor dimension.

    Returns a tuple of all slices along the given dimension, each one already
    without that dimension (obtained with :meth:`Tensor.select`).

    Arguments:
        tensor (Tensor): the tensor to unbind
        dim (int): dimension to remove

    Returns:
        tuple: ``tensor.size(dim)`` slices, in index order
    """
    count = tensor.size(dim)
    return tuple(map(lambda idx: tensor.select(dim, idx), _range(count)))
| |
| |
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
    r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.

    Returns a tuple indexed by:
      0: The pivots, as a batch of permutation matrices ``P``
         (``None`` if :attr:`unpack_pivots` is ``False``).
      1: The L tensor, lower triangular with a unit diagonal
         (``None`` if :attr:`unpack_data` is ``False``).
      2: The U tensor, upper triangular including the diagonal
         (``None`` if :attr:`unpack_data` is ``False``).

    Arguments:
        LU_data (Tensor): the packed LU factorization data, of size
            ``(nBatch, sz, sz)``
        LU_pivots (Tensor): the packed LU factorization pivots, indexed as
            ``LU_pivots[batch, row]``
        unpack_data (bool): flag indicating if the data should be unpacked
        unpack_pivots (bool): flag indicating if the pivots should be unpacked

    Example::

        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = A.btrifact()
        >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
        >>>
        >>> # test that (P, A_L, A_U) gives LU factorization
        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
        >>> assert (A_ - A).abs().max() < 1e-5  # recovers A up to rounding error
    """

    # LU_data must be 3-D; the masks below assume its last two dims are both sz.
    nBatch, sz, _ = LU_data.size()

    if unpack_data:
        # Byte masks, broadcast over the batch: I_U selects the upper triangle
        # (diagonal included), I_L the strictly lower triangle.
        I_U = torch.triu(torch.ones(sz, sz)).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz)
        I_L = 1 - I_U
        L = LU_data.new(LU_data.size()).zero_()
        U = LU_data.new(LU_data.size()).zero_()
        I_diag = torch.eye(sz).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz)
        # The packed diagonal belongs to U (I_U includes it), so L's unit
        # diagonal has to be filled in explicitly.
        L[I_diag] = 1.0
        L[I_L] = LU_data[I_L]
        U[I_U] = LU_data[I_U]
    else:
        L = U = None

    if unpack_pivots:
        # Start from a batch of identity matrices and apply the recorded
        # interchanges in order by swapping columns j and k of P.
        P = torch.eye(sz).type_as(LU_data).unsqueeze(0).repeat(nBatch, 1, 1)
        for i in range(nBatch):
            for j in range(sz):
                # NOTE(review): pivots look 1-based (LAPACK-style), hence the
                # -1 -- confirm against btrifact.
                k = LU_pivots[i, j] - 1
                t = P[i, :, j].clone()
                P[i, :, j] = P[i, :, k]
                P[i, :, k] = t
    else:
        P = None

    return P, L, U
| |
| |
def matmul(tensor1, tensor2, out=None):
    r"""Matrix product of two tensors.

    The behavior depends on the dimensionality of the tensors as follows:

    - If both tensors are 1-dimensional, the dot product (scalar) is returned.
    - If both arguments are 2-dimensional, the matrix-matrix product is returned.
    - If the first argument is 1-dimensional and the second argument is 2-dimensional,
      a 1 is prepended to its dimension for the purpose of the matrix multiply.
      After the matrix multiply, the prepended dimension is removed.
    - If the first argument is 2-dimensional and the second argument is 1-dimensional,
      the matrix-vector product is returned.
    - If both arguments are at least 1-dimensional and at least one argument is
      N-dimensional (where N > 2), then a batched matrix multiply is returned.  If the first
      argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the
      batched matrix multiply and removed after.  If the second argument is 1-dimensional, a
      1 is appended to its dimension for the purpose of the batched matrix multiply and removed after.
      The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus
      must be broadcastable).  For example, if :attr:`tensor1` is a
      :math:`(j \times 1 \times n \times m)` tensor and :attr:`tensor2` is a :math:`(k \times m \times p)`
      tensor, :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor.

    .. note::

        The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.

    Arguments:
        tensor1 (Tensor): the first tensor to be multiplied
        tensor2 (Tensor): the second tensor to be multiplied
        out (Tensor, optional): the output tensor

    Raises:
        ValueError: if both tensors are 1-dimensional and :attr:`out` is given,
            or if either tensor is 0-dimensional.
    """
    dim_tensor1 = tensor1.dim()
    dim_tensor2 = tensor2.dim()
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        # vector . vector -> scalar; there is no tensor to write into.
        if out is None:
            return torch.dot(tensor1, tensor2)
        else:
            raise ValueError("out must be None for 1-d tensor matmul, returns a scalar")
    if dim_tensor1 == 2 and dim_tensor2 == 1:
        # matrix @ vector
        if out is None:
            return torch.mv(tensor1, tensor2)
        else:
            return torch.mv(tensor1, tensor2, out=out)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        # vector @ matrix: treat the vector as a (1, n) row matrix, then drop
        # the added leading dimension in place.
        if out is None:
            return torch.mm(tensor1.unsqueeze(0), tensor2).squeeze_(0)
        else:
            return torch.mm(tensor1.unsqueeze(0), tensor2, out=out).squeeze_(0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        if out is None:
            return torch.mm(tensor1, tensor2)
        else:
            return torch.mm(tensor1, tensor2, out=out)
    elif dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
        # optimization: use mm instead of bmm by folding tensor1's batch into
        # its leading matrix dimension.

        if dim_tensor2 == 1:
            tensor2 = tensor2.unsqueeze(-1)

        size1 = tensor1.size()
        size2 = tensor2.size()
        output_size = size1[:-1] + size2[-1:]

        # fold the batch into the first dimension
        tensor1 = tensor1.contiguous().view(-1, size1[-1])

        # mm can only write straight into a contiguous out; otherwise compute
        # into a fresh tensor and re-point out at it with set_ below.
        if out is None or not out.is_contiguous():
            output = torch.mm(tensor1, tensor2)
        else:
            output = torch.mm(tensor1, tensor2, out=out)

        # unfold the batch back out of the leading dimension
        output = output.view(output_size)

        if dim_tensor2 == 1:
            output = output.squeeze(-1)

        if out is not None:
            out.set_(output)
            return out

        return output
    elif (dim_tensor1 >= 1 and dim_tensor2 >= 1) and (dim_tensor1 >= 3 or dim_tensor2 >= 3):
        # ensure each tensor size is at least 3-dimensional
        tensor1_exp_size = torch.Size((1,) * max(3 - tensor1.dim(), 0) + tensor1.size())
        # rhs needs to be a separate case since we can't freely expand 1s on the rhs, but can on lhs
        if dim_tensor2 == 1:
            tensor2 = tensor2.unsqueeze(1)
        tensor2_exp_size = torch.Size((1,) * max(3 - tensor2.dim(), 0) + tensor2.size())

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = torch._C._infer_size(tensor1_exp_size[:-2], tensor2_exp_size[:-2])

        # flatten expanded batches
        tensor1_expanded = tensor1.expand(*(expand_batch_portion + tensor1_exp_size[-2:])) \
            .contiguous().view(reduce(mul, expand_batch_portion), *tensor1_exp_size[-2:])
        tensor2_expanded = tensor2.expand(*(expand_batch_portion + tensor2_exp_size[-2:])) \
            .contiguous().view(reduce(mul, expand_batch_portion), *tensor2_exp_size[-2:])

        # reshape batches back into result
        total_expansion = expand_batch_portion + (tensor1_exp_size[-2], tensor2_exp_size[-1])

        # Remove the size-1 matrix dimension that was added to a 1-D operand.
        def maybeSqueeze(tensor):
            if dim_tensor1 == 1:
                return tensor.squeeze(-2)
            elif dim_tensor2 == 1:
                return tensor.squeeze(-1)
            else:
                return tensor

        # Same contiguous-out handling as the mm fast path above.
        if out is None or not out.is_contiguous():
            output = torch.bmm(tensor1_expanded, tensor2_expanded)
        else:
            output = torch.bmm(tensor1_expanded, tensor2_expanded, out=out)

        output = maybeSqueeze(output.view(total_expansion))

        if out is not None:
            out.set_(output)
            return out

        return output

    # Only reachable when at least one operand is 0-dimensional.
    raise ValueError("both arguments to __matmul__ need to be at least 1D, "
                     "but they are {}D and {}D".format(dim_tensor1, dim_tensor2))
| |
| |
def det(var):
    r"""Calculates determinant of a 2D square Variable.

    This is a thin wrapper that forwards to ``var.det()``.

    .. note::
        Backward through `det` internally uses SVD results. So double backward
        through `det` will need to backward through :meth:`~Tensor.svd`. This
        can be unstable in certain cases. Please see :meth:`~torch.svd` for
        details.

    Arguments:
        var (Variable): The input 2D square Variable.

    Returns:
        The result of ``var.det()``.

    Raises:
        ValueError: if :attr:`var` is a plain tensor rather than a Variable.
    """
    if not torch.is_tensor(var):
        return var.det()
    raise ValueError("det is currently only supported on Variable")
| |
| |
def stft(var, frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0):
    r"""Short-time Fourier transform (STFT).

    Ignoring the batch dimension, this method computes the following expression:

    .. math::
        X[m, \omega] = \sum_{k = 0}^{frame\_length - 1}%
                            window[k]\ signal[m \times hop + k]\ e^{- j \frac{2 \pi \cdot \omega k}{frame\_length}}

    , where :math:`m` is the index of the sliding window, :math:`signal` is the
    input :attr:`var`, and :math:`\omega` is the frequency with
    :math:`0 \leq \omega < fft\_size`. When :attr:`return_onesided` is the
    default value ``True``, only the values for :math:`\omega` in
    :math:`[0, 1, 2, \dots, \lfloor \frac{fft\_size}{2} \rfloor]` are returned,
    because the real-to-complex transform satisfies the Hermitian symmetry,
    i.e., :math:`X[m, \omega] = X[m, fft\_size - \omega]^*`.

    The input :attr:`var` must be a 1-D sequence :math:`(T)` or a 2-D batch of
    sequences :math:`(N \times T)`. If :attr:`fft_size` is ``None``, it
    defaults to the same value as :attr:`frame_length`. :attr:`window` can be a
    1-D tensor of size :math:`(frame\_length)`, e.g., see
    :meth:`torch.hann_window`. If :attr:`window` is the default value ``None``,
    it is treated as if having :math:`1` everywhere in the frame.
    :attr:`pad_end` indicates the amount of zero padding at the end of
    :attr:`var` before STFT.

    Returns the real and the imaginary parts together as one tensor of size
    :math:`(* \times N \times 2)`, where :math:`*` is the shape of the input :attr:`var`,
    :math:`N` is the number of :math:`\omega`s considered depending on
    :attr:`fft_size` and :attr:`return_onesided`, and each pair in the last
    dimension represents a complex number as real part and imaginary part.

    This function forwards all arguments, positionally, to ``var.stft``.

    Arguments:
        var (Variable): the input signal
        frame_length (int): the size of window frame and STFT filter
        hop (int): the distance between neighboring sliding window frames
        fft_size (int, optional): size of Fourier transform
        return_onesided (bool, optional): controls whether to avoid redundancy in the return value
        window (Tensor, optional): the optional window function
        pad_end (int, optional): implicit zero padding at the end of :attr:`var`

    Returns:
        Tensor: A tensor containing the STFT result

    Raises:
        ValueError: if :attr:`var` is a plain tensor rather than a Variable.
    """
    if not torch.is_tensor(var):
        return var.stft(frame_length, hop, fft_size, return_onesided, window, pad_end)
    raise ValueError("stft is currently only supported on Variable")
| |
| |
def hann_window(window_length, periodic=True):
    r"""Hann window function.

    This method computes the Hann window function:

    .. math::
        w[n] = \frac{1}{2}\ [1 - \cos \left( \frac{2 \pi n}{N - 1} \right)] = \sin^2 \left( \frac{\pi n}{N - 1} \right)

    , where :math:`N` is the full window size.

    The input :attr:`window_length` is a positive integer controlling the
    returned window size. :attr:`periodic` flag determines whether the returned
    window trims off the last duplicate value from the symmetric window and is
    ready to be used as a periodic window with functions like
    :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in
    above formula is in fact :math:`window\_length + 1`. Also, we always have
    ``torch.hann_window(L, periodic=True)`` equal to
    ``torch.hann_window(L + 1, periodic=False)[:-1]``.

    .. note::
        If :attr:`window_length` :math:`= 1`, the returned window contains a single value 1.

    Arguments:
        window_length (int): the size of returned window
        periodic (bool, optional): If True, returns a window to be used as periodic
            function. If False, return a symmetric window.

    Returns:
        Tensor: A 1-D tensor of size :math:`(window\_length)` containing the window

    Raises:
        ValueError: if :attr:`window_length` is not positive.
    """
    if window_length <= 0:
        raise ValueError('window_length must be positive')
    # The Hann window is the generalized Hamming window with alpha = beta = 1/2.
    half = 0.5
    return hamming_window(window_length, periodic=periodic, alpha=half, beta=half)
| |
| |
def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46):
    r"""Hamming window function.

    This method computes the Hamming window function:

    .. math::
        w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right)

    , where :math:`N` is the full window size.

    The input :attr:`window_length` is a positive integer controlling the
    returned window size. :attr:`periodic` flag determines whether the returned
    window trims off the last duplicate value from the symmetric window and is
    ready to be used as a periodic window with functions like
    :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in
    above formula is in fact :math:`window\_length + 1`. Also, we always have
    ``torch.hamming_window(L, periodic=True)`` equal to
    ``torch.hamming_window(L + 1, periodic=False)[:-1]``.

    .. note::
        If :attr:`window_length` :math:`= 1`, the returned window contains a single value 1.

    .. note::
        This is a generalized version of :meth:`torch.hann_window`.

    Arguments:
        window_length (int): the size of returned window
        periodic (bool, optional): If True, returns a window to be used as periodic
            function. If False, return a symmetric window.
        alpha (float, optional): the coefficient :math:`\alpha` in the equation above
        beta (float, optional): the coefficient :math:`\beta` in the equation above

    Returns:
        Tensor: A 1-D tensor of size :math:`(window\_length)` containing the window

    Raises:
        ValueError: if :attr:`window_length` is not positive.
    """
    if window_length <= 0:
        raise ValueError('window_length must be positive')
    if window_length == 1:
        return torch.ones(window_length)
    # Build the symmetric window one sample longer, then drop the last sample.
    window_length += int(periodic)
    # Use a float end point so the sample indices form a floating point tensor:
    # with an integer argument, newer torch.arange returns an integer tensor,
    # on which the in-place float operations below fail.
    window = torch.arange(float(window_length)).mul_(math.pi * 2 / (window_length - 1)).cos_().mul_(-beta).add_(alpha)
    if periodic:
        return window[:-1]
    return window
| |
| |
def bartlett_window(window_length, periodic=True):
    r"""Bartlett window function.

    This method computes the Bartlett window function:

    .. math::
        w[n] = 1 - \lvert \frac{2n}{N-1} - 1 \rvert = \begin{cases}
            \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\
            2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\
        \end{cases}

    , where :math:`N` is the full window size.

    The input :attr:`window_length` is a positive integer controlling the
    returned window size. :attr:`periodic` flag determines whether the returned
    window trims off the last duplicate value from the symmetric window and is
    ready to be used as a periodic window with functions like
    :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in
    above formula is in fact :math:`window\_length + 1`. Also, we always have
    ``torch.bartlett_window(L, periodic=True)`` equal to
    ``torch.bartlett_window(L + 1, periodic=False)[:-1]``.

    .. note::
        If :attr:`window_length` :math:`= 1`, the returned window contains a single value 1.

    Arguments:
        window_length (int): the size of returned window
        periodic (bool, optional): If True, returns a window to be used as periodic
            function. If False, return a symmetric window.

    Returns:
        Tensor: A 1-D tensor of size :math:`(window\_length)` containing the window

    Raises:
        ValueError: if :attr:`window_length` is not positive.
    """
    if window_length <= 0:
        raise ValueError('window_length must be positive')
    if window_length == 1:
        return torch.ones(window_length)
    # Build the symmetric window one sample longer, then drop the last sample.
    window_length += int(periodic)
    # Float end point keeps the tensor floating point even where torch.arange
    # would otherwise return integers for an integer argument.
    window = torch.arange(float(window_length)).mul_(2.0 / (window_length - 1))
    # Indices 0 .. floor((N - 1) / 2) keep 2n/(N-1); the rest become 2 - 2n/(N-1).
    first_half_size = ((window_length - 1) >> 1) + 1
    window.narrow(0, first_half_size, window_length - first_half_size).mul_(-1).add_(2)
    if periodic:
        return window[:-1]
    return window
| |
| |
def where(condition, x, y):
    r"""Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`.

    The operation is defined as::

        out_i = x_i if condition_i
                y_i otherwise

    .. note::
        This function only works with ``Variables``.

    .. note::
        The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be :ref:`broadcastable <broadcasting-semantics>`.

    Arguments:
        condition (ByteTensor): When True (nonzero), yield x, otherwise yield y.
        x (Tensor): values selected at indices where :attr:`condition` is True.
        y (Tensor): values selected at indices where :attr:`condition` is False.

    Returns:
        Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y`
    """
    # This functional form follows numpy's ``where(condition, x, y)`` order,
    # while the method form follows torch's mask convention ``x.where(mask, y)``,
    # so the first two arguments swap places here.
    method = torch._C._VariableBase.where
    return method(x, condition, y)