library(tidyverse)
library(brms)
library(lme4)
library(rstanarm)
library(bridgesampling)

data("PlantGrowth")

# Derive a two-level `category` predictor from the `group` column:
# "ctrl" is labelled "control" and every other group is labelled "test".
d <- mutate(
  PlantGrowth,
  category = if_else(group == "ctrl", "control", "test", missing = "test")
)

Dataset

This is a toy dataset that comes with R. I added a small modification so that we can play around with a mixed-effects model. Below is a visual overview of the dataset. Note that we have 2 categories and 3 groups, and we want to treat the category as the predictor and the group as a random effect.

# Raw weights per category (jittered, coloured by group) overlaid with the
# bootstrapped mean and confidence limits for each category.
ggplot(d, aes(x = category, y = weight)) +
  geom_jitter(aes(color = group), width = .3) +
  stat_summary(fun.data = "mean_cl_boot") +
  theme_classic()

Frequentist model

First we want to see what a good old frequentist model would say. It looks like there’s nothing much going on here (if you are into significance testing, the effect is nowhere near the 0.05 significance threshold). That’s alright. We were not expecting this to work anyway.

# Frequentist baseline: linear mixed model with `category` as the fixed effect
# and a random intercept for each `group`.
fm <- lmer(weight ~ category + (1|group), data = d)
# Print the model summary (the output is reproduced in the comments below).
summary(fm)
## Linear mixed model fit by REML ['lmerMod']
## Formula: weight ~ category + (1 | group)
##    Data: d
## 
## REML criterion at convergence: 60.6
## 
## Scaled residuals: 
##      Min       1Q   Median       3Q      Max 
## -1.79013 -0.66249  0.00241  0.43235  2.12405 
## 
## Random effects:
##  Groups   Name        Variance Std.Dev.
##  group    (Intercept) 0.3353   0.5790  
##  Residual             0.3886   0.6234  
## Number of obs: 30, groups:  group, 3
## 
## Fixed effects:
##              Estimate Std. Error t value
## (Intercept)    5.0320     0.6117   8.227
## categorytest   0.0615     0.7491   0.082
## 
## Correlation of Fixed Effects:
##             (Intr)
## categorytst -0.816

Bayesian models

brms vs rstanarm

Now let’s turn to Bayesian models. There are two packages that are widely used for running Bayesian models: brms and rstanarm. In the past I defaulted to brms because that’s what everyone else in the lab is using. According to the wise strangers on the internet, these two packages are pretty similar to each other. Both are interfaces to Stan, and both have a very lmer-like syntax.

But the first key difference is that rstanarm models are pre-compiled (I don’t really know what it means) whereas brms models are not. As a result, rstanarm will be slightly faster than brms, but the cost is less flexibility. However, I have yet to run into a use case where the flexibility is an issue.

Another difference, which is not very obvious at first, is the way these two packages pick priors for the models. This is going to be the key topic of this RPubs post today – does this difference really matter?

(spoiler alert: yes it does.)

models to compare

Here we are going to look at 5 different cases:

  • m1_brms_default: brms model run on the default prior
  • m2_rstanarm_default: rstanarm model run on the default prior
  • m3_brms_rprior: brms model run on the rstanarm default prior for the predictor
  • m4_nl_rprior: brms model run on “lazy normal priors”
  • m5_nl_rprior: rstanarm model run on “lazy normal priors”

We will first run the model

# Build a prior string such as "normal(0,2.5)" from the coefficient prior that
# prior_summary() reports for a fitted rstanarm model, so that the same prior
# can be passed to set_prior() for a brms model.
#
# rstanarm_model: a fitted model accepted by prior_summary().
# Returns a character vector of "dist(location,scale)" strings, one element
# per coefficient prior reported.
get_prior_string_from_rstanarm <- function(rstanarm_model) {
  # Query the prior summary once instead of once per field.
  coef_prior <- prior_summary(rstanarm_model)$prior

  # Prefer the scale after rstanarm's rescaling. When no adjusted scale is
  # reported (NULL), fall back to the specified scale so the string is not
  # left with an empty scale, e.g. "normal(0,)".
  prior_scale <- coef_prior$adjusted_scale
  if (is.null(prior_scale)) {
    prior_scale <- coef_prior$scale
  }

  paste0(coef_prior$dist, "(", coef_prior$location, ",", prior_scale, ")")
}
# Fit the five models and record how long each fit takes with proc.time().
# NOTE: brm(file = ...) reloads a previously saved fit when the .rds file
# already exists, so the brms timings are only meaningful on a fresh run.

# `brms` model run on the default prior.
start_time <- proc.time()
m1_brms_default <- brm(
  weight ~ category + (1 | group),
  data = d,
  save_pars = save_pars(all = TRUE),
  file = "m1_bm_default_b.rds"
)
end_time <- proc.time()
m1_brms_default_t <- end_time - start_time


# `rstanarm` model run on the default prior.
# Each rstanarm fit writes to its own diagnostic file: bridge_sampler() reads
# these files back later, so a shared path would be overwritten by later fits.
start_time <- proc.time()
m2_rstanarm_default <- rstanarm::stan_glmer(
  weight ~ category + (1 | group),
  data = d,
  diagnostic_file = file.path(tempdir(), "m2_rstanarm_default_df.csv")
)
end_time <- proc.time()
m2_rstanarm_default_t <- end_time - start_time
saveRDS(m2_rstanarm_default, "m2_rstanarm_default.rds")

# `brms` model run on the `rstanarm` default prior for the predictor.
# The prior is read from the rstanarm model fitted above.
rstanarm_prior <- c(set_prior(
  get_prior_string_from_rstanarm(m2_rstanarm_default),
  class = "b",
  coef = "categorytest"
))
start_time <- proc.time()
m3_brms_rprior <- brm(
  weight ~ category + (1 | group),
  data = d,
  prior = rstanarm_prior,
  save_pars = save_pars(all = TRUE),
  file = "m3_brms_rprior.rds"
)
end_time <- proc.time()
m3_brms_rprior_t <- end_time - start_time

# `brms` model run on "lazy normal priors".
# The prior is built before the timer starts, matching the m3 block.
nl_prior <- c(set_prior("normal(0, 100000)", class = "b", coef = "categorytest"))
start_time <- proc.time()
m4_nl_rprior <- brm(
  weight ~ category + (1 | group),
  data = d,
  prior = nl_prior,
  save_pars = save_pars(all = TRUE),
  file = "m4_nl_rprior.rds"
)
end_time <- proc.time()
m4_nl_rprior_t <- end_time - start_time

# `rstanarm` model run on "lazy normal priors".
start_time <- proc.time()
m5_nl_rprior <- rstanarm::stan_glmer(
  weight ~ category + (1 | group),
  data = d,
  prior = normal(location = 0, scale = 100000),
  diagnostic_file = file.path(tempdir(), "m5_nl_rprior_df.csv")
)
end_time <- proc.time()
m5_nl_rprior_t <- end_time - start_time
saveRDS(m5_nl_rprior, "m5_nl_rprior.rds")

Comparison

coefficient

Do the estimates of the coefficients differ? Interestingly, they seem to be quite stable across the five models.

# Extract one fixed-effect estimate and its posterior interval as a tibble.
#
# model: a fitted brms model (class "brmsfit"); any other object is handled
#        through the rstanarm branch.
# name:  coefficient name as the package reports it, e.g. "Intercept" for
#        brms and "(Intercept)" for rstanarm.
# prob:  probability mass of the posterior interval, e.g. .95.
# Returns a tibble with columns name, prob, estimate, ci_lower and ci_upper.
get_tidy_model_fit <- function(model, name, prob) {
  # Compute the interval table once; both bounds are read from it below.
  intervals <- posterior_interval(model, prob = prob) %>%
    as.data.frame() %>%
    rownames_to_column("predictor")

  if (inherits(model, "brmsfit")) {
    estimate <- (summary(model)$fixed %>%
      rownames_to_column("predictor") %>%
      filter(predictor == name))$Estimate
    # The interval rows for brms coefficients carry a "b_" prefix.
    interval <- intervals %>%
      filter(predictor == paste0("b_", name))
  } else {
    # Keep rows whose name matches `name` (as a regular expression) and drop
    # rows containing "[" so only the top-level coefficient remains.
    keep_term <- function(tbl) {
      filter(tbl, grepl(name, predictor) & !grepl("\\[", predictor))
    }
    estimate <- (summary(model) %>%
      as.data.frame() %>%
      rownames_to_column("predictor") %>%
      keep_term())$mean
    # The requested `prob` is honoured here (it was previously fixed at .95).
    interval <- keep_term(intervals)
  }

  # After rownames_to_column(), column 2 is the lower and column 3 the upper
  # bound of the interval.
  tibble(
    name = name,
    prob = prob,
    estimate = estimate,
    ci_lower = interval[, 2],
    ci_upper = interval[, 3]
  )
}
# Models to summarise, keyed by the label used in the plot.
fitted_models <- list(
  m1_brms_default = m1_brms_default,
  m2_rstanarm_default = m2_rstanarm_default,
  m3_brms_rprior = m3_brms_rprior,
  m4_nl_rprior = m4_nl_rprior,
  m5_nl_rprior = m5_nl_rprior
)

# Name under which each model reports its intercept.
intercept_terms <- c(
  m1_brms_default = "Intercept",
  m2_rstanarm_default = "(Intercept)",
  m3_brms_rprior = "Intercept",
  m4_nl_rprior = "Intercept",
  m5_nl_rprior = "(Intercept)"
)

# For every model: one row for the intercept, one for `categorytest`, both
# with 95% intervals and tagged with the model label.
b_ef <- names(fitted_models) %>%
  map(function(label) {
    fit <- fitted_models[[label]]
    bind_rows(
      get_tidy_model_fit(fit, intercept_terms[[label]], .95),
      get_tidy_model_fit(fit, "categorytest", .95)
    ) %>%
      mutate(model_name = label)
  }) %>%
  bind_rows()

# Round-trip through disk so the summary table can be reloaded later.
saveRDS(b_ef, "b_ef.rds")
b_ef <- readRDS("b_ef.rds")

# Use one intercept label for both packages so the facets line up.
coef_plot_data <- mutate(
  b_ef,
  name = if_else(name == "(Intercept)", "Intercept", name)
)

ggplot(
  coef_plot_data,
  aes(x = model_name, y = estimate, ymin = ci_lower, ymax = ci_upper)
) +
  geom_pointrange() +
  facet_wrap(~name) +
  coord_flip() +
  theme_classic()

bayes factor

It’s reassuring to see that the estimates are pretty similar across the five models w.r.t the prior. But the difference on the bayes factor is quite intriguing.

A reminder on the interpretation standard we normally use: when BF > 3 we treat it as moderate evidence supporting H1, and when BF < 1/3 we treat it as moderate evidence supporting H0.

So what we have here seems to suggest that the brms model run on the default prior is giving us results almost opposite to what we are seeing in the frequentist model, whereas the others are more consistent.

# Fit the two intercept-only (null) models that the Bayes factors compare
# against.
# The null model gets its own diagnostic file so that it does not overwrite
# the files written by the earlier rstanarm fits; bridge_sampler() reads those
# files back when it is called on an rstanarm model.
rstanarm_null <- rstanarm::stan_glmer(
  weight ~ 1 + (1 | group),
  data = d,
  diagnostic_file = file.path(tempdir(), "rstanarm_null_df.csv")
)

brms_null <- brm(
  weight ~ 1 + (1 | group),
  data = d,
  save_pars = save_pars(all = TRUE)
)

saveRDS(rstanarm_null, "rstanarm_null.rds")
saveRDS(brms_null, "brms_null.rds")

# Bridge-sample the rstanarm null model once and reuse the result, so both
# rstanarm Bayes factors share the same estimate of the null marginal
# likelihood instead of two independent stochastic estimates.
rstanarm_null_ml <- bridge_sampler(rstanarm_null)

# Bayes factor of each model against the null model from the same package.
bfs_df <- tibble(
  model_name = c(
    "m1_brms_default",
    "m2_rstanarm_default",
    "m3_brms_rprior",
    "m4_nl_rprior",
    "m5_nl_rprior"
  ),
  bf_df = c(
    bayes_factor(m1_brms_default, brms_null)$bf,
    bayes_factor(bridge_sampler(m2_rstanarm_default), rstanarm_null_ml)$bf,
    bayes_factor(m3_brms_rprior, brms_null)$bf,
    bayes_factor(m4_nl_rprior, brms_null)$bf,
    bayes_factor(bridge_sampler(m5_nl_rprior), rstanarm_null_ml)$bf
  )
)

# Round-trip through disk so the table can be reloaded later.
saveRDS(bfs_df, "bfs_df.rds")
bfs_df <- readRDS("bfs_df.rds")

# Reference lines mark the BF = 3 (red) and BF = 1/3 (blue) thresholds.
bfs_df %>%
  ggplot(aes(x = model_name, y = bf_df)) +
  geom_hline(yintercept = 3, color = "red") +
  geom_hline(yintercept = 1 / 3, color = "blue") +
  geom_point() +
  coord_flip() +
  theme_classic()

timing

Finally, I’m interested in the performance of the models. How long does each one take to run?

The spread of rstanarm is really interesting!

# Timing objects recorded while fitting, keyed by model label.
fit_times <- list(
  m1_brms_default = m1_brms_default_t,
  m2_rstanarm_default = m2_rstanarm_default_t,
  m3_brms_rprior = m3_brms_rprior_t,
  m4_nl_rprior = m4_nl_rprior_t,
  m5_nl_rprior = m5_nl_rprior_t
)

# One row per model: label, third element of its timing object, and the
# package that fitted it.
bfs_tdf <- tibble(
  model_name = names(fit_times),
  time = vapply(
    fit_times,
    function(timing) timing[[3]],
    numeric(1),
    USE.NAMES = FALSE
  ),
  package_type = c("brms", "rstanarm", "brms", "brms", "rstanarm")
)

# Round-trip through disk so the table can be reloaded later.
saveRDS(bfs_tdf, "bfs_tdf.RDS")
bfs_tdf <- readRDS("bfs_tdf.RDS")

ggplot(bfs_tdf, aes(x = model_name, y = time, color = package_type)) +
  geom_point() +
  coord_flip() +
  theme_classic()