MCMC for ‘Big Data’ with Stan
Faster sampling with CmdStan using within-chain parallelization
This post extends PyMC-Labs’ benchmarking of MCMC for Big Data and translates it to R. The Stan code was updated to use within-chain parallelization and compiler optimizations for faster CPU sampling.
Please visit this page for a more up-to-date version of this post.
1 Setup
library(here) # File path management
library(pipebind) # Piping goodies
library(data.table) # Data wrangling (fast)
library(dplyr) # Data wrangling
library(tidyr) # Data wrangling (misc)
library(purrr) # Manipulating lists
library(stringr) # Manipulating strings
library(lubridate) # Manipulating dates
library(cmdstanr) # R interface with Stan
library(posterior) # Wrangling Stan model outputs
library(ggplot2) # Plots
library(patchwork) # Combining plots
library(ggridges) # Ridgeline plots
library(bayesplot) # Plots for Stan models
options(
mc.cores = max(1L, parallel::detectCores(logical = TRUE)),
scipen = 999L,
digits = 4L,
ggplot2.discrete.colour = \() scale_color_viridis_d(),
ggplot2.discrete.fill = \() scale_fill_viridis_d()
)
nrows_print <- 10
data.table::setDTthreads(getOption("mc.cores"))
1.1 Stan setup
Installing CmdStan
cmdstanr::check_cmdstan_toolchain(fix = TRUE, quiet = TRUE)
cpp_opts <- list(
stan_threads = TRUE
, STAN_CPP_OPTIMS = TRUE
, STAN_NO_RANGE_CHECKS = TRUE # WARN: remove this if you haven't tested the model
, PRECOMPILED_HEADERS = TRUE
, CXXFLAGS_OPTIM = "-march=native -mtune=native"
, CXXFLAGS_OPTIM_TBB = "-mtune=native -march=native"
, CXXFLAGS_OPTIM_SUNDIALS = "-mtune=native -march=native"
)
cmdstanr::install_cmdstan(cpp_options = cpp_opts, quiet = TRUE)
Loading CmdStan (if already installed)
highest_cmdstan_version <- fs::dir_ls(config$cmdstan_path) |> fs::path_file() |>
keep(\(e) str_detect(e, "cmdstan-")) |>
bind(x, str_split(x, '-', simplify = TRUE)[,2]) |>
reduce(\(x, y) ifelse(utils::compareVersion(x, y) == 1, x, y))
set_cmdstan_path(glue::glue("{config$cmdstan_path}cmdstan-{highest_cmdstan_version}"))
Setting up knitr’s engine for CmdStan
## Inspired by: https://mpopov.com/blog/2020/07/30/replacing-the-knitr-engine-for-stan/
## Note: We could have used cmdstanr::register_knitr_engine(),
## but it wouldn't include compiler optimizations & multi-threading by default
knitr::knit_engines$set(
cmdstan = function(options) {
output_var <- options$output.var
if (!is.character(output_var) || length(output_var) != 1L) {
stop(
"The chunk option output.var must be a character string ",
"providing a name for the returned `CmdStanModel` object."
)
}
if (options$eval) {
if (options$cache) {
cache_path <- options$cache.path
if (length(cache_path) == 0L || is.na(cache_path) || cache_path == "NA")
cache_path <- ""
dir <- paste0(cache_path, options$label)
} else {
dir <- tempdir()
}
file <- write_stan_file(options$code, dir = dir, force_overwrite = TRUE)
mod <- cmdstan_model(
file,
cpp_options = list(
stan_threads = TRUE
, STAN_CPP_OPTIMS = TRUE
, STAN_NO_RANGE_CHECKS = TRUE # The model was already tested
, PRECOMPILED_HEADERS = TRUE
# , CXXFLAGS_OPTIM = "-march=native -mtune=native"
, CXXFLAGS_OPTIM_TBB = "-mtune=native -march=native"
, CXXFLAGS_OPTIM_SUNDIALS = "-mtune=native -march=native"
),
stanc_options = list("Oexperimental")
)
assign(output_var, mod, envir = knitr::knit_global())
}
options$engine <- "stan"
code <- paste(options$code, collapse = "\n")
knitr::engine_output(options, code, '')
}
)
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.2.1 (2022-06-23)
os Ubuntu 20.04.4 LTS
system x86_64, linux-gnu
ui X11
language (EN)
collate C.UTF-8
ctype C.UTF-8
tz Europe/Paris
date 2022-09-24
pandoc 2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown)
Quarto 1.1.251
Stan (CmdStan) 2.30.1
─ Packages ───────────────────────────────────────────────────────────────────
! package * version date (UTC) lib source
bayesplot * 1.9.0 2022-03-10 [1] CRAN (R 4.2.0)
cmdstanr * 0.5.3 2022-08-03 [1] local
data.table * 1.14.2 2021-09-27 [1] CRAN (R 4.2.0)
dplyr * 1.0.10 2022-09-01 [1] CRAN (R 4.2.1)
ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.2.0)
ggridges * 0.5.3 2021-01-08 [1] CRAN (R 4.2.0)
P here * 1.0.1 2020-12-13 [?] CRAN (R 4.2.0)
lubridate * 1.8.0 2021-10-07 [1] CRAN (R 4.2.0)
patchwork * 1.1.2 2022-08-19 [1] CRAN (R 4.2.1)
pipebind * 0.1.1 2022-08-10 [1] CRAN (R 4.2.0)
posterior * 1.3.1 2022-09-06 [1] CRAN (R 4.2.1)
purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.2.0)
P stringr * 1.4.1 2022-08-20 [?] CRAN (R 4.2.1)
tidyr * 1.2.1 2022-09-08 [1] CRAN (R 4.2.1)
[1] /home/mar/Dev/Projects/R/Bayes/renv/library/R-4.2/x86_64-pc-linux-gnu
[2] /home/mar/.cache/R/renv/library/Bayes-9578a481/R-4.2/x86_64-pc-linux-gnu
[3] /usr/lib/R/library
[4] /usr/local/lib/R/site-library
[5] /usr/lib/R/site-library
P ── Loaded and on-disk path mismatch.
──────────────────────────────────────────────────────────────────────────────
2 Data
2.1 Matches data
Loading the raw matches data:
matches_data_raw <- purrr::map_dfr(
fs::dir_ls(matches_data_path, regexp = "atp_matches_(.*).csv"),
\(f) readr::read_csv(f, num_threads = 32, show_col_types = FALSE) |>
select(tourney_date, tourney_level, round, winner_id, winner_name, loser_id, loser_name, score)
) |> mutate(tourney_date = lubridate::ymd(tourney_date))
Filtering matches based on the original post’s data processing:
round_numbers = list(
"R128" = 1,
"RR" = 1,
"R64" = 2,
"R32" = 3,
"R16" = 4,
"QF" = 5,
"SF" = 6,
"F" = 7
)
(matches_data_clean <- matches_data_raw
|> filter(
tourney_date %between% c("1968-01-01", "2021-06-20"),
str_detect(score, "RET|W/O|DEF|nbsp|Def.", negate = TRUE),
str_length(score) > 4,
tourney_level != "D",
round %in% names(round_numbers)
)
|> mutate(
round_number = sapply(round, \(r) round_numbers[[r]]),
label = 1
)
|> arrange(tourney_date, round_number)
|> select(-round, -tourney_level)
)
data.frame [160,399 x 8]
| [ omitted 160,384 entries ] |
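To make the two score conditions in the filter() call above more concrete, here is a minimal sketch on a handful of made-up score strings. The values in `scores_toy` are hypothetical and only chosen to trigger each rule; they are not taken from the ATP data.

```r
library(stringr)

# Hypothetical score strings, chosen only to exercise each filter rule
scores_toy <- c(
  "6-4 6-3",           # completed match: kept
  "6-4 2-1 RET",       # retirement: dropped by the pattern
  "W/O",               # walkover: dropped by the pattern (and by length)
  "7-6(5) 6-7(3) 6-4", # completed match: kept
  "6-0"                # only 3 characters, too short: dropped
)

# Same two conditions as in the filter() call above
is_valid <- str_detect(scores_toy, "RET|W/O|DEF|nbsp|Def.", negate = TRUE) &
  str_length(scores_toy) > 4

valid_scores <- scores_toy[is_valid]
print(valid_scores)
```

Only the two completed matches should remain in `valid_scores`.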
2.2 Player data
Loading the raw player data:
(player_data_raw <- readr::read_csv(player_data_path)
|> mutate(player_name = str_c(name_first, name_last, sep = " "))
|> select(player_id, player_name)
)
data.frame [55,649 x 2]
| [ omitted 55,634 entries ] |
Filtering player_data to only keep the players actually present in our data, and updating their IDs:
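The code for this step is not shown here, so the following is only a sketch of one way it could be done, on toy data. The objects `matches_toy`, `players_toy` and `player_data_toy` are hypothetical stand-ins, and sorting by `player_id` before numbering is an assumption; the original post may assign indices in a different order.

```r
library(dplyr)

# Toy stand-ins for matches_data_clean and player_data_raw (hypothetical values)
matches_toy <- tibble(
  winner_id = c(105, 101, 105),
  loser_id  = c(101, 110, 110)
)
players_toy <- tibble(
  player_id   = c(100, 101, 105, 110),
  player_name = c("A Zero", "B One", "C Five", "D Ten")
)

# Keep only players who appear in at least one match, then assign
# consecutive indices 1..n_players (Stan arrays are 1-indexed)
player_data_toy <- players_toy |>
  filter(player_id %in% c(matches_toy$winner_id, matches_toy$loser_id)) |>
  arrange(player_id) |>
  mutate(player_idx = row_number())

print(player_data_toy)
```

Contiguous indices matter because the Stan model uses them to index directly into the `player_skills` vector.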
2.3 Matches + Player data
Allocating the new player IDs (player_idx) to the winner_id and loser_id from matches_data:
(matches_data <- left_join(matches_data_clean, player_data, by = c("winner_id" = "player_id"))
|> rename(winner_idx = player_idx)
|> relocate(winner_idx, .after = winner_id)
|> left_join(player_data, by = c("loser_id" = "player_id"))
|> rename(loser_idx = player_idx)
|> relocate(loser_idx, .after = loser_id)
|> drop_na(winner_idx, loser_idx)
|> select(-matches("player_name"))
)
data.frame [160,399 x 10]
| [ omitted 160,384 entries ] |
3 Model
3.1 Stan code
Updated Stan code with within-chain parallelization
functions {
array[] int sequence(int start, int end) {
array[end - start + 1] int seq;
for (n in 1 : num_elements(seq)) {
seq[n] = n + start - 1;
}
return seq;
}
// Compute partial sums of the log-likelihood
real partial_log_lik_lpmf(array[] int seq, int start, int end,
data array[] int labels,
data array[] int winner_ids,
data array[] int loser_ids,
vector player_skills) {
real ptarget = 0;
int N = end - start + 1;
vector[N] mu = rep_vector(0.0, N);
for (n in 1 : N) {
int nn = n + start - 1;
mu[n] += player_skills[winner_ids[nn]] - player_skills[loser_ids[nn]];
}
ptarget += bernoulli_logit_lpmf(labels[start : end] | mu);
return ptarget;
}
}
data {
int n_players;
int n_matches;
array[n_matches] int<lower=1, upper=n_players> winner_ids; // Winner of game n
array[n_matches] int<lower=1, upper=n_players> loser_ids; // Loser of game n
array[n_matches] int<lower=0, upper=1> labels; // Always 1 in this model
int grainsize;
}
transformed data {
array[n_matches] int seq = sequence(1, n_matches);
}
parameters {
real<lower=0> player_sd; // Scale of ability variation (hierarchical prior)
vector[n_players] player_skills; // Ability of player k
}
model {
player_sd ~ std_normal();
player_skills ~ normal(0, player_sd);
target += reduce_sum(
partial_log_lik_lpmf, seq, grainsize,
labels, winner_ids, loser_ids, player_skills
);
}
3.2 Stan data
List of 6
$ n_matches : int 160399
$ n_players : int 4830
$ winner_ids: int [1:160399] 3655 3440 253 103 3600 3128 99 24 3349 3436 ...
$ loser_ids : int [1:160399] 3129 2909 3656 3657 3658 3659 3660 3325 3661 3662 ...
$ labels : num [1:160399] 1 1 1 1 1 1 1 1 1 1 ...
$ grainsize : num 2673
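The `grainsize` entry is passed to reduce_sum, which splits the matches into slices, evaluates partial_log_lik_lpmf on each slice (possibly on different threads), and adds the results. The sketch below mimics this in plain R on simulated data, to show that the sum over slices equals the full log-likelihood. All sizes and values are hypothetical, and the fixed-size slicing is a simplification: reduce_sum chooses slice boundaries itself and treats `grainsize` as a target size.

```r
set.seed(256)

# Toy stand-ins for the Stan data and parameters (hypothetical sizes and values)
n_players <- 6
n_matches <- 25
player_skills <- rnorm(n_players)
winner_ids <- sample.int(n_players, n_matches, replace = TRUE)
# Pick a loser different from the winner for each match
loser_ids <- vapply(
  winner_ids,
  \(w) sample(setdiff(seq_len(n_players), w), size = 1),
  integer(1)
)

# Log-likelihood of one slice of matches, mirroring partial_log_lik_lpmf:
# with every label equal to 1, bernoulli_logit_lpmf(1 | mu) is log(inv_logit(mu))
partial_log_lik <- function(start, end) {
  mu <- player_skills[winner_ids[start:end]] - player_skills[loser_ids[start:end]]
  sum(plogis(mu, log.p = TRUE))
}

# Fixed-size slices, for illustration only
grainsize <- 10
starts <- seq(1, n_matches, by = grainsize)
ends <- pmin(starts + grainsize - 1, n_matches)

sum_of_slices <- sum(mapply(partial_log_lik, starts, ends))
full_log_lik <- partial_log_lik(1, n_matches)

all.equal(sum_of_slices, full_log_lik)
```

Because the slices are independent given `player_skills`, they can be evaluated in any order, which is what makes the within-chain parallelization possible.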
3.3 Model fit
tennis_mod_fit <- tennis_mod_exe$sample(
data = tennis_data_stan, seed = 256,
iter_warmup = 1000, iter_sampling = 1000, refresh = 0,
chains = 4, parallel_chains = 4, threads_per_chain = 7
)
Sampling takes ~2.67 minutes on my CPU (Ryzen 5950X, 16 Cores/32 Threads), on WSL2 (Ubuntu 20.04).
data.table [4 x 2]
4 Model diagnostics
bayesplot::mcmc_neff_hist(bayesplot::neff_ratio(tennis_mod_fit))
bayesplot::mcmc_rhat_hist(bayesplot::rhat(tennis_mod_fit))
Plotting random subsets of the traces:
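The plotting code for the traces is not shown here. As a rough sketch of the idea, the example below draws trace plots for a random subset of parameters with bayesplot::mcmc_trace(). It uses a simulated iterations x chains x parameters array (`draws_toy`, hypothetical) so that it runs on its own; with the fitted model, one would presumably pass `tennis_mod_fit$draws()` instead. The subset size of 6 is an arbitrary choice.

```r
library(bayesplot)

set.seed(256)

# Simulated stand-in for the model draws: iterations x chains x parameters
n_iter <- 200
n_chains <- 4
n_pars <- 50
draws_toy <- array(
  rnorm(n_iter * n_chains * n_pars),
  dim = c(n_iter, n_chains, n_pars),
  dimnames = list(NULL, NULL, paste0("player_skills[", seq_len(n_pars), "]"))
)

# Pick a random subset of parameters and plot their traces
pars_subset <- sample(dimnames(draws_toy)[[3]], size = 6)
trace_plot <- mcmc_trace(draws_toy, pars = pars_subset)
print(trace_plot)
```

With thousands of `player_skills` parameters, plotting a random subset keeps the figure readable while still giving a spot check on mixing.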
5 Posterior Predictions
5.1 Posterior data
Drawing 500 random posterior samples of player_skills, computing each player’s posterior mean and SD, and joining the result with player_data:
(player_skills <- as.data.table(tennis_mod_fit$draws(variables = "player_skills") |>
bind(x, subset_draws(x, "player_skills", regex = TRUE, draw = sample.int(ndraws(x), size = 500))))
[, .(player_skills = list(value)), by = variable
][, `:=`(player_idx = as.integer(str_extract(variable, "\\d{1,4}")), variable = NULL)
][, `:=`(skill_mean = sapply(player_skills, mean), skill_sd = sapply(player_skills, sd))
][as.data.table(player_data), on = "player_idx", nomatch = NULL
][order(-skill_mean), .(player_name, player_id, player_idx, skill_mean, skill_sd, player_skills)]
)
data.table [4,830 x 6]
| [ omitted 4,815 entries ] |
5.2 Posterior plots
Plot code
ridgeline_plot <- function(dat, var) {
dat[, .(player_skills = unlist(player_skills)), by = setdiff(names(dat), 'player_skills')
][, player_name := factor(player_name, levels = unique(player_name))] -> dat
ggplot(dat, aes_string(y = var)) +
ggridges::geom_density_ridges(
aes_string(x = "player_skills", fill = var),
alpha = 0.5, scale = 2, color = "grey30"
) +
labs(x = "Player Skills", y = "") +
scale_y_discrete(
position = "right",
limits = \(x) rev(x),
labels = \(x) str_replace_all(x, "\\s", "\n")
) +
theme(legend.position = "none", axis.line.y = element_blank())
}
Plotting the player_skills posteriors of the top 10 players:
Plotting the player_skills posteriors of the bottom 10 players: