Multi-head Latent Attention (In A Nutshell!)

In this post, I will dive into the Multi-head Latent Attention (MLA) mechanism, one of the innovations presented by the DeepSeek Team! This post assumes prior knowledge of the attention mechanism and the key-value cache. For a quick refresher on these topics, refer to my previous post on self-attention!

One of the main problems with multi-head self-attention (MHA) is the memory cost of the key-value (KV) cache, which grows linearly with the number of tokens in the context. MLA reduces the size of the KV cache and, as a result, speeds up LLM inference. The core idea is to cache latent embeddings that are shared across all heads (and by both keys and values), instead of caching separate key and value embeddings for each head as in MHA (Figure 1). The latent embeddings are then multiplied by each head's own key and value up-projection matrices to produce key and value embeddings unique to that head. Keeping unique keys and values per head preserves the expressivity of the attention mechanism. This contrasts with mechanisms such as Multi-Query Attention (MQA) and Grouped-Query Attention (GQA), where all heads (MQA) or each group of heads (GQA) share the same keys and values.
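To make the caching scheme concrete, here is a minimal pure-Python sketch, not DeepSeek's implementation. It assumes that each token's latent is produced by a single linear down-projection of its hidden state (the hypothetical `W_down`), and that each head has its own key and value up-projection matrices (`W_k[h]`, `W_v[h]`). All sizes and names are made up for illustration. Only the latents are cached, and every head rebuilds its own keys and values from them.

```python
import random

# Hypothetical sizes: hidden size, latent size, per-head size, number of heads.
D_MODEL, D_LATENT, D_HEAD, N_HEADS = 16, 4, 8, 2


def rand_matrix(rows, cols, rng):
    """Random rows x cols matrix stored as a list of row lists."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(m, v):
    """Matrix-vector product m @ v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


rng = random.Random(0)
# Hypothetical shared down-projection: hidden state -> latent (d_l).
W_down = rand_matrix(D_LATENT, D_MODEL, rng)
# Per-head up-projections: latent -> that head's key / value (d_h).
W_k = [rand_matrix(D_HEAD, D_LATENT, rng) for _ in range(N_HEADS)]
W_v = [rand_matrix(D_HEAD, D_LATENT, rng) for _ in range(N_HEADS)]

latent_cache = []  # one D_LATENT vector per token, shared by all heads


def append_token(hidden):
    """Compress a new token's hidden state and cache only its latent."""
    latent_cache.append(matvec(W_down, hidden))


def keys_values_for_head(h):
    """Rebuild head h's keys and values from the shared cached latents."""
    keys = [matvec(W_k[h], z) for z in latent_cache]
    values = [matvec(W_v[h], z) for z in latent_cache]
    return keys, values


for _ in range(5):  # simulate five decoded tokens
    append_token([rng.gauss(0.0, 1.0) for _ in range(D_MODEL)])

mla_floats = len(latent_cache) * D_LATENT              # cached by MLA
mha_floats = len(latent_cache) * 2 * N_HEADS * D_HEAD  # would be cached by MHA
keys0, values0 = keys_values_for_head(0)
keys1, values1 = keys_values_for_head(1)
print(mla_floats, mha_floats)  # 20 160
```

This sketch still up-projects the latents explicitly inside `keys_values_for_head`. The sections below explain how DeepSeek avoids that step.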

Figure 1: Comparison between different attention mechanisms. Figure adapted from the DeepSeek-V2 paper via Wikimedia Commons.

Storing latent embeddings instead of the key and value embeddings of every head reduces the size of the KV cache (Table 1). This frees up GPU VRAM for larger models with more parameters, longer contexts, or larger batches.

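Attention | Memory footprint
MHA | \(2 \cdot n_h \cdot d_h \cdot n\)
MLA | \(d_l \cdot n\)

Table 1: Approximate per-layer KV-cache memory footprint (number of cached elements) for different attention mechanisms. Here \(n\) is the number of tokens in the context, \(n_h\) the number of heads, \(d_h\) the per-head key/value dimension, and \(d_l\) the latent dimension.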
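To get a feel for the numbers in Table 1, the short sketch below plugs in one hypothetical configuration: 4096 tokens, 128 heads of size 128, a latent size of 512, and 16-bit (2-byte) elements. These values are illustrative assumptions. Like Table 1, the sketch counts a single layer and a single sequence, so real totals scale with the number of layers and the batch size.

```python
def kv_cache_bytes_mha(n_tokens, n_heads, d_head, bytes_per_elem=2):
    """Per-layer MHA cache: one key and one value vector per head per token."""
    return 2 * n_heads * d_head * n_tokens * bytes_per_elem


def kv_cache_bytes_mla(n_tokens, d_latent, bytes_per_elem=2):
    """Per-layer MLA cache: one shared latent vector per token."""
    return d_latent * n_tokens * bytes_per_elem


# Hypothetical configuration chosen only for illustration.
n, n_h, d_h, d_l = 4096, 128, 128, 512
mha = kv_cache_bytes_mha(n, n_h, d_h)
mla = kv_cache_bytes_mla(n, d_l)
print(f"MHA: {mha / 2**20:.0f} MiB, MLA: {mla / 2**20:.0f} MiB, ratio: {mha // mla}x")
# MHA: 256 MiB, MLA: 4 MiB, ratio: 64x
```

The ratio depends only on \(2 n_h d_h / d_l\), so the context length cancels out.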

However, one might wonder: doesn't up-projecting the latents add computational cost? DeepSeek mitigates this problem by exploiting the associativity of matrix multiplication. This lets it avoid explicitly up-projecting the latents to 1) their key embeddings and 2) their value embeddings, which reduces the computational cost.

Avoiding explicit up-projection of latents to key embeddings

The authors avoid up-projecting the latents to their key embeddings by re-expressing how the attention score is computed:

\begin{aligned} w_{ij} &= \mathbf{q}_i^\top \mathbf{k}_j \\  &= (\mathbf{W}_q \mathbf{z}_i)^\top (\mathbf{W}_k \mathbf{z}_j) \\ &= \mathbf{z}_i^\top (\mathbf{W}_q^\top \mathbf{W}_k) \mathbf{z}_j  \\ &= \tilde{\mathbf{z}}_i^{\top} \mathbf{z}_j \end{aligned}

Typically, the attention score \(w_{ij}\) is computed as the dot product of the query embedding of token \(i\), \( \mathbf{q}_i \), and the key embedding of token \(j\), \( \mathbf{k}_j \). The authors take an alternative approach. Rather than up-projecting the latents with the query and key up-projection matrices \( \mathbf{W}_q, \mathbf{W}_k \in \mathbb{R}^{d_e \times d_l} \) to obtain explicit query and key embeddings, they proceed in two steps.

First, the two up-projection matrices are multiplied to obtain the square matrix \( \mathbf{W}_q^\top \mathbf{W}_k \in \mathbb{R}^{d_l \times d_l} \). Because this product depends only on the weights, it can be computed once ahead of time. Second, this square matrix is premultiplied by \( \mathbf{z}_i^\top \) to obtain the transpose of the "latent query" vector \( \tilde{\mathbf{z}}_i = \mathbf{W}_k^\top \mathbf{W}_q \mathbf{z}_i \).

During inference, attention scores have to be computed between token \( i \) and all tokens in the context. With the rearranged expression, each score is simply the dot product of the latent query \( \tilde{\mathbf{z}}_i \) with the cached latent \( \mathbf{z}_j \) of token \( j \), so the cached latents never have to be up-projected to key embeddings. The derivation is shown for a single head; each head has its own \( \mathbf{W}_q \) and \( \mathbf{W}_k \), and therefore its own absorbed matrix.
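The rearrangement can be checked numerically with a small pure-Python sketch for a single head. The sizes, the random weights, and the reuse of the same latent type for queries and keys are assumptions that mirror the simplified equation above. The sketch computes the scores once the naive way and once with the precomputed \( \mathbf{W}_k^\top \mathbf{W}_q \), and confirms that the two results agree up to floating-point rounding.

```python
import random

D_E, D_L = 8, 4  # hypothetical per-head embedding size and latent size


def rand_matrix(rows, cols, rng):
    """Random rows x cols matrix stored as a list of row lists."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Matrix product a @ b."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


rng = random.Random(42)
W_q = rand_matrix(D_E, D_L, rng)  # query up-projection, d_e x d_l
W_k = rand_matrix(D_E, D_L, rng)  # key up-projection, d_e x d_l
latents = rand_matrix(6, D_L, rng)  # cached latents z_j for 6 tokens
z_i = latents[-1]  # latent of the current token i

# Naive: up-project to q_i and to every k_j, then take dot products.
q_i = matvec(W_q, z_i)
naive = [dot(q_i, matvec(W_k, z_j)) for z_j in latents]

# Absorbed: precompute the d_l x d_l matrix W_k^T W_q once, form the
# latent query z~_i = W_k^T W_q z_i, then dot it with each cached latent.
W_absorbed = matmul(transpose(W_k), W_q)
z_tilde_i = matvec(W_absorbed, z_i)
absorbed = [dot(z_tilde_i, z_j) for z_j in latents]

# Largest difference: tiny, caused only by floating-point rounding.
print(max(abs(a - b) for a, b in zip(naive, absorbed)))
```

The per-token work in the absorbed version involves only \(d_l\)-dimensional vectors, and no key embedding is ever materialized.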

Avoiding explicit up-projection of latents to value embeddings

By combining the up-projection of latents to value embeddings with the output projection of the attention layer, the authors avoid explicitly up-projecting the latents to value embeddings when computing the attention output \(\mathbf{y}\):

\begin{aligned} \mathbf{y} &= \mathbf{W}_o ( \mathbf{V}^\top \mathbf{p} ) \\  &= \mathbf{W}_o  (\mathbf{Z}\mathbf{W}_v^\top)^\top \mathbf{p}  \\ &= \mathbf{W}_o  (\mathbf{W}_v \mathbf{Z}^\top) \mathbf{p} \\ &= ( \mathbf{W}_o  \mathbf{W}_v ) \mathbf{Z}^\top \mathbf{p} \end{aligned}

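where:
\(\mathbf{W}_o \in \mathbb{R}^{d_o \times d_e}\) is the output projection matrix, with \(d_o\) the output dimension (for a single head, this is the slice of the output projection that acts on that head's output);
\(\mathbf{Z} \in \mathbb{R}^{n \times d_l}\) contains the latent embeddings of all \(n\) tokens in the context;
\(\mathbf{W}_v \in \mathbb{R}^{d_e \times d_l}\) is the value up-projection matrix;
\(\mathbf{V} = \mathbf{Z}\mathbf{W}_v^\top \in \mathbb{R}^{n \times d_e}\) contains the value embeddings of all tokens in the context;
\(\mathbf{p} \in \mathbb{R}^{n}\) is the vector of attention probabilities (the softmax of the scores \(w_{ij}\) over \(j\)).

Thus, by multiplying \(\mathbf{W}_o\) and \(\mathbf{W}_v\) ahead of time, we obtain a "combined" weight matrix \(\mathbf{W}_o \mathbf{W}_v \in \mathbb{R}^{d_o \times d_l}\) that projects vectors from the latent space straight to the output space. At inference time, we only need to form the probability-weighted sum of the cached latents, \(\mathbf{Z}^\top \mathbf{p}\), and apply the combined matrix. This avoids explicit up-projection to the value embedding space.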
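As with the key side, a small pure-Python sketch can confirm that the two orderings give the same output for a single head. The sizes, the random weights, and the random scores fed to the softmax are hypothetical. The naive path materializes \(\mathbf{V}\). The absorbed path only forms \(\mathbf{Z}^\top \mathbf{p}\) and applies the precomputed \(\mathbf{W}_o \mathbf{W}_v\).

```python
import math
import random

N, D_E, D_L, D_O = 6, 8, 4, 10  # hypothetical: tokens, value dim, latent dim, output dim


def rand_matrix(rows, cols, rng):
    """Random rows x cols matrix stored as a list of row lists."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Matrix product a @ b."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs):
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


rng = random.Random(7)
Z = rand_matrix(N, D_L, rng)      # cached latents, n x d_l
W_v = rand_matrix(D_E, D_L, rng)  # value up-projection, d_e x d_l
W_o = rand_matrix(D_O, D_E, rng)  # output projection, d_o x d_e
p = softmax([rng.gauss(0.0, 1.0) for _ in range(N)])  # attention probabilities

# Naive: materialize V = Z W_v^T (n x d_e), then y = W_o (V^T p).
V = matmul(Z, transpose(W_v))
y_naive = matvec(W_o, matvec(transpose(V), p))

# Absorbed: precompute W_o W_v (d_o x d_l) once, then y = (W_o W_v)(Z^T p).
W_ov = matmul(W_o, W_v)
y_absorbed = matvec(W_ov, matvec(transpose(Z), p))

# Largest difference: tiny, caused only by floating-point rounding.
print(max(abs(a - b) for a, b in zip(y_naive, y_absorbed)))
```

The weighted sum \(\mathbf{Z}^\top \mathbf{p}\) runs over \(d_l\)-dimensional cached latents, so no \(n \times d_e\) value matrix is ever built.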

In an upcoming section, we will explore how position encodings are implemented in MLA (concatenation instead of addition and decoupled RoPE). Stay tuned!
