# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn
from functorch.dim import dims, dimlists, softmax, cat


class Linear(nn.Linear):
    def forward(self, input):
        # bind any leading batch dimensions of the input to the dimension list `b`
        # and its feature dimension to `ci`, then express the matrix multiply
        # directly as a pointwise product and a reduction over `ci`.
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)
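

# A minimal check (a sketch, not part of the original example) that the
# first-class-dim Linear matches the standard nn.Linear; the sizes are illustrative.
def _check_linear():
    layer = Linear(16, 32)
    x = torch.randn(4, 8, 16)
    expected = nn.functional.linear(x, layer.weight, layer.bias)
    assert torch.allclose(layer(x), expected, atol=1e-5)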


class BertSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads,
                 attention_probs_dropout_prob, position_embedding_type=None,
                 max_position_embeddings=None, linear=Linear):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, and v normally.
        # the meaning of a linear layer is well understood, so there is no need to use explicit dimensions.
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # introduce values that represent each dimension. Dimensions are 'first class'
        # because they are actual Python values introduced here.
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads
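        # for example (illustrative sizes, not from the original): with
        # hidden_size=768 and num_attention_heads=12, heads.size is 12 and
        # features.size is inferred below as 768 // 12 == 64.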

        # bind the positional dimensions of q, k, and v against
        # our dimension objects. the size of each dimension is determined by this
        # binding, and when a dimension is used twice (e.g. batch), its size is
        # checked for consistency across both uses.
        # The group (heads, features) splits a single positional dimension
        # apart into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]
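        # on ordered tensors this binding would correspond to something like
        # q.view(B, S, H, D) followed by the usual permutes (a sketch, not part
        # of the original); with first-class dims no reshape or permute is needed.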

        # this option lets the model attend not only to the elements of the current
        # sequence but also to previously cached key/value tokens.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat introduces the new dimension extended_key_sequence: binding the past
            # tensors to key_sequence above forces them to the same length as the current
            # keys, so the concatenated result is twice as long and needs a fresh dimension.
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer product and summation are explicitly represented here,
        # and like einsum, will be pattern-matched to an efficient matrix-multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)
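        # a rough ordered-tensor equivalent (a sketch, not part of the original):
        # torch.einsum('bqhf,bkhf->bqkh', q_, k_) / math.sqrt(head_size)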

        # relative positional embeddings give a unique embedding based on the distance
        # between the query and key tokens in the sequence, e.g. for sequences of
        # length 4, distance[query][key] = query - key:
        #  0 -1 -2 -3
        #  1  0 -1 -2
        #  2  1  0 -1
        #  3  2  1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the indices along its
            # dimension, so we can directly subtract the two dimensions to get a
            # (query_sequence x key_sequence) tensor of the distances between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use the distance as an indirect index into the embedding table
            # to look up the features for each offset. This is just a `gather` primitive op:
            # the resulting tensor has all the dimensions of the index (query_sequence x
            # key_sequence), plus the dimension of the weight table that was bound
            # positionally (features).
            # this form of indirect indexing is more straightforward than either advanced
            # indexing or torch.gather, which both depend heavily on the positions of the
            # indexing tensors.
            positional_embedding = self.distance_embedding.weight[self.max_position_embeddings - 1 + distance, features]

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the original positional-embedding code because
                # they are not easy to fit to the existing matmul operators, even though
                # they are degenerate matmuls
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(features)
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(attention_probs, p=self.dropout_prob)

        # similarly, we can replace the matmul with an explicit product and sum, which makes
        # it clear we are weighting the values v across all keys with the attention probabilities.
        context_layer = (attention_probs * v).sum(key_sequence)
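        # a rough ordered-tensor equivalent (a sketch, not part of the original):
        # torch.einsum('bqkh,bkhf->bqhf', probs_, v_)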

        # finally, we convert back to a standard tensor by describing the layout of its
        # dimensions. working in reverse of the binding above, the (heads, features) group
        # flattens the two dimensions into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
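

# A minimal smoke-test sketch (not part of the original example); the sizes are
# illustrative and assume functorch.dim is available.
if __name__ == "__main__":
    _check_linear()
    attention = BertSelfAttention(hidden_size=768, num_attention_heads=12,
                                  attention_probs_dropout_prob=0.0)
    hidden_states = torch.randn(2, 16, 768)
    out = attention(hidden_states)
    assert out.shape == (2, 16, 768)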