
from collections import namedtuple

import torch

from torch import Tensor
from typing import List, Sequence

from . import Sequential, ModuleList, Linear
from .module import Module
from ..functional import log_softmax

__all__ = ['AdaptiveLogSoftmaxWithLoss']

_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])


class AdaptiveLogSoftmaxWithLoss(Module):
    r"""Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Cissé, David Grangier, and Hervé Jégou
    <https://arxiv.org/abs/1609.04309>`__.

    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may contain different numbers of targets
    each.
    Additionally, clusters containing less frequent labels assign lower
    dimensional embeddings to those labels, which speeds up the computation.
    For each minibatch, only clusters for which at least one target is
    present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels) should also be
    cheap to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`).
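      For example, with ``in_features = 64`` and the default ``div_value = 4.``
      (values chosen purely for illustration), the first tail cluster projects
      the input down to :math:`\lfloor 64 / 4 \rfloor = 16` dimensions and the
      second one down to :math:`\lfloor 64 / 16 \rfloor = 4`.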

    * :attr:`head_bias`, if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See the paper for details. Set to False in the official
      implementation.

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:
            * **output** is a Tensor of size ``N`` containing computed target
              log probabilities for each example
            * **loss** is a Scalar representing the computed negative
              log likelihood loss

    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes} - 1`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``
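
    Examples::

        >>> # A minimal usage sketch; the sizes and cutoffs below are illustrative only.
        >>> asm = AdaptiveLogSoftmaxWithLoss(64, 1000, cutoffs=[10, 100, 500])
        >>> input = torch.randn(128, 64)
        >>> target = torch.randint(1000, (128,))
        >>> out, loss = asm(input, target)
        >>> out.shape
        torch.Size([128])
        >>> asm.log_prob(input).shape
        torch.Size([128, 1000])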

    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.,
        head_bias: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) > (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any(int(c) != c for c in cutoffs):

            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
                           **factory_kwargs)
        self.tail = ModuleList()

        for i in range(self.n_clusters):

            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            # Each tail cluster first projects the input down to ``hsz``
            # dimensions, then maps it to that cluster's ``osz`` classes.
            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError('Input and target should have the same size '
                                   'in the batch dimension.')
            if input_.dim() != 2:
                raise RuntimeError('1D target tensor expects 2D input tensors, '
                                   'but found inputs with size', input_.size())
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError('0D target tensor expects 1D input tensors, '
                                   'but found inputs with size', input_.size())
        else:
            raise RuntimeError('0D or 1D target tensor expected, '
                               'multi-target not supported')

        # An unbatched (0D-target) call is handled as a batch of size one and
        # squeezed back before returning.
        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

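        # For each row, gather_inds will hold the index into the head output to
        # gather from: the target itself if it falls in the shortlist, or the
        # entry of its cluster otherwise. output accumulates the within-cluster
        # log-probability for rows routed to a tail cluster.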
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):

            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
                               f"but values in range [{target.min().item()}, {target.max().item()}] "
                               "were found. ")

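        # head_logprob contains log-probabilities over the shortlist classes plus
        # one entry per tail cluster; gathering at gather_inds picks either the
        # target's shortlist log-probability or its cluster's log-probability,
        # which combines with the within-cluster term accumulated above.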
        head_output = self.head(input)
        head_logprob = log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given input tensor and output of ``self.head``, compute the log of the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = log_softmax(head_output, dim=1)

        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

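        # A tail class's full log-probability is its cluster's log-probability
        # under the head plus its log-probability within the cluster.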
        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}` classes.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c <= \texttt{n\_classes} - 1`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`

        """
        head_output = self.head(input)
        return self._get_full_log_prob(input, head_output)

    def predict(self, input: Tensor) -> Tensor:
        r"""Return the class with the highest probability for each example in the input minibatch.

        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): the class with the highest probability for each example

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N)`
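
        Examples::

            >>> # illustrative sizes only; predict matches the argmax over log_prob
            >>> asm = AdaptiveLogSoftmaxWithLoss(16, 50, cutoffs=[5, 20])
            >>> x = torch.randn(4, 16)
            >>> asm.predict(x).shape
            torch.Size([4])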
| """ |
| head_output = self.head(input) |
| output = torch.argmax(head_output, dim=1) |
| not_in_shortlist = (output >= self.shortlist_size) |
| all_in_shortlist = not (not_in_shortlist.any()) |
| |
| if all_in_shortlist: |
| return output |
| |
| elif not_in_shortlist.all(): |
| log_prob = self._get_full_log_prob(input, head_output) |
| return torch.argmax(log_prob, dim=1) |
| |
| else: |
| log_prob = self._get_full_log_prob(input[not_in_shortlist], |
| head_output[not_in_shortlist]) |
| output[not_in_shortlist] = torch.argmax(log_prob, dim=1) |
| return output |