A KV cache is a mechanism for avoiding repeated calculations during model inference. Without it, generating each new token means computing a fresh set of key, value and query vectors for the whole sequence. But since the existing tokens stay the same, the vectors associated with them do not need to be recalculated. The only new parts are the vectors associated with the new token, or more precisely, the new row appended to each of the growing key, value and query matrices. Therefore, we cache the intermediate key and value vectors to speed up inference. In this entry, we show why key and value vectors can be cached in causal attention.
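The argument can be written as a short derivation. The notation here is an assumption for illustration: x_i is the i-th row of the input, W_q, W_k and W_v stand for the query_embeddings, key_embeddings and value_embeddings matrices used in the code that follows, and d_k is the key dimension (keys.shape[-1]). Causal attention computes the i-th context row as:

```latex
q_i = x_i W_q, \qquad k_j = x_j W_k, \qquad v_j = x_j W_v
\alpha_{ij} = \frac{\exp\left(q_i k_j^\top / \sqrt{d_k}\right)}{\sum_{l=1}^{i} \exp\left(q_i k_l^\top / \sqrt{d_k}\right)}, \qquad j \le i
c_i = \sum_{j=1}^{i} \alpha_{ij} v_j
```

Row c_i depends only on x_1, ..., x_i. Appending a token x_{t+1} therefore cannot change c_1, ..., c_t, and computing c_{t+1} needs q_{t+1} together with k_j and v_j for all j ≤ t+1, of which only k_{t+1} and v_{t+1} are new.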

To illustrate these repeated calculations, consider the following example:

import torch
import torch.nn as nn

First, we set up the simple causal attention mechanism implemented in a previous post. The input has 6 tokens, each represented by a 3-dimensional embedding.

torch.manual_seed(123)
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
key_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
query_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
value_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))

We write a function to calculate the context vector:

def calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
):
    keys = inputs @ key_embeddings
    queries = inputs @ query_embeddings
    values = inputs @ value_embeddings
    attn_scores = queries @ keys.T
    context_length = attn_scores.shape[0]
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
    context_vec = attn_weights @ values
    return context_vec

We calculate the context vector for all input elements:

ctxt = calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
)
ctxt
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307]], grad_fn=<MmBackward0>)

Now we append a new token (a new row) to the input:

torch.manual_seed(123)
# assume now we have a new token
new_row = torch.rand(1, inputs.shape[1])
new_inputs = torch.cat([inputs, new_row], dim=0)
new_inputs
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500],
        [0.2961, 0.5166, 0.2517]])

We calculate a new context vector from the extended input and use torch.allclose to check whether its first 6 rows match the previous context vector:

new_ctxt = calculate_context_vector(
    new_inputs, key_embeddings, query_embeddings, value_embeddings
)
print(new_ctxt, torch.allclose(ctxt, new_ctxt[:-1]), sep="\n")
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307],
        [0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True
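This agreement relies on the causal mask. As a hedged contrast, the sketch below reruns the experiment with a hypothetical attention helper that can switch the mask off, using fresh random data instead of the post's inputs. Without the mask, every earlier row also attends to the appended token, so the earlier context rows are expected to change and would have to be recomputed.

```python
import torch

torch.manual_seed(123)
d = 3
# Hypothetical data: 6 tokens plus 1 appended token (not the post's values).
inputs = torch.rand(6, d)
new_inputs = torch.cat([inputs, torch.rand(1, d)], dim=0)
W_q, W_k, W_v = torch.rand(d, d), torch.rand(d, d), torch.rand(d, d)

def attention(x, causal):
    """Scaled dot-product attention over all rows of x, optionally causal."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / d ** 0.5
    if causal:
        n = x.shape[0]
        # True above the diagonal: token i may not look at tokens j > i.
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, -torch.inf)
    return torch.softmax(scores, dim=-1) @ v

# Do the first 6 rows stay the same after appending a token?
causal_same = torch.allclose(attention(inputs, True), attention(new_inputs, True)[:-1])
full_same = torch.allclose(attention(inputs, False), attention(new_inputs, False)[:-1])
print(causal_same, full_same)
```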

So adding a new element to the input sequence only adds a single new row to the context vector, while the existing rows stay unchanged. We can leverage this fact to avoid many of the calculations repeated in calculate_context_vector. More specifically, to find the new addition to the context vector, its last row, we only need to perform the same procedure on the last row of the query matrix, new_queries[-1:]. No causal mask is needed here, because the newest token is allowed to attend to every position in the sequence:

new_keys = new_inputs @ key_embeddings
new_values = new_inputs @ value_embeddings
new_queries = new_inputs @ query_embeddings
 
attn_scores = new_queries[-1:] @ new_keys.T
attn_weights = torch.softmax(attn_scores / new_keys.shape[-1] ** 0.5, dim=-1)
context_vec_row = attn_weights @ new_values
print(context_vec_row, torch.allclose(new_ctxt[-1:], context_vec_row), sep="\n")
tensor([[0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True
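To make the saving concrete, here is a rough count of the multiply-accumulate operations in the matrix products, for a sequence of n tokens with embedding size d. It is a sketch under simplifying assumptions: softmax, scaling and masking are ignored, and the helper names are hypothetical. The middle count mirrors the snippet above, which still projects every row; the last assumes the keys and values of the first n - 1 tokens were kept from the previous step.

```python
def matmul_macs(rows, inner, cols):
    """Multiply-accumulate count of a (rows x inner) @ (inner x cols) product."""
    return rows * inner * cols

def full_recompute_macs(n, d):
    # calculate_context_vector on all n tokens: Q, K, V projections,
    # n x n attention scores, then the n x d weighted sum of values.
    return 3 * matmul_macs(n, d, d) + matmul_macs(n, d, n) + matmul_macs(n, n, d)

def last_query_macs(n, d):
    # All n rows are still projected, but only the newest query row
    # takes part in the score and weighted-sum products.
    return 3 * matmul_macs(n, d, d) + matmul_macs(1, d, n) + matmul_macs(1, n, d)

def cached_step_macs(n, d):
    # Earlier keys/values are reused, so only the newest token is projected;
    # the score and weighted-sum products are the same as in last_query_macs.
    return 3 * matmul_macs(1, d, d) + matmul_macs(1, d, n) + matmul_macs(1, n, d)

for n, d in [(7, 3), (1024, 64)]:
    print(n, d, full_recompute_macs(n, d), last_query_macs(n, d), cached_step_macs(n, d))
```

In this count a full recompute costs exactly n times a cached step, since 3nd² + 2n²d = n(3d² + 2nd). For the 7-token example that is 483 versus 69 multiply-accumulates.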

As shown above, to get the new row of the context vector we only need the last row of the query matrix together with the full key and value matrices. But the key and value matrices are not entirely new either: only their last rows are new, and the rest are identical to the previous iteration:

keys = inputs @ key_embeddings
values = inputs @ value_embeddings
print(torch.allclose(new_keys[:-1], keys))
print(torch.allclose(new_values[:-1], values))
 
True
True
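Since only the last rows are new, the extended key and value matrices can be built by projecting just the new token and appending the result to the stored ones. A minimal sketch, using hypothetical random stand-ins for inputs, new_row and the weight matrices rather than the post's exact values:

```python
import torch

torch.manual_seed(123)
d = 3
# Hypothetical stand-ins for the post's inputs, new_row and weight matrices.
inputs = torch.rand(6, d)
new_row = torch.rand(1, d)
key_embeddings, value_embeddings = torch.rand(d, d), torch.rand(d, d)

# Previous step: these are the tensors a KV cache keeps around.
cached_keys = inputs @ key_embeddings
cached_values = inputs @ value_embeddings

# Current step: project only the new token and append it to the cache.
new_keys = torch.cat([cached_keys, new_row @ key_embeddings], dim=0)
new_values = torch.cat([cached_values, new_row @ value_embeddings], dim=0)

# Compare against projecting the full, extended input from scratch.
full_inputs = torch.cat([inputs, new_row], dim=0)
keys_ok = torch.allclose(new_keys, full_inputs @ key_embeddings)
values_ok = torch.allclose(new_values, full_inputs @ value_embeddings)
print(keys_ok, values_ok)
```

This works because a matrix product acts on each input row independently, so projecting rows one at a time and stacking them gives the same result as projecting them all at once.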

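This forms the basis of the KV cache: at each generation step, we compute the query, key and value only for the new token, append the new key and value rows to the cached ones, and attend from the single new query over the whole cache.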
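Putting the pieces together, the loop below feeds a sequence one token at a time through a hypothetical KVCache class and a hypothetical decode_step function, then checks that the stacked outputs match full causal attention over the whole sequence. It is a sketch under stated assumptions: random weights and inputs, a single attention head, and no surrounding model.

```python
import torch

torch.manual_seed(123)
d = 3
seq = torch.rand(7, d)  # hypothetical 7-token sequence
W_q, W_k, W_v = torch.rand(d, d), torch.rand(d, d), torch.rand(d, d)

def causal_attention(x):
    """Reference: recompute causal attention over the whole sequence."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    n = x.shape[0]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = (q @ k.T).masked_fill(mask, -torch.inf)
    return torch.softmax(scores / d ** 0.5, dim=-1) @ v

class KVCache:
    """Hypothetical minimal cache: stores every key and value row seen so far."""

    def __init__(self, dim):
        self.keys = torch.empty(0, dim)
        self.values = torch.empty(0, dim)

    def append(self, k, v):
        self.keys = torch.cat([self.keys, k], dim=0)
        self.values = torch.cat([self.values, v], dim=0)

def decode_step(x_new, cache):
    """Context row for one new token (shape 1 x d), reusing cached keys and values."""
    cache.append(x_new @ W_k, x_new @ W_v)  # project only the new token
    q = x_new @ W_q
    scores = q @ cache.keys.T  # 1 x (tokens so far); nothing to mask
    weights = torch.softmax(scores / d ** 0.5, dim=-1)
    return weights @ cache.values

cache = KVCache(d)
rows = [decode_step(seq[i : i + 1], cache) for i in range(seq.shape[0])]
incremental = torch.cat(rows, dim=0)
print(torch.allclose(incremental, causal_attention(seq)))
```

In an actual language model the cache is kept per layer and per attention head, and x_new would be the embedding of the token just generated; both details are outside this sketch.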