EM Algorithm

Max Turgeon

2019-10-23

Introduction

In Chapter 4 of Modern Statistics for Modern Biology (Holmes and Huber 2019), the authors introduce mixture models as a general class of statistical models that can deal with heterogeneity. They start by discussing finite mixture models, present the associated likelihood, and then introduce the EM algorithm as a way to obtain the maximum likelihood estimates.

Holmes, Susan, and Wolfgang Huber. 2019. Modern Statistics for Modern Biology.

In this short note, I want to clarify a few things about the EM algorithm.

Finite mixture models

We are going to use simulations to discuss the different concepts. We will use a similar data-generating mechanism to the one in Equation 4.1: \[f(x)=\frac{1}{2}\phi_1(x)+\frac{1}{2}\phi_2(x),\] where

  - \(\phi_1(x)\) is the density of a normal random variable with mean \(\mu_1=-1.5\) and standard deviation 1;
  - \(\phi_2(x)\) is the density of a normal random variable with mean \(\mu_2=1.5\) and standard deviation 1.

Following the code in the book, we can plot the density:

library(tidyverse)
# True parameters of the two normal components
means <- c(-1.5, 1.5)
sds <- c(1, 1)

# Evaluate the 50/50 mixture density on a grid and plot it
tibble(
  x = seq(-5, 5, length.out = 1000),
  f = 0.5 * dnorm(x, mean = means[1], sd = sds[1]) +
      0.5 * dnorm(x, mean = means[2], sd = sds[2])) %>% 
  ggplot(aes(x = x, y = f)) +
  geom_line(color = "red", linewidth = 1.5) + ylab("mixture density") +
  theme_minimal()

We can also generate data from this mixture distribution as follows:

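set.seed(1234)
n <- 100
# First draw the (latent) component label of each observation...
component <- sample(c(1, 2), size = n,
                    replace = TRUE)

# ...then draw each observation from the normal distribution of its component
y <- rnorm(n, mean = means[component],
           sd = sds[component])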

ggplot(tibble(value = y), 
       aes(x = value)) +
  geom_histogram() + 
  theme_minimal()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
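This data-generating mechanism can be checked quickly. The mean of a two-component mixture is \(\pi\mu_1+(1-\pi)\mu_2\). By the law of total variance, its variance is \(\pi\sigma_1^2+(1-\pi)\sigma_2^2+\pi(1-\pi)(\mu_1-\mu_2)^2\). With \(\pi=0.5\), \(\mu_1=-1.5\), \(\mu_2=1.5\) and unit standard deviations, this gives a mean of 0 and a variance of \(1+0.25\times 9=3.25\). The sketch below compares these values with a much larger simulated sample. The seed, the sample size and the variable names are arbitrary choices made for illustration.

```r
# Simulate a large sample from the 50/50 mixture of N(-1.5, 1) and N(1.5, 1)
set.seed(2019)
n_big <- 1e5
mu_true <- c(-1.5, 1.5)
sd_true <- c(1, 1)
pi_mix <- 0.5

# Draw the latent labels first, then the observations
labels_big <- sample(c(1, 2), size = n_big, replace = TRUE,
                     prob = c(pi_mix, 1 - pi_mix))
x_big <- rnorm(n_big, mean = mu_true[labels_big], sd = sd_true[labels_big])

# Theoretical mixture moments (law of total variance for the second one)
mix_mean <- pi_mix * mu_true[1] + (1 - pi_mix) * mu_true[2]
mix_var <- pi_mix * sd_true[1]^2 + (1 - pi_mix) * sd_true[2]^2 +
  pi_mix * (1 - pi_mix) * (mu_true[1] - mu_true[2])^2

c(theoretical_mean = mix_mean, sample_mean = mean(x_big))
c(theoretical_var = mix_var, sample_var = var(x_big))
```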

Maximum Likelihood Estimation

Maximum likelihood estimation says that we can obtain estimates for the parameters of a model by maximising the likelihood. Contrary to what the book alludes to, nothing prevents us from directly maximising the likelihood of a mixture model.

Given data points \(x_1, \ldots, x_n\), we write the likelihood of a general mixture of two normal distributions as follows:

\[ L(\pi, \mu_1, \mu_2, \sigma_1, \sigma_2) = \prod_{i=1}^n \left[\pi\phi(x_i\mid\mu_1,\sigma_1)+(1 -\pi)\phi(x_i\mid\mu_2,\sigma_2)\right].\]

We can use the optim function to maximise the likelihood:

# First define the mixture log-likelihood
#' @param theta A vector containing the parameters, in the following order: 
#'   pi, mu1, mu2, sigma1, sigma2.
#' @param data A vector containing the data.
#' @return The mixture log-likelihood evaluated at theta.
mixture_likelihood <- function(theta, data = y) {
  pi <- theta[1] # Mixing proportion (masks the constant pi inside this function)
  mus <- theta[2:3]
  sds <- theta[4:5]
  sum(log(pi*dnorm(data, mean = mus[1], sd = sds[1]) + 
            (1 - pi)*dnorm(data, mean = mus[2], sd = sds[2])))
}

# Initial values
theta0 <- c(0.3, 0, 1, 0.5, 0.5)

# Optimise
results <- optim(
  theta0,
  mixture_likelihood,
  control = list(fnscale = -1, # Turns minimisation into maximisation
                 maxit = 1000) # The Nelder-Mead default of 500 iterations is not enough here
)

results
## $par
## [1]  0.3656915 -1.6535093  1.4580282  0.8658091  1.0611332
## 
## $value
## [1] -192.8535
## 
## $counts
## function gradient 
##      516       NA 
## 
## $convergence
## [1] 0
## 
## $message
## NULL
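Nelder-Mead can propose invalid values, such as \(\pi\) outside \([0, 1]\) or a negative standard deviation. At those values the log-likelihood is not defined. One common workaround, not used in this note, is to optimise over unconstrained transformations: the logit of \(\pi\) and the logs of the standard deviations. This also makes a gradient-based method such as BFGS usable. The sketch below uses freshly simulated data of size 1000 and hypothetical helper names (eta_to_theta(), mixture_loglik_eta()). It illustrates the idea and is not a recommended default.

```r
# Sketch: maximise the mixture log-likelihood over unconstrained parameters
# eta = (logit(pi), mu1, mu2, log(sigma1), log(sigma2))
set.seed(42)
sim_labels <- sample(c(1, 2), size = 1000, replace = TRUE)
x_sim <- rnorm(1000, mean = c(-1.5, 1.5)[sim_labels], sd = 1)

# Map unconstrained eta back to (pi, mu1, mu2, sigma1, sigma2)
eta_to_theta <- function(eta) {
  c(plogis(eta[1]), eta[2], eta[3], exp(eta[4]), exp(eta[5]))
}

# Mixture log-likelihood as a function of eta: always a valid parameter value
mixture_loglik_eta <- function(eta, data) {
  theta <- eta_to_theta(eta)
  sum(log(theta[1] * dnorm(data, mean = theta[2], sd = theta[4]) +
            (1 - theta[1]) * dnorm(data, mean = theta[3], sd = theta[5])))
}

# Same starting values as above, moved to the transformed scale
theta_start <- c(0.3, 0, 1, 0.5, 0.5)
eta_start <- c(qlogis(theta_start[1]), theta_start[2:3], log(theta_start[4:5]))

fit_eta <- optim(eta_start, mixture_loglik_eta, data = x_sim,
                 method = "BFGS",
                 control = list(fnscale = -1, maxit = 500))
theta_hat <- eta_to_theta(fit_eta$par)
theta_hat
```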

Looking at results, we only have point estimates and not confidence intervals. Even so, our estimate of the mixing proportion \(\pi\) (about 0.37, against a true value of 0.5) does not seem very accurate. Moreover, we had to increase the maximum number of iterations in order to achieve convergence. To better visualise the difficulty of directly maximising the mixture likelihood, it may help to plot it. We are going to fix \(\pi = 0.5\), \(\sigma_1=1\), and \(\sigma_2=1\), and look at the log-likelihood as a function of the two mean parameters:

# Grid of candidate values for mu1 (X) and mu2 (Y)
X <- seq(-3, 3, length.out = 100)
Y <- seq(-3, 3, length.out = 100)

# For all combinations of X, Y, compute the log-likelihood,
# keeping pi = 0.5 and sigma1 = sigma2 = 1 fixed
plot_mixture <- outer(X, Y, function(x, y) {
  apply(cbind(x, y), 1, function(mus) mixture_likelihood(c(0.5, mus, 1, 1)))
})

# Breaks of the colour scale: 13 breaks for the 12 default colours.
# Values outside this range are left blank.
breaks <- seq(-1200, -130, length.out = 13)

# With the default palette (R >= 3.6.0), redder cells have larger values
image(X, Y, plot_mixture, 
      xlab = expression(mu[1]),
      ylab = expression(mu[2]),
      breaks = breaks,
      main = "Mixture Likelihood")

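The main issue is that the likelihood has two maxima, and this makes sense. We label the first component with a 1 and the second with a 2, but these labels are interchangeable. Exchanging \((\mu_1, \sigma_1)\) with \((\mu_2, \sigma_2)\) and replacing \(\pi\) with \(1-\pi\) leaves the mixture likelihood unchanged, so every maximum has a mirror image. With \(\pi = 0.5\) and equal standard deviations, the plot is therefore symmetric about the line \(\mu_1 = \mu_2\). In Bayesian statistics, this issue is known as label switching.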
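The symmetry behind label switching can be checked numerically. The sketch below uses a hypothetical helper, mix_loglik(), with the same parameter order as mixture_likelihood(). It evaluates the log-likelihood at an arbitrary parameter vector on arbitrary data, then again after swapping the labels. The two values agree up to rounding error.

```r
# Hypothetical helper: mixture log-likelihood at theta = (pi, mu1, mu2, sigma1, sigma2)
mix_loglik <- function(theta, data) {
  sum(log(theta[1] * dnorm(data, theta[2], theta[4]) +
            (1 - theta[1]) * dnorm(data, theta[3], theta[5])))
}

# Swap the labels: exchange the component parameters and replace pi by 1 - pi
swap_labels <- function(theta) {
  c(1 - theta[1], theta[3], theta[2], theta[5], theta[4])
}

set.seed(1)
x_demo <- rnorm(50)
theta_a <- c(0.3, -1, 2, 0.8, 1.5)
theta_b <- swap_labels(theta_a)

c(original = mix_loglik(theta_a, x_demo), swapped = mix_loglik(theta_b, x_demo))
```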

Complete-data likelihood

The first step towards the EM algorithm is to realise the following. If we knew which subgroup every observation came from, we would not have this identifiability issue, and the likelihood would indeed be simpler.

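Let \(y_i\) be equal to 1 or 2 according to whether the data point \(x_i\) comes from population 1 or 2. In the code, the observations are stored in y and the labels in component. The complete-data likelihood is given by \[ L(\mu_1, \mu_2, \sigma_1, \sigma_2) = \prod_{i:y_i = 1} \phi(x_i\mid\mu_1,\sigma_1)\prod_{i:y_i = 2} \phi(x_i\mid\mu_2,\sigma_2).\] Strictly speaking, the full complete-data likelihood also contains the factor \(\pi^{n_1}(1-\pi)^{n_2}\) for the labels, where \(n_k\) is the number of observations with \(y_i = k\). Since this factor does not involve the means or the standard deviations, we leave it out here.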
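Because the complete-data log-likelihood splits into one term per component, it has a closed-form maximiser. Each mean is estimated by the sample mean of its group. Each standard deviation is estimated by the standard deviation of its group, computed with divisor \(n_k\) rather than \(n_k - 1\). The sketch below uses freshly simulated data and hypothetical names (labels_demo, ml_sd(), ...). It computes these estimates and checks that nudging any single parameter in either direction lowers the complete-data log-likelihood.

```r
set.seed(1234)
labels_demo <- sample(c(1, 2), size = 100, replace = TRUE)
x_demo <- rnorm(100, mean = c(-1.5, 1.5)[labels_demo], sd = 1)

# Complete-data log-likelihood at theta = (mu1, mu2, sigma1, sigma2)
complete_loglik <- function(theta, data, labels) {
  sum(dnorm(data[labels == 1], theta[1], theta[3], log = TRUE)) +
    sum(dnorm(data[labels == 2], theta[2], theta[4], log = TRUE))
}

# Closed-form maximisers: group means and ML (divisor n_k) standard deviations
ml_sd <- function(v) sqrt(mean((v - mean(v))^2))
theta_closed <- c(mean(x_demo[labels_demo == 1]), mean(x_demo[labels_demo == 2]),
                  ml_sd(x_demo[labels_demo == 1]), ml_sd(x_demo[labels_demo == 2]))
theta_closed

# Perturb each parameter by -0.01 and +0.01 and recompute the log-likelihood
ll_closed <- complete_loglik(theta_closed, x_demo, labels_demo)
perturbed <- sapply(c(-0.01, 0.01), function(d) sapply(1:4, function(j) {
  theta_j <- theta_closed
  theta_j[j] <- theta_j[j] + d
  complete_loglik(theta_j, x_demo, labels_demo)
}))
all(perturbed < ll_closed)
```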

As above, we can use optim to maximise this likelihood directly:

# First define the complete-data log-likelihood
#' @param theta A vector containing the parameters, in the following order: 
#'   mu1, mu2, sigma1, sigma2.
#' @param data A list containing the observed variable and the labels.
#' @return The complete-data log-likelihood evaluated at theta.
complete_likelihood <- function(theta, data = list(y, component)) {
  mus <- theta[1:2]
  sds <- theta[3:4]
  
  y <- data[[1]]
  component <- data[[2]]
  sum(dnorm(y[component == 1], mean = mus[1], sd = sds[1], log = TRUE)) + 
    sum(dnorm(y[component == 2], mean = mus[2], sd = sds[2], log = TRUE))
}

# Initial values
theta0 <- c(0, 1, 0.5, 0.5)

# Optimise
results <- optim(
  theta0,
  complete_likelihood,
  control = list(fnscale = -1) # Turns minimisation into maximisation
)

results
## $par
## [1] -1.4785366  1.6220438  0.9491896  0.9411954
## 
## $value
## [1] -136.1799
## 
## $counts
## function gradient 
##      133       NA 
## 
## $convergence
## [1] 0
## 
## $message
## NULL

Our estimates are actually closer to the true population values than those obtained from the mixture likelihood. Moreover, optim also converged more quickly: 133 function evaluations instead of 516, with the default maximum number of iterations. As above, we will plot the likelihood as a function of both mean parameters:

# For all combinations of X, Y, compute the complete-data log-likelihood,
# keeping sigma1 = sigma2 = 1 fixed
plot_complete <- outer(X, Y, function(x, y) {
  apply(cbind(x, y), 1, function(mus) complete_likelihood(c(mus, 1, 1)))
})

# Same colour breaks as before; redder cells have larger values
image(X, Y, plot_complete, 
      xlab = expression(mu[1]),
      ylab = expression(mu[2]),
      breaks = breaks,
      main = "Complete-Data Likelihood")

Let’s plot them next to one another for easier comparison:
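The following self-contained sketch draws both log-likelihood surfaces next to each other with base graphics. It re-simulates the data with the same seed and recipe as above, using new variable names. It fixes \(\pi = 0.5\) and \(\sigma_1=\sigma_2=1\), and uses par(mfrow = c(1, 2)) to place the two panels side by side.

```r
set.seed(1234)
lab <- sample(c(1, 2), size = 100, replace = TRUE)
obs <- rnorm(100, mean = c(-1.5, 1.5)[lab], sd = 1)

grid_mu <- seq(-3, 3, length.out = 100)

# Log-likelihoods as functions of (mu1, mu2) only
mix_ll <- function(m1, m2) {
  sum(log(0.5 * dnorm(obs, m1, 1) + 0.5 * dnorm(obs, m2, 1)))
}
comp_ll <- function(m1, m2) {
  sum(dnorm(obs[lab == 1], m1, 1, log = TRUE)) +
    sum(dnorm(obs[lab == 2], m2, 1, log = TRUE))
}

# outer() needs vectorised functions, so wrap the scalar versions
surf_mix <- outer(grid_mu, grid_mu, Vectorize(mix_ll))
surf_comp <- outer(grid_mu, grid_mu, Vectorize(comp_ll))

# Same 13 colour breaks as above, shared by both panels
brks <- seq(-1200, -130, length.out = 13)
old_par <- par(mfrow = c(1, 2))
image(grid_mu, grid_mu, surf_mix, breaks = brks,
      xlab = expression(mu[1]), ylab = expression(mu[2]),
      main = "Mixture Likelihood")
image(grid_mu, grid_mu, surf_comp, breaks = brks,
      xlab = expression(mu[1]), ylab = expression(mu[2]),
      main = "Complete-Data Likelihood")
par(old_par)
```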

As we can see, there are two main differences between the likelihoods:

These two properties together imply that it is generally easier to maximise the complete-data likelihood than the mixture likelihood. In this toy example, both converged very quickly; however, for larger problems (e.g. more components, more parameters, larger sample size), maximising the mixture likelihood directly can become noticeably slower.

EM Algorithm

The goal of the EM algorithm is to maximise a series of likelihoods where each of them is a “realisation” of the complete-data likelihood. More specifically, we first posit a joint model for the observed data and the unobserved labels that we call the complete-data model. At each iteration, we go through two steps:

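  1. Expectation Step: Given the current estimates of the parameters \(\hat{\theta}\), we take the expected value of the labels given the data. For a mixture model, this amounts to computing, for each observation, the probability that it belongs to each component.
  2. Maximisation Step: We plug these expected values back into the complete-data log-likelihood and maximise over the parameters \(\theta\). The maximiser becomes the new estimate \(\hat{\theta}\).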
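For the two-component normal mixture used in this note, both steps have well-known closed forms. Write \(\gamma_{ik}\) (our notation) for the expected value, computed in the E step, of the indicator that \(x_i\) comes from component \(k\):

\[ \gamma_{i1} = \frac{\hat\pi\,\phi(x_i\mid\hat\mu_1,\hat\sigma_1)}{\hat\pi\,\phi(x_i\mid\hat\mu_1,\hat\sigma_1)+(1-\hat\pi)\,\phi(x_i\mid\hat\mu_2,\hat\sigma_2)}, \qquad \gamma_{i2} = 1-\gamma_{i1}. \]

The M step then replaces the current estimates with the weighted proportion, means and standard deviations:

\[ \hat\pi = \frac{1}{n}\sum_{i=1}^n \gamma_{i1}, \qquad \hat\mu_k = \frac{\sum_{i=1}^n \gamma_{ik}\,x_i}{\sum_{i=1}^n \gamma_{ik}}, \qquad \hat\sigma_k^2 = \frac{\sum_{i=1}^n \gamma_{ik}\,(x_i-\hat\mu_k)^2}{\sum_{i=1}^n \gamma_{ik}}, \quad k = 1, 2. \]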

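What we get is an iterative algorithm for finding the maximum likelihood estimates. Its iterations never decrease the mixture likelihood, and for many models each step is a simple closed-form update. This gives it better convergence properties than simply maximising the mixture likelihood directly, although, like direct maximisation, it is only guaranteed to reach a local maximum.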
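The monotonicity property can be observed directly. The sketch below is a minimal EM for the two-component normal mixture, with hypothetical helper names (loglik(), em_step()). It uses the closed-form maximum likelihood updates for the weighted means and standard deviations. It records the mixture log-likelihood after each of 200 iterations and checks that the sequence never decreases, up to rounding error.

```r
set.seed(1234)
lab <- sample(c(1, 2), size = 100, replace = TRUE)
obs <- rnorm(100, mean = c(-1.5, 1.5)[lab], sd = 1)

# Mixture log-likelihood at th = (pi, mu1, mu2, sigma1, sigma2)
loglik <- function(th, x) {
  sum(log(th[1] * dnorm(x, th[2], th[4]) + (1 - th[1]) * dnorm(x, th[3], th[5])))
}

em_step <- function(th, x) {
  # E step: posterior probability that each point belongs to component 1
  d1 <- th[1] * dnorm(x, th[2], th[4])
  d2 <- (1 - th[1]) * dnorm(x, th[3], th[5])
  g1 <- d1 / (d1 + d2)
  g2 <- 1 - g1
  # M step: weighted proportion, means and ML standard deviations
  mu1 <- sum(g1 * x) / sum(g1)
  mu2 <- sum(g2 * x) / sum(g2)
  c(mean(g1), mu1, mu2,
    sqrt(sum(g1 * (x - mu1)^2) / sum(g1)),
    sqrt(sum(g2 * (x - mu2)^2) / sum(g2)))
}

th_em <- c(0.3, 0, 1, 0.5, 0.5)
ll_path <- loglik(th_em, obs)
for (k in 1:200) {
  th_em <- em_step(th_em, obs)
  ll_path <- c(ll_path, loglik(th_em, obs))
}

# EM never decreases the log-likelihood (allowing for rounding error)
all(diff(ll_path) > -1e-8)
th_em
```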

The EM algorithm is used beyond mixture models:

Another biological example of using the EM algorithm is haplotype frequency estimation. In this setting, the observed data are the (unphased) genotypes at multiple loci. For any homozygous locus, there is no ambiguity. However, when an individual is heterozygous at two or more loci, we cannot determine which alleles were inherited together on the paternal chromosome and which on the maternal chromosome. Under Hardy-Weinberg equilibrium, the EM algorithm can be used to estimate the frequency (or proportion) of each haplotype through an iterative procedure (Excoffier and Slatkin 1995).

Excoffier, Laurent, and Montgomery Slatkin. 1995. “Maximum-Likelihood Estimation of Molecular Haplotype Frequencies in a Diploid Population.” Molecular Biology and Evolution.
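Below is a minimal sketch of this idea for the simplest case: two biallelic loci with alleles A/a and B/b. It assumes that haplotype counts from individuals with unambiguous phase have already been tallied. It also assumes that the only ambiguous individuals are double heterozygotes (AaBb), who carry either the AB/ab pair or the Ab/aB pair of haplotypes. The function haplotype_em() and its inputs are hypothetical. The general method handles any number of loci.

```r
# Hypothetical sketch: EM for two-locus haplotype frequencies under
# Hardy-Weinberg equilibrium.
#   known: named counts of haplotypes AB, Ab, aB, ab from individuals whose
#          phase is unambiguous
#   n_dh:  number of double heterozygotes (AaBb), whose phase is unknown
haplotype_em <- function(known, n_dh, tol = 1e-10, max_iter = 1000) {
  haps <- c("AB", "Ab", "aB", "ab")
  known <- known[haps]
  total <- sum(known) + 2 * n_dh          # each individual carries two haplotypes
  p <- setNames(rep(0.25, 4), haps)       # start from equal frequencies
  for (iter in seq_len(max_iter)) {
    # E step: probability that a double heterozygote carries AB/ab, not Ab/aB
    w <- unname(p["AB"] * p["ab"] / (p["AB"] * p["ab"] + p["Ab"] * p["aB"]))
    # Expected haplotype counts given the current frequencies
    counts <- known + n_dh * c(w, 1 - w, 1 - w, w)
    # M step: new frequencies are the expected counts divided by the total
    p_new <- counts / total
    if (max(abs(p_new - p)) < tol) return(p_new)
    p <- p_new
  }
  p
}

known <- c(AB = 30, Ab = 10, aB = 12, ab = 28)
p_hat <- haplotype_em(known, n_dh = 20)
p_hat
```

In this made-up example, the EM assigns most double heterozygotes to the AB/ab phase, because AB and ab are the most common haplotypes among the unambiguous individuals.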

Example

Let’s look at an example of the EM algorithm for a mixture of two normal distributions.

########################
# initial parameter estimates
theta0 <- c(0.3, 0, 1, 0.5, 0.5)

# E step: calculates conditional probabilities for latent variables
E_step <- function(theta, data = y) {
  comp1 <- theta[1] * dnorm(data, mean = theta[2], sd = theta[4])
  comp2 <- (1 - theta[1]) * dnorm(data, mean = theta[3], sd = theta[5])
  
  return(cbind(comp1/(comp1 + comp2),
               comp2/(comp1 + comp2)))
}

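# M step: calculates the parameter estimates which maximise Q, the expected
# complete-data log-likelihood given the current probabilities (gammas)
M_step <- function(gammas, data = y) {
  pi <- mean(gammas[,1])
  
  mus <- c(weighted.mean(data, gammas[,1]),
           weighted.mean(data, gammas[,2]))
  
  # Note: cov.wt() defaults to method = "unbiased", which rescales the weighted
  # variance; method = "ML" would give the exact maximum likelihood update
  sds <- sqrt(c(cov.wt(as.matrix(data), wt = gammas[,1])$cov,
                cov.wt(as.matrix(data), wt = gammas[,2])$cov))
  
  return(c(pi, mus, sds))
}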

# run EM
epsilon <- 10e-8 # i.e. 1e-7

# Row 1 holds the initial values; row k + 1 the estimates after iteration k
intermediate_values <- matrix(NA, ncol = 5, nrow = 101)
intermediate_values[1,] <- theta0

theta <- theta0
for (iter in 1:100) {
  old_theta <- theta
  mat <- E_step(theta)
  theta <- M_step(mat)
  
  intermediate_values[iter + 1,] <- theta
  
  # Stop once the mixture log-likelihood changes by less than epsilon
  if (abs(mixture_likelihood(theta) - mixture_likelihood(old_theta)) < epsilon) break
}

intermediate_values[iter + 1,]
## [1]  0.3654576 -1.6506847  1.4546061  0.8820147  1.0736684

We can check that this is similar to what we would get using the mixtools package:

library(mixtools)

fit_results <- normalmixEM(y)
## number of iterations= 67
fit_results$lambda
## [1] 0.3656863 0.6343137
fit_results$mu
## [1] -1.653997  1.457641
fit_results$sigma
## [1] 0.8655441 1.0612356

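The slight discrepancy between the results, most visible in the standard deviations, is most likely due to our M step rather than the stopping criterion. By default, cov.wt uses method = "unbiased", which divides the weighted sum of squares by \(1-\sum_i w_i^2\) (with the weights normalised to sum to one). This is not the maximum likelihood update. Passing method = "ML" should bring our estimates closer to those of normalmixEM, although the different stopping criteria can also contribute a little.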
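The two variance formulas can be compared directly. The sketch below uses arbitrary simulated values and weights, with hypothetical names v and wts. It computes the weighted maximum likelihood variance, which is the EM update for a normal mixture, and its rescaled "unbiased" version. It then checks them against stats::cov.wt() with its default method and with method = "ML".

```r
# Arbitrary data and unnormalised weights (like the responsibilities in M_step)
set.seed(1)
v <- rnorm(20)
wts <- runif(20)

wn <- wts / sum(wts)                   # cov.wt() normalises the weights internally
centre <- sum(wn * v)
ml_var <- sum(wn * (v - centre)^2)     # weighted ML variance (the EM update)
unbiased_var <- ml_var / (1 - sum(wn^2))

c(cov_wt_default = cov.wt(as.matrix(v), wt = wts)$cov[1, 1],
  cov_wt_ml = cov.wt(as.matrix(v), wt = wts, method = "ML")$cov[1, 1],
  ml_var = ml_var,
  unbiased_var = unbiased_var)
```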

Finally, let’s visualise how the EM algorithm converges to the maximum likelihood estimates. We will do this by plotting the mixture density at every iteration.

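library(gganimate)

# Evaluate the fitted mixture density on a grid at every EM iteration
densities_em <- purrr::map_df(seq_len(nrow(intermediate_values)), 
                              function(i) {
  theta <- intermediate_values[i,]
  pi <- theta[1]
  means <- theta[2:3]
  sds <- theta[4:5]
  tibble(
  index = i,
  x = seq(-3.5, 4.5, length.out = 1000),
  f = pi * dnorm(x, mean = means[1], sd = sds[1]) +
    (1 - pi) * dnorm(x, mean = means[2], sd = sds[2]))
  }) %>% 
  # Rows of intermediate_values left unfilled after convergence are NA
  filter(!is.na(f))

# Add the original data so that the histogram appears in every frame.
# Each frame has 1000 rows, so y (100 values) is repeated 10 times per frame,
# which leaves the density histogram unchanged. Recent tibble versions no
# longer recycle automatically, hence rep_len().
densities_em$y <- rep_len(y, nrow(densities_em))

animation <- ggplot(densities_em, aes(x = x, y = f)) +
  geom_histogram(aes(x = y, y = after_stat(density)),
                 fill = 'grey80') +
  geom_line(linewidth = 1.5) + ylab("mixture density") + 
  # Dashed lines mark the true means (the global `means` defined at the start)
  geom_vline(xintercept = means, linetype = 'dashed') +
  theme_minimal() + 
  transition_time(index) +
  labs(title = "Iteration: {frame_time}")
animate(animation, end_pause = 10)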

Near the end of the animation, the density barely changes from one iteration to the next. This is typical of the EM algorithm: its convergence is only linear, so it slows down considerably as it approaches the maximum. For this reason, many improvements of the EM algorithm have been suggested (see for example Jamshidian and Jennrich 1997).

Jamshidian, Mortaza, and Robert I. Jennrich. 1997. “Acceleration of the EM Algorithm by Using Quasi-Newton Methods.” JRSS B.