library(keras3)
library(tensorflow)
## 
## Attaching package: 'tensorflow'
## The following objects are masked from 'package:keras3':
## 
##     set_random_seed, shape
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.2     ✔ tibble    3.3.0
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:purrr':
## 
##     lift
## 
## The following object is masked from 'package:tensorflow':
## 
##     train
library(tidyr)

NHANES <- readRDS('/Users/darrensummerlee/Library/CloudStorage/Dropbox/NHANES paper/data set/BP_stats_no_enhance.rds')
n <- nrow(NHANES)
MIMS <- NHANES$MIMS
age <- NHANES$age
gender <- NHANES$gender
CHD <- NHANES$CHD
BMI <- NHANES$BMI
bpS <- NHANES$BPS_avg

MIMS_scaled <- scale(MIMS)
age_scaled <- scale(age)
BMI_scaled <- scale(BMI)
gender_onehot <- model.matrix(~ gender - 1)
CHD_onehot <- model.matrix(~ CHD - 1)

scalars <- cbind(age_scaled, BMI_scaled, gender_onehot, CHD_onehot)
colnames(scalars)[1:2] <- c("Age", "BMI")
set.seed(1)

train_idx <- sample(1:n, size =  0.8 * n, replace = FALSE)
test_idx <- setdiff(1:n, train_idx)

scalar_train <- scalars[train_idx, ]
scalar_test <- scalars[test_idx, ]

bpS_train <- bpS[train_idx]
bpS_test <-bpS[test_idx]

MIMS_train <- MIMS_scaled[train_idx, ]
MIMS_test <- MIMS_scaled[test_idx, ]
MIMS_train <- array(MIMS_train, dim = c(nrow(MIMS_train), 1440, 1))
MIMS_test <- array(MIMS_test, dim = c(nrow(MIMS_test), 1440, 1))
set.seed(1)
func_input <- layer_input(shape = c(1440, 1), name = "func_input")
func_branch <- func_input %>%
  layer_gru(units = 32, activation = "tanh")

scalar_input_layer <- layer_input(shape = ncol(scalar_train), name = "scalar_input")
scalar_branch <- scalar_input_layer %>%
  layer_dense(units = 32, activation = "relu")

combined <- layer_concatenate(list(func_branch, scalar_branch)) %>%
  layer_dense(units = 1)

input <- layer_input(shape = c(10))
output <- input %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1)

model <- keras_model(inputs = list(func_input, scalar_input_layer), outputs = combined)

model %>% compile(
  optimizer = "adam",
  loss = "mse",
  metrics = list("mae")
)
set.seed(1)
history <- model %>% fit(
  x = list(func_input = MIMS_train, scalar_input = scalar_train),
  y = bpS_train,
  epochs = 10,
  batch_size = 32,
  validation_split = 0.2,
  callbacks = list(
    callback_early_stopping(
      monitor = "val_loss",
      patience = 3, 
      restore_best_weights = TRUE
    )
  ),
  verbose = 2
)
## Epoch 1/10
## 158/158 - 32s - 201ms/step - loss: 14024.3418 - mae: 116.8675 - val_loss: 12204.8652 - val_mae: 109.0161
## Epoch 2/10
## 158/158 - 31s - 197ms/step - loss: 10877.1006 - mae: 102.5447 - val_loss: 8996.3320 - val_mae: 93.1568
## Epoch 3/10
## 158/158 - 31s - 193ms/step - loss: 7379.4907 - mae: 83.7094 - val_loss: 5449.1182 - val_mae: 71.7261
## Epoch 4/10
## 158/158 - 32s - 200ms/step - loss: 4027.9019 - mae: 60.5061 - val_loss: 2537.1602 - val_mae: 47.4276
## Epoch 5/10
## 158/158 - 34s - 213ms/step - loss: 1705.7799 - mae: 37.0831 - val_loss: 928.8557 - val_mae: 25.9777
## Epoch 6/10
## 158/158 - 34s - 218ms/step - loss: 640.4885 - mae: 20.0205 - val_loss: 376.8244 - val_mae: 14.4872
## Epoch 7/10
## 158/158 - 32s - 200ms/step - loss: 330.2272 - mae: 13.3580 - val_loss: 267.8032 - val_mae: 12.1342
## Epoch 8/10
## 158/158 - 32s - 200ms/step - loss: 275.8527 - mae: 12.3064 - val_loss: 258.1783 - val_mae: 12.0922
## Epoch 9/10
## 158/158 - 33s - 210ms/step - loss: 270.1825 - mae: 12.2882 - val_loss: 257.8130 - val_mae: 12.1358
## Epoch 10/10
## 158/158 - 32s - 202ms/step - loss: 269.1870 - mae: 12.2690 - val_loss: 257.4179 - val_mae: 12.1301
preds <- model %>% predict(list(func_input = MIMS_test, scalar_input = scalar_test))
## 50/50 - 3s - 70ms/step
rmse <- sqrt(mean((preds - bpS_test)^2))

ss_res <- sum((bpS_test - preds)^2)
ss_tot <- sum((bpS_test - mean(bpS_test))^2)
r_squared <- 1 - (ss_res / ss_tot)
r_2 <- summary(lm(bpS_test ~ preds))$r.squared

#cov rate
residuals <- bpS_test - preds
resid_sd <- sd(residuals)
lower_bound <- preds - 1.96 * resid_sd
upper_bound <- preds + 1.96 * resid_sd
covered <- bpS_test >= lower_bound & bpS_test <= upper_bound
coverage_rate <- mean(covered)

cat("95% Coverage Rate:", round(coverage_rate * 100, 2), "%\n")
## 95% Coverage Rate: 95.13 %
cat("Manual R^2:", round(r_squared, 4), "\n")
## Manual R^2: 0.2054
cat("Auto R^2:", round(r_2, 4), "\n")
## Auto R^2: 0.2064
cat("Test RMSE:", rmse, "\n")
## Test RMSE: 15.91577
plot(history)

results <- data.frame(actual = bpS_test, predicted = as.vector(preds))

ggplot(results, aes(x = actual, y = predicted)) +
  geom_point(alpha = 0.5) +
  geom_abline(color = "red", linetype = "dashed") + 
  labs(title = "Actual vs. Predicted Values",
       x = "Actual Blood Pressure",
       y = "Predicted Blood Pressure") +
  coord_fixed() +
  theme_minimal()

results$residuals <- results$predicted - results$actual

ggplot(results, aes(x = predicted, y = residuals)) +
  geom_point(alpha = 0.5) +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  labs(title = "Residuals vs. Predicted Values",
       x = "Predicted Blood Pressure",
       y = "Residuals (Predicted - Actual)") +
  theme_minimal()

zones <- data.frame(
  xmin = c(50, 100, 120, 130, 140),
  xmax = c(100, 120, 130, 140, Inf),
  fill = c("Low", "Normal", "Elevated", "ISH-S1", "S2"),
  color = c("lightblue", "green3", "yellow", "orange", "red")
)

zone_rects <- expand.grid(x = 1:nrow(zones), y = 1:nrow(zones))

max_risk_index <- pmax(zone_rects$x, zone_rects$y)
zone_rects <- cbind(
  zone_rects,
  xmin = zones$xmin[zone_rects$x],
  xmax = zones$xmax[zone_rects$x],
  ymin = zones$xmin[zone_rects$y],
  ymax = zones$xmax[zone_rects$y],
  fill = zones$color[max_risk_index]
)

ggplot() +
  geom_rect(data = zone_rects, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = fill), alpha = 0.4) +
  scale_fill_identity() +
  geom_point(data = results, aes(x = predicted, y = actual), color = "blue", size = 0.5) +
  geom_abline(intercept = -20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  geom_abline(intercept = 20, slope = 1, linetype = "dashed", color = "black", linewidth = 0.4) +
  xlim(50, 150) +
  ylim(50, 250) +
  coord_cartesian(xlim = c(50, 200), ylim = c(50, 200), expand = FALSE) +
  scale_x_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  scale_y_continuous(breaks = c(50, 80, 100, 120, 140, 150)) +
  labs(x = "Predicted BP", y = "True BP") +
  coord_fixed() +
  theme_minimal()
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
## Scale for y is already present.
## Adding another scale for y, which will replace the existing scale.
## Coordinate system already present. Adding new coordinate system, which will
## replace the existing one.

set.seed(1)

scalar_variable_names <- colnames(scalar_test)
variable_importance <- data.frame(
  Variable = character(),
  Importance_RMSE_Increase = numeric(),
  stringsAsFactors = FALSE
)

for (i in 1:ncol(scalar_test)) {
  scalar_test_shuffled <- scalar_test
  scalar_test_shuffled[, i] <- sample(scalar_test_shuffled[, i])
  
  preds_shuffled <- model %>% predict(list(MIMS_test, scalar_test_shuffled))
  rmse_shuffled <- sqrt(mean((preds_shuffled - bpS_test)^2))
  importance <- rmse_shuffled - rmse
  
  variable_importance <- rbind(
    variable_importance, 
    data.frame(
      Variable = scalar_variable_names[i], 
      Importance_RMSE_Increase = importance
    )
  )
}
## 50/50 - 3s - 61ms/step
## 50/50 - 3s - 58ms/step
## 50/50 - 3s - 58ms/step
## 50/50 - 3s - 60ms/step
## 50/50 - 3s - 57ms/step
## 50/50 - 3s - 59ms/step
## 50/50 - 3s - 57ms/step
## 50/50 - 3s - 60ms/step
MIMS_test_shuffled <- MIMS_test
for (t in 1:dim(MIMS_test)[2]) {
  MIMS_test_shuffled[, t, 1] <- sample(MIMS_test_shuffled[, t, 1])
}

preds_mims_shuffled <- model %>% predict(list(MIMS_test_shuffled, scalar_test))
## 50/50 - 3s - 58ms/step
rmse_mims_shuffled <- sqrt(mean((preds_mims_shuffled - bpS_test)^2))
mims_importance <- rmse_mims_shuffled - rmse

variable_importance <- rbind(
  variable_importance, 
  data.frame(
    Variable = "MIMS", 
    Importance_RMSE_Increase = mims_importance
  )
)

options(scipen = 999)
print("Variable Importance (Increase in RMSE):")
## [1] "Variable Importance (Increase in RMSE):"
print(variable_importance[order(-variable_importance$Importance_RMSE_Increase), ])
##        Variable Importance_RMSE_Increase
## 3    genderMale          12.613344661966
## 4  genderFemale          10.705532627400
## 1           Age           3.761359252388
## 5         CHDNo           2.241904219418
## 6        CHDYes           1.751069767620
## 2           BMI           0.149921321009
## 7    CHDRefused           0.000000000000
## 8 CHDDon't know           0.000000000000
## 9          MIMS          -0.000006617201

```