The goal of self-attention is to compute a context vector for each element in a sequence. To do so, we first compute an attention score for each element with respect to every element in the sequence, including itself. In this simplified version, the score is a dot product: $\omega_{ij} = x^{(i)} \cdot x^{(j)}$. Attention scores measure the similarity of two vectors; here, how similar the given element is to each of the other elements. Then we normalize the attention scores with softmax to get attention weights: $\alpha_{ij} = \exp(\omega_{ij}) / \sum_{k} \exp(\omega_{ik})$. Finally, we compute the context vector of the given element as the weighted sum of all input elements, using the attention weights as coefficients: $z^{(i)} = \sum_{j} \alpha_{ij} x^{(j)}$. The weighted sum answers the question: how important is each element in the sentence to the given element?

In summary, we perform the following steps:

  1. compute attention scores ($\omega$)
  2. compute attention weights ($\alpha$), a normalized version of the attention scores
  3. compute context vectors ($z$)
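As a rough sketch of the three steps, the plain-Python snippet below computes the context vector for a single query. The 3-token, 2-dimensional sequence `x` is made-up toy data, not the example used later in this section.

```python
import math

# Hypothetical toy sequence: 3 tokens, each a 2-dimensional embedding.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = x[2]  # compute the context vector for the third token

# Step 1: attention scores = dot product of the query with every element
scores = [sum(q * k for q, k in zip(query, xj)) for xj in x]

# Step 2: attention weights = softmax of the scores (non-negative, sum to 1)
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

# Step 3: context vector = weighted sum of the input elements
context = [sum(w * xj[d] for w, xj in zip(weights, x)) for d in range(len(query))]

print(scores)  # [1.0, 1.0, 2.0]
print(weights)
print(context)
```

The third token is most similar to itself, so it receives the largest weight, but the context vector still mixes in the other two tokens.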

Next, we walk through an example where we calculate a context vector for each element in an input sequence. Say the input sequence, inputs, has 6 elements, each a 3-dimensional embedding. The attention scores are the matrix product of the input matrix and its own transpose, which produces the pairwise dot products between all input vectors. The attention score matrix has shape 6 x 6: for each input element we have 6 scores, one per element in the sequence. Note that this is a simplified version of the attention mechanism: it has no trainable weights, and no mask is applied.

import torch
import torch.nn as nn
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
# (6, 3) @ (3, 6) -> (6, 6): entry (i, j) is the dot product of token i and token j
attn_scores = inputs @ inputs.T
print(attn_scores.shape)
torch.Size([6, 6])
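To see that the single matrix product really produces all pairwise dot products, we can rebuild the score matrix with an explicit double loop (a sketch for checking only, not something you would use in practice). In this naive version the matrix is also symmetric, since $x^{(i)} \cdot x^{(j)} = x^{(j)} \cdot x^{(i)}$; the trainable query and key projections introduced later generally break that symmetry.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
attn_scores = inputs @ inputs.T

# Rebuild the same 6 x 6 matrix one dot product at a time:
# entry (i, j) is the dot product of token i with token j.
loop_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        loop_scores[i, j] = torch.dot(x_i, x_j)

print(torch.allclose(attn_scores, loop_scores))  # True
print(torch.allclose(attn_scores, attn_scores.T))  # True: the scores are symmetric
```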

Next, we apply softmax to each row of the attention scores (dim=-1) to get the attention weights. Each row of weights is non-negative and sums to 1.

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights.shape)
torch.Size([6, 6])
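As a quick check of what torch.softmax does along dim=-1, here is the same normalization written out by hand: exponentiate every score, then divide by its row sum. The hand-written version is only for illustration; it can overflow for large scores, so in practice torch.softmax is the safer choice.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
attn_scores = inputs @ inputs.T

# Softmax by hand: exponentiate, then divide each row by its row sum.
exp_scores = torch.exp(attn_scores)
manual = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

attn_weights = torch.softmax(attn_scores, dim=-1)
print(torch.allclose(manual, attn_weights))  # True
print(attn_weights.sum(dim=-1))  # each row sums to 1 (up to float rounding)
```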

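Finally, we calculate the context vectors by multiplying the attention weights by the input matrix. Row i of the result is the weighted sum of all input vectors, using row i of the attention weights as coefficients. The context vectors have the same shape as the input matrix, 6 x 3.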

all_context_vectors = attn_weights @ inputs
print(all_context_vectors.shape)
torch.Size([6, 3])
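Each row of all_context_vectors is a weighted sum of the input rows. The sketch below rebuilds the context vector of the second token (row 1) with an explicit loop and compares it with the matrix result.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
attn_weights = torch.softmax(inputs @ inputs.T, dim=-1)
all_context_vectors = attn_weights @ inputs

# Context vector of token 1 as an explicit weighted sum over all tokens j.
z_1 = torch.zeros(inputs.shape[1])
for j, x_j in enumerate(inputs):
    z_1 += attn_weights[1, j] * x_j

print(torch.allclose(z_1, all_context_vectors[1]))  # True
```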

Self-attention with trainable weights

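In the classic scaled dot-product attention introduced in the "Attention Is All You Need" paper, the attention function is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices and $d_k$ is the dimension of the key vectors.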

This is an expanded version of our naive implementation above. The main difference is that each input vector is projected into 3 different vectors: a query, a key, and a value vector. With these three vectors, we map each input vector into three learned subspaces; in effect, we overparameterize the input representation. Each new vector takes on a distinct role, borrowed from the world of information retrieval:

  • Q (query): the search request
  • K (key): the searchable index
  • V (value): the actual content retrieved
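To make the retrieval analogy concrete, here is a small sketch contrasting a hard lookup, which returns only the value whose key best matches the query, with the soft lookup that attention performs, which blends all values according to softmax-normalized query-key similarities. The keys, values, and query are invented toy numbers.

```python
import torch

# Hypothetical toy "database": 3 keys, each paired with a value.
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
values = torch.tensor([[10.0], [20.0], [30.0]])
query = torch.tensor([0.2, 1.0])

scores = keys @ query  # similarity of the query to each key, roughly [0.2, 1.0, 1.2]

# Hard lookup (classic retrieval): return only the value of the best-matching key.
hard = values[scores.argmax().item()]

# Soft lookup (attention): blend all values, weighted by softmax of the scores.
weights = torch.softmax(scores, dim=-1)
soft = weights @ values

print(hard, soft)
```

The soft version is differentiable with respect to the query and keys, which is what lets the projection matrices be learned by gradient descent.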
inputs.shape
torch.Size([6, 3])

Each of the 3 vectors is created by multiplying the input token by its corresponding weight matrix ($W_q$, $W_k$, or $W_v$).

We also introduce separate parameters for the input and output dimensions, d_in and d_out. In GPT models the two values are usually the same, but here we use different values (d_in = 3, d_out = 2) to illustrate how the dimensions change.

Each weight matrix is d_in x d_out. The query, key, and value matrices are seq_len x d_out, since each is the product of the input, seq_len x d_in, and a weight matrix, d_in x d_out. The attention scores and weights have shape seq_len x seq_len, because each element needs one weight for every element in the sequence, including itself. Finally, the context vectors have shape seq_len x d_out: they are the product of the attention weights (seq_len x seq_len) and the values (seq_len x d_out).
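The shape bookkeeping above can be checked directly. This sketch uses random tensors with seq_len = 6, d_in = 3, and d_out = 2 (the same sizes as the running example, though not the same values) and prints the shape after each step.

```python
import torch

torch.manual_seed(123)
seq_len, d_in, d_out = 6, 3, 2
x = torch.rand(seq_len, d_in)       # input: seq_len x d_in
W_query = torch.rand(d_in, d_out)   # each weight matrix: d_in x d_out
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

queries = x @ W_query               # (6, 3) @ (3, 2) -> (6, 2)
keys = x @ W_key                    # (6, 2)
values = x @ W_value                # (6, 2)
attn_scores = queries @ keys.T      # (6, 2) @ (2, 6) -> (6, 6)
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)  # (6, 6)
context = attn_weights @ values     # (6, 6) @ (6, 2) -> (6, 2)

for name, t in [("queries", queries), ("attn_weights", attn_weights), ("context", context)]:
    print(name, tuple(t.shape))
```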

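Also notice that, before applying softmax, we scale the attention scores by dividing them by the square root of the key dimension (keys.shape[-1] ** 0.5). Without this scaling, the dot products grow in magnitude as the dimension grows, which pushes softmax toward near one-hot outputs with very small gradients.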
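The scaled dot-product formula fits in a few lines. As a sanity check, the sketch below compares a hand-written version with torch.nn.functional.scaled_dot_product_attention, which is available in PyTorch 2.0 and later and, by default, scales by 1 / sqrt(d_k) and applies no mask. The random Q, K, and V matrices are placeholders rather than projections of the inputs.

```python
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for 2-D tensors of shape (seq_len, d)."""
    d_k = K.shape[-1]
    weights = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
    return weights @ V


torch.manual_seed(0)
Q, K, V = torch.rand(6, 2), torch.rand(6, 2), torch.rand(6, 2)
ours = scaled_dot_product_attention(Q, K, V)

# The built-in is called with a leading batch dimension of size 1,
# then that dimension is dropped again with [0].
ref = F.scaled_dot_product_attention(Q[None], K[None], V[None])[0]
print(torch.allclose(ours, ref, atol=1e-5))  # True
```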

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
 
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
 
 
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(inputs.shape[1], 2)
print(sa_v1(inputs))
 
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
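A common refinement, sketched here as an assumption rather than as part of the code above, is to replace the raw nn.Parameter matrices with nn.Linear layers without bias. nn.Linear performs the same matrix multiplication but uses PyTorch's default initialization instead of torch.rand, which is often reported to make training more stable. Because nn.Linear stores its weight as d_out x d_in and computes x @ weight.T, copying the transposed SelfAttention_v1 matrices into it reproduces the same outputs. The class name SelfAttentionLinear is hypothetical.

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttentionLinear(nn.Module):
    """Hypothetical variant of SelfAttention_v1 built from bias-free nn.Linear layers."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(123)
x = torch.rand(6, 3)
sa_v1 = SelfAttention_v1(3, 2)
sa_lin = SelfAttentionLinear(3, 2)

# nn.Linear keeps its weight as (d_out, d_in), so copy the transposed v1 matrices.
with torch.no_grad():
    sa_lin.W_query.weight.copy_(sa_v1.W_query.T)
    sa_lin.W_key.weight.copy_(sa_v1.W_key.T)
    sa_lin.W_value.weight.copy_(sa_v1.W_value.T)

print(torch.allclose(sa_v1(x), sa_lin(x), atol=1e-5))  # True
```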

Causal Attention

Finally, we complete the picture by applying a mask to the attention mechanism. Masked attention, or causal attention, is a special form of self-attention in which each token only has access to the previous and current tokens in the sequence, as opposed to the entire sequence. In the attention weight matrix, this makes the weights lower triangular: every entry above the diagonal is zero, because those positions correspond to future tokens and are inaccessible.
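The mask pattern is easiest to see on a small example. For 4 tokens, torch.tril gives the positions each query is allowed to see (ones on and below the diagonal), while torch.triu with diagonal=1 gives the complementary future positions, which is the form of the mask used in this section's code.

```python
import torch

num_tokens = 4
# 1 = allowed, 0 = blocked: row i (the query) may attend to columns j <= i
allowed = torch.tril(torch.ones(num_tokens, num_tokens))
# complementary mask: 1 marks the future positions that must be hidden
future = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)

print(allowed)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
print(torch.equal(allowed + future, torch.ones(num_tokens, num_tokens)))  # True
```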

queries = inputs @ sa_v1.W_query
keys = inputs @ sa_v1.W_key
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]],
       grad_fn=<SoftmaxBackward0>)

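Starting again from the attention scores, we use torch.triu with diagonal=1 to create a mask that has ones above the diagonal (the future positions) and zeros elsewhere. We apply this mask to the attention scores, not the weights, using negative infinity as the fill value. Since exp(-inf) = 0, softmax then assigns zero weight to every masked position while each row still sums to 1, which gives us proper causal attention weights.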

context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3986, 0.6014, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2526, 0.3791, 0.3683, 0.0000, 0.0000, 0.0000],
        [0.2265, 0.2839, 0.2794, 0.2103, 0.0000, 0.0000],
        [0.1952, 0.2363, 0.2331, 0.1820, 0.1534, 0.0000],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]],
       grad_fn=<SoftmaxBackward0>)
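Filling the scores with -inf before softmax is equivalent to a more roundabout approach: apply softmax over all positions, zero out the future positions, and renormalize each row. The two agree because the softmax denominator cancels during renormalization; the -inf version simply does it in one step. The sketch below checks this on random, made-up scores.

```python
import torch

torch.manual_seed(123)
attn_scores = torch.rand(6, 6) * 4  # hypothetical raw scores
n = attn_scores.shape[0]

# Approach 1: mask the scores with -inf, then apply softmax.
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
w_inf = torch.softmax(attn_scores.masked_fill(future, -torch.inf), dim=-1)

# Approach 2: softmax over all positions, zero out future positions,
# then renormalize each row so it sums to 1 again.
w_all = torch.softmax(attn_scores, dim=-1)
w_zeroed = w_all * torch.tril(torch.ones(n, n))
w_renorm = w_zeroed / w_zeroed.sum(dim=-1, keepdim=True)

print(torch.allclose(w_inf, w_renorm, atol=1e-6))  # True
```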

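Putting it all together, we wrap everything in a CausalAttention class. This implementation also adds a dropout step: during training, nn.Dropout randomly zeroes some of the attention weights (with probability 0.2 in the example below), which helps reduce overfitting.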
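Before looking at the class, here is a small sketch of how nn.Dropout behaves on a matrix of attention weights. It uses p = 0.5 so the scaling is easy to see: in training mode, each entry is zeroed with probability p and the survivors are multiplied by 1 / (1 - p), while in evaluation mode the layer returns its input unchanged. One consequence is that, during training, the rows of the dropped-out weights no longer sum exactly to 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
p = 0.5
dropout = nn.Dropout(p)
weights = torch.full((6, 6), 1 / 6)  # e.g. uniform attention weights

dropout.train()  # training mode: dropout is active
dropped = dropout(weights)
# surviving entries are scaled by 1 / (1 - p) = 2, the rest become 0
print(dropped)

dropout.eval()  # evaluation mode: dropout is the identity
print(torch.equal(dropout(weights), weights))  # True
```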

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        self.dropout = nn.Dropout(dropout)
        # use this to ensure buffer is moved to the same device as the model
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
 
    def forward(self, x):
        num_tokens, d_in = x.shape
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
 
 
torch.manual_seed(123)
context_length = inputs.shape[0]
print(context_length)
ca = CausalAttention(inputs.shape[1], 2, context_length, 0.2)
context_vecs = ca(inputs)
print("context_vecs.shape:", context_vecs.shape)
6
context_vecs.shape: torch.Size([6, 2])
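CausalAttention as written expects a single sequence of shape (num_tokens, d_in); on a batched tensor, keys.T would reverse all dimensions rather than just the last two. A hypothetical batched variant, sketched below, transposes only the last two dimensions and relies on broadcasting to apply the 2-D mask to every sequence in the batch. The random inputs stand in for real embeddings. The sketch also checks the causal property: with dropout disabled via eval(), replacing the last token leaves the outputs of all earlier tokens unchanged.

```python
import torch
import torch.nn as nn


class BatchedCausalAttention(nn.Module):
    """Hypothetical batched variant of CausalAttention: x is (batch, num_tokens, d_in)."""

    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        _, num_tokens, _ = x.shape
        keys = x @ self.W_key  # (batch, num_tokens, d_out)
        queries = x @ self.W_query
        values = x @ self.W_value
        # transpose only the last two dims -> (batch, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(1, 2)
        # the (num_tokens, num_tokens) mask broadcasts over the batch dimension
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # (batch, num_tokens, d_out)


torch.manual_seed(123)
inputs = torch.rand(6, 3)
batch = torch.stack((inputs, inputs))  # (2, 6, 3)
ca = BatchedCausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.2)
ca.eval()  # disable dropout so the check below is deterministic

out = ca(batch)
print(out.shape)  # torch.Size([2, 6, 2])

# Causality check: changing the last token must not change earlier outputs.
changed = batch.clone()
changed[:, -1] = torch.rand(3)
out_changed = ca(changed)
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True
```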