# Autoregressive Generation and KV Caching in Transformers
Given a sequence of token representations $X$, each token is projected into query, key, and value vectors: $Q = XW_Q, \quad K = XW_K, \quad V = XW_V$. The self-attention output is computed as $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$.
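As a minimal illustration (assuming a single attention head, no batch dimension, and arbitrary toy sizes), the projections and the attention output can be sketched in PyTorch like this:

```python
import torch
import torch.nn.functional as F

# Toy sizes chosen for illustration: 8 tokens, d_model = 128, d_k = d_v = 64.
n, d_model, d_k, d_v = 8, 128, 64, 64

X = torch.randn(n, d_model)          # stacked token representations
W_Q = torch.randn(d_model, d_k)      # projection matrices (random stand-ins for learned weights)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V  # (n, d_k), (n, d_k), (n, d_v)

scores = Q @ K.T / d_k ** 0.5        # (n, n) scaled dot-product scores
attn = F.softmax(scores, dim=-1)     # each row sums to 1
out = attn @ V                       # (n, d_v)
```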
## Parallel Training with Full Sequence
During training, the entire sequence is processed in parallel.
**Shapes:** $Q \in \mathbb{R}^{n \times d_k}, \quad K \in \mathbb{R}^{n \times d_k}, \quad V \in \mathbb{R}^{n \times d_v}$

**Computation:** $QK^T \in \mathbb{R}^{n \times n}$, and $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}$
This is a **matrix-matrix multiplication**. A causal mask is applied to $QK^T$ to prevent tokens from attending to future positions: $M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$
The mask is applied to the scores _before_ softmax, not after.
$A = \text{softmax}\left(\frac{QK^T + M}{\sqrt{d_k}}\right)$
The $-\infty$ values become 0 after softmax, and the remaining values in each row sum to 1. Note that the mask is *added* to the scores, not multiplied.
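Continuing the toy snippet above, the causal mask can be built and added to the scores as follows (constructing it with `triu` is one common approach, not the only one):

```python
# Causal mask: 0 on and below the diagonal, -inf strictly above it,
# i.e. M[i, j] = -inf whenever j > i (a future position).
M = torch.full((n, n), float("-inf")).triu(diagonal=1)

A = F.softmax((Q @ K.T + M) / d_k ** 0.5, dim=-1)   # row i puts zero weight on positions j > i
out_causal = A @ V                                   # (n, d_v), same shape as the unmasked output
```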
## Autoregressive Generation
During generation, however, tokens are produced one at a time. At step $t$:
**Shapes:** $Q_t \in \mathbb{R}^{1 \times d_k}, \quad K_{1:t} \in \mathbb{R}^{t \times d_k}, \quad V_{1:t} \in \mathbb{R}^{t \times d_v}$

**Computation:** $Q_t K_{1:t}^T \in \mathbb{R}^{1 \times t}$, and $\text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t} \in \mathbb{R}^{1 \times d_v}$
This is a **vector-matrix multiplication**. No mask is needed because future tokens do not exist yet.
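A sketch of one decoding step, using randomly initialized stand-ins for the keys and values of the first $t$ tokens (the shapes are what matter here; `t = 5` is an arbitrary choice):

```python
import torch
import torch.nn.functional as F

d_k, d_v = 64, 64
t = 5                                     # current decoding step (arbitrary for illustration)

q_t = torch.randn(1, d_k)                 # query of the newly generated token
K_1t = torch.randn(t, d_k)                # keys for tokens 1..t
V_1t = torch.randn(t, d_v)                # values for tokens 1..t

scores_t = q_t @ K_1t.T / d_k ** 0.5      # (1, t): a vector-matrix product
weights = F.softmax(scores_t, dim=-1)     # attention over all tokens generated so far
out_t = weights @ V_1t                    # (1, d_v)
```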
## KV Caching
At each generation step $t$, the attention operation requires:
| Component | What's Needed | Size |
|-----------|--------------|------|
| $Q$ | Only $Q_t$ (current token) | $(1, d_k)$ |
| $K$ | All previous: $K_1, K_2, \ldots, K_t$ | $(t, d_k)$ |
| $V$ | All previous: $V_1, V_2, \ldots, V_t$ | $(t, d_v)$ |
The current token's query $Q_t$ compares against all previous keys $K_{1:t}$ to compute attention weights, then retrieves a weighted combination of all previous values $V_{1:t}$.
Rather than recomputing $K$ and $V$ for all previous tokens at each step, we cache them:
- At step $t$, compute $K_t$ and $V_t$ for the new token and append to the cache.
- Reuse cached $K_{1:t-1}$ and $V_{1:t-1}$ from previous steps.
- $Q_t$ is not cached because it is only used at step $t$ and never referenced again.
Without caching, every step would have to re-run attention over the entire prefix: recomputing all $K$ and $V$ projections and all pairwise scores costs $O(n^2)$ for the token at position $n$, and $O(n^3)$ over the whole sequence. With caching, each step is a single query attending over $n$ cached keys and values, which is $O(n)$ per token and $O(n^2)$ overall.
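Putting the pieces together, here is a minimal single-head sketch of the cache update described above. The names (`attend_step`, `K_cache`, `V_cache`) are illustrative; real implementations typically preallocate the cache per layer and head and handle batching, but the append-and-reuse pattern is the same:

```python
import torch
import torch.nn.functional as F

d_model, d_k, d_v = 128, 64, 64
W_Q = torch.randn(d_model, d_k)           # projection matrices (random stand-ins for learned weights)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

K_cache = torch.empty(0, d_k)             # grows to (t, d_k) as decoding proceeds
V_cache = torch.empty(0, d_v)             # grows to (t, d_v)

def attend_step(x_t):
    """Attention for one new token x_t of shape (1, d_model), reusing the cache."""
    global K_cache, V_cache
    q_t = x_t @ W_Q                                   # used only at this step, never cached
    K_cache = torch.cat([K_cache, x_t @ W_K], dim=0)  # append K_t
    V_cache = torch.cat([V_cache, x_t @ W_V], dim=0)  # append V_t
    scores = q_t @ K_cache.T / d_k ** 0.5             # (1, t)
    return F.softmax(scores, dim=-1) @ V_cache        # (1, d_v)

# Simulate a few decoding steps; each call does O(t) work instead of O(t^2).
for _ in range(4):
    out_t = attend_step(torch.randn(1, d_model))
```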