# Variational Inference

When "learning to represent" an input $x$ we assume a latent variable $z$ and try to explain $x$ using all possible $z$:

$$
p(x)=\int_{z} p(x, z)\, dz=\int_{z} p(x \mid z)\, p(z)\, dz=\mathbb{E}_{z \sim p(z)}\left[p_{\theta}(x \mid z)\right]
$$

Hence, [[Latent Variable Models]] can be viewed as a generative process:

- First, we generate a new $z$ by sampling from $p(z)$
- Then, we generate a new $x$ by sampling from $p(x \mid z)$ given the sampled $z$

Generally, $p_{\theta}(x, z)$ is intractable. How do we find the optimal $\theta$?

$$
\log \prod_{x \in D} p(x)=\sum_{x} \log p(x)=\sum_{x} \log \sum_{z} p_{\theta}(x, z)
$$

Like in Boltzmann machines, the $\sum_{z}$ is a nasty computation. E.g., for a 3-dimensional binary $z$ we iterate over $[0,0,0],[0,0,1],[0,1,1], \ldots$; for 20 dimensions that is $2^{20} \approx 1M$ latents and generations, per image $x$! For continuous $z$ it is even harder: we cannot even enumerate.

How do we make it tractable?

## Naive Monte Carlo

We want to optimize per data point $x$: $\log \sum_{z} p_{\theta}(x, z)$. This sum is equivalent to the expected value under a uniform distribution times the number of summands:

$$
\log \sum_{z} p_{\theta}(x, z)=\log |Z|\, \mathbb{E}_{z \sim \mathrm{Uniform}(Z)}\left[p_{\theta}(x, z)\right]
$$

Do we need all the summands to compute the expected value (average)? No: sampling $z$ uniformly at random and averaging gives us an estimate. We basically replace the whole sum with a smaller, reweighted one:

$$
\log |Z|\, \mathbb{E}_{z \sim \mathrm{Uniform}(Z)}\left[p_{\theta}(x, z)\right] \approx \log \frac{|Z|}{K} \sum_{k=1}^{K} p_{\theta}\left(x, z_{k}\right), \quad z_{k} \sim \mathrm{Uniform}(Z)
$$

This doesn't scale. Most $z_{k}$ would lie in very low density regions, i.e., with unimportant $p_{\theta}\left(x, z_{k}\right)$. In technical terms, this is a 'high variance' estimator.

## Importance sampling Monte Carlo

It would be better to select a few good summands in the sum $\sum_{k=1}^{K} p_{\theta}\left(x, z_{k}\right)$. If, theoretically, we had a nice distribution around the mass of relevant $z_{k}$, we could use that distribution to sample the $z_{k}$ and get a better sample average with fewer $k$:

$$
\begin{aligned}
\log \sum_{z} p_{\theta}(x, z) &=\log \sum_{z} q_{\varphi}(z) \frac{1}{q_{\varphi}(z)} p_{\theta}(x, z)\\
&=\log \mathbb{E}_{z \sim q_{\varphi}(z)}\left[\frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right] \approx \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\theta}\left(x, z_{k}\right)}{q_{\varphi}\left(z_{k}\right)}, \quad \text{where the } z_{k} \text{ are sampled from } q_{\varphi}(z)
\end{aligned}
$$

Note the dual use of $q_{\varphi}$:

- In the numerator, $q_{\varphi}(z)$ is the density we use as the sampling mechanism. By sampling from it (e.g., Gaussian samples if it is Gaussian), this factor is absorbed into the expectation and disappears from the sum
- In the denominator, $q_{\varphi}\left(z_{k}\right)$ is simply a function: we feed it $z_{k}$ and it returns how important $z_{k}$ is in our probability space

This scales much better and has much lower variance, but we don't know what a good $q_{\varphi}$ is.

### Learning the importance sampling distribution

Importance sampling is promising, but how do we determine $q_{\varphi}(z)$? Learn it from data!
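To make the contrast concrete, here is a minimal sketch of both estimators on a toy model (my own illustrative setup, not from the note: a binary latent with a Gaussian likelihood, and, standing in for a learned $q_{\varphi}$, a proposal built from the exact posterior marginals, which we can only afford because the latent is tiny):

```python
# Toy comparison of the naive uniform MC estimator of log p(x) with the
# importance-sampled one. D is small so we can also enumerate the exact value.
import itertools
import numpy as np

rng = np.random.default_rng(0)
D, X_DIM = 10, 5                                  # |Z| = 2**D latent configs
A = rng.normal(size=(X_DIM, D))                   # fixed "decoder" weights (theta)
x = A @ (rng.random(D) < 0.5).astype(float) + 0.1 * rng.normal(size=X_DIM)

def log_joint(z):
    # log p(x, z) = log N(x; Az, I) + log p(z), with p(z) uniform on {0,1}^D
    resid = x - z @ A.T
    log_lik = -0.5 * (resid ** 2).sum(axis=-1) - 0.5 * X_DIM * np.log(2 * np.pi)
    return log_lik + D * np.log(0.5)

# Exact log p(x) = log sum_z p(x, z), feasible only because D is small.
all_z = np.array(list(itertools.product([0.0, 1.0], repeat=D)))
log_pxz = log_joint(all_z)
exact = np.logaddexp.reduce(log_pxz)

K = 200

# Naive MC: z_k ~ Uniform, estimate log(|Z|/K) + log sum_k p(x, z_k).
z_u = rng.integers(0, 2, (K, D)).astype(float)
naive = np.log(2.0 ** D / K) + np.logaddexp.reduce(log_joint(z_u))

# Importance sampling with a "nice" q: independent Bernoullis set to the
# exact posterior marginals (a stand-in for a learned q_phi).
post = np.exp(log_pxz - exact)                    # exact posterior over all z
p1 = np.clip((post[:, None] * all_z).sum(axis=0), 1e-3, 1 - 1e-3)
z_q = (rng.random((K, D)) < p1).astype(float)
log_q = (z_q * np.log(p1) + (1 - z_q) * np.log(1 - p1)).sum(axis=-1)
importance = -np.log(K) + np.logaddexp.reduce(log_joint(z_q) - log_q)

print(f"exact {exact:.3f} | naive MC {naive:.3f} | importance {importance:.3f}")
```

With uniform sampling, most draws miss the few latents carrying the mass of $p_{\theta}(x, z)$, so the naive estimate typically falls far below the exact value; the importance-sampled estimate lands much closer with the same $K$.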
Our learning objective is to maximize the log probability

$$
\log \mathbb{E}_{z \sim q_{\varphi}(z)}\left[\frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right] \approx \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\theta}\left(x, z_{k}\right)}{q_{\varphi}\left(z_{k}\right)}
$$

The $\log \mathbb{E}$ is the logarithm of an unknown integral and is not very convenient for derivations and computations. It would be much nicer if we could swap the $\log \mathbb{E}$ for $\mathbb{E} \log$:

- Then we would simply need the expectation of the logarithm of a function
- Especially convenient if $p_{\theta}(x, z)$ belongs to the exponential family

### Lower bound on the maximum likelihood

We can use [[Jensen's Inequality]] to obtain a lower bound!

$$
\log \mathbb{E}_{z \sim q_{\varphi}(z)}\left[\frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right] \geq \mathbb{E}_{z \sim q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right]
$$

We replaced the original MLE objective with a quantity that is always smaller, implying:

1. By improving $\mathbb{E}_{z \sim q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right]$ we always improve $\log \mathbb{E}_{z \sim q_{\varphi}(z)}\left[\frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right]$, because the former is a lower bound on the latter
2. $\mathbb{E}_{z \sim q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right]$ is a tractable and convenient quantity that enables easy optimization: it is an expectation, so Monte Carlo sampling is possible, and the log couples nicely with $p_{\theta}$ if it is chosen properly

## Making the posterior tractable with variational inference

We can also view variational inference through the lens of intractability. The problematic quantity in our latent model is the posterior; the reason is the intractable normalization integral

$$
p(z \mid x)=\frac{p(x, z)}{p(x)}=\frac{p(x, z)}{\int p(x, z)\, dz}
$$

Variational inference approximates the true posterior $p(z \mid x)$ with a variational distribution $q_{\varphi}(z)$:

$$
\begin{aligned}
\mathrm{KL}\left(q_{\varphi}(z) \| p(z \mid x)\right) &=\int q_{\varphi}(z) \log \frac{q_{\varphi}(z)}{p(z \mid x)}\, dz \\
&=-\int q_{\varphi}(z) \log \frac{p(x, z)}{p(x)\, q_{\varphi}(z)}\, dz \\
&=-\int q_{\varphi}(z) \log \frac{p(x, z)}{q_{\varphi}(z)}\, dz+\int q_{\varphi}(z) \log p(x)\, dz \\
&=-\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p(x, z)}{q_{\varphi}(z)}\right]+\log p(x)
\end{aligned}
$$

## Evidence Lower Bound

The expectation term is called the evidence lower bound, or ELBO. Rearranging,

$$
\begin{aligned}
\log p(x) &=\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p(x, z)}{q_{\varphi}(z)}\right]+\mathrm{KL}\left(q_{\varphi}(z) \| p(z \mid x)\right) \\
&=\mathrm{ELBO}+\mathrm{KL}\left(q_{\varphi}(z) \| p(z \mid x)\right)
\end{aligned}
$$

Why is it called 'evidence'? The KL term is always non-negative.
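As a one-line check (an added aside, using the same Jensen's inequality step as above), the KL term is non-negative because

$$
\mathrm{KL}\left(q_{\varphi}(z) \| p(z \mid x)\right)=-\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p(z \mid x)}{q_{\varphi}(z)}\right] \geq -\log \mathbb{E}_{q_{\varphi}(z)}\left[\frac{p(z \mid x)}{q_{\varphi}(z)}\right]=-\log \int p(z \mid x)\, dz=-\log 1=0
$$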
If we drop it, we bound the log evidence $\log p(x)$ from below:

$$
\log p(x) \geq \mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right]
$$

- Higher ELBO -> smaller difference to the true posterior $p(z \mid x)$ -> better latent representation
- Higher ELBO -> the gap to the log-likelihood tightens -> better density model

![[variation-inference.jpg]]

### ELBO balancing reconstruction and the prior

We can expand the ELBO as

$$
\begin{aligned}
\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right] &=\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x \mid z)\, p(z)}{q_{\varphi}(z)}\right] \\
&=\mathbb{E}_{q_{\varphi}(z)}\left[\log p_{\theta}(x \mid z)\right]-\mathrm{KL}\left(q_{\varphi}(z) \| p(z)\right)
\end{aligned}
$$

- The first term encourages reconstructions that maximize the likelihood (see the code sketch at the end of this note)
- The second term keeps the variational distribution close to the prior

### ELBO and entropy regularization

We can also expand the ELBO as

$$
\begin{aligned}
\mathbb{E}_{q_{\varphi}(z)}\left[\log \frac{p_{\theta}(x, z)}{q_{\varphi}(z)}\right] &=\mathbb{E}_{q_{\varphi}(z)}\left[\log p_{\theta}(x, z)\right]-\mathbb{E}_{q_{\varphi}(z)}\left[\log q_{\varphi}(z)\right] \\
&=\mathbb{E}_{q_{\varphi}(z)}\left[\log p_{\theta}(x, z)\right]+H\left(q_{\varphi}(z)\right)
\end{aligned}
$$

where $H(\cdot)$ is the entropy.

- Maximizing the joint likelihood -> something like the Boltzmann energy
- While maintaining enough entropy ('uncertainty') in the distribution of latents
- Avoiding latents collapsing to pathological point estimates ($z$ as single values)

## Variational inference underestimates variance

If you noticed, in the second derivation of the ELBO we minimized the reverse KL

$$
\mathrm{KL}\left(q_{\varphi}(z) \| p(z \mid x)\right)=\int q_{\varphi}(z) \log \frac{q_{\varphi}(z)}{p(z \mid x)}\, dz
$$

We want to sample from $q_{\varphi}(z)$ in expectations later on, as $p(z \mid x)$ is intractable. The model wants to approximate $p(z \mid x)$ but cannot really know where $p(z \mid x)$ is low, so it prefers to hedge and 'bias' $q_{\varphi}(z)$ towards 0 in regions where it cannot be certain:

- Better to pick one mode (randomly) than to miss a 'zero' density region of $p(z \mid x)$ and let the ratio $\frac{q_{\varphi}(z)}{p(z \mid x)}$ skyrocket

### Overestimating variance

We would need to use the forward KL instead:

$$
\mathrm{KL}\left(p(z \mid x) \| q_{\varphi}(z)\right)=\int p(z \mid x) \log \frac{p(z \mid x)}{q_{\varphi}(z)}\, dz
$$

The model would then prefer placing some density everywhere: that way it avoids the ratio $\frac{p(z \mid x)}{q_{\varphi}(z)}$ skyrocketing if it misses areas where $p(z \mid x)$ has mass. But the forward KL is not always easy to compute, since its expectation is taken under the intractable $p(z \mid x)$ itself.
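To see the two behaviors concretely, here is a small self-contained sketch (my own toy illustration, not from the note): we fit a single Gaussian $q$ to a bimodal target standing in for $p(z \mid x)$, minimizing each KL by brute-force grid search with numerical integration.

```python
# Toy illustration: fit q(z) = N(mu, sigma^2) to a bimodal target p(z)
# (standing in for an intractable posterior) by minimizing either the
# reverse KL(q || p) or the forward KL(p || q) over a parameter grid.
import numpy as np

z = np.linspace(-10, 10, 2001)
dz = z[1] - z[0]

def normal_pdf(z, mu, sigma):
    return np.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Bimodal target: two well-separated modes.
p = 0.5 * normal_pdf(z, -3.0, 0.8) + 0.5 * normal_pdf(z, 3.0, 0.8)

def kl(a, b):
    # KL(a || b) by numerical integration, ignoring points where a ~ 0.
    mask = a > 1e-12
    return np.sum(a[mask] * np.log(a[mask] / np.maximum(b[mask], 1e-300))) * dz

best = {"reverse": (np.inf, None), "forward": (np.inf, None)}
for mu in np.linspace(-5.0, 5.0, 101):
    for sigma in np.linspace(0.3, 5.0, 95):
        q = normal_pdf(z, mu, sigma)
        for name, val in (("reverse", kl(q, p)), ("forward", kl(p, q))):
            if val < best[name][0]:
                best[name] = (val, (round(mu, 2), round(sigma, 2)))

print("reverse KL optimum (mu, sigma):", best["reverse"][1])   # mode-seeking
print("forward KL optimum (mu, sigma):", best["forward"][1])   # mass-covering
```

The reverse-KL fit locks onto one of the two modes with a small $\sigma$ (underestimated variance), while the forward-KL fit centers between the modes with a large $\sigma$ so that no region with $p$-mass is left uncovered.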
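Finally, the code sketch promised in 'ELBO balancing reconstruction and the prior' above: a minimal illustration (assuming a diagonal-Gaussian $q_{\varphi}$, a standard normal prior, and a made-up linear decoder; none of these choices come from the note) of how the two ELBO terms are typically estimated.

```python
# Minimal sketch of ELBO = E_q[log p(x|z)] - KL(q || prior), assuming a
# diagonal Gaussian q_phi(z) = N(mu, diag(sigma^2)) and a N(0, I) prior.
import numpy as np

rng = np.random.default_rng(0)

def elbo(x, mu, log_var, decode, n_samples=16):
    sigma = np.exp(0.5 * log_var)
    recon = 0.0
    for _ in range(n_samples):
        z = mu + sigma * rng.normal(size=mu.shape)   # reparameterized z ~ q_phi
        x_hat = decode(z)                            # mean of p_theta(x | z)
        # Gaussian log-likelihood log p(x|z) with unit variance, up to a constant.
        recon += -0.5 * np.sum((x - x_hat) ** 2)
    recon /= n_samples
    # Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)).
    kl = 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var)
    return recon - kl

# Toy usage: a made-up linear "decoder"; mu and log_var stand in for the
# outputs of a learned encoder.
D_Z, D_X = 4, 8
W = rng.normal(size=(D_X, D_Z))
x = rng.normal(size=D_X)
print("ELBO:", elbo(x, mu=np.zeros(D_Z), log_var=np.zeros(D_Z), decode=lambda z: W @ z))
```

The reparameterized sample $z=\mu+\sigma \epsilon$ and the closed-form Gaussian KL are standard choices for this decomposition; any $q_{\varphi}$ with a tractable log-density and sampler would do.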