| import torch |
| from torch.nn.parameter import Parameter |
| |
| from .module import Module |
| from .. import functional as F |
| from .. import init |
| from torch._jit_internal import weak_module, weak_script_method |
| |
| |
@weak_module
class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                         (initialized to zeros) whenever it encounters the index.
                                         A negative value counts back from the end of the dictionary
                                         and is stored as ``num_embeddings + padding_idx``.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)

    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. However, note that this
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from :class:`~torch.nn.Embedding`
        is always zero.

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0,2,0,5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
                     'norm_type', 'scale_grad_by_freq', 'sparse']

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 sparse=False, _weight=None):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Range-check padding_idx; a negative index is folded into its
        # non-negative equivalent (num_embeddings + padding_idx). Zero needs
        # neither a check nor a rewrite.
        if padding_idx is not None:
            if padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
            elif padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
        self.padding_idx = padding_idx

        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if _weight is not None:
            # A caller-supplied table (e.g. from `from_pretrained`) is wrapped
            # as-is after a shape check; it is NOT re-initialized.
            expected_shape = [num_embeddings, embedding_dim]
            assert list(_weight.shape) == expected_shape, \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        else:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()

        self.sparse = sparse

    def reset_parameters(self):
        """Re-draw :attr:`weight` from N(0, 1), then zero the ``padding_idx`` row if one is set."""
        init.normal_(self.weight)
        if self.padding_idx is None:
            return
        # In-place write to a leaf Parameter must happen outside autograd.
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    @weak_script_method
    def forward(self, input):
        # Delegates the lookup (and optional max_norm renormalization) to the
        # functional form with this module's stored configuration.
        return F.embedding(input, self.weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)

    def extra_repr(self):
        """Summarize the configuration, listing only options that differ from their defaults."""
        fields = ['{num_embeddings}', '{embedding_dim}']
        if self.padding_idx is not None:
            fields.append('padding_idx={padding_idx}')
        if self.max_norm is not None:
            fields.append('max_norm={max_norm}')
        if self.norm_type != 2:
            fields.append('norm_type={norm_type}')
        if self.scale_grad_by_freq is not False:
            fields.append('scale_grad_by_freq={scale_grad_by_freq}')
        if self.sparse is not False:
            fields.append('sparse=True')
        return ', '.join(fields).format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
                        sparse=False):
        r"""Creates Embedding instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            padding_idx (int, optional): See module initialization documentation.
            max_norm (float, optional): See module initialization documentation.
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            sparse (bool, optional): See module initialization documentation.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        num_rows, num_cols = embeddings.shape
        # Keyword arguments keep this working for subclasses that reorder
        # their positional parameters.
        pretrained = cls(
            num_embeddings=num_rows,
            embedding_dim=num_cols,
            _weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
        pretrained.weight.requires_grad = not freeze
        return pretrained
| |
| |
@weak_module
class EmbeddingBag(Module):
    r"""Computes sums or means of 'bags' of embeddings, without instantiating the
    intermediate embeddings.

    For bags of constant length and no :attr:`per_sample_weights`, this class

        * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
        * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
        * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations.

    EmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                                 into consideration. ``"mean"`` computes the average of the values
                                 in the bag, ``"max"`` computes the max value over each bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                 Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
        :attr:`per_sample_weights` (Tensor, optional)

        - If :attr:`input` is 2D of shape `(B, N)`,

          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
          this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
          :attr:`offsets` is ignored and required to be ``None`` in this case.

        - If :attr:`input` is 1D of shape `(N)`,

          it will be treated as a concatenation of multiple bags (sequences).
          :attr:`offsets` is required to be a 1D tensor containing the
          starting index positions of each bag in :attr:`input`. Therefore,
          for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
          having ``B`` bags. Empty bags (i.e., having 0-length) will have
          returned vectors filled by zeros.

        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.

    Output shape: `(B, embedding_dim)`

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
        >>> offsets = torch.LongTensor([0,4])
        >>> embedding_sum(input, offsets)
        tensor([[-0.8861, -5.4350, -0.0523],
                [ 1.1306, -2.5798, -1.0044]])
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
                     'scale_grad_by_freq', 'mode', 'sparse']

    def __init__(self, num_embeddings, embedding_dim,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 mode='mean', sparse=False, _weight=None):
        super(EmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            # A caller-supplied table (e.g. from `from_pretrained`) is wrapped
            # as-is after a shape check; it is NOT re-initialized.
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        # `mode` is only stored here; it is forwarded to F.embedding_bag on
        # every call and not validated at construction time.
        self.mode = mode
        self.sparse = sparse

    def reset_parameters(self):
        """Re-draw :attr:`weight` from N(0, 1)."""
        init.normal_(self.weight)

    @weak_script_method
    def forward(self, input, offsets=None, per_sample_weights=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        # The type comment above is read by TorchScript; keep it in sync with
        # the signature.
        return F.embedding_bag(input, self.weight, offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights)

    def extra_repr(self):
        """Summarize the configuration; ``mode`` is always shown, other options only when non-default."""
        s = '{num_embeddings}, {embedding_dim}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        s += ', mode={mode}'
        # Report sparse gradients the same way Embedding.extra_repr does, so a
        # sparse bag is distinguishable from a dense one when printed.
        if self.sparse is not False:
            s += ', sparse=True'
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
                        norm_type=2., scale_grad_by_freq=False,
                        mode='mean', sparse=False):
        r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
                First dimension is being passed to EmbeddingBag as ``num_embeddings``, second as ``embedding_dim``.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
            max_norm (float, optional): See module initialization documentation. Default: ``None``
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            mode (string, optional): See module initialization documentation. Default: ``"mean"``
            sparse (bool, optional): See module initialization documentation. Default: ``False``.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
            >>> # One bag holding indices 1 and 0, reduced with the default mode="mean"
            >>> input = torch.LongTensor([[1, 0]])
            >>> embeddingbag(input)
            tensor([[ 2.5000,  3.7000,  4.6500]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        rows, cols = embeddings.shape
        embeddingbag = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse)
        embeddingbag.weight.requires_grad = not freeze
        return embeddingbag