A KV cache is a mechanism for reusing repeated calculations during model inference. Each time a new token is generated, attention needs key, value and query vectors for the whole sequence. But because the existing tokens stay the same, the key and value vectors associated with them do not need to be recalculated. The only new vectors are those associated with the new token, or more precisely, the new rows appended to the existing key and value matrices. Therefore, we cache these intermediate key and value vectors to speed up inference. In this entry, we show why key and value vectors can be cached in causal attention.
To illustrate these repeated calculations, consider the following example:
import torch
import torch.nn as nn
First, we set up a simple causal attention mechanism implemented in a previous post. The input has 6 elements.
torch.manual_seed(123)
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
key_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
query_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
value_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
We write a function to calculate the context vector:
def calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
):
    # Project every input row into key, query and value space.
    keys = inputs @ key_embeddings
    queries = inputs @ query_embeddings
    values = inputs @ value_embeddings
    attn_scores = queries @ keys.T
    # Causal mask: position i may only attend to positions <= i.
    context_length = attn_scores.shape[0]
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    # Scale by sqrt(d_k), then normalize each row.
    attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
    context_vec = attn_weights @ values
    return context_vec
We calculate the context vector for all the input elements:
ctxt = calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
)
ctxt
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307]], grad_fn=<MmBackward0>)
Now we add a new token (a new row) to the input:
torch.manual_seed(123)
# assume now we have a new token
new_row = torch.rand(1, inputs.shape[1])
new_inputs = torch.cat([inputs, new_row], dim=0)
new_inputs
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500],
        [0.2961, 0.5166, 0.2517]])
We calculate a new context vector from the new input and use torch.allclose to check whether its first 6 rows match the previous context vector:
new_ctxt = calculate_context_vector(
    new_inputs, key_embeddings, query_embeddings, value_embeddings
)
print(new_ctxt, torch.allclose(ctxt, new_ctxt[:-1]), sep="\n")
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307],
        [0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True
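The check above depends on the causal mask. As a minimal sketch of why, the snippet below compares masked and unmasked attention before and after appending a token. The names `x`, `w_q`, `w_k`, `w_v` and `attention` are hypothetical stand-ins introduced for this illustration, not the variables used elsewhere in this post.

```python
import torch

torch.manual_seed(123)

# Hypothetical stand-ins for the post's inputs and projection matrices.
x = torch.rand(6, 3)
w_q, w_k, w_v = torch.rand(3, 3), torch.rand(3, 3), torch.rand(3, 3)


def attention(x, causal):
    """Scaled dot-product attention, with or without a causal mask."""
    queries, keys, values = x @ w_q, x @ w_k, x @ w_v
    scores = queries @ keys.T
    if causal:
        n = scores.shape[0]
        mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
        scores = scores.masked_fill(mask, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values


# Append one new token to the sequence.
x_new = torch.cat([x, torch.rand(1, 3)], dim=0)

# Compare the first six output rows before and after the append.
causal_same = torch.allclose(attention(x, True), attention(x_new, True)[:-1])
full_same = torch.allclose(attention(x, False), attention(x_new, False)[:-1])
print(causal_same)  # masked: earlier rows never see the new token
print(full_same)    # unmasked: every row attends to the new token
```

Without the mask, every earlier row would also attend to the new token and would have to be recomputed. With the mask in place, the first six rows are untouched by the new token.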
So, given a new element in the input sequence, the new context vector only gains a single new row. We can use this fact to skip many of the calculations that calculate_context_vector repeats. More specifically, to find the new last row of the context vector, we only need to run the same procedure on the last row of the queries, new_queries[-1:]. No mask is needed here, because the last token is allowed to attend to every token in the sequence:
new_keys = new_inputs @ key_embeddings
new_values = new_inputs @ value_embeddings
new_queries = new_inputs @ query_embeddings
# Only the last query row is needed for the new context row.
attn_scores = new_queries[-1:] @ new_keys.T
attn_weights = torch.softmax(attn_scores / new_keys.shape[-1] ** 0.5, dim=-1)
context_vec_row = attn_weights @ new_values
print(context_vec_row, torch.allclose(new_ctxt[-1], context_vec_row), sep="\n")
tensor([[0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True
As shown above, to get the new context vector row, we only need the last row of the queries and the full key and value matrices. But the keys and values are not entirely new either. Only their last rows are new; the rest are the same as in the previous iteration:
keys = inputs @ key_embeddings
values = inputs @ value_embeddings
print(torch.allclose(new_keys[:-1], keys))
print(torch.allclose(new_values[:-1], values))
True
True
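The same observation can be written in symbols. This is a sketch under assumed notation that does not appear in the code: x_t is the input row for the newest token, W_Q, W_K and W_V stand for query_embeddings, key_embeddings and value_embeddings, K_t and V_t are the key and value matrices after t tokens, and d_k is the key dimension.

```latex
\begin{aligned}
K_t &= \begin{bmatrix} K_{t-1} \\ x_t W_K \end{bmatrix}, \qquad
V_t = \begin{bmatrix} V_{t-1} \\ x_t W_V \end{bmatrix}, \\
q_t &= x_t W_Q, \\
c_t &= \operatorname{softmax}\!\left( \frac{q_t K_t^{\top}}{\sqrt{d_k}} \right) V_t .
\end{aligned}
```

Each row of K_t and V_t depends only on its own input row, so K_{t-1} and V_{t-1} carry over unchanged, and the new context row c_t needs only one new query, one new key row and one new value row.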
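This forms the basis of the KV cache: at each step we keep the key and value rows already computed, compute only the rows for the new token, and append them to the cache.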
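Putting the pieces together, a minimal sketch of a cached decoding loop might look like the following. It is an illustration under stated assumptions, not a production implementation: `tokens`, `w_q`, `w_k`, `w_v`, `k_cache`, `v_cache` and `full_causal_attention` are hypothetical names introduced here, and the cache is a plain tensor grown with torch.cat.

```python
import torch

torch.manual_seed(123)

d = 3
# Hypothetical stand-ins for the post's projection matrices and inputs.
w_q, w_k, w_v = torch.rand(d, d), torch.rand(d, d), torch.rand(d, d)
tokens = torch.rand(7, d)


def full_causal_attention(x):
    """Reference: recompute keys, queries and values for the whole sequence."""
    queries, keys, values = x @ w_q, x @ w_k, x @ w_v
    scores = queries @ keys.T
    n = scores.shape[0]
    mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(mask, -torch.inf)
    return torch.softmax(scores / d ** 0.5, dim=-1) @ values


# Start with empty caches and feed one token at a time.
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
rows = []
for t in range(tokens.shape[0]):
    x_t = tokens[t : t + 1]  # shape (1, d)
    # Only the new token's key and value rows are computed at this step.
    k_cache = torch.cat([k_cache, x_t @ w_k], dim=0)
    v_cache = torch.cat([v_cache, x_t @ w_v], dim=0)
    # The newest query attends to every cached position, so no mask is needed.
    q_t = x_t @ w_q
    weights = torch.softmax(q_t @ k_cache.T / d ** 0.5, dim=-1)
    rows.append(weights @ v_cache)

incremental = torch.cat(rows, dim=0)
matches = torch.allclose(incremental, full_causal_attention(tokens), atol=1e-6)
print(matches)
```

Each step projects a single row instead of the whole sequence, at the cost of keeping the cached keys and values in memory. Growing the cache with torch.cat keeps the sketch short; preallocating the cache is a common alternative when the maximum length is known.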