# Autoregressive Generation and KV Caching in Transformers

Given a sequence of $n$ tokens whose representations are stacked as the rows of $X$, each token's representation is projected into three vectors: a query, a key, and a value.

$$Q = XW_Q \quad K = XW_K \quad V = XW_V$$

The self-attention output is computed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

## Parallel Training with Full Sequence

During training, the entire sequence is processed in parallel.

**Shapes:**

$$Q \in \mathbb{R}^{n \times d_k}, \quad K \in \mathbb{R}^{n \times d_k}, \quad V \in \mathbb{R}^{n \times d_v}$$

**Computation:**

$$QK^T \in \mathbb{R}^{n \times n}$$

$$\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}$$

This is a **matrix-matrix multiplication**.

A causal mask is applied to the scores so that tokens cannot attend to future positions:

$$M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$

The mask is applied to the scores _before_ the softmax, not after, and it is *added*, not multiplied:

$$A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right), \qquad \text{output} = AV$$

The $-\infty$ entries become 0 after the softmax, and the remaining entries in each row still sum to 1. Multiplying the scores by a 0/1 mask would not work: a zeroed score still receives weight $e^0 > 0$ after the softmax.

## Autoregressive Generation

During generation, however, tokens are produced one at a time. At step $t$:

**Shapes:**

$$Q_t \in \mathbb{R}^{1 \times d_k}, \quad K_{1:t} \in \mathbb{R}^{t \times d_k}, \quad V_{1:t} \in \mathbb{R}^{t \times d_v}$$

**Computation:**

$$Q_t K_{1:t}^T \in \mathbb{R}^{1 \times t}$$

$$\text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t} \in \mathbb{R}^{1 \times d_v}$$

This is a **vector-matrix multiplication**. No mask is needed because future tokens do not exist yet: the single query row only sees positions $1$ through $t$, which is exactly row $t$ of the masked training computation.
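The claim that generation step $t$ reproduces row $t$ of the masked training computation can be checked numerically. Below is a minimal pure-Python sketch (no NumPy, so it runs with the standard library alone) on random matrices. The helpers `causal_attention` and `attention_step` are hypothetical names chosen for this illustration, and the code uses 0-based indices, so row `t` attends to positions `0..t`.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax(row: list[float]) -> list[float]:
    m = max(row)  # finite as long as the row has at least one unmasked entry
    exps = [math.exp(x - m) for x in row]  # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Training view: all n queries at once, scaled scores plus additive causal mask M."""
    n, d_k = len(Q), len(Q[0])
    scores = matmul(Q, transpose(K))  # (n, n)
    A = [softmax([scores[i][j] / math.sqrt(d_k) + (0.0 if i >= j else -math.inf)
                  for j in range(n)])
         for i in range(n)]
    return matmul(A, V)  # (n, d_v)


def attention_step(q_t: list[float], K_prefix: Matrix, V_prefix: Matrix) -> list[float]:
    """Generation view: one query against the keys and values seen so far, no mask."""
    scores = matmul([q_t], transpose(K_prefix))[0]  # (1, t)
    w = softmax([s / math.sqrt(len(q_t)) for s in scores])
    return matmul([w], V_prefix)[0]  # (1, d_v)


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
n, d_k, d_v = 5, 4, 3
Q, K, V = rand_matrix(n, d_k), rand_matrix(n, d_k), rand_matrix(n, d_v)

full = causal_attention(Q, K, V)
for t in range(n):
    step = attention_step(Q[t], K[: t + 1], V[: t + 1])
    assert max(abs(a - b) for a, b in zip(full[t], step)) < 1e-12
print("row-by-row generation matches masked full-sequence attention")
```

The masked entries contribute exactly zero weight, so each row of the full computation and the corresponding single-query step sum the same terms.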
## KV Caching

At each generation step $t$, the attention operation requires:

| Component | What's needed | Shape |
|-----------|---------------|-------|
| $Q$ | Only $Q_t$ (current token) | $(1, d_k)$ |
| $K$ | All keys so far: $K_1, K_2, \ldots, K_t$ | $(t, d_k)$ |
| $V$ | All values so far: $V_1, V_2, \ldots, V_t$ | $(t, d_v)$ |

The current token's query $Q_t$ is compared against all keys $K_{1:t}$ to compute the attention weights, which are then used to take a weighted combination of all values $V_{1:t}$.

The keys and values of earlier tokens do not change from one step to the next: because of the causal structure, position $i$ only ever depends on positions $1$ through $i$, so tokens generated later cannot affect it. Rather than recomputing $K$ and $V$ for all previous tokens at every step, we cache them:

- At step $t$, compute $K_t$ and $V_t$ for the new token only and append them to the cache.
- Reuse the cached $K_{1:t-1}$ and $V_{1:t-1}$ from previous steps.
- $Q_t$ is not cached, because it is used only at step $t$ and never referenced again.

Without caching, the naive approach reruns the forward pass over the whole prefix at every step: recomputing the projections for all $n$ tokens and the full $n \times n$ score matrix costs $O(n^2)$ for the token at position $n$, or $O(n^3)$ to generate a sequence of length $n$. With caching, each step projects a single token and computes one row of $n$ scores against the cached keys, which is $O(n)$ per token and $O(n^2)$ overall. These bounds track only the dependence on sequence length; the model dimensions contribute constant factors.
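The caching scheme above can be sketched for a single attention head in plain Python. This is an illustration under simplifying assumptions: one layer and one head, random projection matrices, and a fixed list of input vectors `xs` standing in for the token representations (in a real model each new input would come from the previous step's output). `KVCache`, `generate_cached`, and `generate_uncached` are hypothetical names, not any library's API. Both functions produce the same outputs; the counters record how many K/V projections each one performs.

```python
import math
import random

Matrix = list[list[float]]


def matvec(x: list[float], W: Matrix) -> list[float]:
    """Project one row vector x (length d_in) with W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q: list[float], keys: Matrix, values: Matrix) -> list[float]:
    """softmax(q K^T / sqrt(d_k)) V for a single query; no mask needed."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[c] for e, v in zip(exps, values))
            for c in range(len(values[0]))]


class KVCache:
    """Hypothetical cache for one layer and one head: rows of K_{1:t} and V_{1:t}."""

    def __init__(self) -> None:
        self.keys: Matrix = []
        self.values: Matrix = []

    def append(self, k: list[float], v: list[float]) -> None:
        self.keys.append(k)
        self.values.append(v)


def generate_cached(xs: Matrix, W_Q: Matrix, W_K: Matrix, W_V: Matrix):
    cache, outputs, projections = KVCache(), [], 0
    for x_t in xs:
        q_t = matvec(x_t, W_Q)  # Q_t is used at this step only, so it is not cached
        cache.append(matvec(x_t, W_K), matvec(x_t, W_V))  # project only the new token
        projections += 2
        outputs.append(attend(q_t, cache.keys, cache.values))
    return outputs, projections


def generate_uncached(xs: Matrix, W_Q: Matrix, W_K: Matrix, W_V: Matrix):
    outputs, projections = [], 0
    for t in range(1, len(xs) + 1):
        q_t = matvec(xs[t - 1], W_Q)
        keys = [matvec(x, W_K) for x in xs[:t]]  # recompute K_{1:t} from scratch
        values = [matvec(x, W_V) for x in xs[:t]]  # recompute V_{1:t} from scratch
        projections += 2 * t
        outputs.append(attend(q_t, keys, values))
    return outputs, projections


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
n, d_model, d_k, d_v = 6, 8, 4, 4
W_Q, W_K, W_V = rand_matrix(d_model, d_k), rand_matrix(d_model, d_k), rand_matrix(d_model, d_v)
xs = rand_matrix(n, d_model)  # stand-ins for the token representations x_1..x_n

cached, p_cached = generate_cached(xs, W_Q, W_K, W_V)
uncached, p_uncached = generate_uncached(xs, W_Q, W_K, W_V)
print(p_cached, p_uncached)  # 12 42
```

With $n = 6$, the cached version performs $2n = 12$ projections and the uncached one $2(1 + 2 + \cdots + 6) = 42$. The uncached baseline here still computes only one query per step; a naive implementation that reran the full forward pass would also recompute every query and the whole $t \times t$ score matrix at each step, which is where the quadratic per-token cost comes from.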