# Self-Attention Mechanism

## Motivation for Attention Mechanism

The concept of attention tries to build this question into the network architecture: how relevant is the `i`-th element in the sequence relative to the other elements in the same sequence?

Key ideas:
- Don't try to learn one global representation for the source sequence.
- Instead, learn context-sensitive token representations for each token.
- When generating a target token, dynamically combine the most relevant source representations (a weighted sum, with weights representing some notion of similarity between tokens).

## Self-Attention

The attention mechanism used inside [[Transformers]] is referred to as scaled dot-product attention. Each sequence element provides a key, value, and query vector, each of which is a learned linear projection of the token representation. For each element, we compute attention by comparing its query against the keys of all other sequence elements, and return a different, averaged value vector for each element, progressively mixing in information from the other tokens.

The input is a set of queries $Q \in \mathbb{R}^{T \times d_{k}}$, keys $K \in \mathbb{R}^{T \times d_{k}}$ and values $V \in \mathbb{R}^{T \times d_{v}}$, where $T$ is the sequence length and $d_{k}$, $d_{v}$ are the dimensionalities of queries/keys and values respectively. The attention output is given as:

$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
$$

Large dot products cause softmax to saturate (outputting ~1 for the max value and ~0 for the others), leading to vanishing gradients. Dividing by $\sqrt{d_k}$ prevents the dot products from growing too large, keeping softmax in its sensitive range where gradients remain meaningful. Why $\sqrt{d_k}$ specifically? This [[High-Dimensional Dot Product Normalization]] preserves the variance of the logits regardless of dimensionality, ensuring stable training across different attention head dimensions.

```python
import math

import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    # Similarity of every query with every key: (..., T, T)
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        # Push masked positions to ~-inf so softmax gives them ~0 weight
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)  # weighted average of the value vectors
    return values, attention
```

To account for the fact that an element in a sequence can have multiple interpretations or different relations to its neighbors, we can combine several attention mechanisms with **Multi-Head Attention**. We generate and use multiple sets of Q, K, V vectors. To get the final output attention vector, we multiply the concatenated per-head attention outputs with $W^o$, which is then fed to the fully connected layer.

Multi-headed attention improves the attention layer in the following ways:
1. Expands the ability to focus on different positions.
2. Gives the attention layer multiple "representation subspaces".

## Scaling Self-Attention

Self-attention is expensive.

$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$

- Vanilla self-attention has $O\left(n^2\right)$ time and memory complexity, where $n$ is the input sequence length.
- The term $QK^T$ is the main culprit.
- Q, K, V have shape (batch_size, max_length, model_dim).
- $QK^T$ has shape (batch_size, max_length, max_length).
- So for a sequence length of 64K, even a batch size of 1 gives a 64K $\times$ 64K attention matrix, which is ~16GB in 32-bit floats (a 16K sequence already needs ~1GB).
- For deep transformers with a large number of layers this blows up quickly, just for the forward pass.
- For training, i.e. the backward pass, we need roughly 5x this memory: parameters, activations, gradients, and first- and second-order moments (AdamW is generally used).
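A quick back-of-the-envelope check of these numbers (an illustrative sketch: `attn_matrix_bytes` is just a made-up helper, and only the attention-logits tensor is counted, ignoring heads, activations, and other buffers):

```python
def attn_matrix_bytes(seq_len, batch_size=1, bytes_per_float=4):
    # Size of the (batch_size, seq_len, seq_len) attention-logits tensor
    return batch_size * seq_len * seq_len * bytes_per_float

for n in (2_048, 16_384, 65_536):
    print(f"seq_len={n:>6}: {attn_matrix_bytes(n) / 2**30:6.2f} GiB")
# seq_len=  2048:   0.02 GiB
# seq_len= 16384:   1.00 GiB
# seq_len= 65536:  16.00 GiB
```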
There are several approaches and variants that help scale self-attention in modern architectures:

### [[Multi-Head Latent Attention (MHLA)]]

Project the key and value tensors into a lower-dimensional latent space, drastically shrinking the KV cache. Introduced in [DeepSeek-V2, 2024](https://arxiv.org/abs/2405.04434); it not only saves memory but can outperform MHA!

### Grouped-query Attention (GQA)

[Grouped-query Attention (GQA), 2023](https://arxiv.org/abs/2305.13245) showed that each query head doesn't necessarily need its own keys and values: sharing the same key and value heads across multiple query heads saves a lot of memory without noticeably affecting modeling performance. Used heavily in modern architectures like Llama.

### Local Attention

- Key idea: Divide the input into groups of neighbors, apply self-attention within each group separately, and then combine the outputs; this is called block-sparse attention.
- To make sure distant elements can still interact, alternate block-attention layers with full-attention layers.
- Complexity reduces to $O(n \sqrt{n})$.
- A similar interleaving of sparse and full self-attention is used in GPT-3.
- Also introduced in [Liu et al. (ICLR 2018)](https://arxiv.org/abs/1801.10198).
- Longformer by [Beltagy et al. (2020)](https://arxiv.org/abs/2004.05150) introduces a couple of additional ideas:
    - Use a sliding window $w$ across the different layers, similar to CNNs: each token attends to $\frac{w}{2}$ tokens on each side, giving complexity $O(n \times w)$.
    - Additionally, sliding windows can be "dilated" by adding gaps of size $d$ in the windows, increasing coverage without increasing computation.
    - Global attention: allow certain tokens to attend across all tokens, not just their local window, to emulate properties like the CLS token in [[BERT]].
    - Global attention tokens can be thought of as "memory" tokens; this is generalized in the Extended Transformer Construction (ETC) of [Ainslie et al.](https://arxiv.org/abs/2004.08483).
- LongT5 by [Guo et al. 2022](https://arxiv.org/abs/2112.07916) adopts a similar idea by introducing Transient Global Attention (TGlobal):
    - Instead of arbitrary global tokens as in ETC, create "transient" global tokens for fixed blocks of tokens by summing the tokens in each block.
    - Allow these global tokens to attend to the full input, while the rest attend only locally.

### Compressed Attention

#### Memory-compressed attention

- Introduced by [Liu et al. (ICLR 2018)](https://arxiv.org/abs/1801.10198) for long-sequence generation.
- ![[memory compressed attention.png]]
- Key idea: Use a strided [[Convolution]] on the key and value matrices to reduce the size of the attention matrix.
- Introduces kernels $\theta_k$ and $\theta_v$ to compute self-attention as
$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (\theta_k \circledast K)^T}{\sqrt{d_k}}\right) (\theta_v \circledast V)
$$
- Reduces the attention matrix from $n \times n$ to $n \times (n/s)$, where $s$ is the stride and kernel size (3 in the original paper).

#### Low-rank approximation of the attention matrix

- Introduced by [Wang et al. (2020)](https://arxiv.org/abs/2006.04768) in the Linformer paper.
- Key idea: The self-attention matrix can be approximated by a low-rank matrix.
- Enables linear $O(n)$ time and space complexity!
- The paper proves that such a low-rank matrix exists!
- Introduces two linear projection matrices $E_i, F_i \in \mathbb{R}^{k \times n}$ that project the original $n \times d$ key and value matrices down to $k \times d$. Choosing $k=O\left(d / \epsilon^2\right)$ approximates full self-attention with error $\epsilon$.
$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (E_i K)^T}{\sqrt{d_k}}\right) F_i V
$$
- The projection matrices can be shared across layers and heads!
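A minimal single-head sketch of the low-rank projection idea (illustrative only, not the paper's full multi-head implementation; `linformer_style_attention`, `E`, and `F_proj` are made-up names, and the projections are randomly initialized here whereas the paper learns them):

```python
import math

import torch
import torch.nn.functional as F

def linformer_style_attention(q, k, v, E, F_proj):
    """q, k, v: (batch, n, d); E, F_proj: (k_dim, n) sequence-length projections."""
    k_low = torch.matmul(E, k)       # (batch, k_dim, d): keys compressed along the sequence dim
    v_low = torch.matmul(F_proj, v)  # (batch, k_dim, d): values compressed along the sequence dim
    d_k = q.size(-1)
    # Attention matrix is (batch, n, k_dim) instead of (batch, n, n)
    scores = torch.matmul(q, k_low.transpose(-2, -1)) / math.sqrt(d_k)
    return torch.matmul(F.softmax(scores, dim=-1), v_low)

# Usage: n=4096 tokens compressed to k_dim=256 projected positions
n, d, k_dim = 4096, 64, 256
q = k = v = torch.randn(2, n, d)
E, F_proj = torch.randn(k_dim, n), torch.randn(k_dim, n)
out = linformer_style_attention(q, k, v, E, F_proj)
print(out.shape)  # torch.Size([2, 4096, 64])
```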
### Kernelized Attention

- We can view the exponential of the dot product $QK^T$ as a [[Kernel Methods|kernel function]], i.e. it computes a similarity.
- Instead of computing the full dot product, can we find a feature map $\phi$ of Q and K that approximates the similarity function $k(x, y)=\exp(x^{T} y)$ of full self-attention?
- Linear Transformers [Katharopoulos et al. 2020](https://arxiv.org/pdf/2006.16236.pdf):
    - Self-attention can be rewritten as $\operatorname{Attention}(Q, K, V) \approx \frac{\phi(Q)\left(\phi(K)^T V\right)}{\phi(Q) \phi(K)^T}$, so $\phi(K)^T V$ can be computed once and reused.
    - They choose $\phi(x)=\operatorname{elu}(x)+1$.
    - Reduces complexity to linear time, even for causal attention.
- Performer [Choromanski et al. 2021](https://arxiv.org/abs/2009.14794) uses random features to build an unbiased approximation of the softmax kernel $\exp(QK^T)$.

### Conditional Attention

- What if not all tokens need the same amount of computation with full self-attention? Can we learn to "route" inputs between a light and a heavy computation path? Similar in spirit to [[Mixture of Experts]].
- CoLT5 [Ainslie et al. April 2023](https://arxiv.org/pdf/2303.09752.pdf) introduces token-level conditional computation:
    - The light branch has a lower hidden dimension, fewer heads, and applies only local attention. The heavy branch performs full attention.
    - How to "route", i.e. find the important tokens? Multiply tokens with a learned embedding to get scores $s$, and select the top-k highest-scoring tokens.
    - Apply a conditional feedforward: $X_i=X_i+\operatorname{FFd}_{\text {Light }}\left(X_i\right)+\tilde{s}_i \cdot \operatorname{FFd}_{\text {Heavy }}\left(X_i\right)$
    - Apply conditional attention: $X_i=X_i+\mathrm{A}_{\text {Light }}\left(X_i, X\right)+\tilde{s}_i^q \cdot \mathrm{A}_{\text {Heavy }}\left(X_i, \tilde{s}^{k v} X\right)$
    - Results in up to 75% faster training and 100% faster inference compared to LongT5, alongside modeling performance improvements.
    - Can handle inputs of up to 64K tokens.

### Recurrence-based Transformers

- Can we use recurrence and memory states with Transformers?
- RMT (Bulatov et al., 2022) uses global memory tokens, similar to ETC, as part of the input, and outputs memory tokens.
- Shown to scale to 1M+ tokens!
- Compatible with existing small-input Transformers.
- Incorporates recurrence:
    - The input is segmented into N segments. Memory tokens are concatenated to the first segment, which is fed to the Transformer.
    - The second segment is then concatenated with the output memory tokens of the first segment and fed to the Transformer again.
    - This process is repeated until the full sequence has been processed.
    - Basically a [[Recurrent Neural Networks (RNN)]]!
- Complexity becomes linear in the number of segments (quadratic only within each segment), and arbitrary input lengths can be handled.

### Hardware Optimization

Flash Attention [Dao et al. 2022](https://arxiv.org/pdf/2205.14135.pdf)

- Uses a single fused kernel for self-attention that takes the GPU memory hierarchy into account.
- Computes exact attention, not an approximation.
- Two techniques:
    - Incrementally compute the softmax by splitting the input into blocks (tiling).
    - Store the softmax normalization factor from the forward pass to quickly recompute attention on-chip during the backward pass.
- Up to 7.6x faster on GPT-2 and uses less memory.
- Allows longer sequence lengths (16K).
- The block-sparse version allows even longer sequences (64K).
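Relatedly, PyTorch 2.x ships fused attention kernels behind `torch.nn.functional.scaled_dot_product_attention`. A minimal usage sketch (assuming PyTorch ≥ 2.0; on GPU it dispatches to a FlashAttention-style backend when one is available):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 4096, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 4096, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 4096, 64, device=device, dtype=dtype)

# When a fused backend is selected, the full (seq_len x seq_len)
# attention matrix is never materialized in GPU main memory.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 4096, 64])
```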