Bayesian analysis with Stan

Extracting draws and generating predictions

Statistics
Bayesian Modeling
Stan
R
Author

Marc-Aurèle Rivière

Published

November 24, 2023

Abstract
Made in reply to this question from the R4DS Slack.

Setup

renv::install(
  c("rstan", "cmdstanr"), 
  repos = list(Stan = "https://mc-stan.org/r-packages/", CRAN = "https://cloud.r-project.org/")
)
Installing CmdStan
cmdstanr::check_cmdstan_toolchain(fix = TRUE, quiet = TRUE)

cpp_opts <- list(
  stan_threads = TRUE
  , STAN_CPP_OPTIMS = TRUE
  , PRECOMPILED_HEADERS = TRUE
  , CXXFLAGS_OPTIM = "-march=native -mtune=native"
  , CXXFLAGS_OPTIM_TBB = "-mtune=native -march=native"
  , CXXFLAGS_OPTIM_SUNDIALS = "-mtune=native -march=native"
)

cmdstanr::install_cmdstan(cpp_options = cpp_opts, quiet = TRUE)
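# install_cmdstan() already sets the CmdStan path; this line is only needed to point at
# an existing installation (adjust the path to your own setup)
cmdstanr::set_cmdstan_path("/home/mar/Dev/SDK/.cmdstan/cmdstan-2.33.1")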

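cmdstanr::register_knitr_engine(override = FALSE)

# Packages used throughout this post
library(palmerpenguins) # penguins data
library(tidyverse)      # dplyr, tidyr, purrr, ggplot2
library(rstan)
library(brms)
library(posterior)
library(tidybayes)
library(bayesplot)
library(patchwork)      # wrap_plots()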

1 Data

How does body mass affect flipper length for the different species?

  • Dependent variable: flipper length (mm)
  • Independent variables: body mass (kg, converted from g and mean-centered) and species
(penguins_data <- penguins
  |> mutate(
    body_mass_kg = body_mass_g/1000,
    body_mass_kg_cntr = body_mass_kg - mean(body_mass_kg, na.rm = TRUE)
  ) 
  |> filter(!is.na(flipper_length_mm), !is.na(species), !is.na(body_mass_kg_cntr))
  |> mutate(ID = row_number())
  |> select(ID, species, body_mass_kg_cntr, flipper_length_mm)
)
penguins_stan_data <- with(
  penguins_data,
  list(
    N = nrow(penguins_data),
    bodymass = body_mass_kg_cntr,
    species = as.integer(species),
    flipperlength = flipper_length_mm
  )
)
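Since the Stan model indexes `alpha` with `species`, it relies on `as.integer()` mapping each factor level to its position in `levels()`, which are alphabetical by default. A quick base-R check of that mapping, on a small hypothetical vector rather than the real data:

```r
# Hypothetical toy vector with the same three levels as penguins$species
species <- factor(c("Gentoo", "Adelie", "Chinstrap", "Adelie"))

levels(species)      # alphabetical by default: "Adelie" "Chinstrap" "Gentoo"
as.integer(species)  # position of each value in levels(): 3 1 2 1

# Lookup table from species name to the index used by alpha[species] in Stan
species_index <- setNames(seq_along(levels(species)), levels(species))
species_index
```

With the palmerpenguins levels, `alpha[1]`, `alpha[2]` and `alpha[3]` are therefore the Adelie, Chinstrap and Gentoo intercepts.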

2 RStan

2.1 Model fitting

Compiling

penguins_rstan_model
data {
  int<lower=0> N;
  vector[N] bodymass;
  array[N] int<lower=1, upper=3> species;
  vector[N] flipperlength;
}

parameters {
  vector[3] alpha;
  real beta_bodymass;
  real<lower=0> sigma;
}

model {
  alpha ~ normal(200, 100);
  beta_bodymass ~ normal(0, 100);
  sigma ~ gamma(3, 2);

  flipperlength ~ normal(alpha[species] + beta_bodymass * bodymass, sigma);
}
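To make the model above explicit, here is a sketch of the same unnormalised log posterior written in base R (the function and the toy data are hypothetical, for illustration only). Stan's `gamma(3, 2)` is parameterised by shape and rate, like `dgamma(shape, rate)`. Stan's `~` statements also drop additive constants, so this value will not match Stan's `lp__` exactly.

```r
# Unnormalised log posterior of the Stan model, in base R (illustrative sketch)
log_posterior <- function(alpha, beta_bodymass, sigma, bodymass, species, flipperlength) {
  if (sigma <= 0) return(-Inf)  # the likelihood is undefined for non-positive sigma
  mu <- alpha[species] + beta_bodymass * bodymass  # one intercept per species
  sum(dnorm(alpha, mean = 200, sd = 100, log = TRUE)) +           # alpha ~ normal(200, 100)
    dnorm(beta_bodymass, mean = 0, sd = 100, log = TRUE) +        # beta_bodymass ~ normal(0, 100)
    dgamma(sigma, shape = 3, rate = 2, log = TRUE) +              # sigma ~ gamma(3, 2), shape/rate
    sum(dnorm(flipperlength, mean = mu, sd = sigma, log = TRUE))  # likelihood
}

# Hypothetical toy data (centred body mass in kg, species index, flipper length in mm)
bodymass <- c(-0.5, 0, 0.5, -0.2, 0.3, 0.1)
species <- c(1, 2, 3, 1, 2, 3)
flipperlength <- c(190, 200, 214, 188, 203, 211)

lp_good <- log_posterior(c(190, 200, 210), 8, 5, bodymass, species, flipperlength)
lp_bad <- log_posterior(c(100, 100, 100), 8, 5, bodymass, species, flipperlength)
lp_good > lp_bad  # plausible intercepts get a much higher log posterior
```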

Sampling

penguins_rstan_fit <- sampling(
  penguins_rstan_model, penguins_stan_data,
  chains = 4, cores = 4, warmup = 1500, iter = 4000,
  refresh = 0, seed = 256
)
Inference for Stan model: anon_model.
4 chains, each with iter=4000; warmup=1500; thin=1; 
post-warmup draws per chain=2500, total post-warmup draws=10000.

                mean se_mean   sd   2.5%    50%  97.5% n_eff Rhat
alpha[1]      194.16    0.01 0.53 193.11 194.17 195.18  5869    1
alpha[2]      199.76    0.01 0.71 198.38 199.76 201.12  7472    1
alpha[3]      209.84    0.01 0.74 208.43 209.83 211.30  5165    1
beta_bodymass   8.40    0.01 0.63   7.14   8.40   9.63  4403    1
sigma           5.35    0.00 0.20   4.97   5.34   5.76  9122    1

Samples were drawn using NUTS(diag_e) at Sun Nov 26 21:12:43 2023.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
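The printed `Rhat` compares between-chain and within-chain variance after splitting each chain in half. A minimal base-R sketch of the classic split-R-hat from Gelman et al.'s Bayesian Data Analysis (3rd ed.); rstan and posterior implement their own variants (rank-normalised, in posterior's case), so treat this as an illustration of the idea rather than the exact number printed above:

```r
# Classic split-R-hat (Gelman et al., BDA3). `draws` is an iterations x chains matrix.
split_rhat <- function(draws) {
  n <- nrow(draws) %/% 2  # length of each half-chain
  first <- draws[seq_len(n), , drop = FALSE]
  second <- draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE]
  halves <- cbind(first, second)  # 2 * chains half-chains, each of length n

  B <- n * var(colMeans(halves))       # between-chain variance
  W <- mean(apply(halves, 2, var))     # mean within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(123)
mixed <- matrix(rnorm(4000), ncol = 4)  # 4 well-mixed chains of 1000 draws
stuck <- mixed
stuck[, 1] <- stuck[, 1] + 5            # one chain exploring a different region
split_rhat(mixed)  # close to 1
split_rhat(stuck)  # well above 1
```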

2.2 Model diagnostics

2.2.1 Posterior draws

Using rstan:

as.data.frame(penguins_rstan_fit) |> select(matches("alpha"), beta_bodymass, sigma)
Alternative
rstan::extract(penguins_rstan_fit, c("alpha", "beta_bodymass", "sigma")) |> as.data.frame()

Using posterior:

(penguins_rstan_draws <- penguins_rstan_fit |> as_draws_df() |> subset_draws(c("alpha", "beta_bodymass", "sigma")))

Using tidybayes:

Wide format:

spread_draws(penguins_rstan_fit, `alpha\\[\\d+\\]`, beta_bodymass, sigma, ndraws = 500, regex = TRUE)

Semi-long format:

spread_draws(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)

Long format:

gather_draws(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)
gather_rvars(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)

2.2.2 Diagnostic plots

A small sample of useful diagnostic plots:

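# bayesplot:: prefix avoids picking up posterior::rhat(), which expects draws rather than a stanfit
bayesplot::neff_ratio(penguins_rstan_fit) |> mcmc_neff_hist()

bayesplot::rhat(penguins_rstan_fit) |> mcmc_rhat_hist()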

mcmc_acf(penguins_rstan_draws)

wrap_plots(
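  # nvariables() counts parameters only, unlike ncol(), which also counts .chain/.iteration/.draw
  mcmc_hist(penguins_rstan_draws, facet_args = list(nrow = nvariables(penguins_rstan_draws))),
  mcmc_trace(penguins_rstan_draws, facet_args = list(nrow = nvariables(penguins_rstan_draws))),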
  widths = c(1, 1.5)
)
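The autocorrelation and `n_eff` plots are linked: the more autocorrelated a chain, the fewer effectively independent draws it holds. A simplified base-R sketch of a single-chain effective sample size, N / (1 + 2 × sum of autocorrelations), truncating the sum at the first negative autocorrelation estimate (a cruder rule than the ones Stan actually uses):

```r
# Simplified single-chain effective sample size: N / (1 + 2 * sum of autocorrelations),
# truncating the sum at the first negative autocorrelation estimate
ess_single_chain <- function(x, max_lag = 200) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(42)
iid_chain <- rnorm(4000)                                               # no autocorrelation
ar1_chain <- as.numeric(arima.sim(model = list(ar = 0.9), n = 10000))  # strongly autocorrelated
ess_single_chain(iid_chain)  # close to 4000
ess_single_chain(ar1_chain)  # far below 10000; theory gives 10000 * 0.1 / 1.9, about 526
```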

2.3 Model predictions

The low-level Stan interfaces (e.g. rstan & cmdstanr) don’t have the machinery required to generate predictions (i.e. nothing like brms’ posterior_epred). Making predictions requires running data through the model, and that model only exists in our Stan code: R only has access to the posterior draws it produced, not to the model structure itself.

There are generally four ways to make predictions from an rstan / cmdstanr model:

  1. Modifying the model’s Stan code so it generates predictions itself (the recommended way)
  2. Creating a new Stan model just to generate predictions
  3. Re-creating the model in R and using it to generate predictions from the Stan model’s posterior draws
  4. Hijacking brms’ machinery and letting it do the work for us

2.3.1 Adding predictive abilities to the model

Here, the idea is to change our original model to include the ability to make predictions for the provided data.

See this page of the Stan manual.

penguins_rstan_model2
data {
  int<lower=0> N;
  vector[N] bodymass;
  array[N] int<lower=1, upper=3> species;
  vector[N] flipperlength;
}

parameters {
  vector[3] alpha;
  real beta_bodymass;
  real<lower=0> sigma;
}

model {
  alpha ~ normal(200, 100);
  beta_bodymass ~ normal(0, 100);
  sigma ~ gamma(3, 2);

  flipperlength ~ normal(alpha[species] + beta_bodymass * bodymass, sigma);
}

generated quantities {
  vector[N] linpred = alpha[species] + beta_bodymass * bodymass;
  vector[N] epred = linpred; // No inverse link function to apply here
  array[N] real prediction = normal_rng(epred, sigma);
}

This code could also be extended to make predictions for new data instead of the data used to fit the model (e.g. by passing both a training and a testing set).
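A sketch of the R side of such an extension, assuming a hypothetical version of the model whose data block gains `N_new`, `bodymass_new` and `species_new` (these names are illustrative, not part of the model above). Base R's `expand.grid()` builds the grid of new predictor values:

```r
# Hypothetical extension of the Stan data block (not part of the model above):
#   int<lower=0> N_new;
#   vector[N_new] bodymass_new;
#   array[N_new] int<lower=1, upper=3> species_new;
# with generated quantities looping over N_new instead of N.
species_levels <- c("Adelie", "Chinstrap", "Gentoo")

# New predictor values: every species crossed with 5 centred body masses (kg)
newdata <- expand.grid(
  bodymass = seq(-1, 1, length.out = 5),
  species = factor(species_levels, levels = species_levels)
)

# Build the extra data-list entries the hypothetical model would need
make_stan_newdata <- function(newdata) {
  list(
    N_new = nrow(newdata),
    bodymass_new = newdata$bodymass,
    species_new = as.integer(newdata$species)
  )
}

new_list <- make_stan_newdata(newdata)
str(new_list)
```

Those entries would then be appended to the fitting data, e.g. `c(penguins_stan_data, new_list)`, before calling `sampling()`.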

Sampling the model:

penguins_rstan_fit2 <- sampling(
  penguins_rstan_model2, penguins_stan_data,
  chains = 4, cores = 4, warmup = 1500, iter = 4000,
  refresh = 0, seed = 256
)

Joining the predictions to the data:

left_join(penguins_data, spread_rvars(penguins_rstan_fit2, epred[ID], prediction[ID], ndraws = 500, seed = 256), join_by(ID))
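Each cell of an rvar column holds every draw for one observation, and it prints as mean ± sd. A base-R sketch of a similar per-observation summary, computed from a draws × observations matrix (the draws are simulated here, purely illustrative):

```r
set.seed(256)
n_draws <- 500
true_means <- c(190, 195, 200, 210)
# Hypothetical draws: rows = posterior draws, columns = observations (IDs)
epred_draws <- matrix(
  rnorm(n_draws * length(true_means), mean = rep(true_means, each = n_draws), sd = 1),
  nrow = n_draws
)

# Per-observation summary, similar in spirit to printing an rvar column
summarise_draws_matrix <- function(draws) {
  data.frame(
    ID = seq_len(ncol(draws)),
    mean = colMeans(draws),
    sd = apply(draws, 2, sd),
    q2.5 = apply(draws, 2, quantile, probs = 0.025),
    q97.5 = apply(draws, 2, quantile, probs = 0.975)
  )
}

summarise_draws_matrix(epred_draws)
```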

Prediction plots

Generating posterior predictive checks (PPC):

make_ppc_density_plot
make_ppc_density_plot <- function(dat, pred_name = ".prediction") {
  return(
    dat
    |> select(ID, .draw, {{ pred_name }})
    |> pivot_wider(names_from = ID, values_from = {{ pred_name }})
    |> select(-.draw)
    |> data.matrix()
    |> ppc_dens_overlay(y = penguins_data$flipper_length_mm, yrep = _)
    + xlim(162, 243)
  )
}
(penguins_data
  |> left_join(spread_draws(penguins_rstan_fit2, epred[ID], prediction[ID], ndraws = 100, seed = 256), join_by(ID))
  |> make_ppc_density_plot("prediction")
)
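`ppc_dens_overlay()` expects `yrep` as a matrix with one row per draw and one column per observation, in the same order as `y`; that is what the `pivot_wider()` + `data.matrix()` steps in make_ppc_density_plot build. A base-R sketch of the same reshaping on hypothetical long-format draws (columns come out sorted by ID, so `y` must be in ID order too):

```r
# Hypothetical long-format draws: one row per (draw, observation) pair
long <- expand.grid(.draw = 1:3, ID = 1:4)
long$prediction <- 100 * long$.draw + long$ID  # easy-to-check dummy values

# Draws x observations matrix: rows = .draw, columns = ID (sorted)
yrep <- tapply(long$prediction, list(long$.draw, long$ID), identity)
yrep
```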

Generating the prediction curves:

make_prediction_plot
make_prediction_plot <- function(dat, epred_name = ".epred") {
  return(
    ggplot(dat, aes(x = body_mass_kg_cntr, y = flipper_length_mm, color = species))
    + geom_line(aes(y = .data[[epred_name]], group = paste(species, .draw)), alpha = .1) 
    + geom_point(data = penguins_data)
    + scale_color_brewer(palette = "Dark2")
  )
}
(penguins_data
  |> left_join(spread_draws(penguins_rstan_fit2, epred[ID], ndraws = 100, seed = 256), join_by(ID))
  |> make_prediction_plot("epred")
)

2.3.2 Using a new prediction-only Stan model

Here, we make a second Stan model with the sole purpose of generating predictions based on a set of predictor values (i.e. newdata) and a set of posterior samples from the original model.

penguins_predict_model
data {
  // New data to make predictions for
  int N_newdata;
  array[N_newdata] int<lower=1, upper=3> species;
  vector[N_newdata] bodymass;
  
  // Posterior samples from the original model
  int N_draws;
  matrix[N_draws, 3] alphas;
  vector[N_draws] beta_bodymass;
  vector[N_draws] sigma;
}

parameters {}

model {}

generated quantities {
  matrix[N_draws, N_newdata] linpred;
  matrix[N_draws, N_newdata] epred;
  matrix[N_draws, N_newdata] prediction;
  
  for(n in 1:N_newdata) {
    for(i in 1:N_draws) {
      linpred[i, n] = alphas[i, species[n]] + beta_bodymass[i] * bodymass[n];
      epred[i, n] = linpred[i, n]; // No inverse link function to apply here
      prediction[i, n] = normal_rng(epred[i, n], sigma[i]);
    }  
  }
}
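The nested loops in `generated quantities` fill one cell per (draw, observation) pair. The same computation can be written with matrix operations; a base-R sketch on hypothetical inputs with the same shapes as the data block, checked against a loop that mirrors the Stan code:

```r
# Hypothetical inputs with the same shapes as the data block above
set.seed(1)
N_draws <- 5
alphas <- matrix(rnorm(N_draws * 3, 200, 5), nrow = N_draws)  # N_draws x 3
beta_bodymass <- rnorm(N_draws, 8, 0.5)
sigma <- abs(rnorm(N_draws, 5, 0.2))
species <- c(1L, 3L, 2L, 1L)  # N_newdata = 4
bodymass <- c(-0.4, 0.2, 0.0, 0.6)

# Vectorised equivalent of the nested loops: entry [i, n] uses draw i and observation n
linpred <- alphas[, species, drop = FALSE] + outer(beta_bodymass, bodymass)
epred <- linpred  # identity link, nothing to invert
# rnorm() recycles sigma (length N_draws) down each column, so row i uses sigma[i]
prediction <- matrix(rnorm(length(epred), mean = epred, sd = sigma), nrow = N_draws)

# Check against an explicit loop mirroring the Stan code
linpred_loop <- matrix(NA_real_, nrow = N_draws, ncol = length(bodymass))
for (n in seq_along(bodymass)) {
  for (i in seq_len(N_draws)) {
    linpred_loop[i, n] <- alphas[i, species[n]] + beta_bodymass[i] * bodymass[n]
  }
}
all.equal(linpred, linpred_loop)
```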

Generating the data for the prediction model (i.e. new data & posterior samples from the original model):

penguins_rstan_draws_samp <- penguins_rstan_draws |> as.data.frame() |> slice_sample(n = 100)

penguins_predict_data <- list(
  N_newdata = nrow(penguins_data),
  species = as.integer(penguins_data$species),
  bodymass = penguins_data$body_mass_kg_cntr,
  N_draws = nrow(penguins_rstan_draws_samp),
  alphas = select(penguins_rstan_draws_samp, matches("alpha")) |> data.matrix(),
  beta_bodymass = pull(penguins_rstan_draws_samp, beta_bodymass),
  sigma = pull(penguins_rstan_draws_samp, sigma)
)


Sampling the model:

penguins_predict_fit <- sampling(
  penguins_predict_model, 
  penguins_predict_data,
  chains = 1, iter = 1, 
  refresh = 0, seed = 256, 
  algorithm = "Fixed_param"
)

Joining the predictions to the data, then condensing the results into rvars for legibility:

penguins_predict_draws <- left_join(
  penguins_data,
  as.data.frame(penguins_predict_fit)
  |> select(matches("^epred|^prediction"))
  |> pivot_longer(
    cols = everything(), names_pattern = "(epred|prediction)\\[(\\d+),(\\d+)\\]", names_to = c(".value", ".draw", "ID"), 
    names_transform = list(ID = as.integer)
  ),
  join_by(ID)
)

(penguins_predict_draws |> summarize(across(contains("pred"), rvar), .by = c(ID, species, body_mass_kg_cntr, flipper_length_mm)))
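The `names_pattern` regex splits names like `epred[12,205]` into the variable, the draw index and the observation index; the draw comes first because the Stan matrices are declared `[N_draws, N_newdata]`. A base-R sketch of the same parsing on a few example names:

```r
# Column names in the format as.data.frame() gives for the matrices above
cols <- c("epred[1,1]", "epred[2,1]", "prediction[1,2]", "linpred[3,4]", "lp__")

pattern <- "^(epred|prediction)\\[(\\d+),(\\d+)\\]$"
keep <- grepl(pattern, cols)  # drops linpred[...] and lp__

parsed <- data.frame(
  variable = sub(pattern, "\\1", cols[keep]),
  .draw = as.integer(sub(pattern, "\\2", cols[keep])),  # first index: posterior draw
  ID = as.integer(sub(pattern, "\\3", cols[keep]))      # second index: observation
)
parsed
```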

Prediction plots

Generating PPC:

make_ppc_density_plot(penguins_predict_draws, "prediction")

Generating the prediction curves:

make_prediction_plot(penguins_predict_draws, "epred")

2.3.3 Generating predictions in R

Here, we write an R function that generates predictions from the posterior samples of our initial model’s coefficients. This is the same approach as the previous solution, but implemented in R instead of Stan.

Let’s start by writing a function to generate predictions from posterior samples:

penguins_rstan_draws_samp <- spread_draws(penguins_rstan_fit, `alpha\\[\\d+\\]`, beta_bodymass, sigma, ndraws = 100, regex = TRUE)

generate_penguins_r_preds <- function(...) {
  args <- list(...)

  return(
    penguins_rstan_draws_samp
    |> mutate(
      .linpred = `alpha[1]` * args$Adelie + `alpha[2]` * args$Chinstrap + `alpha[3]` * args$Gentoo + 
        beta_bodymass * args$body_mass_kg_cntr,
      .epred = .linpred, # No inverse link function to apply here
      .prediction = rnorm(n(), .epred, sigma)
    )
    |> select(.draw, .epred, .prediction)
  )
}

Then, let’s put the predictors into the format that function expects (a design matrix, here with one 0/1 indicator column per species):

(penguins_design_matrix <- penguins_data 
  |> pivot_wider(names_from = species, values_from = species, values_fn = \(x) as.integer(!is.na(x)), values_fill = 0)
)

Instead of our original data (here, penguins_data), we could have provided a new dataset (e.g. one generated using modelr::data_grid).
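For a new grid, base R's `expand.grid()` can stand in for `modelr::data_grid()`, and `model.matrix()` builds the same kind of species indicator columns as the `pivot_wider()` call above. A sketch on hypothetical values; the renamed copy drops the `species` prefix so the columns match the argument names generate_penguins_r_preds reads:

```r
species_levels <- c("Adelie", "Chinstrap", "Gentoo")

# Hypothetical new data: three centred body masses (kg) for each species
newdata <- expand.grid(
  body_mass_kg_cntr = c(-0.5, 0, 0.5),
  species = factor(species_levels, levels = species_levels)
)

# Without an intercept, the species factor gets one 0/1 indicator column per level
X <- model.matrix(~ 0 + body_mass_kg_cntr + species, data = newdata)
colnames(X)

# Renamed copy matching the argument names generate_penguins_r_preds() reads
X_named <- X
colnames(X_named) <- sub("^species", "", colnames(X))
head(as.data.frame(X_named))
```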

Now, let’s apply the prediction function to this design matrix, and then condense the results into rvar for legibility:

penguins_r_preds <- (penguins_design_matrix 
  |> mutate(preds = pmap(pick(-ID), generate_penguins_r_preds)) # Generating the predictions rowwise
  |> unnest(cols = preds)
  |> pivot_longer(
    cols = all_of(levels(penguins_data$species)), names_to = "species", 
    values_transform = \(x) na_if(x, 0), values_drop_na = TRUE
  )
  |> select(ID, species, body_mass_kg_cntr, flipper_length_mm, .draw, .epred, .prediction)
)

(penguins_r_preds |> summarize(across(contains("pred"), rvar), .by = c(ID, species, body_mass_kg_cntr, flipper_length_mm)))

Prediction plots

Generating PPC:

make_ppc_density_plot(penguins_r_preds)

Generating the prediction curves:

make_prediction_plot(penguins_r_preds)

2.3.4 Hijacking brms

The last solution is to hijack brms (i.e. inject our rstan fit into an empty brmsfit object) in order to use brms’ prediction and plotting tools. However, this requires our Stan model to use the same parameter naming scheme as the Stan code brms generates, so that brms can properly link our model’s parameters to the variables in our data (penguins_data).

This means we need to refit the model entirely, with the correct names, since renaming the parameters of an already-fitted Stan model is a mess.

First, let’s define the model’s formula and priors:

penguins_formula <- bf(flipper_length_mm ~ 0 + body_mass_kg_cntr + species, family = gaussian(), center = FALSE)

penguins_priors <- c(
  prior(normal(200, 100), class = "b")
  , prior(normal(0, 100), coef = "body_mass_kg_cntr")
  , prior(gamma(3, 2), class = "sigma")
)

Now, let’s have brms generate the Stan code, which we’ll use as a reference to rewrite our own:

penguins_brms_stan_code <- make_stancode(formula = penguins_formula, data = penguins_data, prior = penguins_priors)
brms-generated Stan code (not run)
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
}

parameters {
  vector[K] b;  // regression coefficients
  real<lower=0> sigma;  // dispersion parameter
}

transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += normal_lpdf(b[1] | 0, 100);
  lprior += normal_lpdf(b[2] | 200, 100);
  lprior += normal_lpdf(b[3] | 200, 100);
  lprior += normal_lpdf(b[4] | 200, 100);
  lprior += gamma_lpdf(sigma | 3, 2);
}

model {
  // likelihood including constants
  target += normal_id_glm_lpdf(Y | X, 0, b, sigma);
  // priors including constants
  target += lprior;
}

And here’s our model rewritten using brms’ variable naming scheme:

penguins_brms_stan_model
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;     // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
}

parameters {
  vector[K] b;          // regression coefficients
  real<lower=0> sigma;  // dispersion parameter
}

model {
  b[1] ~ normal(0, 100); // prior for body_mass_kg_cntr's coefficient
  
  for (k in 2:K) 
    b[k] ~ normal(200, 100); // priors for the intercepts (speciesAdelie, ...)
  
  sigma ~ gamma(3, 2);
  
  Y ~ normal(X * b, sigma);
}
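Both parameterisations describe the same linear predictor: with brms' design matrix, `X * b` reproduces `alpha[species] + beta_bodymass * bodymass` once `b` stacks the slope and the three intercepts. A base-R check on hypothetical toy data, assuming `X` has the column order shown in the brms summary (`body_mass_kg_cntr`, then one column per species):

```r
species_levels <- c("Adelie", "Chinstrap", "Gentoo")
# Hypothetical toy data
toy <- data.frame(
  body_mass_kg_cntr = c(-0.4, 0.1, 0.7, 0.0),
  species = factor(c("Adelie", "Gentoo", "Chinstrap", "Gentoo"), levels = species_levels)
)
X <- model.matrix(~ 0 + body_mass_kg_cntr + species, data = toy)

# Parameter values close to the posterior means printed earlier, in brms order:
# b[1] = body-mass slope, b[2:4] = species intercepts
beta_bodymass <- 8.4
alpha <- c(194.2, 199.8, 209.8)
b <- c(beta_bodymass, alpha)

mu_brms <- as.vector(X %*% b)  # brms-style linear predictor: X * b
mu_orig <- alpha[as.integer(toy$species)] + beta_bodymass * toy$body_mass_kg_cntr  # original model
all.equal(mu_brms, mu_orig)
```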

Now, let’s have brms generate the data for that Stan model:

penguins_brms_stan_data <- make_standata(formula = penguins_formula, data = penguins_data) |> discard_at("prior_only")


Then, let’s fit our brms-like Stan model with rstan:

penguins_brms_rstan_fit <- sampling(
  penguins_brms_stan_model, penguins_brms_stan_data,
  chains = 4, cores = 4, warmup = 1500, iter = 4000,
  refresh = 0, seed = 256
)

And finally, we inject the resulting stanfit object into an empty brms shell:

penguins_fake_brms_mod <- brm(
  penguins_formula,
  penguins_data,
  empty = TRUE
)

penguins_fake_brms_mod$fit <- penguins_brms_rstan_fit
penguins_fake_brms_mod <- rename_pars(penguins_fake_brms_mod) # This would not work with our original parameter names
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: flipper_length_mm ~ 0 + body_mass_kg_cntr + species 
   Data: penguins_data (Number of observations: 342) 
  Draws: 4 chains, each with iter = 4000; warmup = 1500; thin = 1;
         total post-warmup draws = 10000

Population-Level Effects: 
                  Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
body_mass_kg_cntr     8.40      0.64     7.15     9.62 1.00     4674     6306
speciesAdelie       194.16      0.54   193.09   195.22 1.00     5806     6582
speciesChinstrap    199.76      0.72   198.34   201.19 1.00     6872     6683
speciesGentoo       209.84      0.74   208.41   211.30 1.00     5179     6787

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     5.35      0.20     4.97     5.76 1.00    10127     7527

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

And now, we can use tidybayes’ very convenient add_*_draws / add_*_rvars family of functions:

add_epred_rvars(penguins_data, penguins_fake_brms_mod, ndraws = 500)

Prediction plots

Generating PPC:

pp_check(penguins_fake_brms_mod, type = "dens_overlay", ndraws = 100)

Generating the prediction curves:

penguins_data |> 
  add_epred_draws(penguins_fake_brms_mod, ndraws = 100) |> 
  make_prediction_plot()

But at this point, why not simply fit the model with brms in the first place?

3 brms

3.1 Model fitting

(With brms default priors)

penguins_mod_brms <- brm(
  penguins_formula, penguins_data,
  chains = 4, cores = 4, warmup = 1500, iter = 4000,
  refresh = 0, seed = 256, backend = "cmdstanr", silent = 2
)

3.2 Model predictions

add_epred_rvars(penguins_data, penguins_mod_brms, ndraws = 500)

Prediction plots

Generating PPC:

pp_check(penguins_mod_brms, type = "dens_overlay", ndraws = 100)

Generating the prediction curves:

add_epred_draws(penguins_data, penguins_mod_brms, ndraws = 100) |> 
  make_prediction_plot()