In this post we introduce an alternative view of Expectation Maximization (EM) based on KL divergence, due to Jianlin from https://kexue.fm. Unlike the common derivation of EM via Jensen's inequality, the KL-divergence derivation is shorter and more intuitive.
Jianlin's post (written in Chinese)
From our previous post about EM, we know that EM's objective is to maximize the expected log likelihood of the training dataset. Here, we expand on that idea to show that this optimization is a special case of a more general one: minimizing the KL divergence. Given two probability distributions $p(x)$ and $q(x)$, the KL divergence measures how close $q(x)$ is to $p(x)$
\begin{align} {KL} (p(x) \parallel q(x)) &= \sum_x p(x) \log \frac{p(x)}{q(x)} \tag{KL divergence} \end{align}
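As a concrete illustration, here is a minimal NumPy sketch of this definition for discrete distributions (the helper name `kl_divergence` is ours, not from the original post):

```python
import numpy as np

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as probability vectors."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0                      # terms with p(x) = 0 contribute 0 to the sum
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])
print(kl_divergence(p, q))            # > 0: q differs from p
print(kl_divergence(p, p))            # 0.0: KL is zero when the distributions match
```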
To derive EM using the concept of KL divergence, let's say we want to use $p_{\theta}(x)$ to approximate the underlying distribution of the training dataset, $p_{true}(x)$. Since $p_{true}(x)$ is fixed (it does not depend on $\theta$), minimizing the KL divergence is equivalent to maximizing the expected log likelihood of the training dataset.
\begin{align} \theta_{best} &= \operatorname*{argmin}_{\theta} {KL} (p_{true}(x) \parallel p_{\theta}(x)) \\ &= \operatorname*{argmin}_{\theta} \sum_x p_{true}(x) \log \frac{p_{true}(x)}{p_{\theta}(x)} \\ &= \operatorname*{argmin}_{\theta} \sum_x p_{true}(x) \log p_{true}(x) - \sum_x p_{true}(x) \log p_{\theta}(x) \tag{first term is constant} \\ &= \operatorname*{argmin}_{\theta} - \sum_x p_{true}(x) \log p_{\theta}(x) \tag{Cross Entropy} \\ &= \operatorname*{argmax}_{\theta} \mathbb{E}_x (\log p_{\theta}(x)) \tag{0} \end{align}
where (0) is the objective of the general EM algorithm.
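As a quick sanity check of (0), assuming a toy Bernoulli family for $p_{\theta}(x)$ (our own example, not from the original post), the $\theta$ that minimizes the KL divergence and the $\theta$ that maximizes the expected log likelihood coincide:

```python
import numpy as np

p_true = np.array([0.3, 0.7])            # true distribution: P(x=0), P(x=1)
thetas = np.linspace(0.01, 0.99, 197)    # candidate Bernoulli parameters

def kl(p, q):
    return np.sum(p * np.log(p / q))

kls    = [kl(p_true, np.array([1 - t, t])) for t in thetas]
exp_ll = [p_true @ np.log(np.array([1 - t, t])) for t in thetas]

# Both criteria select the same parameter (theta = 0.7, up to grid resolution).
print(thetas[np.argmin(kls)], thetas[np.argmax(exp_ll)])
```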
Incorporating the hidden variable $Z$ that is used to explain $X$, we can rewrite our objective using the joint probability of $X$ and $Z$
\begin{align} & {KL} (p_{true}(x, z) \parallel p_{\theta}(x, z)) \\ &= \sum_{x, z} p_{true}(x, z) \log \frac{p_{true}(x, z)}{p_{\theta}(x, z)} \\ &= \sum_{x} p_{true}(x) \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)p_{true}(x)}{p_{\theta}(z|x)p_{\theta}(x)} \\ &= \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)p_{true}(x)}{p_{\theta}(z|x)p_{\theta}(x)} \Big] \\ &= \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)p_{\theta}(x)} + \sum_{z} p_{true}(z | x) \log p_{true}(x) \Big] \\ &= \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)p_{\theta}(x)} \Big] + \mathbb{E}_x \Big[ \log p_{true}(x) \Big] \\ &= \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)p_{\theta}(x)} \Big] + C \tag{1} \end{align}
To minimize (1), there are two parts we can adjust: $p_{true}(z | x)$ and $p_{\theta}(z|x)p_{\theta}(x)$. As in EM, we can optimize one part while holding the other constant, alternating between the two. So, if we assume $p_{\theta}(z|x)p_{\theta}(x)$ is known, $p_{true}(z | x)$ can be adjusted to minimize (1) as follows
\begin{align} p_{true}(z | x) &= \operatorname*{argmin}_{p_{true}(z | x)} \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)p_{\theta}(x)} \Big] \\ &= \operatorname*{argmin}_{p_{true}(z | x)} \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)} - \sum_{z} p_{true}(z | x) \log p_{\theta}(x) \Big] \\ &= \operatorname*{argmin}_{p_{true}(z | x)} \mathbb{E}_x \Big[ {KL} (p_{true}(z | x) \parallel p_{\theta}(z | x)) - C \Big] \\ &= \operatorname*{argmin}_{p_{true}(z | x)} \mathbb{E}_x \Big[ {KL} (p_{true}(z | x) \parallel p_{\theta}(z | x)) \Big] \tag{2} \\ &= p_{\theta}(z | x) \tag{E-step} \end{align}
To go from (2) to (E-step), we use the fact that the KL divergence is non-negative and equals zero only when the two distributions are identical; in particular, ${KL}(p_{\theta}(z \vert x) \parallel p_{\theta}(z \vert x)) = 0$, so the minimum of (2) is attained at $p_{true}(z \vert x) = p_{\theta}(z \vert x)$. For the M-step, we assume $p_{true}(z \vert x)$ is constant and equal to $p_{\theta}(z \vert x)$ obtained in the E-step
\begin{align} \theta_{new} &= \operatorname*{argmin}_{\theta} \mathbb{E}_x \Big[ \sum_{z} p_{true}(z | x) \log \frac{p_{true}(z|x)}{p_{\theta}(z|x)p_{\theta}(x)} \Big] \\ &= \operatorname*{argmin}_{\theta} \mathbb{E}_x \Big[\sum_{z} p_{true}(z | x) \log p_{true}(z | x) - \sum_{z} p_{true}(z | x) \log {p_{\theta}(z|x)p_{\theta}(x)} \Big] \\ &= \operatorname*{argmax}_{\theta} \mathbb{E}_x \Big[\sum_{z} p_{true}(z | x) \log {p_{\theta}(z|x)p_{\theta}(x)} \Big] \\ &= \operatorname*{argmax}_{\theta} \mathbb{E}_x \Big[\sum_{z} p_{\theta}(z | x) \log p_{\theta}(x, z) \Big] \tag{M-step} \end{align}
Note that in (M-step), $p_{\theta}(z \vert x)$ inside the expectation is the posterior computed with the current parameters in the E-step and held fixed while maximizing over $\theta$. The above results are identical to the definition of the EM algorithm.
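To make the E-step and M-step concrete, below is a minimal sketch of EM for a 1-D mixture of two Gaussians (our own toy example; names such as `resp` for the posterior $p_{\theta}(z \vert x)$ are assumptions, not from the original post). The E-step computes $p_{\theta}(z \vert x)$ with the current parameters, and the M-step maximizes $\mathbb{E}_x \big[ \sum_{z} p_{\theta}(z \vert x) \log p_{\theta}(x, z) \big]$, which for a Gaussian mixture has a closed-form solution:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data drawn from a two-component Gaussian mixture.
x = np.concatenate([rng.normal(-2.0, 1.0, 300), rng.normal(3.0, 1.5, 700)])

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Initial guesses for mixing weights, means, and standard deviations.
pi, mu, sigma = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])

for _ in range(100):
    # E-step: posterior p_theta(z | x) under the current parameters (responsibilities).
    joint = pi[None, :] * normal_pdf(x[:, None], mu[None, :], sigma[None, :])  # p_theta(x, z)
    resp = joint / joint.sum(axis=1, keepdims=True)                            # p_theta(z | x)

    # M-step: maximize E_x[ sum_z p_theta(z|x) log p_theta(x, z) ] over the
    # parameters; for a Gaussian mixture the maximizer is the weighted updates below.
    nk = resp.sum(axis=0)
    pi = nk / len(x)
    mu = (resp * x[:, None]).sum(axis=0) / nk
    sigma = np.sqrt((resp * (x[:, None] - mu[None, :]) ** 2).sum(axis=0) / nk)

print(pi, mu, sigma)  # should approach the generating parameters
```

Here the M-step updates are the standard weighted maximum-likelihood estimates for a Gaussian mixture; for other latent-variable models the M-step maximization may have to be carried out numerically.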