Self-Attention and the Key-Value Cache (In A Nutshell!)

The Transformer architecture underpins most modern large language models (LLMs). In the seminal paper "Attention is All You Need", Vaswani et al. propose a Transformer architecture that relies solely on the multi-head self-attention mechanism to learn global dependencies (i.e. relationships) between words in a sentence. In this post, I first explain the multi-head self-attention mechanism used in LLMs such as the original ChatGPT model (which was derived from GPT-3.5), and then explain why a key-value (KV) cache is needed for efficient inference. For illustrative purposes, I use words to represent tokens.

ChatGPT uses a decoder-only Transformer architecture, and is trained to predict the next token (e.g. sub-word, punctuation) given a context (i.e. message). I start by illustrating the multi-head self-attention mechanism using a single sentence as an example. Assume that we are training an LLM with a context window of 10 words; a 10-word sentence in the training dataset will then be used to generate 10 samples of lengths 1 to 10 (Figure 1).

Figure 1. Example of how a sentence in the training dataset is used to generate 10 samples. The blue text indicates the words in the sample (i.e. the context), while the red text indicates the target word to be predicted. The <eos> tag represents the end-of-sequence token. The period is treated as a token/word as well.
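As a rough sketch of how such samples might be produced, the snippet below builds the 10 context/target pairs of Figure 1 from a made-up 10-word sentence (a toy example, not text from any actual training set):

```python
# Minimal sketch: one 10-word sentence (period counted as a word) yields
# 10 next-word prediction samples of lengths 1 to 10, as in Figure 1.
sentence = "the quick brown fox jumps over the lazy dog .".split()  # hypothetical sentence

samples = []
for t in range(1, len(sentence)):
    context, target = sentence[:t], sentence[t]   # words 1..t predict word t+1
    samples.append((context, target))
samples.append((sentence, "<eos>"))               # full sentence predicts end-of-sequence

for context, target in samples:
    print(" ".join(context), "->", target)
```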

I first describe how a single head of the multi-head self-attention mechanism works. An attention head works by first transforming the embedding vector of each word into its query (Q), key (K) and value (V) embeddings. Subsequently, each sample (in Figure 1) is represented as a \(d_h\)-dimensional vector based on its Q, K, and V embeddings, regardless of the number of words it contains.

As an example, I illustrate how Sample 8 in Figure 1 is represented as a vector. We first calculate the strength of the relationship (i.e. weight) between the 8th word and word \(i\) by taking the dot product of the query embedding of the 8th word, \(\mathbf{q}_8\), and the key embedding of word \(i\), \(\mathbf{k}_i\). This is repeated for words 1 to 8 to obtain a weight vector, \( \mathbf{w} = (\mathbf{q}_8^\top \mathbf{k}_1, \ldots, \mathbf{q}_8^\top \mathbf{k}_8) \). Subsequently, the weight vector is scaled by the square root of the head dimension and transformed using softmax to obtain a probability vector, \( \mathbf{p} = \text{softmax}(\mathbf{w} / \sqrt{d_h}) \). These probabilities are used to calculate the vector representation of the sample produced by head \(j\), which in essence is the weighted sum of the value embeddings of the words in the sample, \( \mathbf{z}_j = \mathbf{V}^\top \mathbf{p} \), where \(\mathbf{V} \in \mathbb{R}^{8 \times d_h} \) contains the value embeddings of all 8 words in the sample.
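To make this concrete, here is a minimal NumPy sketch of the single-head computation for Sample 8, with random numbers standing in for the learned Q, K and V embeddings and a toy head dimension:

```python
import numpy as np

d_h = 4          # head dimension (toy value)
n = 8            # number of words in Sample 8

rng = np.random.default_rng(0)
Q = rng.normal(size=(n, d_h))   # query embeddings, one row per word (random stand-ins)
K = rng.normal(size=(n, d_h))   # key embeddings
V = rng.normal(size=(n, d_h))   # value embeddings

q8 = Q[-1]                      # query embedding of the 8th word
w = K @ q8                      # weights: dot products q_8 . k_i for i = 1..8
p = np.exp(w / np.sqrt(d_h))
p /= p.sum()                    # softmax of the scaled weights
z_j = V.T @ p                   # weighted sum of value embeddings, shape (d_h,)
print(z_j)
```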

The output of the multi-head self-attention mechanism, \( \mathbf{z} \in \mathbb{R}^{d_e} \), is obtained by concatenating the output of each single head \(j\), denoted by \( \mathbf{z}_j \in \mathbb{R}^{d_h} \),

\[ \mathbf{z}  = \begin{bmatrix} \mathbf{z}_1 \\ \vdots \\ \mathbf{z}_{n_h} \end{bmatrix} \]

where the dimension of each single head is obtained by dividing the embedding dimension by the number of heads: \( d_h = d_e / n_h \).
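A small sketch of this concatenation step, with random vectors standing in for the per-head outputs and toy values for \(n_h\) and \(d_h\):

```python
import numpy as np

n_h, d_h = 4, 4
d_e = n_h * d_h                 # embedding dimension = number of heads x head dimension

rng = np.random.default_rng(0)
head_outputs = [rng.normal(size=d_h) for _ in range(n_h)]  # stand-ins for z_1, ..., z_{n_h}
z = np.concatenate(head_outputs)                           # multi-head output, shape (d_e,)
assert z.shape == (d_e,)
```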

In matrix form, the (masked) self-attention computed by each head can be expressed mathematically as:
\[ \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax} \left( \frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_h}} + \mathbf{M} \right) \mathbf{V} \]

where \(\mathbf{Q} \in \mathbb{R}^{n \times d_h} \) is the query embeddings matrix, \(\mathbf{K} \in \mathbb{R}^{n \times d_h} \) is the key embeddings matrix, \(\mathbf{V} \in \mathbb{R}^{n \times d_h} \) is the value embeddings matrix, and \(\mathbf{M} \in \mathbb{R}^{n \times n} \) is a causal (triangular) mask matrix with elements: \[ m_{ij} = \begin{cases} 0 & \text{if } j \leq i \\[6pt] -\infty & \text{if } j > i \end{cases} \]
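Under these definitions, a toy NumPy implementation of this masked, per-head attention might look as follows (a sketch for illustration, not production code):

```python
import numpy as np

def masked_attention(Q, K, V):
    """Sketch of masked self-attention for one head; Q, K, V have shape (n, d_h)."""
    n, d_h = Q.shape
    scores = Q @ K.T / np.sqrt(d_h)                 # (n, n) scaled dot products
    M = np.triu(np.full((n, n), -np.inf), k=1)      # -inf where j > i, 0 where j <= i
    scores = scores + M                             # future positions are masked out
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)              # row-wise softmax
    return p @ V                                    # (n, d_h): one output per position
```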

Now that we know how the multi-head self-attention mechanism works, we can answer why a KV cache is needed. The KV cache speeds up inference by storing the key and value embeddings computed in previous steps so that they do not have to be recomputed. Going back to our example, the key and value embeddings of words 1 to 8 were computed to predict the 9th word, and they are required to predict the 10th word as well. With the KV cache, these values are simply retrieved, and only the query, key and value embeddings of the 9th word have to be computed. This reduces the computational cost and increases the speed of inference, at the expense of extra memory.
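The following sketch illustrates the idea for a single head, with random stand-ins for the per-word q/k/v embeddings; the cache grows by one key and one value per word, and each decoding step only computes the newest word's embeddings:

```python
import numpy as np

d_h = 4
rng = np.random.default_rng(0)
k_cache, v_cache = [], []        # grow by one entry per processed word

def decode_step(q_new, k_new, v_new):
    """Attend from the newest word over all cached keys/values (plus its own)."""
    k_cache.append(k_new)        # cache the new key and value ...
    v_cache.append(v_new)
    K = np.stack(k_cache)        # ... and retrieve all earlier ones, no recomputation
    V = np.stack(v_cache)
    w = K @ q_new / np.sqrt(d_h)
    p = np.exp(w - w.max())
    p /= p.sum()
    return V.T @ p               # attention output for the newest position

# Process 9 context words one at a time; each step only computes the newest
# word's q/k/v (random stand-ins here) and reuses the cached k/v of earlier words.
for _ in range(9):
    q_new, k_new, v_new = rng.normal(size=(3, d_h))
    z = decode_step(q_new, k_new, v_new)
```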
