| import torch |
| from torch.nn.parameter import Parameter |
| |
| from .module import Module |
| from .. import functional as F |
| |
| |
| class Embedding(Module): |
| r"""A simple lookup table that stores embeddings of a fixed dictionary and size. |
| |
| This module is often used to store word embeddings and retrieve them using indices. |
| The input to the module is a list of indices, and the output is the corresponding |
| word embeddings. |
| |
| Args: |
| num_embeddings (int): size of the dictionary of embeddings |
| embedding_dim (int): the size of each embedding vector |
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                     (initialized to zeros) whenever it encounters :attr:`padding_idx` in the input.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm` before lookup.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of the frequency of
                                             the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, the gradient w.r.t. the :attr:`weight` matrix will be a sparse tensor.
                                 See the Notes for more details regarding sparse gradients.
| |
| Attributes: |
        weight (Tensor): the learnable weights of the module of shape ``(num_embeddings, embedding_dim)``,
                         initialized from :math:`\mathcal{N}(0, 1)`
| |
| Shape: |
| |
| - Input: LongTensor of arbitrary shape containing the indices to extract |
| - Output: `(*, embedding_dim)`, where `*` is the input shape |
| |
| .. note:: |
| Keep in mind that only a limited number of optimizers support |
| sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), |
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`).
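
        A minimal sketch of the sparse-gradient path (the indices are
        illustrative; only the gradient layout is checked)::

            >>> embedding = nn.Embedding(10, 3, sparse=True)
            >>> embedding(torch.LongTensor([1, 2])).sum().backward()
            >>> embedding.weight.grad.is_sparse
            True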
| |
| .. note:: |
| With :attr:`padding_idx` set, the embedding vector at |
| :attr:`padding_idx` is initialized to all zeros. However, note that this |
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output (see the last example below). The gradient for this vector
        from :class:`~torch.nn.Embedding` is always zero.
| |
| Examples:: |
| |
| >>> # an Embedding module containing 10 tensors of size 3 |
| >>> embedding = nn.Embedding(10, 3) |
| >>> # a batch of 2 samples of 4 indices each |
| >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) |
| >>> embedding(input) |
| tensor([[[-0.0251, -1.6902, 0.7172], |
| [-0.6431, 0.0748, 0.6969], |
| [ 1.4970, 1.3448, -0.9685], |
| [-0.3677, -2.7265, -0.1685]], |
| |
| [[ 1.4970, 1.3448, -0.9685], |
| [ 0.4362, -0.4004, 0.9400], |
| [-0.6431, 0.0748, 0.6969], |
| [ 0.9124, -2.3616, 1.1151]]]) |
| |
| |
| >>> # example with padding_idx |
| >>> embedding = nn.Embedding(10, 3, padding_idx=0) |
| >>> input = torch.LongTensor([[0,2,0,5]]) |
| >>> embedding(input) |
| tensor([[[ 0.0000, 0.0000, 0.0000], |
| [ 0.1535, -2.0309, 0.9315], |
| [ 0.0000, 0.0000, 0.0000], |
| [-0.1655, 0.9897, 0.0635]]]) |
| """ |
| |
| def __init__(self, num_embeddings, embedding_dim, padding_idx=None, |
| max_norm=None, norm_type=2, scale_grad_by_freq=False, |
| sparse=False, _weight=None): |
| super(Embedding, self).__init__() |
| self.num_embeddings = num_embeddings |
| self.embedding_dim = embedding_dim |
| if padding_idx is not None: |
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'padding_idx must be within num_embeddings'
                # normalize a negative padding_idx to the equivalent positive index
                padding_idx = self.num_embeddings + padding_idx
| self.padding_idx = padding_idx |
| self.max_norm = max_norm |
| self.norm_type = norm_type |
| self.scale_grad_by_freq = scale_grad_by_freq |
| if _weight is None: |
| self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) |
| self.reset_parameters() |
| else: |
| assert list(_weight.shape) == [num_embeddings, embedding_dim], \ |
| 'Shape of weight does not match num_embeddings and embedding_dim' |
| self.weight = Parameter(_weight) |
| self.sparse = sparse |
| |
    def reset_parameters(self):
        # initialize weights from N(0, 1); zero out the padding vector, if any
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)
| |
| def forward(self, input): |
| return F.embedding( |
| input, self.weight, self.padding_idx, self.max_norm, |
| self.norm_type, self.scale_grad_by_freq, self.sparse) |
| |
| def extra_repr(self): |
| s = '{num_embeddings}, {embedding_dim}' |
| if self.padding_idx is not None: |
| s += ', padding_idx={padding_idx}' |
| if self.max_norm is not None: |
| s += ', max_norm={max_norm}' |
| if self.norm_type != 2: |
| s += ', norm_type={norm_type}' |
| if self.scale_grad_by_freq is not False: |
| s += ', scale_grad_by_freq={scale_grad_by_freq}' |
| if self.sparse is not False: |
| s += ', sparse=True' |
| return s.format(**self.__dict__) |
| |
| @classmethod |
| def from_pretrained(cls, embeddings, freeze=True, sparse=False): |
| r"""Creates Embedding instance from given 2-dimensional FloatTensor. |
| |
| Args: |
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                The first dimension is passed to Embedding as ``num_embeddings``, the second as ``embedding_dim``.
            freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            sparse (bool, optional): If ``True``, the gradient w.r.t. the weight matrix will be a sparse tensor.
                See the Notes for more details regarding sparse gradients.
| |
| Examples:: |
| |
| >>> # FloatTensor containing pretrained weights |
| >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) |
| >>> embedding = nn.Embedding.from_pretrained(weight) |
| >>> # Get embeddings for index 1 |
| >>> input = torch.LongTensor([1]) |
| >>> embedding(input) |
| tensor([[ 4.0000, 5.1000, 6.3000]]) |
| """ |
| assert embeddings.dim() == 2, \ |
| 'Embeddings parameter is expected to be 2-dimensional' |
| rows, cols = embeddings.shape |
| embedding = cls( |
| num_embeddings=rows, |
| embedding_dim=cols, |
| _weight=embeddings, |
| sparse=sparse, |
| ) |
| embedding.weight.requires_grad = not freeze |
| return embedding |
| |
| |
| class EmbeddingBag(Module): |
| r"""Computes sums or means of 'bags' of embeddings, without instantiating the |
| intermediate embeddings. |
| |
| For bags of constant length, this class |
| |
| * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``, |
| * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``, |
| * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``. |
| |
| However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these |
| operations. |
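
    For example, the ``mode="sum"`` equivalence can be checked directly; a minimal
    sketch (the weights are tied by hand, and the input indices are illustrative)::

        >>> es = nn.EmbeddingBag(10, 3, mode='sum')
        >>> e = nn.Embedding(10, 3)
        >>> e.weight.data = es.weight.data
        >>> input = torch.LongTensor([[1, 2, 4, 5]])
        >>> torch.allclose(es(input), e(input).sum(dim=1))
        True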
| |
| Args: |
| num_embeddings (int): size of the dictionary of embeddings |
| embedding_dim (int): the size of each embedding vector |
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm` before lookup.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of the frequency of
                                             the words in the mini-batch. Default ``False``.
                                             Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 Default: ``"mean"``
        sparse (bool, optional): If ``True``, the gradient w.r.t. the :attr:`weight` matrix will be a sparse tensor.
                                 See the Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.
| |
| Attributes: |
        weight (Tensor): the learnable weights of the module of shape ``(num_embeddings, embedding_dim)``,
                         initialized from :math:`\mathcal{N}(0, 1)`
| |
| Inputs: :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) |
| |
| - If :attr:`input` is 2D of shape ``B x N``, |
| |
| it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and |
| this will return ``B`` values aggregated in a way depending on the :attr:`mode`. |
| :attr:`offsets` is ignored and required to be ``None`` in this case. |
| |
| - If :attr:`input` is 1D of shape ``N``, |
| |
| it will be treated as a concatenation of multiple bags (sequences). |
| :attr:`offsets` is required to be a 1D tensor containing the |
| starting index positions of each bag in :attr:`input`. Therefore, |
| for :attr:`offsets` of shape ``B``, :attr:`input` will be viewed as |
          having ``B`` bags. Empty bags (i.e., bags of length zero) return
          vectors filled with zeros.
| |
| Output shape: ``B x embedding_dim`` |
| |
| Examples:: |
| |
| >>> # an Embedding module containing 10 tensors of size 3 |
| >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') |
| >>> # a batch of 2 samples of 4 indices each |
| >>> input = torch.LongTensor([1,2,4,5,4,3,2,9]) |
| >>> offsets = torch.LongTensor([0,4]) |
| >>> embedding_sum(input, offsets) |
| tensor([[-0.8861, -5.4350, -0.0523], |
| [ 1.1306, -2.5798, -1.0044]]) |
| """ |
| |
| def __init__(self, num_embeddings, embedding_dim, |
| max_norm=None, norm_type=2, scale_grad_by_freq=False, |
| mode='mean', sparse=False): |
| super(EmbeddingBag, self).__init__() |
| self.num_embeddings = num_embeddings |
| self.embedding_dim = embedding_dim |
| self.max_norm = max_norm |
| self.norm_type = norm_type |
| self.scale_grad_by_freq = scale_grad_by_freq |
| self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) |
| self.mode = mode |
| self.sparse = sparse |
| |
| self.reset_parameters() |
| |
| def reset_parameters(self): |
| self.weight.data.normal_(0, 1) |
| |
| def forward(self, input, offsets=None): |
| return F.embedding_bag(input, self.weight, offsets, |
| self.max_norm, self.norm_type, |
| self.scale_grad_by_freq, self.mode, self.sparse) |
| |
| def extra_repr(self): |
| s = '{num_embeddings}, {embedding_dim}' |
| if self.max_norm is not None: |
| s += ', max_norm={max_norm}' |
| if self.norm_type != 2: |
| s += ', norm_type={norm_type}' |
| if self.scale_grad_by_freq is not False: |
| s += ', scale_grad_by_freq={scale_grad_by_freq}' |
| s += ', mode={mode}' |
| return s.format(**self.__dict__) |
| |
| # TODO: SparseLinear |