Basics of MCMC

Author

Caughlin

What is MCMC?

Markov Chain Monte Carlo (MCMC) is a method for sampling from probability distributions when direct sampling is difficult.

  • Monte Carlo → Uses randomness to approximate a distribution.

  • Markov Chain → Each new sample depends only on the current state (not the full history).

  • Goal → Generate samples that, over time, follow a target distribution.

  • In Bayesian statistics, that target distribution is usually the posterior distribution.

Key Components of MCMC (Metropolis Algorithm)

  1. Current state
    Where we are now in parameter space.

  2. Proposal step
    Suggest a nearby candidate state.

  3. Acceptance rule
    Decide whether to move to the proposed state based on a probability ratio.

  4. Repeat many times
    The sequence of states forms a Markov chain.

  5. Result
    Over time, the chain spends more time in high-probability regions.

In the following example, we will consider a squirrel that visits 10 patches arranged in a ring, where patch number i holds i acorns (from 1 to 10). These patches represent possible parameter states. The squirrel's position represents the current state of the Markov chain.

num_weeks <- 500
positions <- rep(0, num_weeks)
current   <- 10
for (i in 1:num_weeks) {
  # record current position
  positions[i] <- current

  # flip a coin to propose the patch on either side of the current one
  proposal <- current + sample(c(-1,1),size=1)

  # the patches form a ring: stepping past patch 1 or patch 10 wraps around
  if (proposal < 1) proposal <- 10
  if (proposal > 10) proposal <- 1

  # move with probability min(1, acorns in proposal / acorns in current)
  move_ratio <- proposal / current
  current   <- ifelse(runif(1) < move_ratio, proposal, current)
}

The key elements here are:

proposal <- current + sample(c(-1,1),size=1)

The proposal depends only on the current patch, which is what makes the sequence a Markov chain. We draw a value of either -1 or 1, so the squirrel can only propose a move to one of the two neighboring patches.

move_ratio <- proposal / current
current <- ifelse(runif(1) < move_ratio, proposal, current)

This is the core of MCMC.

  • The squirrel compares acorns in the new patch vs. current patch.

  • If the proposal has more acorns, ratio > 1 → always move.

  • If it has fewer acorns, ratio < 1 → move with probability equal to the ratio; otherwise stay put.
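To make the rule concrete, here is a minimal sketch of the acceptance probability for a single proposed move. The helper name accept_prob is hypothetical (it is not part of the simulation above), and it assumes, as in the squirrel example, that the number of acorns in a patch equals the patch number.

```r
# Hypothetical helper: probability that the squirrel accepts a proposed move,
# assuming acorns in a patch = patch number.
accept_prob <- function(current, proposal) {
  min(1, proposal / current)
}

accept_prob(current = 5, proposal = 6)   # uphill move: 1 (always accepted)
accept_prob(current = 5, proposal = 4)   # downhill move: 0.8
accept_prob(current = 10, proposal = 1)  # wrap-around from patch 10 to 1: 0.1
```

Comparing runif(1) against the raw ratio, as the simulation does, gives the same behavior: a uniform draw is always below a ratio larger than 1.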

Why Accept Worse Patches?

This is the key insight of MCMC.

If the squirrel only moved uphill, it would get stuck at patch 10.

By occasionally accepting worse patches:

  • The squirrel can explore the whole space.

  • The chain doesn’t get trapped.

  • Long-run visitation frequency matches the target distribution.
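To see the contrast, the sketch below swaps the acceptance rule for a purely greedy one. This greedy rule is an assumption made for illustration only, and the variable names (greedy_positions, greedy_current, greedy_proposal) are hypothetical so that they do not overwrite the simulation above.

```r
set.seed(1)
n_weeks_greedy   <- 500
greedy_positions <- rep(0, n_weeks_greedy)
greedy_current   <- 1

for (i in 1:n_weeks_greedy) {
  greedy_positions[i] <- greedy_current

  # same coin-flip proposal and wrap-around as before
  greedy_proposal <- greedy_current + sample(c(-1, 1), size = 1)
  if (greedy_proposal < 1) greedy_proposal <- 10
  if (greedy_proposal > 10) greedy_proposal <- 1

  # greedy rule (illustrative assumption): move only to a strictly better patch
  if (greedy_proposal > greedy_current) greedy_current <- greedy_proposal
}

# where did the greedy squirrel spend its last 100 weeks?
table(tail(greedy_positions, 100))
```

Once the greedy squirrel reaches patch 10, both neighbors (patch 9 and patch 1) are worse, so it never leaves. With the Metropolis rule, by contrast, the squirrel keeps visiting every patch.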

Let’s look at the histogram of squirrel positions to see how this plays out:

hist(positions,main="Time spent in each patch",xlab="Number of acorns in a patch")
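With only 500 weeks the histogram is still noisy. As a rough check, the sketch below reruns the same rules for many more weeks and compares the fraction of time spent in each patch with the target, which here is proportional to the number of acorns (patch i has probability i / 55). The names n_long, long_positions, pos, and prop are hypothetical, and the number of weeks is an arbitrary choice.

```r
set.seed(1)
n_long         <- 100000
long_positions <- numeric(n_long)
pos            <- 10

for (i in seq_len(n_long)) {
  long_positions[i] <- pos

  # same proposal and wrap-around as the squirrel simulation
  prop <- pos + sample(c(-1, 1), size = 1)
  if (prop < 1) prop <- 10
  if (prop > 10) prop <- 1

  # same acceptance rule
  if (runif(1) < prop / pos) pos <- prop
}

# fraction of weeks spent in each patch vs. the target distribution
observed <- as.numeric(table(factor(long_positions, levels = 1:10))) / n_long
target   <- (1:10) / sum(1:10)
round(rbind(observed, target), 3)
```

The two rows should be close but not identical; the gap shrinks as the chain gets longer.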

Now, let’s look at a more complicated example. We will use the Metropolis algorithm (the special case of Metropolis-Hastings in which the proposal distribution is symmetric) to estimate parameters for a linear regression. We will begin by simulating data:

set.seed(1)

trueA = 5
trueB = 0
trueSd = 1
sampleSize = 31

x = runif(sampleSize)
y = rnorm(n = sampleSize, mean = trueA * x + trueB, sd = trueSd)

plot(x, y, main="Test Data")
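Before running any MCMC, it can help to have a reference point for what the sampler should roughly recover. The sketch below repeats the simulation so that it runs on its own, then fits ordinary least squares with lm(). Using lm() as a sanity check is an addition for illustration, not part of the MCMC procedure, and the name fit is hypothetical.

```r
# same simulation as above, repeated so this block is self-contained
set.seed(1)
trueA      <- 5
trueB      <- 0
trueSd     <- 1
sampleSize <- 31
x <- runif(sampleSize)
y <- rnorm(n = sampleSize, mean = trueA * x + trueB, sd = trueSd)

# ordinary least squares fit as a reference
fit <- lm(y ~ x)
coef(fit)    # "(Intercept)" estimates b, "x" estimates a
sigma(fit)   # residual standard deviation, estimates sd
```

With only 31 data points, the estimates will differ somewhat from the true values, and the posterior samples should be read with that in mind.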

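The function below calculates the log likelihood of the data given our three parameters (slope a, intercept b, and residual standard deviation sd) in a linear regression. Note that we can use sum() because we have set log = TRUE: the log of a product of densities is the sum of their logs (this is the logarithmic product rule again).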

likelihood = function(param){
  a  = param[1]
  b  = param[2]
  sd = param[3]
  
  if (sd <= 0) return(-Inf)   # IMPORTANT: a non-positive sd is invalid, so give it log-likelihood -Inf
  
  mu = a*x + b
  sum(dnorm(y, mean = mu, sd = sd, log = TRUE))
}
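The remark about summing log densities can be checked numerically. The sketch below uses made-up standard normal draws (the names z and z_big are hypothetical) to show that the two routes agree for a handful of points, and that only the log-scale route survives when there are many points.

```r
set.seed(42)

# a few points: both routes give the same number
z <- rnorm(5)
log(prod(dnorm(z)))         # log of the product of densities
sum(dnorm(z, log = TRUE))   # sum of the log densities

# many points: the product underflows, the sum of logs does not
z_big <- rnorm(2000)
prod(dnorm(z_big))               # 0, because the product is too small to represent
sum(dnorm(z_big, log = TRUE))    # a finite (large negative) number
```

This is the practical reason the likelihood function works on the log scale throughout.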

Next, we will look at the priors. Are these priors informative or non-informative?

prior = function(param){
  a  = param[1]
  b  = param[2]
  sd = param[3]
  
  if (sd <= 0) return(-Inf)   # IMPORTANT: a non-positive sd gets zero prior probability (log = -Inf)
  
  aprior  = dunif(a, min=0, max=10, log=TRUE)
  bprior  = dnorm(b, mean=0, sd=5, log=TRUE)
  sdprior = dgamma(sd, shape=0.5, rate=0.2, log=TRUE)
  
  aprior + bprior + sdprior
}
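One way to approach the question about informativeness is to look at the range of values each prior considers plausible. The sketch below computes the central 95% interval of each prior using the same distributions and settings as the prior function; the names a_range, b_range, and sd_range are hypothetical.

```r
# central 95% interval of each prior
a_range  <- qunif(c(0.025, 0.975), min = 0, max = 10)
b_range  <- qnorm(c(0.025, 0.975), mean = 0, sd = 5)
sd_range <- qgamma(c(0.025, 0.975), shape = 0.5, rate = 0.2)

round(rbind(a = a_range, b = b_range, sd = sd_range), 2)
```

Comparing these intervals with the values used to simulate the data (a = 5, b = 0, sd = 1) gives a feel for how strongly each prior constrains its parameter.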

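Now we can define our posterior, in terms of prior and likelihood. Remember, because we are working with logs, we can add these together (rather than multiply). Strictly, this gives the log posterior only up to an additive constant, but that constant cancels whenever we compare two sets of parameter values, which is all the algorithm needs.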

posterior = function(param){
  likelihood(param) + prior(param)
}

This code describes our proposal function: where should we move next, given our current state? In effect, it adds a small amount of normally distributed noise to each of our three parameters. The proposals are centered at the current values (mean = param), and the three sd values set how much noise is added to a, b, and sd, respectively. If these values are too small, the chain will move very slowly (poor mixing). If these values are too big, proposals will often be rejected (low acceptance).

proposalfunction = function(param){
  rnorm(3, mean = param, sd = c(0.1, 0.5, 0.3))
}

Now we will develop code to run the algorithm. This function takes a starting value (an initial guess for the parameters) and runs the MCMC algorithm for a given number of iterations. We store the output in a matrix where each row is one MCMC step, and the first row is the starting value. Similar to the squirrel problem, we compare the acceptance ratio to a uniform random number. Here the comparison is done on the log scale, so the ratio of posteriors becomes a difference of log posteriors.

run_metropolis_MCMC = function(startvalue, iterations){
  chain = matrix(NA_real_, nrow = iterations + 1, ncol = 3)
  chain[1,] = startvalue
  
  for (i in 1:iterations){
    current  = chain[i,]
    proposal = proposalfunction(current)
    
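    # difference of log posteriors = log of the acceptance ratio
    log_alpha = posterior(proposal) - posterior(current)
    
    # accept the proposal with probability min(1, ratio); otherwise stay put
    if (log(runif(1)) < log_alpha){
      chain[i+1,] = proposal
    } else {
      chain[i+1,] = current
    }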
  }
  chain
}
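The log-scale comparison in the loop makes the same decisions as the ratio comparison in the squirrel example. The sketch below checks this for a made-up acceptance ratio of 0.3 (the names ratio, u, accept_raw, and accept_log are hypothetical), and then shows why the log scale is preferred when posterior values are extremely small.

```r
set.seed(1)
ratio <- 0.3          # hypothetical value of posterior(proposal) / posterior(current)
u     <- runif(10000)

accept_raw <- u < ratio             # squirrel-style comparison
accept_log <- log(u) < log(ratio)   # log-scale comparison used in the loop

identical(accept_raw, accept_log)   # same decision for every draw
mean(accept_raw)                    # fraction accepted, close to 0.3

# with very small densities the raw ratio breaks down, the log difference does not
exp(-1000) / exp(-1001)   # NaN: both terms underflow to 0
(-1000) - (-1001)         # 1
```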

We will now run our algorithm. We choose some starting values and run for 10000 iterations. We will discard the first 5000 rows of the chain as “burn-in.” These are early iterations during which the chain may still be travelling from the starting values toward regions of high probability density, so they may not represent the posterior.

startvalue = c(4, 0, 1)
chain = run_metropolis_MCMC(startvalue, 10000)

burnIn = 5000
post = chain[-(1:burnIn), , drop=FALSE]
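One way to judge whether the proposal sd values are reasonable is to look at the acceptance rate, meaning the fraction of steps in which the chain actually moved. The helper below is a hypothetical sketch (acceptance_rate is not defined elsewhere in this document), demonstrated on a small hand-made chain so that it runs on its own.

```r
# Hypothetical helper: fraction of steps where at least one parameter changed
acceptance_rate <- function(samples) {
  moved <- rowSums(diff(samples) != 0) > 0   # diff() works row by row on a matrix
  mean(moved)
}

# toy chain: rows 2 and 4 repeat the previous row (rejected proposals)
toy_chain <- rbind(
  c(4.0, 0.0, 1.0),
  c(4.0, 0.0, 1.0),
  c(4.1, 0.2, 0.9),
  c(4.1, 0.2, 0.9),
  c(4.3, 0.1, 1.0)
)
acceptance_rate(toy_chain)   # 2 of 4 steps moved: 0.5

# usage with the objects created above:
# acceptance_rate(post)
```

Commonly cited rules of thumb put a workable acceptance rate very roughly between 20% and 50%. Rates far outside that range suggest the proposal sd values are too large or too small.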

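We can now explore the output. Below are histograms of the posterior samples. In each histogram, the black line marks the posterior mean and the red line marks the true value. Since we know the true values (we made them up above), we can evaluate how well our algorithm works.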

par(mfrow = c(2,3))
hist(post[,1], nclass=30, main="Posterior of a", xlab="True value = red line")
abline(v = mean(post[,1]))
abline(v = trueA, col="red")

hist(post[,2], nclass=30, main="Posterior of b", xlab="True value = red line")
abline(v = mean(post[,2]))
abline(v = trueB, col="red")

hist(post[,3], nclass=30, main="Posterior of sd", xlab="True value = red line")
abline(v = mean(post[,3]))
abline(v = trueSd, col="red")
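The histograms can be complemented with numerical summaries. The sketch below defines a hypothetical helper, summarise_posterior, that reports the mean and a central 95% interval for each column of a sample matrix. It is demonstrated on made-up draws (fake_post) so that it runs on its own; these are not results from the chain above.

```r
# Hypothetical helper: mean and central 95% interval for each parameter
summarise_posterior <- function(samples, param_names = c("a", "b", "sd")) {
  out <- t(apply(samples, 2, function(draws) {
    c(mean = mean(draws), quantile(draws, probs = c(0.025, 0.975)))
  }))
  rownames(out) <- param_names
  out
}

# made-up draws standing in for posterior samples
set.seed(1)
fake_post <- cbind(rnorm(1000, 5, 0.5), rnorm(1000, 0, 0.3), rnorm(1000, 1, 0.1))
summarise_posterior(fake_post)

# usage with the samples created above:
# summarise_posterior(post)
```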

We can also explore the behavior of the chains directly:

plot(post[,1], type="l", xlab="Iteration", main="Chain values of a")
abline(h = trueA, col="red")

plot(post[,2], type="l", xlab="Iteration", main="Chain values of b")
abline(h = trueB, col="red")

plot(post[,3], type="l", xlab="Iteration", main="Chain values of sd")
abline(h = trueSd, col="red")
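Trace plots show mixing visually. A rough numerical companion is the lag-1 autocorrelation: values near 1 mean consecutive samples are almost copies of each other (slow mixing), while values near 0 mean they are nearly independent. The sketch below illustrates this on two made-up series rather than on the chain above; the names sticky_chain, free_chain, and lag1 are hypothetical.

```r
set.seed(1)

# made-up series: one strongly autocorrelated, one independent
sticky_chain <- as.numeric(arima.sim(model = list(ar = 0.9), n = 5000))
free_chain   <- rnorm(5000)

# lag-1 autocorrelation (element 1 of $acf is lag 0, element 2 is lag 1)
lag1 <- function(x) acf(x, lag.max = 1, plot = FALSE)$acf[2]

lag1(sticky_chain)   # close to 0.9
lag1(free_chain)     # close to 0

# usage with the samples created above:
# lag1(post[, 1])
```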

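The algorithms brms uses are a lot more complex than this simple example: brms fits models with Stan, whose default sampler is the No-U-Turn Sampler (NUTS), a form of Hamiltonian Monte Carlo. To get a flavor for the wide range of MCMC algorithms out there, check out this interactive website: https://chi-feng.github.io/mcmc-demo/app.html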