library(randomForest)
library(fastshap)
library(dplyr)
library(caret)
library(pROC)
library(PRROC)
library(prg)
library(tidyr)
library(patchwork)
library(cowplot)
library(ggplot2)
# Global knitr chunk defaults for the rendered report.
do.call(
  knitr::opts_chunk$set,
  list(
    message   = FALSE,     # do not show messages emitted by chunks
    warning   = FALSE,     # do not show warnings emitted by chunks
    fig.align = "center",
    out.width = "150%"
  )
)
Two AI-READI cohorts. Study 1 is the internal development set; Study 2 is the independent external validation set, applied without model retraining.
# Load the two cohorts. Despite its name, `df_test` is Study 1 (the internal
# development set); `df_external` is Study 2.
df_test <- read.csv(file = "~/Downloads/read-ai/df_test_f_v2.csv")
df_external <- read.csv(file = "~/Downloads/read-ai/df_external_f_v2.csv")

# Size and class balance of Study 1.
local({
  labels <- df_test$IR_label
  cat(sprintf(
    "Study 1 (internal): n = %d IR = %d Non-IR = %d prevalence = %.3f\n",
    nrow(df_test),
    sum(labels == "IR"),
    sum(labels == "Non-IR"),
    mean(labels == "IR")
  ))
})
## Study 1 (internal): n = 97 IR = 32 Non-IR = 65 prevalence = 0.330

# Size and class balance of Study 2.
local({
  labels <- df_external$IR_label
  cat(sprintf(
    "Study 2 (external): n = %d IR = %d Non-IR = %d prevalence = %.3f\n",
    nrow(df_external),
    sum(labels == "IR"),
    sum(labels == "Non-IR"),
    mean(labels == "IR")
  ))
})
## Study 2 (external): n = 61 IR = 19 Non-IR = 42 prevalence = 0.311
Three modality-specific feature sets. mtry is fixed a
priori as floor(sqrt(p)) — the Breiman default for RF
classification — throughout all models and all stages of the
pipeline.
# Predictor columns per model. Column order is kept exactly as listed because
# it determines the column order of the matrices passed to randomForest().
feature_sets <- list()
feature_sets$CGM_model <- c(
  "bmi", "mean_fasting_glucose", "sd_fasting_glucose",
  "excursions_per_day", "rec_time_50pct_median",
  "overall_max_glucose", "age", "whr"
)
feature_sets$Smartwatch_model <- c(
  "stress_overall_daily_mean_stress", "sd_hr_day_night_diff",
  "age", "bmi", "whr"
)
feature_sets$Baseline_model <- c("bmi", "age", "whr")

# Display names, in the same order as `feature_sets`.
model_names <- c("Model CGM", "Model smartwatch", "Model baseline")

# Print the number of predictors and the resulting mtry for every model.
for (nm in names(feature_sets)) {
  p <- lengths(feature_sets)[[nm]]
  cat(sprintf(" %-20s p = %d mtry = %d\n", nm, p, floor(sqrt(p))))
}
## CGM_model p = 8 mtry = 2
## Smartwatch_model p = 5 mtry = 2
## Baseline_model p = 3 mtry = 1
# Coerce a metrics object to a data frame and tag it with the model's display
# name in a `Model` column.
#   df         : object accepted by as.data.frame()
#   model_name : value stored in the new `Model` column
# Returns the tagged data frame.
standardize <- function(df, model_name) {
  tagged <- as.data.frame(df)
  tagged[["Model"]] <- model_name
  tagged
}
# Area Under the Precision-Recall-Gain curve (Flach & Kull, NeurIPS 2015);
# a random classifier scores 0 and a perfect one scores 1.
# Requires: install.packages("prg")
#   labels : class labels handed to prg::create_prg_curve()
#   scores : predicted scores handed to prg::create_prg_curve()
# Returns the AUPRG, or NA_real_ if either prg call raises an error.
compute_auprg <- function(labels, scores) {
  on_failure <- function(e) NA_real_
  tryCatch(
    {
      prg_curve <- prg::create_prg_curve(labels, scores)
      prg::calc_auprg(prg_curve)
    },
    error = on_failure
  )
}
Single function used by both the bootstrap loop and the multi-seed
external evaluation. Pipeline:
medianImpute -> upSample -> randomForest. No
cross-validation is performed here because hyperparameters are fixed a
priori (mtry = floor(sqrt(p)), ntree = 300,
nodesize = 1), leaving CV with no selection role to
fulfil.
# Fit one random forest for a given feature set and RNG seed.
# Pipeline: medianImpute -> upSample -> randomForest, with fixed
# hyperparameters (mtry = floor(sqrt(p)), ntree = 300, nodesize = 1).
#   features : character vector of predictor column names
#   df_train : training data frame; must contain `features` and `IR_label`
#   seed     : value passed to set.seed() before resampling and fitting
# Returns list(rf = fitted forest, imp = fitted imputer, features = features).
fit_rf_seed <- function(features, df_train, seed) {
  # Seed first so the resampling and the forest share one seeded RNG stream.
  set.seed(seed)

  predictors <- mutate(
    dplyr::select(df_train, all_of(features)),
    across(where(is.numeric), as.double)
  )
  # make.names() turns "Non-IR" into the syntactic level "Non.IR".
  outcome <- factor(make.names(trimws(df_train$IR_label)))

  imputer <- caret::preProcess(predictors, method = "medianImpute")
  predictors_imp <- predict(imputer, predictors)

  balanced <- caret::upSample(
    x = predictors_imp, y = outcome, yname = ".outcome"
  )

  forest <- randomForest::randomForest(
    x = balanced[, names(predictors_imp), drop = FALSE],
    y = balanced$.outcome,
    ntree = 300,
    mtry = floor(sqrt(length(features))),
    nodesize = 1
  )

  list(rf = forest, imp = imputer, features = features)
}
# Predicted probability of the "IR" class for new data.
#   fit_obj : list returned by fit_rf_seed() (uses $features, $imp, $rf)
#   df_new  : data frame containing the columns in fit_obj$features
# New data are imputed with the imputer fitted on the training data.
# Returns the "IR" column of the forest's class-probability matrix.
predict_rf <- function(fit_obj, df_new) {
  new_predictors <- mutate(
    dplyr::select(df_new, all_of(fit_obj$features)),
    across(where(is.numeric), as.double)
  )
  new_imputed <- predict(fit_obj$imp, new_predictors)
  class_probs <- predict(fit_obj$rf, new_imputed, type = "prob")
  class_probs[, "IR"]
}
100 stratified 70/30 splits on Study 1 (n = 97). Each iteration trains three models on ~68 patients, evaluates on ~29, and records AUC-ROC, AUPRC, AUPRG, and classification metrics. The 2.5–97.5 percentiles across 100 iterations form the internal 95% CIs. SHAP values on each holdout are stored for internal feature importance figures.
# ---- Internal validation: repeated stratified 70/30 hold-out on Study 1 ----
# Each iteration fits the three models on the training part, scores the
# hold-out part, stores one metrics row per model, and stores SHAP values of
# the hold-out rows for the feature-importance figures.
n_iter <- 100
all_results <- vector("list", n_iter) # one metrics table per iteration
all_shap_val <- list()                # one entry per (iteration, model)
set.seed(123)

# Outcome factor used for stratified partitioning. It is the same in every
# iteration, so it is built once.
y <- factor(trimws(df_test$IR_label))
levels(y) <- make.names(levels(y))

for (i in seq_len(n_iter)) {
  # NOTE: fit_rf_seed() calls set.seed(iter_seed) below, so from the second
  # iteration on this partition is drawn from the RNG stream left behind by
  # the previous iteration, not directly from set.seed(123).
  train_idx <- createDataPartition(y, p = 0.7, list = FALSE)
  test_set <- df_test[-train_idx, ]
  train_set <- df_test[train_idx, ]
  y_test <- factor(make.names(trimws(test_set$IR_label)))

  # All three models of an iteration share the same seed.
  iter_seed <- i * 7L
  CGM_fit <- fit_rf_seed(feature_sets$CGM_model, train_set, iter_seed)
  SW_fit <- fit_rf_seed(feature_sets$Smartwatch_model, train_set, iter_seed)
  Base_fit <- fit_rf_seed(feature_sets$Baseline_model, train_set, iter_seed)
  fits <- list(CGM_fit, SW_fit, Base_fit)

  iter_rows <- lapply(seq_along(fits), function(m) {
    # No metric below is defined if the hold-out contains a single class.
    if (length(unique(as.integer(y_test))) < 2) return(NULL)

    prob <- predict_rf(fits[[m]], test_set)
    pred <- factor(ifelse(prob >= 0.5, "IR", "Non.IR"),
                   levels = c("Non.IR", "IR"))

    # State the positive class explicitly instead of relying on "IR" being
    # the alphabetically first level of y_test.
    cm <- confusionMatrix(pred, y_test, positive = "IR")
    sens <- as.numeric(cm$byClass["Sensitivity"])
    spec <- as.numeric(cm$byClass["Specificity"])
    prec <- as.numeric(cm$byClass["Pos Pred Value"])
    f1 <- if (is.na(prec) || is.na(sens) || prec + sens == 0) {
      NA_real_
    } else {
      2 * prec * sens / (prec + sens)
    }

    # direction = "<" fixes the orientation (controls score lower than
    # cases). With pROC's automatic direction a worse-than-chance split is
    # flipped and reported as AUC > 0.5, biasing the resampled AUC upward.
    roc_obj <- roc(y_test, prob, levels = c("Non.IR", "IR"),
                   direction = "<", quiet = TRUE)

    standardize(data.frame(
      Accuracy = as.numeric(cm$overall["Accuracy"]),
      Sensitivity = sens,
      Specificity = spec,
      Balanced_Accuracy = (sens + spec) / 2,
      Precision = prec,
      F1 = f1,
      AUC = as.numeric(auc(roc_obj)),
      PRAUC = pr.curve(scores.class0 = prob,
                       weights.class0 = ifelse(y_test == "IR", 1, 0),
                       curve = FALSE)$auc.integral,
      AUPRG = compute_auprg(as.integer(y_test == "IR"), prob),
      Dataset = "Internal",
      iteration = i
    ), model_names[m])
  })
  all_results[[i]] <- dplyr::bind_rows(iter_rows)

  # SHAP values of the hold-out rows for each model. A failing explain() call
  # is skipped, so the number of stored entries is not known in advance.
  for (m in seq_along(fits)) {
    fit <- fits[[m]]
    X_test_raw <- test_set %>%
      dplyr::select(all_of(fit$features)) %>%
      mutate(across(where(is.numeric), as.double))
    X_test_imp <- predict(fit$imp, X_test_raw)
    shap_s <- tryCatch(
      fastshap::explain(
        object = fit$rf,
        X = as.data.frame(X_test_imp),
        pred_wrapper = function(object, newdata)
          predict(object, newdata, type = "prob")[, "IR"],
        nsim = 50,
        adjust = TRUE
      ),
      error = function(e) NULL
    )
    if (!is.null(shap_s)) {
      all_shap_val[[length(all_shap_val) + 1]] <- list(
        iteration = i,
        model_name = model_names[m],
        shap = shap_s,
        X_df = as.data.frame(X_test_imp)
      )
    }
  }
}
# Stack the per-iteration metric rows from the internal validation loop.
bootstrap_results <- dplyr::bind_rows(all_results)

# Mean and 2.5th / 97.5th percentiles of every metric, per model and dataset.
final_summary <- summarise(
  group_by(bootstrap_results, Model, Dataset),
  across(
    c(Accuracy, Sensitivity, Specificity, Balanced_Accuracy,
      Precision, F1, AUC, PRAUC, AUPRG),
    list(
      mean  = function(v) mean(v, na.rm = TRUE),
      lower = function(v) quantile(v, 0.025, na.rm = TRUE),
      upper = function(v) quantile(v, 0.975, na.rm = TRUE)
    ),
    .names = "{.col}_{.fn}"
  ),
  .groups = "drop"
)

# Displayed table: the three ranking metrics, renamed and rounded to 3 dp.
final_summary %>%
  dplyr::select(
    Model,
    AUC_ROC = AUC_mean, AUC_lo = AUC_lower, AUC_hi = AUC_upper,
    AUPRC = PRAUC_mean, AUPRC_lo = PRAUC_lower, AUPRC_hi = PRAUC_upper,
    AUPRG = AUPRG_mean, AUPRG_lo = AUPRG_lower, AUPRG_hi = AUPRG_upper
  ) %>%
  mutate(across(where(is.numeric), function(v) round(v, 3)))
write.csv(bootstrap_results, "bootstrap_results.csv", row.names = FALSE)
write.csv(final_summary, "bootstrap_summary.csv", row.names = FALSE)

# Multi-seed external evaluation of one feature set.
#
# Fits the model N_SEEDS times (seeds 1..N_SEEDS) on `df_train`, scores
# `df_ext` with every fit, and averages the predicted IR probabilities across
# seeds into an ensemble score.
#   features : character vector of predictor column names
#   df_train : training data (must contain `features` and `IR_label`)
#   df_ext   : external data (must contain `features` and `IR_label`)
#   N_SEEDS  : number of seeds / refits
#   n_boot   : bootstrap resamples for the ensemble AUPRC / AUPRG intervals
# Returns a list with per-seed AUC/AUPRC/AUPRG (vectors plus rounded mean and
# SD), the ensemble AUC with DeLong 95% CI, the ensemble AUPRC/AUPRG with
# percentile bootstrap 95% CIs, the ensemble probabilities, the external
# outcome factor, the confusion matrix at the 0.5 cut-off, and N_SEEDS.
# Side effect: calls set.seed(42) before the bootstrap.
evaluate_multiseed <- function(features, df_train, df_ext,
                               N_SEEDS = 20, n_boot = 2000) {
  y_ext <- factor(make.names(trimws(df_ext$IR_label)))

  # direction = "<" fixes the orientation (Non.IR scores lower than IR).
  # pROC's automatic direction would flip a worse-than-chance model and
  # report it as AUC > 0.5.
  roc_ir <- function(scores) {
    roc(y_ext, scores, levels = c("Non.IR", "IR"),
        direction = "<", quiet = TRUE)
  }
  # Area under the precision-recall curve for outcome `truth`.
  auprc_of <- function(truth, scores) {
    pr.curve(scores.class0 = scores,
             weights.class0 = ifelse(truth == "IR", 1, 0),
             curve = FALSE)$auc.integral
  }

  # ---- one fit per seed ----------------------------------------------------
  all_probs <- matrix(NA_real_, nrow = nrow(df_ext), ncol = N_SEEDS)
  all_aucs <- all_auprcs <- all_auprgs <- numeric(N_SEEDS)
  for (s in seq_len(N_SEEDS)) {
    fit <- fit_rf_seed(features, df_train, seed = s)
    prob <- predict_rf(fit, df_ext)
    all_probs[, s] <- prob
    all_aucs[s] <- as.numeric(auc(roc_ir(prob)))
    all_auprcs[s] <- auprc_of(y_ext, prob)
    all_auprgs[s] <- compute_auprg(as.integer(y_ext == "IR"), prob)
  }

  # ---- ensemble: mean probability across seeds -------------------------------
  ens <- rowMeans(all_probs)
  roc_ens <- roc_ir(ens)
  ens_auc <- as.numeric(auc(roc_ens))
  # Elements 1 and 3 of the ci.auc() result are used as the CI bounds below.
  dl <- tryCatch(
    as.numeric(ci.auc(roc_ens, method = "delong", conf.level = 0.95)),
    error = function(e) rep(NA_real_, 3)
  )
  ens_auprc <- auprc_of(y_ext, ens)
  ens_auprg <- compute_auprg(as.integer(y_ext == "IR"), ens)

  # ---- percentile bootstrap over external patients ---------------------------
  set.seed(42)
  n_ext <- length(y_ext)
  # Resamples containing a single class stay NA and are dropped by the
  # quantile() calls below.
  prauc_boot <- auprg_boot <- rep(NA_real_, n_boot)
  for (i in seq_len(n_boot)) {
    idx <- sample(n_ext, n_ext, replace = TRUE)
    yb <- y_ext[idx]
    pb <- ens[idx]
    if (length(unique(yb)) < 2) next
    prauc_boot[i] <- tryCatch(
      auprc_of(yb, pb),
      error = function(e) NA_real_
    )
    auprg_boot[i] <- tryCatch(
      compute_auprg(as.integer(yb == "IR"), pb),
      error = function(e) NA_real_
    )
  }

  ens_pred <- factor(ifelse(ens >= 0.5, "IR", "Non.IR"),
                     levels = c("Non.IR", "IR"))

  list(
    auc_mean = round(mean(all_aucs), 3),
    auc_sd = round(sd(all_aucs), 3),
    auprc_mean = round(mean(all_auprcs), 3),
    auprc_sd = round(sd(all_auprcs), 3),
    auprg_mean = round(mean(all_auprgs), 3),
    auprg_sd = round(sd(all_auprgs), 3),
    all_aucs = all_aucs,
    all_auprcs = all_auprcs,
    all_auprgs = all_auprgs,
    ens_auc = round(ens_auc, 3),
    ens_auc_lo = round(dl[1], 3),
    ens_auc_hi = round(dl[3], 3),
    ens_auprc = round(ens_auprc, 3),
    ens_auprc_lo = round(quantile(prauc_boot, 0.025, na.rm = TRUE), 3),
    ens_auprc_hi = round(quantile(prauc_boot, 0.975, na.rm = TRUE), 3),
    ens_auprg = round(ens_auprg, 3),
    ens_auprg_lo = round(quantile(auprg_boot, 0.025, na.rm = TRUE), 3),
    ens_auprg_hi = round(quantile(auprg_boot, 0.975, na.rm = TRUE), 3),
    ensemble_prob = ens,
    y_ext = y_ext,
    # Positive class stated explicitly rather than taken from level order.
    cm = confusionMatrix(ens_pred, y_ext, positive = "IR"),
    N_SEEDS = N_SEEDS
  )
}
Each model is fitted 20 times (seeds 1–20) on all 97 Study 1 patients and applied directly to the 61 Study 2 patients. The ensemble prediction is the mean probability across seeds. No retraining or hyperparameter adjustment is performed on the external cohort.
# Run the multi-seed external evaluation for each feature set
# (trained on Study 1, scored on Study 2).
ms_CGM <- evaluate_multiseed(
  features = feature_sets$CGM_model, df_train = df_test, df_ext = df_external
)
ms_SW <- evaluate_multiseed(
  features = feature_sets$Smartwatch_model, df_train = df_test, df_ext = df_external
)
ms_Base <- evaluate_multiseed(
  features = feature_sets$Baseline_model, df_train = df_test, df_ext = df_external
)

ms_list <- list(
  "CGM model" = ms_CGM,
  "Smartwatch model" = ms_SW,
  "Baseline model" = ms_Base
)

# One summary row per model: per-seed mean/SD and ensemble estimate with CI
# for AUC, AUPRC and AUPRG.
ms_summary <- dplyr::bind_rows(lapply(names(ms_list), function(nm) {
  with(ms_list[[nm]], data.frame(
    Model = nm,
    AUC_mean = auc_mean, AUC_sd = auc_sd,
    Ens_AUC = ens_auc, Ens_AUC_lo = ens_auc_lo, Ens_AUC_hi = ens_auc_hi,
    AUPRC_mean = auprc_mean, AUPRC_sd = auprc_sd,
    Ens_AUPRC = ens_auprc, Ens_AUPRC_lo = ens_auprc_lo, Ens_AUPRC_hi = ens_auprc_hi,
    AUPRG_mean = auprg_mean, AUPRG_sd = auprg_sd,
    Ens_AUPRG = ens_auprg, Ens_AUPRG_lo = ens_auprg_lo, Ens_AUPRG_hi = ens_auprg_hi,
    stringsAsFactors = FALSE
  ))
}))
ms_summary %>% mutate(across(where(is.numeric), function(v) round(v, 3)))

# Ensemble probabilities per model, plus the shared external outcome as a
# factor, a 0/1 vector, and its prevalence.
ext_probs <- lapply(ms_list, function(ms) ms$ensemble_prob)
y_ext_fac <- ms_CGM$y_ext
y_bin <- as.integer(y_ext_fac == "IR")
ir_prev <- mean(y_bin)
CIs from 2,000-iteration bootstrap on ensemble probabilities.
# Threshold-based metrics (cut-off 0.5) plus AUPRC / AUPRG for one model on
# the external cohort, with percentile bootstrap 95% CIs.
#   y_true     : factor of true classes with levels "IR" and "Non.IR"
#   prob       : predicted probability of "IR", aligned with y_true
#   model_name : value stored in the `Model` column of the result
#   n_boot     : number of bootstrap resamples
#   seed       : passed to set.seed() before resampling (changes global RNG)
# Returns a one-row data frame: point estimate, *_lo and *_hi for sensitivity,
# specificity, balanced accuracy, precision, F1, accuracy, AUPRC and AUPRG,
# all rounded to 3 decimals.
compute_ext_full <- function(y_true, prob, model_name, n_boot = 2000, seed = 42) {
  # F1 from precision and recall; NA when either is missing or both are 0.
  f1_score <- function(precision, recall) {
    if (is.na(precision) || is.na(recall) || precision + recall == 0) {
      return(NA_real_)
    }
    2 * precision * recall / (precision + recall)
  }

  n <- length(y_true)
  pred <- factor(ifelse(prob >= 0.5, "IR", "Non.IR"), levels = c("Non.IR", "IR"))

  # ---- point estimates -------------------------------------------------------
  # positive = "IR" is explicit so sensitivity / specificity / precision do not
  # silently swap if the caller's factor has a different level order.
  cm <- confusionMatrix(pred, y_true, positive = "IR")
  sens <- as.numeric(cm$byClass["Sensitivity"])
  spec <- as.numeric(cm$byClass["Specificity"])
  prec <- as.numeric(cm$byClass["Pos Pred Value"])
  f1 <- f1_score(prec, sens)
  acc <- as.numeric(cm$overall["Accuracy"])
  auprc <- pr.curve(scores.class0 = prob,
                    weights.class0 = ifelse(y_true == "IR", 1, 0),
                    curve = FALSE)$auc.integral
  auprg <- compute_auprg(as.integer(y_true == "IR"), prob)

  # ---- bootstrap over patients -----------------------------------------------
  set.seed(seed)
  mat <- matrix(NA_real_, nrow = n_boot, ncol = 8,
                dimnames = list(NULL,
                  c("sens", "spec", "prec", "f1", "acc", "balacc", "prauc", "auprg")))
  for (i in seq_len(n_boot)) {
    idx <- sample(n, n, replace = TRUE)
    yb <- y_true[idx]
    pb_cls <- pred[idx]
    pb_pr <- prob[idx]
    # A single-class resample keeps its NA row, which quantile() drops below.
    if (length(unique(yb)) < 2) next
    cm_b <- confusionMatrix(pb_cls, yb, positive = "IR")
    s_b <- as.numeric(cm_b$byClass["Sensitivity"])
    sp_b <- as.numeric(cm_b$byClass["Specificity"])
    pr_b <- as.numeric(cm_b$byClass["Pos Pred Value"])
    prauc_b <- tryCatch(
      pr.curve(scores.class0 = pb_pr,
               weights.class0 = ifelse(yb == "IR", 1, 0),
               curve = FALSE)$auc.integral,
      error = function(e) NA_real_
    )
    auprg_b <- tryCatch(
      compute_auprg(as.integer(yb == "IR"), pb_pr),
      error = function(e) NA_real_
    )
    mat[i, ] <- c(s_b, sp_b, pr_b, f1_score(pr_b, s_b),
                  as.numeric(cm_b$overall["Accuracy"]),
                  (s_b + sp_b) / 2, prauc_b, auprg_b)
  }
  # 2 x 8 matrix with rows "2.5%" and "97.5%".
  ci <- apply(mat, 2, quantile, probs = c(0.025, 0.975), na.rm = TRUE)
  lo <- function(metric) round(ci["2.5%", metric], 3)
  hi <- function(metric) round(ci["97.5%", metric], 3)

  data.frame(
    Model = model_name,
    Sensitivity = round(sens, 3), Sensitivity_lo = lo("sens"), Sensitivity_hi = hi("sens"),
    Specificity = round(spec, 3), Specificity_lo = lo("spec"), Specificity_hi = hi("spec"),
    Balanced_Accuracy = round((sens + spec) / 2, 3),
    BalAcc_lo = lo("balacc"), BalAcc_hi = hi("balacc"),
    Precision = round(prec, 3), Precision_lo = lo("prec"), Precision_hi = hi("prec"),
    F1 = round(f1, 3), F1_lo = lo("f1"), F1_hi = hi("f1"),
    Accuracy = round(acc, 3), Accuracy_lo = lo("acc"), Accuracy_hi = hi("acc"),
    AUPRC = round(auprc, 3), AUPRC_lo = lo("prauc"), AUPRC_hi = hi("prauc"),
    AUPRG = round(auprg, 3), AUPRG_lo = lo("auprg"), AUPRG_hi = hi("auprg"),
    stringsAsFactors = FALSE
  )
}
# Threshold-based external metrics per model, joined to the multi-seed
# AUC / AUPRC / AUPRG summary on `Model`.
ext_full <- left_join(
  dplyr::bind_rows(lapply(names(ext_probs), function(nm) {
    compute_ext_full(y_ext_fac, ext_probs[[nm]], nm)
  })),
  ms_summary,
  by = "Model"
)

# Displayed subset, rounded to 3 decimals.
ext_full %>%
  dplyr::select(Model, AUC_mean, AUC_sd, Ens_AUC_lo, Ens_AUC_hi,
                AUPRC_mean, AUPRC_sd, AUPRG_mean, AUPRG_sd,
                Sensitivity, Specificity, Balanced_Accuracy, F1) %>%
  mutate(across(where(is.numeric), function(v) round(v, 3)))
write.csv(ext_full, "external_results.csv", row.names = FALSE)
write.csv(ms_summary, "external_summary.csv", row.names = FALSE)
Distribution of AUC-ROC, AUPRC, and AUPRG across 20 seeds for each model on the external cohort. Crossbars show mean ± 1 SD; dashed lines show random-classifier baselines.
# Long table of per-seed external metrics: one row per (model, seed).
seed_df <- dplyr::bind_rows(lapply(names(ms_list), function(nm) {
  ms <- ms_list[[nm]]
  data.frame(
    Model = nm,
    AUC = ms$all_aucs,
    AUPRC = ms$all_auprcs,
    AUPRG = ms$all_auprgs
  )
}))
# Keep the models in the order of `ms_list` on the plots.
seed_df$Model <- factor(seed_df$Model, levels = names(ms_list))

# Colour per model, keyed by the model names used in `ms_list`.
model_colors <- c(
  "CGM model" = "#1f77b4",
  "Smartwatch model" = "#2ca02c",
  "Baseline model" = "#d62728"
)

# Mean, SD and maximum across seeds for each metric, per model.
summ_seed <- summarise(
  group_by(seed_df, Model),
  m_auc = mean(AUC), sd_auc = sd(AUC), mx_auc = max(AUC),
  m_auprc = mean(AUPRC), sd_auprc = sd(AUPRC), mx_auprc = max(AUPRC),
  m_auprg = mean(AUPRG), sd_auprg = sd(AUPRG), mx_auprg = max(AUPRG),
  .groups = "drop"
)
# One seed-stability panel: jittered per-seed values, a mean +/- 1 SD crossbar
# per model, a "mean / +-SD" text label above each model's maximum, and an
# optional dashed reference line.
#   df     : per-seed data with columns `Model` and `y_var`
#   summ   : per-model summary with columns m_<y>, sd_<y>, mx_<y>, where <y>
#            is tolower(y_var)
#   y_var  : name of the metric column in `df` (e.g. "AUC")
#   y_lab  : y-axis label
#   y_ref  : y position of the dashed reference line, or NULL for none
#   title  : panel title
#   y_lim  : default visible y range. The range is widened, never narrowed,
#            so that `y_ref` and all per-seed values stay visible. The former
#            fixed range c(0.30, 1.10) clipped the AUPRG chance line at 0 and
#            any seed value below 0.30.
# Returns a ggplot object.
make_seed_panel <- function(df, summ, y_var, y_lab, y_ref = NULL, title = "",
                            y_lim = c(0.30, 1.10)) {
  mc <- paste0("m_", tolower(y_var))
  sc <- paste0("sd_", tolower(y_var))
  xc <- paste0("mx_", tolower(y_var))

  # min()/max() ignore a NULL y_ref, so this is a no-op without a reference.
  view_lim <- c(
    min(y_lim[1], y_ref, df[[y_var]], na.rm = TRUE),
    max(y_lim[2], y_ref, na.rm = TRUE)
  )

  p <- ggplot(df, aes(x = Model, y = .data[[y_var]], colour = Model)) +
    geom_jitter(width = 0.12, alpha = 0.65, size = 2, show.legend = FALSE) +
    geom_crossbar(data = summ,
                  aes(x = Model, y = .data[[mc]],
                      ymin = .data[[mc]] - .data[[sc]],
                      ymax = .data[[mc]] + .data[[sc]]),
                  width = 0.35, linewidth = 0.8, fill = NA, show.legend = FALSE) +
    geom_text(data = summ,
              aes(x = Model, y = .data[[xc]] + 0.022,
                  label = sprintf("%.3f\n\u00b1%.3f", .data[[mc]], .data[[sc]])),
              size = 2.9, vjust = 0, show.legend = FALSE) +
    scale_colour_manual(values = model_colors) +
    scale_x_discrete(labels = c("CGM model" = "CGM",
                                "Smartwatch model" = "Smartwatch",
                                "Baseline model" = "Baseline")) +
    coord_cartesian(ylim = view_lim) +
    labs(title = title, x = NULL, y = y_lab) +
    theme_bw(base_size = 11) +
    theme(plot.title = element_text(face = "bold", size = 11),
          panel.grid.minor = element_blank(),
          panel.grid.major.x = element_blank())

  if (!is.null(y_ref)) {
    p <- p + geom_hline(yintercept = y_ref, linetype = "dashed",
                        colour = "grey55", linewidth = 0.7)
  }
  p
}
# Seed-stability panels; each dashed line marks the random-classifier level
# of the metric (0.5 for AUC-ROC, prevalence for AUPRC, 0 for AUPRG).
p_auc <- make_seed_panel(
  df = seed_df, summ = summ_seed, y_var = "AUC", y_lab = "AUC-ROC",
  y_ref = 0.5, title = "A AUC-ROC"
)

chance_auprc <- annotate("text", x = 0.65, y = ir_prev + 0.055,
                         label = sprintf("Chance (%.2f)", ir_prev),
                         size = 3, colour = "grey55")
p_auprc <- make_seed_panel(
  df = seed_df, summ = summ_seed, y_var = "AUPRC", y_lab = "AUPRC",
  y_ref = ir_prev, title = "B AUPRC"
) + chance_auprc

chance_auprg <- annotate("text", x = 0.65, y = 0.055, label = "Chance = 0",
                         size = 3, colour = "grey55")
p_auprg <- make_seed_panel(
  df = seed_df, summ = summ_seed, y_var = "AUPRG", y_lab = "AUPRG",
  y_ref = 0, title = "C AUPRG"
) + chance_auprg

# Show the three panels side by side.
Reduce(`|`, list(p_auc, p_auprc, p_auprg))
Net benefit is computed as
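$\mathrm{NB}(t) = \frac{TP}{n} - \frac{FP}{n} \cdot \frac{t}{1 - t}$, where $t$ is the threshold probability and $n$ is the number of external patients. Evaluated on ensemble
probabilities from the 20-seed fit.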
# Number of external patients (length of the 0/1 outcome vector).
n_ext <- length(y_bin)
# Threshold probabilities at which net benefit is evaluated: 0.05 to 0.50 in
# steps of 0.01.
thresholds <- seq(0.05, 0.50, by = 0.01)
# Net benefit of treating everyone whose predicted probability is >= t:
#   NB(t) = TP / n - FP / n * t / (1 - t)
#   y    : 0/1 outcome vector (1 = event)
#   prob : predicted probabilities, aligned with y
#   t    : threshold probability (must be < 1)
# n is taken from `y` itself rather than from the global `n_ext`, so the
# result is correct for any input length and does not depend on script state.
nb_fn <- function(y, prob, t) {
  n <- length(y)
  flagged <- prob >= t
  tp <- sum(flagged & y == 1)
  fp <- sum(flagged & y == 0)
  tp / n - fp / n * t / (1 - t)
}
# Decision-curve table: net benefit of each model and of the "treat all" /
# "treat none" strategies at every threshold (five rows per threshold).
dca_df <- do.call(rbind, lapply(thresholds, function(t) {
  model_nb <- vapply(
    ext_probs[c("CGM model", "Smartwatch model", "Baseline model")],
    function(prob) nb_fn(y_bin, prob, t),
    numeric(1)
  )
  # Treat all: every patient is flagged, so TP/n is the prevalence and FP/n
  # is 1 - prevalence.
  treat_all <- ir_prev - (1 - ir_prev) * t / (1 - t)
  data.frame(
    threshold = t,
    model = c("CGM", "SW", "Baseline", "All", "None"),
    net_benefit = c(unname(model_nb), treat_all, 0),
    stringsAsFactors = FALSE
  )
}))
# Decision-curve plot: one line per strategy, a solid line at net benefit 0
# and a dashed vertical line at the external prevalence.
p_dca <- local({
  strategy_labels <- c("CGM" = "CGM model", "SW" = "Smartwatch model",
                       "Baseline" = "Baseline", "All" = "Treat all",
                       "None" = "Treat none")
  strategy_colours <- c("CGM" = "#1f77b4", "SW" = "#2ca02c",
                        "Baseline" = "#d62728",
                        "All" = "#888888", "None" = "#bbbbbb")
  strategy_linetypes <- c("CGM" = "solid", "SW" = "dashed",
                          "Baseline" = "dotdash",
                          "All" = "dotted", "None" = "dotted")

  ggplot(dca_df, aes(x = threshold, y = net_benefit,
                     colour = model, linetype = model)) +
    geom_line(linewidth = 0.9) +
    scale_colour_manual(values = strategy_colours, labels = strategy_labels) +
    scale_linetype_manual(values = strategy_linetypes, labels = strategy_labels) +
    geom_hline(yintercept = 0, colour = "black", linewidth = 0.5) +
    geom_vline(xintercept = ir_prev, colour = "black",
               linetype = "dashed", linewidth = 0.6, alpha = 0.35) +
    annotate("text", x = ir_prev + 0.01, y = 0.38,
             label = sprintf("Prevalence\n(%.2f)", ir_prev),
             size = 3, colour = "#555555", hjust = 0) +
    coord_cartesian(xlim = c(0.05, 0.50), ylim = c(-0.06, 0.42)) +
    labs(x = "Threshold probability", y = "Net benefit",
         colour = NULL, linetype = NULL) +
    theme_bw(base_size = 12) +
    theme(legend.position = "right", panel.grid.minor = element_blank())
})
print(p_dca)
##
## Net benefit at prevalence threshold (t = 0.31):
# Print each strategy's net benefit at the threshold closest to the
# external prevalence.
for (m in c("CGM", "SW", "Baseline", "All", "None")) {
  in_model <- dca_df$model == m
  nearest <- which.min(abs(dca_df$threshold[in_model] - ir_prev))
  nb <- dca_df$net_benefit[in_model][nearest]
  cat(sprintf(" %-12s NB = %.4f\n", m, nb))
}
## CGM NB = 0.1314
## SW NB = 0.1224
## Baseline NB = 0.1003
## All NB = 0.0021
## None NB = 0.0000
# Display labels for the raw feature column names (name = column,
# value = label shown on the figures).
feature_name_map <- c(
  bmi                              = "BMI",
  mean_fasting_glucose             = "Mean Fasting Glucose",
  sd_fasting_glucose               = "SD Fasting Glucose",
  excursions_per_day               = "Glucose Excursions / Day",
  rec_time_50pct_median            = "Recovery Time (50th pct)",
  overall_max_glucose              = "Overall Max Glucose",
  age                              = "Age",
  whr                              = "WHR",
  stress_overall_daily_mean_stress = "Mean Stress (HRV-based)",
  sd_hr_day_night_diff             = "HR Day-Night Variability"
)
# Map raw feature names to display labels; names without an entry in
# `feature_name_map` are returned unchanged.
rename_features <- function(x) {
  has_label <- x %in% names(feature_name_map)
  ifelse(has_label, feature_name_map[x], x)
}

# Per-feature importance summary for one model.
#   shap_list    : list of entries with $model_name and $shap
#   model_filter : model name to keep
# Each matching entry contributes one row of column-wise mean |SHAP|.
# Returns a data frame sorted by decreasing mean_shap with columns feature
# (display label), mean_shap and sd_shap (mean / SD across entries) and
# pct_share (mean_shap as a percentage of the total).
aggregate_shap <- function(shap_list, model_filter) {
  is_match <- function(entry) entry$model_name == model_filter
  per_entry <- lapply(
    Filter(is_match, shap_list),
    function(e) colMeans(abs(as.data.frame(e$shap)))
  )
  mat <- do.call(rbind, per_entry)

  importance <- data.frame(
    feature = colnames(mat),
    mean_shap = colMeans(mat),
    sd_shap = apply(mat, 2, sd),
    row.names = NULL
  )
  importance <- arrange(importance, desc(mean_shap))
  mutate(
    importance,
    pct_share = mean_shap / sum(mean_shap) * 100,
    feature = rename_features(feature)
  )
}
# Long-format SHAP values joined with the matching feature values.
#   shap_list    : list of entries with $model_name, $iteration, $shap, $X_df
#   model_filter : model name to keep
# Returns one row per (entry, observation, feature) with columns obs,
# iteration, feature (display label), shap_value and feature_value.
collect_shap_long <- function(shap_list, model_filter) {
  is_match <- function(entry) entry$model_name == model_filter

  one_entry <- function(e) {
    shap_long <- pivot_longer(
      mutate(as.data.frame(e$shap), obs = row_number(), iteration = e$iteration),
      -c(obs, iteration),
      names_to = "feature", values_to = "shap_value"
    )
    shap_long <- mutate(shap_long, feature = rename_features(feature))

    value_long <- pivot_longer(
      mutate(as.data.frame(e$X_df), obs = row_number()),
      -obs,
      names_to = "feature", values_to = "feature_value"
    )
    value_long <- mutate(value_long, feature = rename_features(feature))

    # Rows are matched by position within the entry (obs) and by feature.
    left_join(shap_long, value_long, by = c("obs", "feature"))
  }

  bind_rows(lapply(Filter(is_match, shap_list), one_entry))
}
# Rank stability of feature importance for one model.
#   shap_list    : list of entries with $model_name and $shap
#   model_filter : model name to keep
# Within each entry, features are ranked by mean |SHAP| using rank(), which
# is ascending: rank 1 is the SMALLEST mean |SHAP| and the highest rank is
# the largest.
# Returns a data frame sorted by increasing mean_rank with columns feature
# (display label), mean_rank and iqr_rank (mean / IQR of the rank across
# entries).
rank_stability_df <- function(shap_list, model_filter) {
  is_match <- function(entry) entry$model_name == model_filter
  per_entry <- lapply(
    Filter(is_match, shap_list),
    function(e) colMeans(abs(as.data.frame(e$shap)))
  )
  mat <- do.call(rbind, per_entry)

  # apply() over rows returns features x entries; transpose back.
  rank_mat <- t(apply(mat, 1, rank))

  stability <- data.frame(
    feature = colnames(rank_mat),
    mean_rank = colMeans(rank_mat),
    iqr_rank = apply(rank_mat, 2, IQR),
    row.names = NULL
  )
  stability <- arrange(stability, mean_rank)
  mutate(stability, feature = rename_features(feature))
}
# SHAP beeswarm panel for one model: one jittered point per SHAP value,
# coloured by the feature value min-max scaled within each feature.
#   shap_list      : list of SHAP entries (see collect_shap_long())
#   model_filter   : model name; also used as the panel title
#   shared_x_range : optional c(min, max) x range shared across panels
#   is_external    : TRUE draws points at alpha 0.9 instead of 0.5
# Returns a ggplot object with the legend suppressed.
panel_beeswarm <- function(shap_list, model_filter,
                           shared_x_range = NULL, is_external = FALSE) {
  agg <- aggregate_shap(shap_list, model_filter)

  # Axis label per feature: "<feature> (<share>%)".
  label_map <- dplyr::select(
    mutate(agg, label = paste0(feature, " (", round(pct_share, 1), "%)")),
    feature, label
  )
  # `agg` is sorted by decreasing importance; reversing it puts the most
  # important feature at the last factor level.
  label_levels <- label_map$label[match(rev(agg$feature), label_map$feature)]

  long <- collect_shap_long(shap_list, model_filter) %>%
    left_join(label_map, by = "feature") %>%
    group_by(feature) %>%
    mutate(
      # The 1e-9 term avoids division by zero for constant features.
      feat_scaled =
        (feature_value - min(feature_value, na.rm = TRUE)) /
        (max(feature_value, na.rm = TRUE) - min(feature_value, na.rm = TRUE) + 1e-9)
    ) %>%
    ungroup() %>%
    mutate(label = factor(label, levels = label_levels))

  point_alpha <- ifelse(is_external, 0.9, 0.5)

  p <- ggplot(long, aes(x = shap_value, y = label, colour = feat_scaled)) +
    geom_jitter(height = 0.18, alpha = point_alpha, size = 1.5, stroke = 0) +
    scale_colour_gradient(low = "#3B82C4", high = "#E84545",
                          name = "Feature\nvalue",
                          breaks = c(0, 1), labels = c("Low", "High")) +
    geom_vline(xintercept = 0, linetype = "dashed",
               colour = "grey50", linewidth = 0.4) +
    labs(title = model_filter, x = "SHAP value", y = NULL) +
    theme_bw(base_size = 11) +
    theme(plot.title = element_text(face = "bold", hjust = 0.5, size = 11),
          legend.position = "none",
          panel.grid.minor = element_blank(),
          axis.text.y = element_text(size = 9))

  if (is.null(shared_x_range)) {
    return(p)
  }
  p + coord_cartesian(xlim = shared_x_range)
}
# Horizontal bar panel of mean |SHAP| per feature with +/- 1 SD error bars.
#   shap_list    : list of SHAP entries (see aggregate_shap())
#   model_filter : model name; also used as the panel title
#   shared_x_max : optional upper x limit shared across panels
# Returns a ggplot object.
panel_bar <- function(shap_list, model_filter, shared_x_max = NULL) {
  agg <- aggregate_shap(shap_list, model_filter)

  # "<feature> (<share>%)"; levels are reversed so the first (most important)
  # row of `agg` becomes the last factor level.
  label_txt <- paste0(agg$feature, " (", round(agg$pct_share, 1), "%)")
  agg$label <- factor(label_txt, levels = rev(label_txt))

  # Error-bar ends; the lower end is truncated at 0.
  agg$ymin_eb <- pmax(agg$mean_shap - agg$sd_shap, 0)
  agg$ymax_eb <- agg$mean_shap + agg$sd_shap

  p <- ggplot(agg, aes(x = mean_shap, y = label)) +
    geom_col(fill = "#4C72B0", alpha = 0.85) +
    geom_errorbarh(aes(xmin = ymin_eb, xmax = ymax_eb),
                   height = 0.3, colour = "grey30") +
    labs(title = model_filter, x = "Mean |SHAP|", y = NULL) +
    theme_bw(base_size = 11) +
    theme(plot.title = element_text(face = "bold", hjust = 0.5, size = 11),
          panel.grid.minor = element_blank(),
          axis.text.y = element_text(size = 9))

  if (is.null(shared_x_max)) {
    return(p)
  }
  p + coord_cartesian(xlim = c(0, shared_x_max))
}
# Rank-stability panel for one model: mean importance rank per feature with a
# horizontal bar of total width equal to the IQR, centred on the mean.
#   shap_list      : list of SHAP entries (see rank_stability_df())
#   model_filter   : model name; also used as the panel title
#   shared_x_range : optional x range shared across panels
# Returns a ggplot object.
#
# rank_stability_df() ranks with rank(), which is ascending: the feature with
# the LARGEST mean |SHAP| gets the HIGHEST rank. Two things follow:
#   * the axis label said "lower = more important", which was the wrong way
#     round; it now says "higher = more important";
#   * the rows arrive sorted least-important first, so the factor levels are
#     taken in that order (not reversed) to put the most important feature on
#     top, as in the bar and beeswarm panels.
panel_rank <- function(shap_list, model_filter, shared_x_range = NULL) {
  rs <- rank_stability_df(shap_list, model_filter)
  # Order explicitly by mean_rank: first level = lowest rank = least
  # important, drawn at the bottom of the discrete y axis.
  rs$label <- factor(rs$feature, levels = rs$feature[order(rs$mean_rank)])

  p <- ggplot(rs, aes(x = mean_rank, y = label)) +
    geom_point(size = 3, colour = "#2ca02c") +
    geom_errorbarh(aes(xmin = mean_rank - iqr_rank / 2,
                       xmax = mean_rank + iqr_rank / 2),
                   height = 0.3, colour = "grey40") +
    scale_x_reverse() +
    labs(title = model_filter,
         x = "Mean rank (higher = more important)", y = NULL) +
    theme_bw(base_size = 11) +
    theme(plot.title = element_text(face = "bold", hjust = 0.5, size = 11),
          panel.grid.minor = element_blank(),
          axis.text.y = element_text(size = 9))

  if (!is.null(shared_x_range)) {
    p <- p + coord_cartesian(xlim = shared_x_range)
  }
  p
}
# Multi-panel SHAP figure (one panel per model) of a given type.
#   shap_list       : list of SHAP entries
#   dataset_label   : text appended to the figure title
#   model_names_vec : model names to plot
#   plot_type       : "beeswarm", "bar" or "rank"
#   panel_labels    : tags for the panels, left to right
# All panels share one x range. Panels are drawn in the REVERSE order of
# `model_names_vec`, while `panel_labels` are applied left to right.
# NOTE(review): make_shap_ext_figure() does not reverse the order; confirm
# the difference is intended.
# Returns a patchwork object with a title and an explanatory caption.
make_shap_figure <- function(shap_list, dataset_label, model_names_vec,
                             plot_type = c("beeswarm", "bar", "rank"),
                             panel_labels = LETTERS[seq_along(model_names_vec)]) {
  plot_type <- match.arg(plot_type)

  # ---- x range shared by all panels ---------------------------------------
  if (plot_type == "beeswarm") {
    # Symmetric around 0, 5% wider than the largest |SHAP| of any model.
    all_vals <- unlist(lapply(model_names_vec, function(mn) {
      hits <- Filter(function(x) x$model_name == mn, shap_list)
      unlist(lapply(hits, function(e) as.numeric(as.matrix(e$shap))))
    }))
    half_width <- max(abs(all_vals), na.rm = TRUE) * 1.05
    srng <- c(-half_width, half_width)
  } else if (plot_type == "bar") {
    # 5% beyond the largest upper error-bar end (mean + SD).
    bar_tops <- unlist(lapply(model_names_vec, function(mn) {
      a <- aggregate_shap(shap_list, mn)
      a$mean_shap + a$sd_shap
    }))
    sxmax <- max(bar_tops, na.rm = TRUE) * 1.05
  } else {
    # Ranks run from 1 to the largest feature count of any model.
    nf <- max(unlist(lapply(model_names_vec, function(mn) {
      nrow(rank_stability_df(shap_list, mn))
    })))
    srng <- c(nf + 0.5, 0.5)
  }

  # ---- panels --------------------------------------------------------------
  display_order <- rev(seq_along(model_names_vec))
  panels <- lapply(seq_along(model_names_vec), function(j) {
    mn <- model_names_vec[display_order[j]]
    panel <- switch(plot_type,
      beeswarm = panel_beeswarm(shap_list, mn, shared_x_range = srng),
      bar = panel_bar(shap_list, mn, shared_x_max = sxmax),
      rank = panel_rank(shap_list, mn, shared_x_range = srng)
    )
    panel +
      labs(tag = panel_labels[j]) +
      theme(plot.tag = element_text(face = "bold", size = 13))
  })

  subtitle <- switch(plot_type,
    beeswarm = "Each point = one observation x one iteration. Labels = mean |SHAP| share (%).",
    bar = "Error bars = +/-1 SD across iterations. Labels = mean |SHAP| share (%).",
    rank = "Error bars = IQR across iterations. Leftward = more important."
  )
  headline <- switch(plot_type,
    beeswarm = "SHAP Beeswarm",
    bar = "Mean |SHAP| Feature Importance",
    rank = "Feature Rank Stability"
  )

  # ---- assemble ------------------------------------------------------------
  combined <- Reduce(`|`, panels)
  if (plot_type == "beeswarm") {
    # The panels hide their legends, so take one colour legend from a
    # throw-away plot and append it as a narrow extra column.
    dummy <- ggplot(data.frame(x = 0, y = 0, v = c(0, 1)), aes(x, y, colour = v)) +
      geom_point() +
      scale_colour_gradient(low = "#3B82C4", high = "#E84545",
                            name = "Feature\nvalue",
                            breaks = c(0, 1), labels = c("Low", "High")) +
      theme(legend.position = "right",
            legend.title = element_text(size = 9),
            legend.text = element_text(size = 8),
            legend.key.height = unit(1.2, "cm"))
    leg <- cowplot::get_legend(dummy)
    combined <- (combined | patchwork::wrap_elements(leg)) +
      plot_layout(widths = c(rep(1, length(panels)), 0.15))
  }

  combined + plot_annotation(
    title = paste0(headline, " — ", dataset_label),
    caption = subtitle,
    theme = theme(
      plot.title = element_text(face = "bold", hjust = 0.5, size = 13),
      plot.caption = element_text(size = 10, colour = "grey40", hjust = 0.5)
    )
  )
}
Aggregated across all 100 bootstrap iterations.
# Draw and save the three internal-validation SHAP figures.
for (pt in c("beeswarm", "bar", "rank")) {
  fig <- make_shap_figure(
    shap_list = all_shap_val,
    dataset_label = "Internal Validation",
    model_names_vec = model_names,
    plot_type = pt
  )
  print(fig)
  ggsave(
    filename = sprintf("shap_%s_internal.pdf", pt),
    plot = fig,
    width = 15, height = 5.5, device = cairo_pdf
  )
}
SHAP values are computed for each of the 20 seed models and averaged element-wise, ensuring consistency with the ensemble predictions used for all reported performance metrics.
# SHAP values for the external cohort, averaged element-wise over the seed
# ensemble.
#   features : character vector of predictor column names
#   df_train : training data (must contain `features` and `IR_label`)
#   df_ext   : external data to explain (must contain `features`)
#   N_SEEDS  : number of seeds / refits (seeds 1..N_SEEDS), at least 1
#   nsim     : Monte Carlo repetitions passed to fastshap::explain()
# Returns list(shap = data frame of averaged SHAP values, one row per external
# patient; X_df = the imputed external predictors).
#
# Models come from fit_rf_seed() instead of a second copy of the pipeline, so
# the explained forests are by construction the ones evaluate_multiseed()
# scores. The copy made the same RNG draws in the same order, so the results
# are unchanged.
compute_ensemble_shap <- function(features, df_train, df_ext,
                                  N_SEEDS = 20, nsim = 50) {
  stopifnot(N_SEEDS >= 1)

  X_ext_raw <- df_ext %>%
    dplyr::select(all_of(features)) %>%
    mutate(across(where(is.numeric), as.double))

  # Probability of the IR class; this is the quantity being explained.
  prob_ir <- function(object, newdata) {
    predict(object, newdata, type = "prob")[, "IR"]
  }

  shap_matrices <- vector("list", N_SEEDS)
  X_ext_ref <- NULL
  for (s in seq_len(N_SEEDS)) {
    fit <- fit_rf_seed(features, df_train, seed = s)
    # Impute with the imputer fitted on the training data.
    X_ext_imp <- as.data.frame(predict(fit$imp, X_ext_raw))
    shap_matrices[[s]] <- as.matrix(fastshap::explain(
      object = fit$rf,
      X = X_ext_imp,
      pred_wrapper = prob_ir,
      nsim = nsim,
      adjust = TRUE
    ))
    # Keep the imputed predictors of the first seed as the reference values.
    if (is.null(X_ext_ref)) X_ext_ref <- X_ext_imp
  }

  list(
    shap = as.data.frame(Reduce(`+`, shap_matrices) / N_SEEDS),
    X_df = X_ext_ref
  )
}
# Ensemble SHAP values on the external cohort for each feature set.
es_CGM <- compute_ensemble_shap(
  features = feature_sets$CGM_model, df_train = df_test, df_ext = df_external
)
es_SW <- compute_ensemble_shap(
  features = feature_sets$Smartwatch_model, df_train = df_test, df_ext = df_external
)
es_Base <- compute_ensemble_shap(
  features = feature_sets$Baseline_model, df_train = df_test, df_ext = df_external
)

# Same entry layout as `all_shap_val`, with a single "iteration" per model,
# so the panel helpers can be reused.
all_shap_ext <- unname(Map(
  function(label, es) {
    list(iteration = 1, model_name = label, shap = es$shap, X_df = es$X_df)
  },
  c("Model CGM", "Model smartwatch", "Model baseline"),
  list(es_CGM, es_SW, es_Base)
))
# Beeswarm figure for the external cohort: one panel per model, in the order
# of `model_names_vec`, all on one symmetric x range, plus a shared legend.
#   shap_list       : list of SHAP entries (one per model)
#   model_names_vec : model names to plot
#   panel_labels    : tags for the panels, left to right
# Returns a patchwork object.
make_shap_ext_figure <- function(shap_list, model_names_vec,
                                 panel_labels = LETTERS[seq_along(model_names_vec)]) {
  # Symmetric x range, 5% wider than the largest |SHAP| of any model.
  all_vals <- unlist(lapply(model_names_vec, function(mn) {
    hits <- Filter(function(x) x$model_name == mn, shap_list)
    unlist(lapply(hits, function(e) as.numeric(as.matrix(e$shap))))
  }))
  half_width <- max(abs(all_vals), na.rm = TRUE) * 1.05
  srng <- c(-half_width, half_width)

  panels <- lapply(seq_along(model_names_vec), function(j) {
    panel <- panel_beeswarm(shap_list, model_names_vec[j],
                            shared_x_range = srng, is_external = TRUE)
    panel +
      labs(tag = panel_labels[j]) +
      theme(plot.tag = element_text(face = "bold", size = 13))
  })

  # The panels hide their legends; take one from a throw-away plot.
  dummy <- ggplot(data.frame(x = 0, y = 0, v = c(0, 1)), aes(x, y, colour = v)) +
    geom_point() +
    scale_colour_gradient(low = "#3B82C4", high = "#E84545",
                          name = "Feature\nvalue",
                          breaks = c(0, 1), labels = c("Low", "High")) +
    theme(legend.position = "right",
          legend.title = element_text(size = 9),
          legend.text = element_text(size = 8),
          legend.key.height = unit(1.2, "cm"))
  leg <- cowplot::get_legend(dummy)

  caption_txt <- paste0(
    "Each point = one external patient. ",
    "SHAP values averaged across 20 independently fitted models. ",
    "Labels = mean |SHAP| share (%)."
  )

  (Reduce(`|`, panels) | patchwork::wrap_elements(leg)) +
    plot_layout(widths = c(rep(1, length(panels)), 0.15)) +
    plot_annotation(
      title = "SHAP Beeswarm — External Test Set",
      caption = caption_txt,
      theme = theme(
        plot.title = element_text(face = "bold", hjust = 0.5, size = 13),
        plot.caption = element_text(size = 10, colour = "grey40", hjust = 0.5)
      )
    )
}
# External-cohort SHAP beeswarm figure.
fig_ext <- make_shap_ext_figure(
  shap_list = all_shap_ext,
  model_names_vec = model_names
)
print(fig_ext)

# Combined figure: decision curve next to the three seed-stability panels.
Reduce(`|`, list(
  p_dca + labs(title = "A Decision curve analysis"),
  p_auc + labs(title = "B AUC-ROC seed stability"),
  p_auprc + labs(title = "C AUPRC seed stability"),
  p_auprg + labs(title = "D AUPRG seed stability")
))
## R version 4.5.2 (2025-10-31)
## Platform: aarch64-apple-darwin20
## Running under: macOS Sequoia 15.5
##
## Matrix products: default
## BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.1
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: Europe/Zurich
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] cowplot_1.2.0 patchwork_1.3.2 tidyr_1.3.2
## [4] prg_0.5.1 PRROC_1.4 rlang_1.2.0
## [7] pROC_1.19.0.1 caret_7.0-1 lattice_0.22-7
## [10] ggplot2_4.0.1 dplyr_1.2.0 fastshap_0.1.1
## [13] randomForest_4.7-1.2
##
## loaded via a namespace (and not attached):
## [1] gtable_0.3.6 xfun_0.55 bslib_0.9.0
## [4] recipes_1.3.1 vctrs_0.7.2 tools_4.5.2
## [7] generics_0.1.4 stats4_4.5.2 parallel_4.5.2
## [10] proxy_0.4-29 tibble_3.3.0 ModelMetrics_1.2.2.2
## [13] pkgconfig_2.0.3 Matrix_1.7-4 data.table_1.18.0
## [16] RColorBrewer_1.1-3 S7_0.2.1 lifecycle_1.0.5
## [19] compiler_4.5.2 farver_2.1.2 stringr_1.6.0
## [22] codetools_0.2-20 htmltools_0.5.9 class_7.3-23
## [25] sass_0.4.10 yaml_2.3.12 prodlim_2026.03.11
## [28] pillar_1.11.1 jquerylib_0.1.4 MASS_7.3-65
## [31] cachem_1.1.0 gower_1.0.2 iterators_1.0.14
## [34] rpart_4.1.24 foreach_1.5.2 nlme_3.1-168
## [37] parallelly_1.46.1 lava_1.8.2 tidyselect_1.2.1
## [40] digest_0.6.39 stringi_1.8.7 future_1.69.0
## [43] reshape2_1.4.5 purrr_1.2.0 listenv_0.10.1
## [46] labeling_0.4.3 splines_4.5.2 fastmap_1.2.0
## [49] grid_4.5.2 cli_3.6.6 magrittr_2.0.4
## [52] survival_3.8-3 e1071_1.7-17 future.apply_1.20.2
## [55] withr_3.0.2 scales_1.4.0 lubridate_1.9.4
## [58] timechange_0.3.0 rmarkdown_2.30 globals_0.19.1
## [61] otel_0.2.0 nnet_7.3-20 timeDate_4052.112
## [64] evaluate_1.0.5 knitr_1.51 hardhat_1.4.2
## [67] Rcpp_1.1.1 glue_1.8.0 ipred_0.9-15
## [70] rstudioapi_0.17.1 jsonlite_2.0.0 R6_2.6.1
## [73] plyr_1.8.9