# Layer Normalization

Layer normalization tries to overcome some of the limitations of batch normalization, especially when the batch size is too small to extract robust statistics for each feature. In layer normalization, the mean and variance are computed independently for each element of the batch by aggregating over the feature dimension. Effectively, layer normalization treats each sample as its own little world for normalization, while batch normalization treats each batch as the world.

**Why "layer" normalization?** It normalizes across all features within a layer for each sample individually.

**Isn't this noisy?** Yes, computing statistics from one sample can be noisier than using batch statistics, but:

- In practice, layers often have hundreds or thousands of features, making single-sample statistics reasonably reliable
- The noise can act as regularization
- It is independent of the batch size and works well for sequential models

For each sample $s$ with $M$ features:

- Mean: $\mu_s = \frac{1}{M} \sum_{i=1}^{M} X_{si}$ (average across this sample's $M$ features)
- Variance: $\sigma_s^2 = \frac{1}{M} \sum_{i=1}^{M} (X_{si} - \mu_s)^2$
- Normalize: $\hat{X}_{si} = \frac{X_{si} - \mu_s}{\sqrt{\sigma_s^2 + \epsilon}}$
- Scale & shift: $Y_{si} = \gamma_i \hat{X}_{si} + \beta_i$

Each sample is normalized independently, but the scale and shift parameters are learned per feature.
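The forward pass maps directly onto a few array operations. Below is a minimal NumPy sketch, assuming `X` has shape `(N, M)` and `gamma`, `beta` have shape `(M,)`; the function name and the returned `cache` tuple are illustrative conventions, not something defined above:

```python
import numpy as np

def layer_norm_forward(X, gamma, beta, eps=1e-5):
    """Normalize each row (sample) of X over its M features, then scale and shift."""
    mu = X.mean(axis=1, keepdims=True)        # per-sample mean, shape (N, 1)
    var = X.var(axis=1, keepdims=True)        # per-sample (biased) variance, shape (N, 1)
    X_hat = (X - mu) / np.sqrt(var + eps)     # normalized activations
    Y = gamma * X_hat + beta                  # per-feature scale (gamma) and shift (beta)
    cache = (X, X_hat, mu, var, gamma, eps)   # kept around for the backward pass
    return Y, cache
```

A quick sanity check: each row of `X_hat` should have approximately zero mean and unit variance.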
## Gradients

Let's start with an intermediate gradient, the gradient with respect to the product $\gamma_i \hat{X}_{ri}$. Throughout, $\delta$ denotes the Kronecker delta, which picks out the single nonzero term of each sum.

$
\begin{aligned}
\frac{\partial L}{\partial (\gamma \hat{X})} \Rightarrow \left[ \frac{\partial L}{\partial (\gamma \hat{X})} \right]_{ri} = \frac{\partial L}{\partial (\gamma_i \hat{X}_{ri})} &= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \cdot \frac{\partial Y_{sj}}{\partial (\gamma_i \hat{X}_{ri})} \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \cdot \frac{\partial \left( \gamma_j \hat{X}_{sj} + \beta_j \right)}{\partial (\gamma_i \hat{X}_{ri})} \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \, \delta_{sr} \delta_{ji} \\
&= \frac{\partial L}{\partial Y_{ri}}
\end{aligned}
$

Now the gradients of the learnable parameters are given by

$
\begin{aligned}
\frac{\partial L}{\partial \beta} \Rightarrow \left[ \frac{\partial L}{\partial \beta} \right]_i = \frac{\partial L}{\partial \beta_i} &= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \frac{\partial Y_{sj}}{\partial \beta_i} \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \frac{\partial}{\partial \beta_i} \left( \gamma_j \hat{X}_{sj} + \beta_j \right) \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \, \delta_{ji} \\
&= \sum_{s} \frac{\partial L}{\partial Y_{si}}
\end{aligned}
$

$
\begin{aligned}
\frac{\partial L}{\partial \gamma} \Rightarrow \left[ \frac{\partial L}{\partial \gamma} \right]_i = \frac{\partial L}{\partial \gamma_i} &= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \frac{\partial Y_{sj}}{\partial \gamma_i} \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \frac{\partial}{\partial \gamma_i} \left( \gamma_j \hat{X}_{sj} + \beta_j \right) \\
&= \sum_{sj} \frac{\partial L}{\partial Y_{sj}} \hat{X}_{sj} \, \delta_{ji} \\
&= \sum_{s} \frac{\partial L}{\partial Y_{si}} \hat{X}_{si}
\end{aligned}
$

Other intermediate gradients needed to compute the input gradient are

$
\begin{aligned}
\frac{\partial L}{\partial \hat{X}} \Rightarrow \left[ \frac{\partial L}{\partial \hat{X}} \right]_{si} &= \sum_{j} \frac{\partial L}{\partial (\gamma_j \hat{X}_{sj})} \cdot \frac{\partial (\gamma_j \hat{X}_{sj})}{\partial \hat{X}_{si}} \\
&= \sum_{j} \frac{\partial L}{\partial Y_{sj}} \, \gamma_j \, \delta_{ji} \\
&= \frac{\partial L}{\partial Y_{si}} \, \gamma_i
\end{aligned}
$

$
\begin{aligned}
\frac{\partial L}{\partial \sigma^2} \Rightarrow \left[ \frac{\partial L}{\partial \sigma^2} \right]_s &= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{\partial \hat{X}_{sj}}{\partial \sigma_s^2} \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{\partial}{\partial \sigma_s^2} \left( \frac{X_{sj} - \mu_s}{\sqrt{\sigma_s^2 + \epsilon}} \right) \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \, (X_{sj} - \mu_s) \, \frac{\partial}{\partial \sigma_s^2} (\sigma_s^2 + \epsilon)^{-\frac{1}{2}} \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \, (X_{sj} - \mu_s) \cdot \left( -\frac{1}{2} \right) (\sigma_s^2 + \epsilon)^{-\frac{3}{2}}
\end{aligned}
$

$
\begin{aligned}
\frac{\partial L}{\partial \mu} \Rightarrow \left[ \frac{\partial L}{\partial \mu} \right]_s &= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{\partial \hat{X}_{sj}}{\partial \mu_s} + \frac{\partial L}{\partial \sigma_s^2} \frac{\partial \sigma_s^2}{\partial \mu_s} \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{\partial}{\partial \mu_s} \frac{X_{sj} - \mu_s}{\sqrt{\sigma_s^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_s^2} \frac{\partial}{\partial \mu_s} \frac{1}{M} \sum_{k=1}^{M} (X_{sk} - \mu_s)^2 \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{-1}{\sqrt{\sigma_s^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_s^2} \frac{1}{M} \sum_{k=1}^{M} \left( -2 (X_{sk} - \mu_s) \right) \\
&= \sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{-1}{\sqrt{\sigma_s^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_s^2} \cdot (-2) \left( \mu_s - \mu_s \right) \\
&= -\sum_{j} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{1}{\sqrt{\sigma_s^2 + \epsilon}}
\end{aligned}
$

The second term vanishes because $\frac{1}{M} \sum_{k} X_{sk} = \mu_s$.

So the final input gradient is given by

$
\begin{aligned}
\frac{\partial L}{\partial X} \Rightarrow \left[ \frac{\partial L}{\partial X} \right]_{ri} &= \sum_{sj} \frac{\partial L}{\partial \hat{X}_{sj}} \frac{\partial \hat{X}_{sj}}{\partial X_{ri}} + \sum_{s} \frac{\partial L}{\partial \mu_s} \frac{\partial \mu_s}{\partial X_{ri}} + \sum_{s} \frac{\partial L}{\partial \sigma_s^2} \frac{\partial \sigma_s^2}{\partial X_{ri}} \\
&= \frac{\partial L}{\partial \hat{X}_{ri}} \frac{1}{\sqrt{\sigma_r^2 + \epsilon}} + \frac{\partial L}{\partial \mu_r} \frac{\partial}{\partial X_{ri}} \frac{1}{M} \sum_{k=1}^{M} X_{rk} + \frac{\partial L}{\partial \sigma_r^2} \frac{\partial}{\partial X_{ri}} \frac{1}{M} \sum_{k=1}^{M} (X_{rk} - \mu_r)^2 \\
&= \frac{\partial L}{\partial \hat{X}_{ri}} \frac{1}{\sqrt{\sigma_r^2 + \epsilon}} + \frac{1}{M} \frac{\partial L}{\partial \mu_r} + \frac{2}{M} \frac{\partial L}{\partial \sigma_r^2} (X_{ri} - \mu_r)
\end{aligned}
$

Here $\frac{\partial \hat{X}_{sj}}{\partial X_{ri}}$ and $\frac{\partial \sigma_s^2}{\partial X_{ri}}$ are taken with $\mu_s$ and $\sigma_s^2$ treated as fixed inputs; their own dependence on $X$ is already accounted for by the separate $\mu$ and $\sigma^2$ terms.
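Collecting these results gives a vectorized backward pass. Below is a minimal NumPy sketch that mirrors the derivation, reusing the (hypothetical) `cache` tuple returned by the forward sketch above:

```python
def layer_norm_backward(dY, cache):
    """Gradients of L w.r.t. X, gamma, beta, given dY = dL/dY of shape (N, M)."""
    X, X_hat, mu, var, gamma, eps = cache
    M = X.shape[1]
    inv_std = 1.0 / np.sqrt(var + eps)                        # shape (N, 1)

    dbeta = dY.sum(axis=0)                                    # [dL/dbeta]_i  = sum_s dL/dY_si
    dgamma = (dY * X_hat).sum(axis=0)                         # [dL/dgamma]_i = sum_s dL/dY_si * X_hat_si
    dX_hat = dY * gamma                                       # [dL/dX_hat]_si = dL/dY_si * gamma_i

    # [dL/dsigma^2]_s and [dL/dmu]_s, summed over the feature axis
    dvar = (dX_hat * (X - mu) * -0.5 * inv_std**3).sum(axis=1, keepdims=True)
    dmu = (-dX_hat * inv_std).sum(axis=1, keepdims=True)      # the dsigma^2/dmu contribution is zero

    # [dL/dX]_ri = dL/dX_hat_ri * inv_std_r + dL/dmu_r / M + (2/M) * dL/dsigma^2_r * (X_ri - mu_r)
    dX = dX_hat * inv_std + dmu / M + dvar * 2.0 * (X - mu) / M
    return dX, dgamma, dbeta
```

Comparing `dX`, `dgamma`, and `dbeta` against finite-difference estimates (or against an autograd implementation such as `torch.nn.LayerNorm`) is a useful check that the derivation and the code agree.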