Mixing and sampling efficiency of MCMC algorithms

Random-walk Metropolis sampler

# Random-walk Metropolis sampler with an independent Gaussian proposal.
#
# Arguments:
#   eval_logp  Function mapping a state vector to the log target density.
#              Only differences of its values are used, so an additive
#              constant is irrelevant. A 1 x 1 matrix result is accepted
#              and coerced to a plain number.
#   init       Numeric vector; initial state (its length sets the dimension).
#   n_iter     Number of iterations to run; 0 returns just the initial state.
#   prop_sd    Proposal scale multiplying standard normal noise elementwise.
#   seed       Optional RNG seed; `set.seed(seed)` is called unless it is NA.
#
# Returns a list with
#   samples      n_dim x (n_iter + 1) matrix, column 1 being `init`, with
#                rows named "param_1", "param_2", ...
#   logp         Log density at each stored state (length n_iter + 1).
#   accept_rate  Fraction of proposals accepted (NaN when n_iter is 0).
rand_walk_metropolis <- function(eval_logp, init, n_iter, prop_sd, seed = NA) {

  if (!is.na(seed)) {
    set.seed(seed)
  }

  n_dim <- length(init)
  samples <- matrix(NA_real_, nrow = n_dim, ncol = n_iter + 1L)
  logp_samples <- numeric(n_iter + 1L)
  n_accepted <- 0L
  x <- init
  logp <- as.numeric(eval_logp(x))
  if (length(logp) != 1L || is.na(logp)) {
    stop(
      "`eval_logp(init)` must return a single non-missing number.",
      call. = FALSE
    )
  }
  samples[, 1] <- init
  logp_samples[1] <- logp

  # seq_len() (unlike 1:n_iter) yields an empty loop when n_iter is 0.
  for (i in seq_len(n_iter)) {
    x_prop <- x + prop_sd * rnorm(n_dim)
    logp_prop <- as.numeric(eval_logp(x_prop))
    log_u <- log(runif(1))
    log_ratio <- logp_prop - logp
    # An undefined ratio (e.g. NaN log density outside the support) is
    # treated as a rejection instead of crashing the `if` below.
    accepted <- !is.na(log_ratio) && log_ratio > log_u
    if (accepted) {
      x <- x_prop
      logp <- logp_prop
    }
    samples[, i + 1] <- x
    logp_samples[i + 1] <- logp
    n_accepted <- n_accepted + accepted
  }

  rownames(samples) <- paste0("param_", seq_len(n_dim))
  list(
    samples = samples,
    logp = logp_samples,
    accept_rate = n_accepted / n_iter
  )
}

Sample from 16-dimensional i.i.d. Gaussian

# Target: product of `d` standard normals; log density up to a constant.
d <- 16
iid_target_logp <- function(x) {
  -sum(x^2) / 2
}

# Proposal sd 2.38 / sqrt(d): the theoretically optimal choice when `d` is
# sufficiently large.
prop_sd <- 2.38 / sqrt(d)

# Start at the origin and run 10^5 iterations with a fixed seed.
init <- numeric(d)
n_iter <- 1e5
rwm_output <- rand_walk_metropolis(
  eval_logp = iid_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = prop_sd,
  seed = 1918
)

Examine the traceplot, auto-correlation, and ESS along the first coordinate

# Draws of the first coordinate (row 1 of the samples matrix).
coord_samples <- rwm_output[["samples"]][1, ]
trace_plot(coord_samples[1:1000])

auto_cor_plot(coord_samples, lag.max = 100)

# Three ESS estimates for the same coordinate, collected under fixed names.
ess_est <- list()
ess_est$ar_proc <- coda::effectiveSize(coord_samples)[[1]]
ess_est$batch_means <- mcmcse::ess(coord_samples, method = "bm")
ess_est$init_pos_seq <- n_iter / mcmc::initseq(coord_samples)$var.pos
print(ess_est)
## $ar_proc
## [1] 2026.527
## 
## $batch_means
## [1] 1885.564
## 
## $init_pos_seq
## [1] 1937.285
# Iterations needed per effective sample, for each estimator.
mcmc_iter_per_ess <- lapply(ess_est, function(ess_value) n_iter / ess_value)
print(mcmc_iter_per_ess)
## $ar_proc
## [1] 49.34551
## 
## $batch_means
## [1] 53.03453
## 
## $init_pos_seq
## [1] 51.61863

Check that the distribution of the MCMC samples coincides with the target

Compound symmetric Gaussian target

Consider a Gaussian target with unit variances \(\Sigma_{ii} = 1\) and covariance \(\Sigma_{ij} = \rho\) for all \(i \neq j\).

d <- 16
rho <- 0.9
# Compound symmetric covariance: 1 on the diagonal, `rho` off the diagonal.
Sigma <- (1 - rho) * diag(d) + rho * array(1., dim = c(d, d))
# Analytically invert `Sigma`
Prec <- 1 / (1 - rho) * (
  diag(d) - rho / (1 - rho + rho * d) * array(1., dim = c(d, d))
) 

# Log density (up to a constant) of the zero-mean Gaussian with precision
# `Prec`. The quadratic form is a 1 x 1 matrix; as.numeric() returns it as a
# plain scalar so matrix attributes do not leak into downstream results.
cs_target_logp <- function(x) {
  as.numeric(- 0.5 * t(x) %*% Prec %*% x)
}

# Approximate, based on the compound symmetry structure, the optimal proposal sd
smaller_target_sd <- sqrt(min(eigen(Sigma)$values)) # only two unique eigenvalues
cs_prop_sd <- smaller_target_sd * 2.38 / sqrt(d) 

# Run 10^5 iterations on the compound symmetric target from the origin.
init <- numeric(d)
n_iter <- 1e5
rwm_output <- rand_walk_metropolis(
  eval_logp = cs_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = cs_prop_sd,
  seed = 1
)

As before, examine the traceplot, auto-correlation, and ESS

# Draws of the first coordinate (row 1 of the samples matrix).
coord_samples <- rwm_output[["samples"]][1, ]
trace_plot(coord_samples, thin = 100L)

auto_cor_plot(coord_samples, lag.max = 10000)

# Three ESS estimates for the same coordinate, collected under fixed names.
ess_est <- list()
ess_est$ar_proc <- coda::effectiveSize(coord_samples)[[1]]
ess_est$batch_means <- mcmcse::ess(coord_samples, method = "bm")
ess_est$init_pos_seq <- n_iter / mcmc::initseq(coord_samples)$var.pos
print(ess_est)
## $ar_proc
## [1] 73.92964
## 
## $batch_means
## [1] 47.39285
## 
## $init_pos_seq
## [1] 24.41295
# Iterations needed per effective sample, for each estimator.
mcmc_iter_per_ess <- lapply(ess_est, function(ess_value) n_iter / ess_value)
print(mcmc_iter_per_ess)
## $ar_proc
## [1] 1352.637
## 
## $batch_means
## [1] 2110.023
## 
## $init_pos_seq
## [1] 4096.186

How does the distribution of the MCMC samples compare to the actual target?

Don’t easily trust the output when the mixing is this bad

# Same target and settings as before; only the RNG seed differs.
rwm_output <- rand_walk_metropolis(
  eval_logp = cs_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = cs_prop_sd,
  seed = 615
)
coord_samples <- rwm_output[["samples"]][1, ]
trace_plot(coord_samples, thin = 100L)

auto_cor_plot(coord_samples, lag.max = 10000)

# Three ESS estimates for the first coordinate of this second run.
ess_est <- list()
ess_est$ar_proc <- coda::effectiveSize(coord_samples)[[1]]
ess_est$batch_means <- mcmcse::ess(coord_samples, method = "bm")
ess_est$init_pos_seq <- n_iter / mcmc::initseq(coord_samples)$var.pos
print(ess_est)
## $ar_proc
## [1] 38.80634
## 
## $batch_means
## [1] 36.57983
## 
## $init_pos_seq
## [1] 5.265385
hist_against_truth(coord_samples, true_density = dnorm)

Convergence of MCMC algorithms

Start away from stationarity on the compound symmetric target

# Random starting point, with the first two coordinates forced to 3.5.
set.seed(410)
init <- replace(rnorm(d), c(1, 2), 3.5)

n_iter <- 1e4
rwm_output <- rand_walk_metropolis(
  eval_logp = cs_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = cs_prop_sd,
  seed = 21205
)

Examine the chain’s trajectory in the first two dimensions

# Coordinates whose joint trajectory is plotted.
index <- c(1, 2)

# Zero-mean Gaussian density using the `index` sub-block of `Sigma`.
target_density <- function(xy) {
  marginal_cov <- Sigma[index, index]
  mvtnorm::dmvnorm(xy, sigma = marginal_cov)
}

n_iter_to_plot <- 1000L
plot_trajectory(
  rwm_output$samples, index, n_iter_to_plot, target_density
)

Assess convergence via the log target density

# Plot the first 500 stored log-density values of the chain; entry 1 is the
# log density at the initial state.
trace_plot(
  rwm_output$logp[1:500],
  ylab = "Log density"
)

Visually check if the convergence in log density coincides with that in parameter values

# Redraw the trajectory with a smaller iteration count (200).
# NOTE(review): presumably plot_trajectory shows only the first
# `n_iter_to_plot` iterations -- confirm against the helper's definition.
n_iter_to_plot <- 200L
plot_trajectory(rwm_output$samples, index, n_iter_to_plot, target_density)

Also check that, by zooming in a bit, the log density actually stabilized

# Plot stored log-density values 200 through 1200, i.e. skipping the initial
# stretch of the chain.
trace_plot(
  rwm_output$logp[200:1200],
  ylab = "Log density"
)

Running multiple chains for convergence/mixing diagnostics

Let’s try sampling from a target of unknown structure

# First chain: 10^3 iterations, started at 0.1 with proposal sd 0.1.
n_iter <- 1e3
init <- 0.1
prop_sd <- 0.1
rwm_output <- rand_walk_metropolis(
  eval_logp = unknown_target,
  init = init,
  n_iter = n_iter,
  prop_sd = prop_sd,
  seed = 1876
)
samples_1 <- rwm_output$samples
trace_plot(samples_1, ylim = param_range)

Traceplot seems fine? But let’s run another chain with different initialization.

# Second chain: identical settings and seed, started at -0.1 instead.
init <- -0.1
rwm_output <- rand_walk_metropolis(
  eval_logp = unknown_target,
  init = init,
  n_iter = n_iter,
  prop_sd = prop_sd,
  seed = 1876
)
samples_2 <- rwm_output$samples
trace_plot(samples_2, ylim = param_range)

The issue here is multi-modality in the target:

ESS would look fine if the chain is mixing well locally, oblivious to other modes

# ESS of the first chain using stored draws 100 through `n_iter`.
post_burnin_samples <- samples_1[seq(100, n_iter)]
coda::effectiveSize(post_burnin_samples)[[1]]
## [1] 113.5144

But \(\hat{R}\) can detect the issue

# posterior's R-hat functions take the draws of a single variable as an
# iterations x chains matrix, so each chain must be a column (cbind), not a
# row (rbind).
combined_samples <- cbind(as.vector(samples_1), as.vector(samples_2))
c(
  posterior::rhat_basic(combined_samples), # 2013 ver.
  posterior::rhat(combined_samples) # 2021 ver.
) # >= 1.01 indicates a likely convergence/mixing issue
# NOTE: the output recorded below predates the orientation fix; re-run to
# refresh it.
## [1] 10.198620  2.124564

Let’s re-run with larger \(\sigma_\mathrm{prop}\) to help jump between the modes

# Two chains of 4000 iterations with proposal sd 1, started at +1 and -1.
n_iter <- 4000
prop_sd <- 1

init <- 1
rwm_output <- rand_walk_metropolis(
  eval_logp = bimodal_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = prop_sd,
  seed = 111
)
samples_1 <- rwm_output$samples

init <- -1
rwm_output <- rand_walk_metropolis(
  eval_logp = bimodal_target_logp,
  init = init,
  n_iter = n_iter,
  prop_sd = prop_sd,
  seed = 222
)
samples_2 <- rwm_output$samples

Better, but 1000 iterations are not quite enough

# First quarter of the stored draws from each chain.
quarter_samples_1 <- samples_1[1:ceiling(n_iter / 4)]
quarter_samples_2 <- samples_2[1:ceiling(n_iter / 4)]

trace_plot(quarter_samples_1, col = tableau10[1])
trace_plot(quarter_samples_2, col = tableau10[2], add = TRUE)

# posterior's R-hat functions take the draws of a single variable as an
# iterations x chains matrix, so each chain must be a column (cbind).
combined_quarter_samples <- cbind(quarter_samples_1, quarter_samples_2)
c(posterior::rhat_basic(combined_quarter_samples),
  posterior::rhat(combined_quarter_samples))
# NOTE: the output recorded below predates the orientation fix; re-run to
# refresh it.
## [1] 1.044920 1.066895

\(\hat{R}\) reaches the 1.01 threshold at 4000 iterations (for this run)

trace_plot(samples_1, col = tableau10[1])
trace_plot(samples_2, col = tableau10[2], add = TRUE)

# posterior's R-hat functions take the draws of a single variable as an
# iterations x chains matrix, so each chain must be a column (cbind).
combined_samples <- cbind(as.vector(samples_1), as.vector(samples_2))
c(posterior::rhat_basic(combined_samples),
  posterior::rhat(combined_samples))
# NOTE: the output recorded below predates the orientation fix; re-run to
# refresh it.
## [1] 1.004153 1.004762
hist_against_truth(samples_1, true_density)

Let’s try another, quadrimodal, example

# Parameters handed to log_dnorm_mixture(); presumably the component
# locations, scales and weights of a four-component normal mixture
# (TODO confirm against the helper's definition).
mu <- c(-1.5, -0.5, 0.5, 1.5)
sigma <- c(0.1, 0.04, 0.1, 0.1)
w <- c(0.05, 0.1, 0.25, 0.6)

quadrimodal_target_logp <- function(x) log_dnorm_mixture(x, mu, sigma, w)

param_range <- c(-2, 2)
prop_sd <- 0.1
n_iter <- 1e3
n_chain <- 4L

# One chain per entry of `mu`: chain k starts at mu[k] and uses seed k.
samples_list <- lapply(
  seq_len(n_chain),
  function(k) {
    chain_output <- rand_walk_metropolis(
      eval_logp = quadrimodal_target_logp,
      init = mu[k],
      n_iter = n_iter,
      prop_sd = prop_sd,
      seed = k
    )
    chain_output$samples
  }
)

# Overlay the chains; every chain after the first is added to the same plot.
for (chain_index in seq_len(n_chain)) {
  overlay <- chain_index > 1
  trace_plot(
    samples_list[[chain_index]],
    col = tableau10[chain_index],
    add = overlay,
    ylim = param_range
  )
}

Note that simply averaging over the chains yields an incorrect posterior approximation.

# Evaluate the quadrimodal target density at each element of `x`.
#
# `quadrimodal_target_logp` is called one element at a time and its result
# exponentiated. Returns a numeric vector the same length as `x`; an empty
# `x` yields numeric(0).
true_density <- function(x) {
  vapply(
    seq_along(x),
    function(i) exp(quadrimodal_target_logp(x[i])),
    numeric(1)
  )
}
# Histogram of the draws from all chains concatenated, against `true_density`.
hist_against_truth(
  do.call(cbind, samples_list), 
  true_density, xlim = param_range, breaks = 201
)

# posterior's R-hat functions take the draws of a single variable as an
# iterations x chains matrix. Each chain's 1 x (n_iter + 1) sample matrix is
# flattened and placed in its own column; stacking chains as rows made the
# functions treat the chains as iterations.
combined_samples <- do.call(cbind, lapply(samples_list, as.vector))
c(posterior::rhat_basic(combined_samples),
  posterior::rhat(combined_samples))
# NOTE: the output recorded below came from the chains-as-rows layout and
# must be regenerated.
## [1] 1.580313 1.311234

\(\hat{R}\) is also about mixing, not just convergence

Sample from the compound symmetric target, initializing at stationarity

# Draw one starting point per chain from N(0, Sigma); row k initialises
# chain k.
set.seed(955)
init_list <- MASS::mvrnorm(n_chain, mu = numeric(d), Sigma = Sigma)

n_iter <- 1e4
samples_list <- lapply(
  seq_len(n_chain),
  function(k) {
    chain_output <- rand_walk_metropolis(
      eval_logp = cs_target_logp,
      init = init_list[k, ],
      n_iter = n_iter,
      prop_sd = cs_prop_sd,
      seed = k
    )
    chain_output$samples
  }
)
# Overlay the trace of one parameter across all chains.
param_index <- 1L
for (chain_index in seq_len(n_chain)) {
  trace_plot(
    samples_list[[chain_index]][param_index, ], 
    col = tableau10[chain_index], 
    add = (chain_index > 1),
    ylim = c(-3, 3), thin = 10L
  )
}

# R-hat is computed for a single variable from an iterations x chains matrix.
# Take the `param_index` row of every chain's samples as one column each;
# row-binding the full n_dim-row sample matrices mixed all parameters and
# chains into one matrix.
combined_samples <- do.call(
  cbind,
  lapply(samples_list, function(chain_samples) chain_samples[param_index, ])
)
c(posterior::rhat_basic(combined_samples),
  posterior::rhat(combined_samples))
# NOTE: the output recorded below came from the row-bound layout and must be
# regenerated.
## [1] 1.334540 1.332496