Georg Grab

The Variance Collapse Problem in the Expectation-Maximization Algorithm of Gaussian Mixtures.

In my opinion, the problem of variance collapse when fitting a gaussian mixture model to data tends to be downplayed a little. But see for yourself.

What is the pixelated mess that's flashing across the screen above, you ask? That's the Expectation-Maximization algorithm running right here in your browser [1], applied to a dataset generated by a gaussian mixture model. A pretty easy dataset, if I may add: the three point clusters are clearly separated. The gaussians we want to fit are randomly initialized, then the algorithm runs to convergence, and the final result is shown for a few seconds before the whole thing repeats. You will probably notice that, more often than not, some of the distributions collapse onto a single data point while the other distributions explode and fit the rest of the points. This phenomenon will be the topic of this article, but we're getting ahead of ourselves. Let's briefly review what the Expectation-Maximization (EM) algorithm is about, what Gaussian mixtures are, and why we would need any of that.

EM Algorithm with Gaussian Mixtures

The EM algorithm goes back to a paper by Dempster, Laird, and Rubin from 1977 [2]. It is an iterative method for finding the maximum likelihood estimate of the parameters of a statistical model given some data. In the following, we assume that we have \(N\) data points \(\{x_i\}_{i=1}^N\) generated by our presumed model. For example, the simulation above assumes that the statistical model is a mixture of \(k=3\) gaussians, and hence the parameters we would like to fit are \(\mu_k\) and \(\Sigma_k\), the mean and covariance of each component. It should be noted that the method is not tied to this particular choice of model. It can be something simple like the mixture of gaussians assumed here, but in principle we could also assume more complex models like neural networks. With that said, the algorithm works broadly as follows.
First, we initialize the parameters: for us, \(\mu_k\) and \(\Sigma_k\), as well as a prior \(\pi_k\) over the components. The simulation above sets a random mean vector and a fixed, diagonal covariance matrix for each component, though most libraries you come across will use something more sophisticated, for reasons we will get into. The prior is usually initialized to be uniform, i.e., \(\pi_k=1/k\).
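To make the notation concrete, here is a minimal sketch of such an initialization in Python/NumPy. The simulation itself runs in the browser, so the function name init_params, the array shapes, and the choice of picking random data points as initial means are assumptions made purely for illustration.

```python
import numpy as np

def init_params(X, K, sigma2=1.0, seed=0):
    """Random means, fixed diagonal covariances, uniform prior (illustrative only)."""
    rng = np.random.default_rng(seed)
    N, d = X.shape
    mu = X[rng.choice(N, size=K, replace=False)]  # K random points as means, shape (K, d)
    Sigma = np.stack([sigma2 * np.eye(d)] * K)    # fixed diagonal covariances, shape (K, d, d)
    pi = np.full(K, 1.0 / K)                      # uniform prior over the K components
    return mu, Sigma, pi
```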
Next, the iteration begins. We start by computing the likelihood of every datapoint under each of the clusters, $$ p(X=x | C=k) = \mathcal{N}(x; \mu_k, \Sigma_k). $$ This quantity tells us how likely the datapoint \(x\) is under distribution \(k\), but what we need for the EM algorithm is some notion of "cluster membership": in other words, the probability that it was distribution \(k\) that generated \(x\). So we additionally compute the posterior by applying Bayes' rule: $$ p(C=k | X=x) = \dfrac{p(X=x | C=k) \pi_k}{\sum_{k'} p(X=x | C=k') \pi_{k'}}. $$ Intuitively, the posterior can be seen as a "soft clustering": for a given data point \(x\), we obtain a probability vector (the components of which of course sum to one) of length \(k\) that represents the responsibility each distribution takes for that data point. Computing this posterior concludes the E-step of the EM algorithm.
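A sketch of the E-step in code, continuing the NumPy example above; the name e_step and the array shapes are again my own convention, and SciPy's multivariate_normal stands in for whatever density routine an implementation actually uses.

```python
from scipy.stats import multivariate_normal
import numpy as np

def e_step(X, mu, Sigma, pi):
    """Return the (N, K) matrix of posteriors p(C=k | X=x_i), the 'responsibilities'."""
    N, K = X.shape[0], mu.shape[0]
    lik = np.empty((N, K))
    for k in range(K):
        lik[:, k] = multivariate_normal.pdf(X, mean=mu[k], cov=Sigma[k])  # p(X=x_i | C=k)
    weighted = lik * pi                                     # numerator of Bayes' rule
    # no numerical safeguards (log-space, underflow) here, for brevity
    return weighted / weighted.sum(axis=1, keepdims=True)   # normalize over components
```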
In the M-step, we update the parameters of our model given the posterior computed in the E-step. First, the new prior is computed: $$ \pi_k = \dfrac{N_k}{N}, \text{ where } N_k=\sum_{i=1}^N p(C=k | X=x_i). $$ This is straightforward: for each distribution, we sum its responsibilities over all data points to arrive at a number quantifying its overall responsibility for the data. Next, we update \(\mu_k\) and \(\Sigma_k\): $$ \mu_k = \dfrac{1}{N_k}\sum_{i=1}^N p(C=k | X=x_i) x_i, $$ $$ \Sigma_k = \dfrac{1}{N_k}\sum_{i=1}^N p(C=k | X=x_i)(x_i - \mu_k)(x_i - \mu_k)^T. $$ For both of these parameters, we use the posterior to assign a weight to every data point before doing the usual mean and covariance calculation. This concludes the M-step [3]. We now simply repeat both steps until convergence, or until some other stopping criterion is met.
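A matching sketch of the M-step, together with the outer loop. It reuses init_params and e_step from the snippets above; a production implementation would at least track the log likelihood for a convergence check and guard against the degenerate case discussed next.

```python
def m_step(X, post):
    """Update priors, means, and covariances from the (N, K) responsibilities."""
    N, d = X.shape
    Nk = post.sum(axis=0)                        # effective number of points per component
    pi = Nk / N
    mu = (post.T @ X) / Nk[:, None]              # weighted means, shape (K, d)
    Sigma = np.empty((len(Nk), d, d))
    for k in range(len(Nk)):
        diff = X - mu[k]
        Sigma[k] = (post[:, k, None] * diff).T @ diff / Nk[k]  # weighted covariance
    return mu, Sigma, pi

def em(X, K, n_iter=100):
    mu, Sigma, pi = init_params(X, K)
    for _ in range(n_iter):                      # or stop once the log likelihood stalls
        post = e_step(X, mu, Sigma, pi)
        mu, Sigma, pi = m_step(X, post)
    return mu, Sigma, pi
```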

Collapsing Variances

So, why do we see so many of the clusters in the simulation above collapsing onto a single point? Quite simply, for a gaussian mixture model this gives the algorithm an "easy way out": instead of doing the work of fitting reasonable distributions to all of the data, it can increase the likelihood without bound by shrinking one component onto a single point. If distribution \(k\) collapses onto the point \(x_i\), its variance goes to zero, the density \(p(X=x_i|C=k)\) diverges, the posterior \(p(C=k|X=x_i)\) approaches 1, and the posterior of that cluster given any other data point approaches 0. This is a degenerate solution, a singularity of the likelihood, that the algorithm cannot recover from.
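To see why the collapse actually pays off in terms of likelihood, consider a component sitting exactly on \(x_i\) with isotropic covariance \(\Sigma_k = \sigma^2 I\) in \(d\) dimensions. Its density at that point is $$ \mathcal{N}(x_i; x_i, \sigma^2 I) = \dfrac{1}{(2\pi\sigma^2)^{d/2}} \xrightarrow{\;\sigma \to 0\;} \infty, $$ so as long as every other data point keeps non-zero density under some other component, shrinking \(\sigma\) drives the total log likelihood towards infinity, and nothing in the update equations pushes back.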
Although this is a well-known problem with the EM algorithm and gaussian mixture models, I didn't appreciate that it happens basically all the time when the parameters are initialized randomly. Contemporary implementations get around this by initializing the parameters more cleverly, most of the time by running some variant of k-means clustering before the EM algorithm.
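For a sense of what this looks like in practice, here is a small example using scikit-learn; the dataset is made up for illustration. Its GaussianMixture initializes the means with k-means by default and additionally adds a small regularization term (reg_covar) to the diagonal of every covariance matrix, which together make the collapse far less likely.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# three well-separated clusters, similar in spirit to the simulation above
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(100, 2))
               for c in ([-5.0, 0.0], [0.0, 5.0], [5.0, 0.0])])

gmm = GaussianMixture(
    n_components=3,
    init_params="kmeans",   # the default: place initial means via k-means, not at random
    reg_covar=1e-6,         # the default: small value added to the covariance diagonals
    n_init=5,               # several restarts, keeping the run with the best log likelihood
).fit(X)

print(gmm.means_)
print(gmm.converged_)
```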

EM Simulation

Below is a variant of the title simulation that you can step through, and where, additionally, the current log likelihood of the model and all the parameters of the distributions are logged. Draw some points by clicking and holding somewhere; adjust the parameters of the distribution you're generating them with using the scroll wheel. The code for this simulation is on Github.