Bayesian analysis with Stan
Extracting draws and generating predictions
Setup
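The code in this document relies on several packages whose loading is not shown. As an assumption inferred from the functions called throughout (not a list given in the text), something like the following would need to run first:

```r
# Assumed package list, inferred from the functions used in this document
library(palmerpenguins) # penguins dataset
library(dplyr)          # mutate, filter, select, left_join, summarize, ...
library(tidyr)          # pivot_wider, pivot_longer, unnest
library(purrr)          # pmap, discard_at
library(ggplot2)        # ggplot, geom_line, geom_point, xlim
library(patchwork)      # wrap_plots
library(rstan)          # sampling
library(posterior)      # as_draws_df, subset_draws, rvar
library(tidybayes)      # spread_draws, gather_draws, add_epred_draws, ...
library(bayesplot)      # mcmc_* and ppc_* plots
library(brms)           # brm, make_stancode, make_standata
```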
Installing CmdStan
cmdstanr::check_cmdstan_toolchain(fix = TRUE, quiet = TRUE)
cpp_opts <- list(
stan_threads = TRUE
, STAN_CPP_OPTIMS = TRUE
, PRECOMPILED_HEADERS = TRUE
, CXXFLAGS_OPTIM = "-march=native -mtune=native"
, CXXFLAGS_OPTIM_TBB = "-mtune=native -march=native"
, CXXFLAGS_OPTIM_SUNDIALS = "-mtune=native -march=native"
)
cmdstanr::install_cmdstan(cpp_options = cpp_opts, quiet = TRUE)
cmdstanr::set_cmdstan_path("/home/mar/Dev/SDK/.cmdstan/cmdstan-2.33.1")
cmdstanr::register_knitr_engine(override = FALSE)
1 Data
How does body mass affect flipper length for the different species?
- Dependent variable: flipper length (mm).
- Independent variable: body mass (kg, converted from g)
(penguins_data <- penguins
|> mutate(
body_mass_kg = body_mass_g/1000,
body_mass_kg_cntr = body_mass_kg - mean(body_mass_kg, na.rm = TRUE)
)
|> filter(!is.na(flipper_length_mm), !is.na(species), !is.na(body_mass_kg_cntr))
|> mutate(ID = row_number())
|> select(ID, species, body_mass_kg_cntr, flipper_length_mm)
)
penguins_stan_data <- with(
penguins_data,
list(
N = nrow(penguins_data),
bodymass = body_mass_kg_cntr,
species = as.integer(species),
flipperlength = flipper_length_mm
)
)
2 RStan
2.1 Model fitting
Compiling
penguins_rstan_model
data {
int<lower=0> N;
vector[N] bodymass;
array[N] int<lower=1, upper=3> species;
vector[N] flipperlength;
}
parameters {
vector[3] alpha;
real beta_bodymass;
real<lower=0> sigma; // must be positive: standard deviation with a gamma prior
}
model {
alpha ~ normal(200, 100);
beta_bodymass ~ normal(0, 100);
sigma ~ gamma(3, 2);
flipperlength ~ normal(alpha[species] + beta_bodymass * bodymass, sigma);
}
Sampling
penguins_rstan_fit <- sampling(
penguins_rstan_model, penguins_stan_data,
chains = 4, cores = 4, warmup = 1500, iter = 4000,
refresh = 0, seed = 256
)
Inference for Stan model: anon_model.
4 chains, each with iter=4000; warmup=1500; thin=1;
post-warmup draws per chain=2500, total post-warmup draws=10000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
alpha[1] 194.16 0.01 0.53 193.11 194.17 195.18 5869 1
alpha[2] 199.76 0.01 0.71 198.38 199.76 201.12 7472 1
alpha[3] 209.84 0.01 0.74 208.43 209.83 211.30 5165 1
beta_bodymass 8.40 0.01 0.63 7.14 8.40 9.63 4403 1
sigma 5.35 0.00 0.20 4.97 5.34 5.76 9122 1
Samples were drawn using NUTS(diag_e) at Sun Nov 26 21:12:43 2023.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
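To make the Rhat column more concrete, here is a minimal base R sketch of the classic split-chain potential scale reduction factor, applied to simulated draws. It is an illustration only: the function name split_rhat and the simulated chains are made up for this example, and the value rstan reports may be computed with a more refined variant.

```r
# Classic split-chain R-hat for one parameter.
# `draws` is an iterations x chains matrix (hypothetical input format).
split_rhat <- function(draws) {
  n <- nrow(draws) %/% 2
  # Split every chain into a first and a second half
  halves <- cbind(
    draws[seq_len(n), , drop = FALSE],
    draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE]
  )
  W <- mean(apply(halves, 2, var))  # average within-chain variance
  B <- n * var(colMeans(halves))    # between-chain variance
  sqrt(((n - 1) / n * W + B / n) / W)
}

set.seed(256)
good_chains <- matrix(rnorm(4000), ncol = 4)  # 4 chains exploring the same distribution
bad_chains <- good_chains
bad_chains[, 1] <- bad_chains[, 1] + 3        # one chain stuck in a different region

rhat_good <- split_rhat(good_chains)  # close to 1
rhat_bad <- split_rhat(bad_chains)    # clearly above 1
```

Chains that agree with one another give a value close to 1, while a chain that sits somewhere else inflates the between-chain variance and pushes the value well above 1.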
2.2 Model diagnostics
2.2.1 Posterior draws
Using rstan:
as.data.frame(penguins_rstan_fit) |> select(matches("alpha"), beta_bodymass, sigma)
Alternatively:
rstan::extract(penguins_rstan_fit, c("alpha", "beta_bodymass", "sigma")) |> as.data.frame()
Using posterior:
(penguins_rstan_draws <- penguins_rstan_fit |> as_draws_df() |> subset_draws(c("alpha", "beta_bodymass", "sigma")))
Using tidybayes:
Wide format:
spread_draws(penguins_rstan_fit, `alpha\\[\\d+\\]`, beta_bodymass, sigma, ndraws = 500, regex = TRUE)
Semi-long format:
spread_draws(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)
Long format:
gather_draws(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)
rvar format, which is faster and less memory-intensive:
gather_rvars(penguins_rstan_fit, alpha[species], beta_bodymass, sigma, ndraws = 500)
2.2.2 Diagnostic plots
A sample of useful diagnostic plots:
neff_ratio(penguins_rstan_fit) |> mcmc_neff_hist()
rhat(penguins_rstan_fit) |> mcmc_rhat_hist()
mcmc_acf(penguins_rstan_draws)
wrap_plots(
mcmc_hist(penguins_rstan_draws, facet_args = list(nrow = ncol(penguins_rstan_draws))),
mcmc_trace(penguins_rstan_draws, facet_args = list(nrow = ncol(penguins_rstan_draws))),
widths = c(1, 1.5)
)
2.3 Model predictions
The low-level Stan interfaces (e.g. rstan & cmdstanr) don’t have the required machinery to generate predictions (e.g. something like posterior_epred). Making predictions requires running data through the model, which is defined in our Stan code. R doesn’t have access to that.
There are generally 4 ways to go about making predictions from an rstan / cmdstanr model:
- Modifying the model’s Stan code to give it the ability to generate predictions (the recommended way)
- Creating a new Stan model just to generate predictions
- Re-creating the model in R and using it to generate predictions based on the Stan model’s posterior draws
- Hijacking brms machinery and letting it do the work for us
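Whichever route is taken, "running data through the model" boils down to the same three quantities, computed once per posterior draw. The base R sketch below illustrates this with made-up draws (simulated to loosely resemble the summary above, not taken from the fitted model) and one hypothetical new penguin:

```r
set.seed(256)
n_draws <- 1000

# Made-up posterior draws (assumption: stand-ins for the real draws)
alpha <- cbind(
  rnorm(n_draws, 194, 0.5),  # species 1
  rnorm(n_draws, 200, 0.7),  # species 2
  rnorm(n_draws, 210, 0.7)   # species 3
)
beta_bodymass <- rnorm(n_draws, 8.4, 0.6)
sigma <- abs(rnorm(n_draws, 5.35, 0.2))

# Hypothetical new observation: species 3, 0.5 kg above the mean body mass
new_species <- 3
new_bodymass <- 0.5

linpred <- alpha[, new_species] + beta_bodymass * new_bodymass  # linear predictor
epred <- linpred                                # identity link: nothing to invert
prediction <- rnorm(n_draws, epred, sigma)      # adds the observation noise
```

Here epred only carries the uncertainty about the coefficients, while prediction also includes the residual noise, which is why its spread is much wider.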
2.3.1 Adding predictive abilities to the model
Here, the idea is to change our original model to include the ability to make predictions for the provided data.
See this page of the Stan manual.
penguins_rstan_model2
data {
int<lower=0> N;
vector[N] bodymass;
array[N] int<lower=1, upper=3> species;
vector[N] flipperlength;
}
parameters {
vector[3] alpha;
real beta_bodymass;
real<lower=0> sigma; // must be positive: standard deviation with a gamma prior
}
model {
alpha ~ normal(200, 100);
beta_bodymass ~ normal(0, 100);
sigma ~ gamma(3, 2);
flipperlength ~ normal(alpha[species] + beta_bodymass * bodymass, sigma);
}
generated quantities {
vector[N] linpred = alpha[species] + beta_bodymass * bodymass;
vector[N] epred = linpred; // No inverse link function to apply here
array[N] real prediction = normal_rng(epred, sigma);
}
Of course, this code could be extended to make predictions for new data instead of the data used to fit the model (e.g. by passing both a training set and a testing set).
Sampling the model:
penguins_rstan_fit2 <- sampling(
penguins_rstan_model2, penguins_stan_data,
chains = 4, cores = 4, warmup = 1500, iter = 4000,
refresh = 0, seed = 256
)
Joining the predictions to the data:
left_join(penguins_data, spread_rvars(penguins_rstan_fit2, epred[ID], prediction[ID], ndraws = 500, seed = 256), join_by(ID))
Prediction plots
Generating PPC:
make_ppc_density_plot
make_ppc_density_plot <- function(dat, pred_name = ".prediction") {
return(
dat
|> select(ID, .draw, {{ pred_name }})
|> pivot_wider(names_from = ID, values_from = {{ pred_name }})
|> select(-.draw)
|> data.matrix()
|> ppc_dens_overlay(y = penguins_data$flipper_length_mm, yrep = _)
+ xlim(162, 243)
)
}
(penguins_data
|> left_join(spread_draws(penguins_rstan_fit2, epred[ID], prediction[ID], ndraws = 100, seed = 256), join_by(ID))
|> make_ppc_density_plot("prediction")
)
Generating the prediction curves:
make_prediction_plot
make_prediction_plot <- function(dat, epred_name = ".epred") {
return(
ggplot(dat, aes(x = body_mass_kg_cntr, y = flipper_length_mm, color = species))
+ geom_line(aes(y = .data[[epred_name]], group = paste(species, .draw)), alpha = .1)
+ geom_point(data = penguins_data)
+ scale_color_brewer(palette = "Dark2")
)
}
(penguins_data
|> left_join(spread_draws(penguins_rstan_fit2, epred[ID], ndraws = 100, seed = 256), join_by(ID))
|> make_prediction_plot("epred")
)
2.3.2 Using a new prediction-only Stan model
Here, we make a second Stan model with the sole purpose of generating predictions based on a set of predictor values (i.e. newdata) and a set of posterior samples from the original model.
penguins_predict_model
data {
// New data to make predictions for
int N_newdata;
array[N_newdata] int<lower=1, upper=3> species;
vector[N_newdata] bodymass;
// Posterior samples from the original model
int N_draws;
matrix[N_draws, 3] alphas;
vector[N_draws] beta_bodymass;
vector[N_draws] sigma;
}
parameters {}
model {}
generated quantities {
matrix[N_draws, N_newdata] linpred;
matrix[N_draws, N_newdata] epred;
matrix[N_draws, N_newdata] prediction;
for(n in 1:N_newdata) {
for(i in 1:N_draws) {
linpred[i, n] = alphas[i, species[n]] + beta_bodymass[i] * bodymass[n];
epred[i, n] = linpred[i, n]; // No inverse link function to apply here
prediction[i, n] = normal_rng(epred[i, n], sigma[i]);
}
}
}
Generating the data for the prediction model (i.e. new data and posterior samples from the original model):
penguins_rstan_draws_samp <- penguins_rstan_draws |> as.data.frame() |> slice_sample(n = 100)
penguins_predict_data <- list(
N_newdata = nrow(penguins_data),
species = as.integer(penguins_data$species),
bodymass = penguins_data$body_mass_kg_cntr,
N_draws = 100,
alphas = select(penguins_rstan_draws_samp, matches("alpha")) |> data.matrix(),
beta_bodymass = pull(penguins_rstan_draws_samp, beta_bodymass),
sigma = pull(penguins_rstan_draws_samp, sigma)
)
Sampling the model:
penguins_predict_fit <- sampling(
penguins_predict_model,
penguins_predict_data,
chains = 1, iter = 1,
refresh = 0, seed = 256,
algorithm = "Fixed_param"
)
Joining the predictions to the data, then condensing the results into rvars for legibility:
penguins_predict_draws <- left_join(
penguins_data,
as.data.frame(penguins_predict_fit)
|> select(matches("^epred|^prediction"))
|> pivot_longer(
cols = everything(), names_pattern = "(epred|prediction)\\[(\\d+),(\\d+)\\]", names_to = c(".value", ".draw", "ID"),
names_transform = list(ID = as.integer)
),
join_by(ID)
)
(penguins_predict_draws |> summarize(across(contains("pred"), rvar), .by = c(ID, species, body_mass_kg_cntr, flipper_length_mm)))
Prediction plots
Generating PPC:
make_ppc_density_plot(penguins_predict_draws, "prediction")
Generating the prediction curves:
make_prediction_plot(penguins_predict_draws, "epred")
2.3.3 Generating predictions in R
Here, we write an R function to generate predictions based on the posterior samples of our initial model’s coefficients. Basically, we’re doing the same thing as the previous solution, but in R instead of Stan.
Let’s start by writing a function to generate predictions from posterior samples:
penguins_rstan_draws_samp <- spread_draws(penguins_rstan_fit, `alpha\\[\\d+\\]`, beta_bodymass, sigma, ndraws = 100, regex = TRUE)
generate_penguins_r_preds <- function(...) {
args <- list(...)
return(
penguins_rstan_draws_samp
|> mutate(
.linpred = `alpha[1]` * args$Adelie + `alpha[2]` * args$Chinstrap + `alpha[3]` * args$Gentoo +
beta_bodymass * args$body_mass_kg_cntr,
.epred = .linpred, # No inverse link function to apply here
.prediction = rnorm(n(), .epred, sigma)
)
|> select(.draw, .epred, .prediction)
)
}
Then, let’s put the data (i.e. the predictors) that we need to feed to that function into the proper format (often called a design matrix):
(penguins_design_matrix <- penguins_data
|> pivot_wider(names_from = species, values_from = species, values_fn = \(x) as.integer(!is.na(x)), values_fill = 0)
)
Instead of our original data (here, penguins_data), we could have provided a new dataset (e.g. one generated using modelr::data_grid).
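For comparison, base R can build the same kind of design matrix with model.matrix. The sketch below uses a small made-up data frame (toy_penguins and the coefficient values are hypothetical) and shows how a single set of coefficients turns the matrix into a linear predictor:

```r
# Toy data standing in for penguins_data (made-up values)
toy_penguins <- data.frame(
  species = factor(
    c("Adelie", "Gentoo", "Chinstrap", "Adelie"),
    levels = c("Adelie", "Chinstrap", "Gentoo")
  ),
  body_mass_kg_cntr = c(-0.5, 1.1, -0.3, 0.2)
)

# No global intercept (0 +), so every species gets its own indicator column
X <- model.matrix(~ 0 + body_mass_kg_cntr + species, data = toy_penguins)

# One hypothetical coefficient vector: slope first, then one intercept per species
b <- c(8.4, 194, 200, 210)
linpred <- as.numeric(X %*% b)
```

Each row of X contains the centered body mass followed by a 0/1 indicator per species, so X %*% b picks the right intercept for each penguin. Repeating that product for every posterior draw is what the rowwise approach in this section does.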
Now, let’s apply the prediction function to this design matrix, and then condense the results into rvar for legibility:
penguins_r_preds <- (penguins_design_matrix
|> mutate(preds = pmap(pick(-ID), generate_penguins_r_preds)) # Generating the predictions rowwise
|> unnest(cols = preds)
|> pivot_longer(
cols = levels(penguins_data$species), names_to = "species",
values_transform = \(x) na_if(x, 0), values_drop_na = TRUE
)
|> select(ID, species, body_mass_kg_cntr, flipper_length_mm, .draw, .epred, .prediction)
)
(penguins_r_preds |> summarize(across(contains("pred"), rvar), .by = c(ID, species, body_mass_kg_cntr, flipper_length_mm)))
Prediction plots
Generating PPC:
make_ppc_density_plot(penguins_r_preds)
Generating the prediction curves:
make_prediction_plot(penguins_r_preds)
2.3.4 Hijacking brms
The last solution is to hijack brms (i.e. inject our rstan model into an empty brmsfit object) in order to use brms’ prediction and plotting tools. However, this requires our Stan model to use the same parameter naming scheme as the Stan code brms generates, so that brms can properly link our rstan model’s parameters to the variables in our data (penguins_data).
This means we need to refit our model entirely, with the correct names, since renaming the parameters of an already-fitted Stan model is messy.
First, let’s define the model’s formula and priors:
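A sketch of what these definitions could look like, inferred from the formula reported in the model summary and from the priors visible in the brms-generated Stan code further down. The object names penguins_formula and penguins_priors are the ones used by the later calls; the exact way they were originally written is an assumption:

```r
library(brms)

# Formula without a global intercept: one intercept per species plus a common slope
penguins_formula <- bf(flipper_length_mm ~ 0 + body_mass_kg_cntr + species)

# Priors mirroring those of the hand-written Stan model
penguins_priors <- c(
  set_prior("normal(0, 100)", class = "b", coef = "body_mass_kg_cntr"),
  set_prior("normal(200, 100)", class = "b", coef = "speciesAdelie"),
  set_prior("normal(200, 100)", class = "b", coef = "speciesChinstrap"),
  set_prior("normal(200, 100)", class = "b", coef = "speciesGentoo"),
  set_prior("gamma(3, 2)", class = "sigma")
)
```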
Now, have brms generate the Stan code, which we’ll use as a reference to rewrite our own:
penguins_brms_stan_code <- make_stancode(formula = penguins_formula, data = penguins_data, prior = penguins_priors)
brms-generated Stan code (not run)
data {
int<lower=1> N; // total number of observations
vector[N] Y; // response variable
int<lower=1> K; // number of population-level effects
matrix[N, K] X; // population-level design matrix
}
parameters {
vector[K] b; // regression coefficients
real<lower=0> sigma; // dispersion parameter
}
transformed parameters {
real lprior = 0; // prior contributions to the log posterior
lprior += normal_lpdf(b[1] | 0, 100);
lprior += normal_lpdf(b[2] | 200, 100);
lprior += normal_lpdf(b[3] | 200, 100);
lprior += normal_lpdf(b[4] | 200, 100);
lprior += gamma_lpdf(sigma | 3, 2);
}
model {
// likelihood including constants
target += normal_id_glm_lpdf(Y | X, 0, b, sigma);
// priors including constants
target += lprior;
}
And here’s our model rewritten using brms’ variable naming scheme:
penguins_brms_stan_model
data {
int<lower=1> N; // total number of observations
vector[N] Y; // response variable
int<lower=1> K; // number of population-level effects
matrix[N, K] X; // population-level design matrix
}
parameters {
vector[K] b; // regression coefficients
real<lower=0> sigma; // dispersion parameter
}
model {
b[1] ~ normal(0, 100); // prior for body_mass_kg_cntr's coefficient
for (k in 2:K)
b[k] ~ normal(200, 100); // priors for the intercepts (speciesAdelie, ...)
sigma ~ gamma(3, 2);
Y ~ normal(X * b, sigma);
}
Now, let’s have brms generate the data for that Stan model:
penguins_brms_stan_data <- make_standata(formula = penguins_formula, data = penguins_data) |> discard_at("prior_only")
Then, let’s fit our brms-like Stan model with rstan:
penguins_brms_rstan_fit <- sampling(
penguins_brms_stan_model, penguins_brms_stan_data,
chains = 4, cores = 4, warmup = 1500, iter = 4000,
refresh = 0, seed = 256
)
And finally, we inject the resulting stanfit object into an empty brms shell:
penguins_fake_brms_mod <- brm(
penguins_formula,
penguins_data,
empty = TRUE
)
penguins_fake_brms_mod$fit <- penguins_brms_rstan_fit
penguins_fake_brms_mod <- rename_pars(penguins_fake_brms_mod) # This would not work with our original parameter names
Family: gaussian
Links: mu = identity; sigma = identity
Formula: flipper_length_mm ~ 0 + body_mass_kg_cntr + species
Data: penguins_data (Number of observations: 342)
Draws: 4 chains, each with iter = 4000; warmup = 1500; thin = 1;
total post-warmup draws = 10000
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
body_mass_kg_cntr 8.40 0.64 7.15 9.62 1.00 4674 6306
speciesAdelie 194.16 0.54 193.09 195.22 1.00 5806 6582
speciesChinstrap 199.76 0.72 198.34 201.19 1.00 6872 6683
speciesGentoo 209.84 0.74 208.41 211.30 1.00 5179 6787
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 5.35 0.20 4.97 5.76 1.00 10127 7527
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
And now, we can use the very convenient add_*_draws/rvars family of functions:
add_epred_rvars(penguins_data, penguins_fake_brms_mod, ndraws = 500)
Prediction plots
Generating PPC:
pp_check(penguins_fake_brms_mod, type = "dens_overlay", ndraws = 100)
Generating the prediction curves:
penguins_data |>
add_epred_draws(penguins_fake_brms_mod, ndraws = 100) |>
make_prediction_plot()
But at this point, why not simply fit the model with brms in the first place?
3 BRMS
3.1 Model fitting
(With brms default priors)
penguins_mod_brms <- brm(
penguins_formula, penguins_data,
chains = 4, cores = 4, warmup = 1500, iter = 4000,
refresh = 0, seed = 256, backend = "cmdstanr", silent = 2
)
3.2 Model predictions
add_epred_rvars(penguins_data, penguins_mod_brms, ndraws = 500)
Prediction plots
Generating PPC:
pp_check(penguins_mod_brms, type = "dens_overlay", ndraws = 100)
Generating the prediction curves:
add_epred_draws(penguins_data, penguins_mod_brms, ndraws = 100) |>
make_prediction_plot()