devtools::install_github("elianachristou/fquantdr")
## Skipping install of 'fquantdr' from a github remote, the SHA1 (4e6298a0) has not changed since last install.
## Use `force = TRUE` to force installation
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats 1.0.0 ✔ readr 2.1.5
## ✔ ggplot2 3.5.2 ✔ stringr 1.5.1
## ✔ lubridate 1.9.4 ✔ tibble 3.3.0
## ✔ purrr 1.0.4 ✔ tidyr 1.3.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(ggplot2)
library(fquantdr)
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
library(fda)
## Loading required package: splines
## Loading required package: fds
## Loading required package: rainbow
## Loading required package: MASS
##
## Attaching package: 'MASS'
##
## The following object is masked from 'package:dplyr':
##
## select
##
## Loading required package: pcaPP
## Loading required package: RCurl
##
## Attaching package: 'RCurl'
##
## The following object is masked from 'package:tidyr':
##
## complete
##
## Loading required package: deSolve
##
## Attaching package: 'fda'
##
## The following object is masked from 'package:lattice':
##
## melanoma
##
## The following object is masked from 'package:graphics':
##
## matplot
##
## The following object is masked from 'package:datasets':
##
## gait
library(vip)
##
## Attaching package: 'vip'
##
## The following object is masked from 'package:utils':
##
## vi
NHANES <- readRDS('/Users/darrensummerlee/Library/CloudStorage/Dropbox/NHANES paper/data set/BP_stats_no_enhance.rds')
model_data <- data.frame(
  bpS = NHANES$BPS_avg,
  Gender = as.factor(NHANES$gender),
  BMI = NHANES$BMI,
  CHD = as.factor(NHANES$CHD),
  Age = NHANES$age
)
MIMS_array <- array(NHANES$MIMS, dim = c(nrow(NHANES$MIMS), ncol(NHANES$MIMS),1))
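# Optional sanity check (a minimal sketch; assumes NHANES$MIMS is a numeric matrix
# with one row per subject and one column per time point of the MIMS profile):
stopifnot(
  dim(MIMS_array)[1] == nrow(model_data),
  dim(MIMS_array)[3] == 1
)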
rmse_metric <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}
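# Worked example: errors of 2 and -5 give sqrt(mean(c(4, 25))) = sqrt(14.5) ~ 3.81
# rmse_metric(c(120, 130), c(118, 135))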
set.seed(1)
folds <- createFolds(model_data$bpS, k = 10, list = TRUE)
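# Quick check (sketch): the 10 folds should partition the rows exactly once
stopifnot(length(unlist(folds)) == nrow(model_data))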
cv_results <- data.frame(
  Fold = 1:10, RMSE = numeric(10), R_squared = numeric(10), Coverage = numeric(10)
)
importance_list <- list()
custom_predict <- function(object, newdata) { predict(object, newdata) }
for (i in 1:10) {
  cat("Processing Fold:", i, "\n")
  # Train/test split
  test_idx <- folds[[i]]
  train_idx <- setdiff(1:nrow(model_data), test_idx)
  # Training data for FSIR
  MIMS_train <- MIMS_array[train_idx, , , drop = FALSE]
  Y_train <- model_data$bpS[train_idx]
  # Test data for bpS and MIMS
  MIMS_test <- MIMS_array[test_idx, , , drop = FALSE]
  Y_test <- model_data$bpS[test_idx]
  # Fit FSIR on the training fold and extract the sufficient predictors
  model_fsir <- mfsir(MIMS_train, Y_train, H = 5, nbasis = 20)
  FSIR1_train <- model_fsir$sufpred[, 1]
  FSIR2_train <- model_fsir$sufpred[, 2]
  # Functional (B-spline) representation of the test curves
  nt <- dim(MIMS_test)[2]
  MIMS_test_mat <- t(MIMS_test[, , 1])
  basis_obj <- create.bspline.basis(rangeval = c(1, nt), nbasis = 20)
  MIMS_test_fd <- Data2fd(argvals = 1:nt, y = MIMS_test_mat, basisobj = basis_obj)
  test_coef <- as.matrix(t(MIMS_test_fd$coefs))
  # Project the test-fold basis coefficients onto the FSIR directions
  sufpred_test <- test_coef %*% model_fsir$phi
  FSIR1_test <- sufpred_test[, 1]
  FSIR2_test <- sufpred_test[, 2]
  # Combine the scalar covariates with the FSIR scores for the linear model
  fold_train_data <- data.frame(model_data[train_idx, ], FSIR1 = FSIR1_train, FSIR2 = FSIR2_train)
  fold_test_data <- data.frame(model_data[test_idx, ], FSIR1 = FSIR1_test, FSIR2 = FSIR2_test)
  # Linear model on the training fold
  model_lm <- lm(bpS ~ FSIR1 + FSIR2 + Gender + BMI + CHD + Age,
                 data = fold_train_data)
  # Predictions (with 95% prediction intervals) on the test fold
  pred_obj <- predict(model_lm, newdata = fold_test_data,
                      interval = "prediction", level = 0.95)
  preds <- pred_obj[, "fit"]
  # RMSE on the test fold
  residuals <- fold_test_data$bpS - preds
  cv_results$RMSE[i] <- sqrt(mean(residuals^2))
  # R^2 (in-sample, from the training-fold fit)
  cv_results$R_squared[i] <- summary(model_lm)$r.squared
  # Coverage rate of the 95% prediction intervals
  within_interval <- Y_test >= pred_obj[, "lwr"] & Y_test <= pred_obj[, "upr"]
  cv_results$Coverage[i] <- mean(within_interval)
  # Permutation-based variable importance on the test fold
  vi_fold <- vip(
    model_lm,
    method = "permute",
    train = fold_test_data,
    target = "bpS",
    metric = rmse_metric,
    pred_wrapper = custom_predict,
    nsim = 10,
    smaller_is_better = TRUE
  )
  vi_fold$data$Fold <- i
  importance_list[[i]] <- vi_fold$data
}
## Processing Fold: 1
## Processing Fold: 2
## Processing Fold: 3
## Processing Fold: 4
## Processing Fold: 5
## Processing Fold: 6
## Processing Fold: 7
## Processing Fold: 8
## Processing Fold: 9
## Processing Fold: 10
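# Consistency check for the out-of-fold projection (a sketch, not evaluated above).
# It assumes mfsir() represents the training curves with the same 20-function
# B-spline expansion used for the test fold, so that model_fsir$phi lives in the
# same coefficient coordinates; the objects below refer to the final fold of the loop.
MIMS_train_fd <- Data2fd(argvals = 1:nt, y = t(MIMS_train[, , 1]), basisobj = basis_obj)
train_coef <- as.matrix(t(MIMS_train_fd$coefs))
sufpred_train_check <- train_coef %*% model_fsir$phi
# If the assumption holds, these scores should track model_fsir$sufpred closely:
# cor(sufpred_train_check[, 1], model_fsir$sufpred[, 1])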
cat("10-fold FSIR Model:\n")
## 10-fold FSIR Model:
print(cv_results)
## Fold RMSE R_squared Coverage
## 1 1 18.63613 0.2067116 0.9252218
## 2 2 19.74847 0.2083922 0.9100127
## 3 3 19.35363 0.2082860 0.9278481
## 4 4 18.49449 0.2143038 0.9280303
## 5 5 18.75374 0.2139740 0.9177215
## 6 6 20.04222 0.2112004 0.9050633
## 7 7 19.60423 0.2136966 0.9015152
## 8 8 18.32588 0.2090780 0.9367089
## 9 9 18.40637 0.2119296 0.9342604
## 10 10 22.58268 0.2137966 0.8833967
cat("RMSE:", round(mean(cv_results$RMSE), 2), "\n")
## RMSE: 19.39
cat("R-squared:", round(mean(cv_results$R_squared), 3), "\n")
## R-squared: 0.211
cat("95% Coverage Rate:", round(mean(cv_results$Coverage) * 100, 3), "%\n")
## 95% Coverage Rate: 91.698 %
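# Fold-to-fold spread of the CV estimates (a small sketch; treating the 10 folds as
# roughly independent, a naive standard error of the mean RMSE is sd/sqrt(10)):
round(sd(cv_results$RMSE), 2)
round(sd(cv_results$RMSE) / sqrt(nrow(cv_results)), 2)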
importance_df <- do.call(rbind, importance_list)
avg_importance <- importance_df %>%
  group_by(Variable) %>%
  summarise(
    Avg_Importance = mean(Importance),
    Std_Dev = sd(Importance)
  )
print(avg_importance)
## # A tibble: 6 × 3
## Variable Avg_Importance Std_Dev
## <chr> <dbl> <dbl>
## 1 Age 2.21 0.211
## 2 BMI 0.129 0.0469
## 3 CHD 0.0229 0.0127
## 4 FSIR1 0.172 0.158
## 5 FSIR2 0.0163 0.0746
## 6 Gender 0.0743 0.0658
ggplot(avg_importance, aes(x = reorder(Variable, Avg_Importance), y = Avg_Importance)) +
  geom_col(fill = "turquoise") +
  geom_errorbar(
    aes(ymin = Avg_Importance - Std_Dev, ymax = Avg_Importance + Std_Dev),
    width = 0.2
  ) +
  coord_flip() +
  labs(
    title = "Average Permutation Importance from 10-Fold CV",
    subtitle = "Error bars show standard deviation across folds",
    x = "Variable",
    y = "Average Increase in RMSE"
  ) +
  theme_minimal()

results <- data.frame(actual = Y_test, predicted = as.vector(preds))
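# Note: `results` above uses Y_test and preds from the final CV fold only. A pooled
# version across all 10 folds (sketch; assumes the loop above is extended to store
# each fold's predictions and observations in hypothetical lists pred_list[[i]] and
# obs_list[[i]]) could be built as:
# results <- data.frame(actual = unlist(obs_list), predicted = unlist(pred_list))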
# 2-D blood-pressure category zones used to shade the prediction scatter plot
zones <- data.frame(
  xmin = c(50, 100, 120, 130, 140),
  xmax = c(100, 120, 130, 140, Inf),
  fill = c("Low", "Normal", "Elevated", "ISH-S1", "S2"),
  color = c("lightblue", "green3", "yellow", "orange", "red")
)
# Build one rectangle per (predicted zone, actual zone) pair and color each cell by
# the higher-risk zone of the pair (e.g., predicted "Normal" with actual "S2" is red)
zone_rects <- expand.grid(x = 1:nrow(zones), y = 1:nrow(zones))
max_risk_index <- pmax(zone_rects$x, zone_rects$y)
zone_rects <- cbind(
  zone_rects,
  xmin = zones$xmin[zone_rects$x],
  xmax = zones$xmax[zone_rects$x],
  ymin = zones$xmin[zone_rects$y],
  ymax = zones$xmax[zone_rects$y],
  fill = zones$color[max_risk_index]
)
ggplot() +
  geom_rect(data = zone_rects, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = fill), alpha = 0.4) +
  scale_fill_identity() +
  geom_point(data = results, aes(x = predicted, y = actual), color = "blue", size = 0.5) +
  geom_abline(intercept = -20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  geom_abline(intercept = 20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  scale_x_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  scale_y_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  coord_fixed(xlim = c(50, 200), ylim = c(50, 200), expand = FALSE) +
  labs(x = "Predicted BP", y = "True BP") +
  theme_minimal()
