library(keras3)
library(tensorflow)
##
## Attaching package: 'tensorflow'
## The following objects are masked from 'package:keras3':
##
## set_random_seed, shape
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.2 ✔ tibble 3.3.0
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
##
## The following object is masked from 'package:tensorflow':
##
## train
library(tidyr)
NHANES <- readRDS('/Users/darrensummerlee/Library/CloudStorage/Dropbox/NHANES paper/data set/BP_stats_no_enhance.rds')
set.seed(1)
n <- nrow(NHANES)
train_idx <- sample(1:n, size = 0.8 * n, replace = FALSE)
test_idx <- setdiff(1:n, train_idx)
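# Pull out the functional covariate (the 1440-minute MIMS activity curve), the scalar covariates,
# and the outcome (average systolic BP, BPS_avg)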
MIMS <- NHANES$MIMS
age <- NHANES$age
gender <- NHANES$gender
CHD <- NHANES$CHD
BMI <- NHANES$BMI
bpS <- NHANES$BPS_avg
MIMS_train <- MIMS[train_idx, ]
MIMS_test <- MIMS[test_idx, ]
age_train <- age[train_idx]
age_test <- age[test_idx]
BMI_train <- BMI[train_idx]
BMI_test <- BMI[test_idx]
gender_train <- gender[train_idx]
gender_test <- gender[test_idx]
CHD_train <- CHD[train_idx]
CHD_test <- CHD[test_idx]
bpS_train <- bpS[train_idx]
bpS_test <- bpS[test_idx]
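# Standardize using training-set statistics only: the test set is centered and scaled with the
# training means and SDs to avoid leakage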
MIMS_mean <- colMeans(MIMS_train)
MIMS_sd <- apply(MIMS_train, 2, sd)
MIMS_train_scaled <- scale(MIMS_train)
MIMS_test_scaled <- scale(MIMS_test, center = MIMS_mean, scale = MIMS_sd)
MIMS_train <- array(MIMS_train, dim = c(nrow(MIMS_train), 1440, 1))
MIMS_test <- array(MIMS_test, dim = c(nrow(MIMS_test), 1440, 1))
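# Note: MIMS_train_scaled / MIMS_test_scaled are computed above but not used; the arrays just built
# reshape the raw MIMS values, so the RNN receives unscaled inputs. If standardized curves were
# intended, a minimal alternative (an assumption about intent, not what was fit here, and it would
# change the reported results) is:
# MIMS_train <- array(MIMS_train_scaled, dim = c(nrow(MIMS_train_scaled), 1440, 1))
# MIMS_test <- array(MIMS_test_scaled, dim = c(nrow(MIMS_test_scaled), 1440, 1))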
age_mean <- mean(age_train)
age_sd <- sd(age_train)
age_train_scaled <- scale(age_train)
age_test_scaled <- scale(age_test, center = age_mean, scale = age_sd)
BMI_mean <- mean(BMI_train)
BMI_sd <- sd(BMI_train)
BMI_train_scaled <- scale(BMI_train)
BMI_test_scaled <- scale(BMI_test, center = BMI_mean, scale = BMI_sd)
#one-hot encoding
gender_train_oh <- model.matrix(~ gender_train - 1)
gender_test_oh <- model.matrix(~ gender_test - 1)
CHD_train_oh <- model.matrix(~ CHD_train - 1)
CHD_test_oh <- model.matrix(~ CHD_test - 1)
#combine data
scalar_train <- cbind(age_train_scaled, BMI_train_scaled, gender_train_oh, CHD_train_oh)
scalar_test <- cbind(age_test_scaled, BMI_test_scaled, gender_test_oh, CHD_test_oh)
colnames(scalar_train)[1:2] <- c("Age", "BMI")
colnames(scalar_test)[1:2] <- c("Age", "BMI")
set.seed(1)
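# Two-branch network: the 1440-point MIMS curve feeds a simple RNN (32 tanh units), the scalar
# covariates feed a dense layer (32 ReLU units), and the concatenated features feed a single
# linear unit that predicts systolic BP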
func_input <- layer_input(shape = c(1440, 1), name = "func_input")
func_branch <- func_input %>%
  layer_simple_rnn(units = 32, activation = "tanh")
scalar_input <- layer_input(shape = ncol(scalar_train), name = "scalar_input")
scalar_branch <- scalar_input %>%
  layer_dense(units = 32, activation = "relu")
combined <- layer_concatenate(list(func_branch, scalar_branch)) %>%
  layer_dense(units = 1)
model <- keras_model(inputs = list(func_input, scalar_input), outputs = combined)
model %>% compile(
  optimizer = "adam",
  loss = "mse",
  metrics = list("mae")
)
set.seed(1)
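# Train for up to 10 epochs with a 20% validation split; early stopping monitors validation loss
# with patience 2 and restores the best weights. (set.seed() seeds R's RNG only; fully reproducible
# Keras training would also need tensorflow::set_random_seed().)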
history <- model %>% fit(
  x = list(func_input = MIMS_train, scalar_input = scalar_train),
  y = bpS_train,
  epochs = 10,
  batch_size = 32,
  validation_split = 0.2,
  callbacks = list(
    callback_early_stopping(
      monitor = "val_loss",
      patience = 2,
      restore_best_weights = TRUE
    )
  )
)
## Epoch 1/10
## 158/158 - 13s - 85ms/step - loss: 14205.1787 - mae: 117.6825 - val_loss: 12687.3379 - val_mae: 111.1721
## Epoch 2/10
## 158/158 - 12s - 77ms/step - loss: 11362.9922 - mae: 104.8019 - val_loss: 9438.6758 - val_mae: 95.3914
## Epoch 3/10
## 158/158 - 12s - 78ms/step - loss: 7724.5640 - mae: 85.6043 - val_loss: 5682.6689 - val_mae: 73.1922
## Epoch 4/10
## 158/158 - 12s - 78ms/step - loss: 4199.5864 - mae: 61.7835 - val_loss: 2652.6790 - val_mae: 48.5761
## Epoch 5/10
## 158/158 - 12s - 78ms/step - loss: 1803.7251 - mae: 38.3123 - val_loss: 994.4409 - val_mae: 27.1116
## Epoch 6/10
## 158/158 - 12s - 78ms/step - loss: 690.7731 - mae: 20.8962 - val_loss: 403.8442 - val_mae: 14.9787
## Epoch 7/10
## 158/158 - 12s - 78ms/step - loss: 348.1240 - mae: 13.6341 - val_loss: 271.5634 - val_mae: 12.1123
## Epoch 8/10
## 158/158 - 12s - 77ms/step - loss: 279.7185 - mae: 12.2656 - val_loss: 257.3118 - val_mae: 11.9738
## Epoch 9/10
## 158/158 - 12s - 77ms/step - loss: 270.6432 - mae: 12.2241 - val_loss: 257.3408 - val_mae: 12.0648
## Epoch 10/10
## 158/158 - 12s - 77ms/step - loss: 269.5128 - mae: 12.2380 - val_loss: 257.7947 - val_mae: 12.1145
set.seed(1)
preds <- model %>% predict(list(func_input = MIMS_test, scalar_input = scalar_test))
## 50/50 - 1s - 28ms/step
rmse <- sqrt(mean((preds - bpS_test)^2))
ss_res <- sum((bpS_test - preds)^2)
ss_tot <- sum((bpS_test - mean(bpS_test))^2)
manual_r2 <- 1 - (ss_res / ss_tot)
auto_r2 <- summary(lm(bpS_test ~ preds))$r.squared
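# manual_r2 is 1 - SS_res/SS_tot computed directly from the predictions, while auto_r2 is the R^2
# of regressing observed on predicted (a squared correlation), so the two need not agree exactly.
# Next, an empirical 95% prediction interval: prediction +/- 1.96 * SD of the test residuals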
residuals <- bpS_test - preds
resid_sd <- sd(residuals)
lower_bound <- preds - 1.96 * resid_sd
upper_bound <- preds + 1.96 * resid_sd
covered <- bpS_test >= lower_bound & bpS_test <= upper_bound
coverage_rate <- mean(covered)
cat("Test RMSE:", rmse, "\n")
## Test RMSE: 15.93641
cat("Manual R^2:", round(manual_r2, 4), "\n")
## Manual R^2: 0.2033
cat("Auto R^2:", round(auto_r2, 4), "\n")
## Auto R^2: 0.2064
cat("95% Coverage Rate:", round(coverage_rate * 100, 2), "%\n")
## 95% Coverage Rate: 95.07 %
val_loss <- history$metrics$val_loss
best_epoch <- which.min(val_loss)
best_val_loss <- min(val_loss)
cat("Best epoch was:", best_epoch, "\n")
## Best epoch was: 8
cat("Best validation loss was:", best_val_loss, "\n")
## Best validation loss was: 257.3118
plot(history) +
  geom_vline(xintercept = best_epoch, color = "red", linetype = "dashed", linewidth = 1) +
  labs(title = "Training History with Best Epoch Highlighted")

results <- data.frame(actual = bpS_test, predicted = as.vector(preds))
# Predicted vs. observed systolic BP on the test set, shaded by BP category risk zones:
# define category cut points, expand to a 5 x 5 grid, and color each cell by the more
# severe (higher) of its predicted and observed categories
zones <- data.frame(
  xmin = c(50, 100, 120, 130, 140),
  xmax = c(100, 120, 130, 140, Inf),
  fill = c("Low", "Normal", "Elevated", "ISH-S1", "S2"),
  color = c("lightblue", "green3", "yellow", "orange", "red")
)
zone_rects <- expand.grid(x = 1:nrow(zones), y = 1:nrow(zones))
max_risk_index <- pmax(zone_rects$x, zone_rects$y)
zone_rects <- cbind(
  zone_rects,
  xmin = zones$xmin[zone_rects$x],
  xmax = zones$xmax[zone_rects$x],
  ymin = zones$xmin[zone_rects$y],
  ymax = zones$xmax[zone_rects$y],
  fill = zones$color[max_risk_index]
)
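# Dashed lines below mark a +/- 20 mmHg band around perfect agreement (y = x)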
ggplot() +
  geom_rect(data = zone_rects, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = fill), alpha = 0.4) +
  scale_fill_identity() +
  geom_point(data = results, aes(x = predicted, y = actual), color = "blue", size = 0.5) +
  geom_abline(intercept = -20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  geom_abline(intercept = 20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  scale_x_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  scale_y_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  coord_fixed(xlim = c(50, 200), ylim = c(50, 200), expand = FALSE) +
  labs(x = "Predicted BP", y = "True BP") +
  theme_minimal()

set.seed(1)
# Initialize a table to collect permutation-importance results
scalar_variable_names <- colnames(scalar_test)
variable_importance <- data.frame(
  Variable = character(),
  Importance_RMSE_Increase = numeric(),
  stringsAsFactors = FALSE
)
# Permutation importance for the scalar covariates: shuffle one column at a time
# and record the resulting increase in test RMSE
for (i in 1:ncol(scalar_test)) {
  scalar_test_shuffled <- scalar_test
  scalar_test_shuffled[, i] <- sample(scalar_test_shuffled[, i])
  preds_shuffled <- model %>% predict(list(MIMS_test, scalar_test_shuffled))
  rmse_shuffled <- sqrt(mean((preds_shuffled - bpS_test)^2))
  importance <- rmse_shuffled - rmse
  variable_importance <- rbind(
    variable_importance,
    data.frame(
      Variable = scalar_variable_names[i],
      Importance_RMSE_Increase = importance
    )
  )
}
## 50/50 - 1s - 30ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
# Permutation importance for the functional covariate: shuffle each minute of the
# MIMS curve independently across subjects, then re-predict
MIMS_test_shuffled <- MIMS_test
for (t in 1:dim(MIMS_test)[2]) {
  MIMS_test_shuffled[, t, 1] <- sample(MIMS_test_shuffled[, t, 1])
}
preds_mims_shuffled <- model %>% predict(list(MIMS_test_shuffled, scalar_test))
## 50/50 - 1s - 24ms/step
rmse_mims_shuffled <- sqrt(mean((preds_mims_shuffled - bpS_test)^2))
mims_importance <- rmse_mims_shuffled - rmse
#combine results
variable_importance <- rbind(
  variable_importance,
  data.frame(
    Variable = "MIMS",
    Importance_RMSE_Increase = mims_importance
  )
)
options(scipen = 999)
print("Variable Importance (Increase in RMSE):")
## [1] "Variable Importance (Increase in RMSE):"
print(variable_importance[order(-variable_importance$Importance_RMSE_Increase), ])
##              Variable Importance_RMSE_Increase
## 3     gender_testMale            12.1481203522
## 4   gender_testFemale            10.2430390554
## 1                 Age             3.7070151085
## 5          CHD_testNo             2.1000676700
## 6         CHD_testYes             1.6431470791
## 2                 BMI             0.1423151669
## 9                MIMS             0.0001693863
## 7     CHD_testRefused             0.0000000000
## 8  CHD_testDon't know             0.0000000000