library(keras3)
library(tensorflow)
## 
## Attaching package: 'tensorflow'
## The following objects are masked from 'package:keras3':
## 
##     set_random_seed, shape
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.2     ✔ tibble    3.3.0
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:purrr':
## 
##     lift
## 
## The following object is masked from 'package:tensorflow':
## 
##     train
library(tidyr)

NHANES <- readRDS('/Users/darrensummerlee/Library/CloudStorage/Dropbox/NHANES paper/data set/BP_stats_no_enhance.rds')
set.seed(1)
n <- nrow(NHANES)
train_idx <- sample(1:n, size =  0.8 * n, replace = FALSE)
test_idx <- setdiff(1:n, train_idx)

MIMS <- NHANES$MIMS
age <- NHANES$age
gender <- NHANES$gender
CHD <- NHANES$CHD
BMI <- NHANES$BMI
bpS <- NHANES$BPS_avg

MIMS_train <- MIMS[train_idx, ]
MIMS_test <- MIMS[test_idx, ]

age_train <- age[train_idx]
age_test <- age[test_idx]

BMI_train <- BMI[train_idx]
BMI_test <- BMI[test_idx]

gender_train <- gender[train_idx]
gender_test <- gender[test_idx]

CHD_train <- CHD[train_idx]
CHD_test <- CHD[test_idx]

bpS_train <- bpS[train_idx]
bpS_test <- bpS[test_idx]
MIMS_mean <- colMeans(MIMS_train)
MIMS_sd <- apply(MIMS_train, 2, sd)
MIMS_train_scaled <- scale(MIMS_train)
MIMS_test_scaled <- scale(MIMS_test, center = MIMS_mean, scale = MIMS_sd)
MIMS_train <- array(MIMS_train, dim = c(nrow(MIMS_train), 1440, 1))
MIMS_test <- array(MIMS_test, dim = c(nrow(MIMS_test), 1440, 1))

age_mean <- mean(age_train)
age_sd <- sd(age_train)
age_train_scaled <- scale(age_train)
age_test_scaled <- scale(age_test, center = age_mean, scale = age_sd)

BMI_mean <- mean(BMI_train)
BMI_sd <- sd(BMI_train)
BMI_train_scaled <- scale(BMI_train)
BMI_test_scaled <- scale(BMI_test, center = BMI_mean, scale = BMI_sd)

#one-hot encoding
gender_train_oh <- model.matrix(~ gender_train - 1)
gender_test_oh <- model.matrix(~ gender_test - 1)

CHD_train_oh <- model.matrix(~ CHD_train - 1)
CHD_test_oh <- model.matrix(~ CHD_test - 1)

#combine data
scalar_train <- cbind(age_train_scaled, BMI_train_scaled, gender_train_oh, CHD_train_oh)
scalar_test <- cbind(age_test_scaled, BMI_test_scaled, gender_test_oh, CHD_test_oh)
colnames(scalar_train)[1:2] <- c("Age", "BMI")
colnames(scalar_test)[1:2] <- c("Age", "BMI")
set.seed(1)
func_input <- layer_input(shape = c(1440, 1), name = "func_input")
func_branch <- func_input %>%
  layer_simple_rnn(units = 32, activation = "tanh")

scalar_input <- layer_input(shape = ncol(scalar_train), name = "scalar_input")
scalar_branch <- scalar_input %>%
  layer_dense(units = 32, activation = "relu")

combined <- layer_concatenate(list(func_branch, scalar_branch)) %>%
  layer_dense(units = 1)

model <- keras_model(inputs = list(func_input, scalar_input), outputs = combined)

model %>% compile(
  optimizer = "adam",
  loss = "mse",
  metrics = list("mae")
)
set.seed(1)
history <- model %>% fit(
  x = list(func_input = MIMS_train, scalar_input = scalar_train),
  y = bpS_train,
  epochs = 10,
  batch_size = 32,
  validation_split = 0.2,
  callbacks = list(
    callback_early_stopping(
      monitor = "val_loss",
      patience = 2,
      restore_best_weights = TRUE
    )
  )
)
## Epoch 1/10
## 158/158 - 13s - 85ms/step - loss: 14205.1787 - mae: 117.6825 - val_loss: 12687.3379 - val_mae: 111.1721
## Epoch 2/10
## 158/158 - 12s - 77ms/step - loss: 11362.9922 - mae: 104.8019 - val_loss: 9438.6758 - val_mae: 95.3914
## Epoch 3/10
## 158/158 - 12s - 78ms/step - loss: 7724.5640 - mae: 85.6043 - val_loss: 5682.6689 - val_mae: 73.1922
## Epoch 4/10
## 158/158 - 12s - 78ms/step - loss: 4199.5864 - mae: 61.7835 - val_loss: 2652.6790 - val_mae: 48.5761
## Epoch 5/10
## 158/158 - 12s - 78ms/step - loss: 1803.7251 - mae: 38.3123 - val_loss: 994.4409 - val_mae: 27.1116
## Epoch 6/10
## 158/158 - 12s - 78ms/step - loss: 690.7731 - mae: 20.8962 - val_loss: 403.8442 - val_mae: 14.9787
## Epoch 7/10
## 158/158 - 12s - 78ms/step - loss: 348.1240 - mae: 13.6341 - val_loss: 271.5634 - val_mae: 12.1123
## Epoch 8/10
## 158/158 - 12s - 77ms/step - loss: 279.7185 - mae: 12.2656 - val_loss: 257.3118 - val_mae: 11.9738
## Epoch 9/10
## 158/158 - 12s - 77ms/step - loss: 270.6432 - mae: 12.2241 - val_loss: 257.3408 - val_mae: 12.0648
## Epoch 10/10
## 158/158 - 12s - 77ms/step - loss: 269.5128 - mae: 12.2380 - val_loss: 257.7947 - val_mae: 12.1145
set.seed(1)
preds <- model %>% predict(list(func_input = MIMS_test, scalar_input = scalar_test))
## 50/50 - 1s - 28ms/step
rmse <- sqrt(mean((preds - bpS_test)^2))

ss_res <- sum((bpS_test - preds)^2)
ss_tot <- sum((bpS_test - mean(bpS_test))^2)
manual_r2 <- 1 - (ss_res / ss_tot)
auto_r2 <- summary(lm(bpS_test ~ preds))$r.squared

residuals <- bpS_test - preds
resid_sd <- sd(residuals)
lower_bound <- preds - 1.96 * resid_sd
upper_bound <- preds + 1.96 * resid_sd
covered <- bpS_test >= lower_bound & bpS_test <= upper_bound
coverage_rate <- mean(covered)

cat("Test RMSE:", rmse, "\n")
## Test RMSE: 15.93641
cat("Manual R^2:", round(manual_r2, 4), "\n")
## Manual R^2: 0.2033
cat("Auto R^2:", round(auto_r2, 4), "\n")
## Auto R^2: 0.2064
cat("95% Coverage Rate:", round(coverage_rate * 100, 2), "%\n")
## 95% Coverage Rate: 95.07 %
val_loss <- history$metrics$val_loss
best_epoch <- which.min(val_loss)
best_val_loss <- min(val_loss)

cat("Best epoch was:", best_epoch, "\n")
## Best epoch was: 8
cat("Best validation loss was:", best_val_loss, "\n")
## Best validation loss was: 257.3118
plot(history) +
  geom_vline(xintercept = best_epoch, color = "red", linetype = "dashed", size = 1) +
  labs(title = "Training History with Best Epoch Highlighted")
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

results <- data.frame(actual = bpS_test, predicted = as.vector(preds))
#2d color
zones <- data.frame(
  xmin = c(50, 100, 120, 130, 140),
  xmax = c(100, 120, 130, 140, Inf),
  fill = c("Low", "Normal", "Elevated", "ISH-S1", "S2"),
  color = c("lightblue", "green3", "yellow", "orange", "red")
)

zone_rects <- expand.grid(x = 1:nrow(zones), y = 1:nrow(zones))

max_risk_index <- pmax(zone_rects$x, zone_rects$y)
zone_rects <- cbind(
  zone_rects,
  xmin = zones$xmin[zone_rects$x],
  xmax = zones$xmax[zone_rects$x],
  ymin = zones$xmin[zone_rects$y],
  ymax = zones$xmax[zone_rects$y],
  fill = zones$color[max_risk_index]
)

ggplot() +
  geom_rect(data = zone_rects, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = fill), alpha = 0.4) +
  scale_fill_identity() +
  geom_point(data = results, aes(x = predicted, y = actual), color = "blue", size = 0.5) +
  geom_abline(intercept = -20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  geom_abline(intercept = 20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  xlim(50, 150) +
  ylim(50, 250) +
  coord_cartesian(xlim = c(50, 200), ylim = c(50, 200), expand = FALSE) +
  scale_x_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  scale_y_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  labs(x = "Predicted BP", y = "True BP") +
  coord_fixed() +
  theme_minimal()
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
## Scale for y is already present.
## Adding another scale for y, which will replace the existing scale.
## Coordinate system already present. Adding new coordinate system, which will
## replace the existing one.

set.seed(1)
#create structure
scalar_variable_names <- colnames(scalar_test)
variable_importance <- data.frame(
  Variable = character(),
  Importance_RMSE_Increase = numeric(),
  stringsAsFactors = FALSE
)

#test scalar variables
for (i in 1:ncol(scalar_test)) {
  scalar_test_shuffled <- scalar_test
  scalar_test_shuffled[, i] <- sample(scalar_test_shuffled[, i])
  
  preds_shuffled <- model %>% predict(list(MIMS_test, scalar_test_shuffled))
  rmse_shuffled <- sqrt(mean((preds_shuffled - bpS_test)^2))
  importance <- rmse_shuffled - rmse

  variable_importance <- rbind(
    variable_importance, 
    data.frame(
      Variable = scalar_variable_names[i], 
      Importance_RMSE_Increase = importance
    )
  )
}
## 50/50 - 1s - 30ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
## 50/50 - 1s - 24ms/step
## 50/50 - 1s - 25ms/step
#test functional variable
MIMS_test_shuffled <- MIMS_test
for (t in 1:dim(MIMS_test)[2]) {
  MIMS_test_shuffled[, t, 1] <- sample(MIMS_test_shuffled[, t, 1])
}
preds_mims_shuffled <- model %>% predict(list(MIMS_test_shuffled, scalar_test))
## 50/50 - 1s - 24ms/step
rmse_mims_shuffled <- sqrt(mean((preds_mims_shuffled - bpS_test)^2))
mims_importance <- rmse_mims_shuffled - rmse

#combine results
variable_importance <- rbind(
  variable_importance, 
  data.frame(
    Variable = "MIMS", 
    Importance_RMSE_Increase = mims_importance
  )
)

options(scipen = 999)
print("Variable Importance (Increase in RMSE):")
## [1] "Variable Importance (Increase in RMSE):"
print(variable_importance[order(-variable_importance$Importance_RMSE_Increase), ])
##             Variable Importance_RMSE_Increase
## 3    gender_testMale            12.1481203522
## 4  gender_testFemale            10.2430390554
## 1                Age             3.7070151085
## 5         CHD_testNo             2.1000676700
## 6        CHD_testYes             1.6431470791
## 2                BMI             0.1423151669
## 9               MIMS             0.0001693863
## 7    CHD_testRefused             0.0000000000
## 8 CHD_testDon't know             0.0000000000