# Load required libraries
library(readr)
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.7     ✔ recipes      1.1.0
## ✔ dials        1.3.0     ✔ rsample      1.2.1
## ✔ dplyr        1.1.4     ✔ tibble       3.2.1
## ✔ ggplot2      3.5.1     ✔ tidyr        1.3.1
## ✔ infer        1.0.7     ✔ tune         1.2.1
## ✔ modeldata    1.4.0     ✔ workflows    1.1.4
## ✔ parsnip      1.2.1     ✔ workflowsets 1.1.0
## ✔ purrr        1.0.2     ✔ yardstick    1.3.1
## Warning: package 'dplyr' was built under R version 4.4.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard()  masks scales::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(xgboost)
## 
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
## 
##     slice
library(glmnet)
## Loading required package: Matrix
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 4.1-8
library(dplyr)
library(ggplot2)
library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
library(broom)
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
library(yardstick)
library(purrr)
library(tidyr)
library(hardhat)
library(magrittr) # Ensure the pipe operator %>% is available
## 
## Attaching package: 'magrittr'
## The following object is masked from 'package:tidyr':
## 
##     extract
## The following object is masked from 'package:purrr':
## 
##     set_names
library(rpart)
## 
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
## 
##     prune
library(rpart.plot) # For visualizing individual decision tree structure
library(vip)
# Loading training and test datasets
train_data <- read_csv("C:/Users/JOYCE/OneDrive/Desktop/FHS/FHS-Cleaned-V2/train.csv")
## Rows: 6115 Columns: 27
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): Gender, Education Level, Smoker, Cigarettes/per day, Diabetic, PRE...
## dbl (17): TOTCHOL, AGE, SYSBP, DIABP, BMI, BPMEDS, HEARTRTE, GLUCOSE, PERIOD...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
test_data <- read_csv("C:/Users/JOYCE/OneDrive/Desktop/FHS/FHS-Cleaned-V2/test.csv")
## Rows: 2039 Columns: 27
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): Gender, Education Level, Smoker, Cigarettes/per day, Diabetic, PRE...
## dbl (17): TOTCHOL, AGE, SYSBP, DIABP, BMI, BPMEDS, HEARTRTE, GLUCOSE, PERIOD...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Define a simple decision tree with limited depth for visualization purposes
simple_tree <- rpart(CVD ~ ., data = train_data, method = "class", control = rpart.control(maxdepth = 5, cp = 0.01))
# Plot the decision tree
rpart.plot(simple_tree, main="Decision Tree for Cardiovascular Disease Prediction (Depth = 5)", 
           type = 3, extra = 4, under = TRUE, cex = 0.7, box.palette = "RdBu", shadow.col = "gray", nn = TRUE)

# Convert categorical columns to factors in `train_data` and `test_data`
categorical_cols <- c("Gender", "Education Level", "Smoker", "Cigarettes/per day", 
                      "Diabetic", "PREVCHD - Coronary Disease", "PREVAP - Angina Pectoris", 
                      "PREVMI - Myocardial Infarction", "PREVSTRK - Stroke History", 
                      "PREVHYP - Hypertension")
train_data <- train_data %>%
  mutate(CVD = as.factor(CVD), across(all_of(categorical_cols), as.factor))

test_data <- test_data %>%
  mutate(CVD = as.factor(CVD), across(all_of(categorical_cols), as.factor))
# Convert categorical variables to dummy variables and normalize for xgboost
train_data_numeric <- train_data %>%
  select(-CVD) %>%
  recipe(~ .) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_normalize(all_predictors()) %>%  # Add normalization step
  prep() %>%
  bake(new_data = NULL) %>%
  as.matrix()

test_data_numeric <- test_data %>%
  select(-CVD) %>%
  recipe(~ .) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_normalize(all_predictors()) %>%  # Add normalization step
  prep() %>%
  bake(new_data = NULL) %>%
  as.matrix()
# Create the xgboost DMatrix
train_matrix <- xgb.DMatrix(data = train_data_numeric, label = as.numeric(train_data$CVD) - 1)
test_matrix <- xgb.DMatrix(data = test_data_numeric, label = as.numeric(test_data$CVD) - 1)
# Logistic Regression with L1 Regularization
set.seed(123)

log_reg_spec <- logistic_reg(
  penalty = tune(), 
  mixture = 1  # Set mixture to 1 for pure L1 regularization (lasso)
) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

log_reg_wf <- workflow() %>%
  add_model(log_reg_spec) %>%
  add_formula(CVD ~ .)
# Cross-validation setup and tuning for logistic regression
folds <- vfold_cv(train_data, v = 5)

log_reg_params <- extract_parameter_set_dials(log_reg_wf) %>%
  update(
    penalty = penalty(range = c(-6, 1))  # Range for penalty strength in logistic regression
  )

log_reg_res <- tune_grid(
  log_reg_wf,
  resamples = folds,
  param_info = log_reg_params,
  grid = 10,
  control = control_grid()
)
# Extract best parameters and fit the final logistic regression model
best_log_reg <- select_best(log_reg_res, metric = "roc_auc")
final_log_reg_wf <- finalize_workflow(log_reg_wf, best_log_reg)
log_reg_fit <- fit(final_log_reg_wf, data = train_data)
# Decision Tree (xgboost) Model with Adjusted Parameters for Complexity Control
decision_tree_params <- list(
  booster = "gbtree",
  eta = 0.3,               # Lower learning rate for better generalization
  max_depth = 3,           # Reduced depth to limit tree complexity
  min_child_weight = 5,    # Increase to avoid overly specific splits
  alpha = 0.1,             # L1 regularization term
  lambda = 0.1,            # L2 regularization term
  objective = "binary:logistic",
  nrounds = 100            # Increased rounds to compensate for lower eta
)

decision_tree_model <- xgboost(
  params = decision_tree_params,
  data = train_matrix,
  nrounds = decision_tree_params$nrounds,
  verbose = 0
)
## [22:58:53] WARNING: src/learner.cc:767: 
## Parameters: { "nrounds" } are not used.
# Random Forest (xgboost) Model with L1 Regularization
random_forest_params <- list(
  booster = "gbtree",
  eta = 0.1,              
  max_depth = 6,          
  alpha = 0.1,            
  lambda = 0,             
  subsample = 0.8,        
  colsample_bytree = 0.8, 
  objective = "binary:logistic",
  nrounds = 500           
)

random_forest_model <- xgboost(
  params = random_forest_params,
  data = train_matrix,
  nrounds = random_forest_params$nrounds,
  verbose = 0
)
## [22:58:54] WARNING: src/learner.cc:767: 
## Parameters: { "nrounds" } are not used.
# Predictions for Training and Test Sets
log_reg_train_preds <- predict(log_reg_fit, train_data, type = "prob") %>% pull(.pred_1)
log_reg_test_preds <- predict(log_reg_fit, test_data, type = "prob") %>% pull(.pred_1)

decision_tree_train_preds <- predict(decision_tree_model, newdata = train_matrix)
decision_tree_test_preds <- predict(decision_tree_model, newdata = test_matrix)

random_forest_train_preds <- predict(random_forest_model, newdata = train_matrix)
random_forest_test_preds <- predict(random_forest_model, newdata = test_matrix)
# Convert predictions to binary classes for Training and Test Sets
log_reg_train_classes <- factor(ifelse(log_reg_train_preds > 0.5, "1", "0"))
log_reg_test_classes <- factor(ifelse(log_reg_test_preds > 0.5, "1", "0"))

decision_tree_train_classes <- factor(ifelse(decision_tree_train_preds > 0.5, "1", "0"))
decision_tree_test_classes <- factor(ifelse(decision_tree_test_preds > 0.5, "1", "0"))

random_forest_train_classes <- factor(ifelse(random_forest_train_preds > 0.5, "1", "0"))
random_forest_test_classes <- factor(ifelse(random_forest_test_preds > 0.5, "1", "0"))
# Performance metrics for each model
individual_model_results <- bind_rows(
  tibble(Model = "Logistic Regression", 
         Accuracy = accuracy_vec(test_data$CVD, log_reg_test_classes),
         Precision = precision_vec(test_data$CVD, log_reg_test_classes),
         Recall = recall_vec(test_data$CVD, log_reg_test_classes),
         Specificity = spec_vec(test_data$CVD, log_reg_test_classes),
         F1 = f_meas_vec(test_data$CVD, log_reg_test_classes)),
  tibble(Model = "Decision Tree (xgboost)", 
         Accuracy = accuracy_vec(test_data$CVD, decision_tree_test_classes),
         Precision = precision_vec(test_data$CVD, decision_tree_test_classes),
         Recall = recall_vec(test_data$CVD, decision_tree_test_classes),
         Specificity = spec_vec(test_data$CVD, decision_tree_test_classes),
         F1 = f_meas_vec(test_data$CVD, decision_tree_test_classes)),
  tibble(Model = "Random Forest (xgboost)", 
         Accuracy = accuracy_vec(test_data$CVD, random_forest_test_classes),
         Precision = precision_vec(test_data$CVD, random_forest_test_classes),
         Recall = recall_vec(test_data$CVD, random_forest_test_classes),
         Specificity = spec_vec(test_data$CVD, random_forest_test_classes),
         F1 = f_meas_vec(test_data$CVD, random_forest_test_classes))
)
# Convert ensemble predictions to factors for metrics calculations
ensemble_lr_dt_classes <- factor(ifelse((log_reg_test_preds + decision_tree_test_preds) / 2 > 0.5, "1", "0"))
ensemble_lr_rf_classes <- factor(ifelse((log_reg_test_preds + random_forest_test_preds) / 2 > 0.5, "1", "0"))
ensemble_dt_rf_classes <- factor(ifelse((decision_tree_test_preds + random_forest_test_preds) / 2 > 0.5, "1", "0"))
ensemble_three_classes <- factor(ifelse((log_reg_test_preds + decision_tree_test_preds + random_forest_test_preds) / 3 > 0.5, "1", "0"))
# Ensemble performance metrics
ensemble_results <- bind_rows(
  tibble(Model = "Logistic + Decision Tree (xgboost) Ensemble", 
         Accuracy = accuracy_vec(test_data$CVD, ensemble_lr_dt_classes),
         Precision = precision_vec(test_data$CVD, ensemble_lr_dt_classes),
         Recall = recall_vec(test_data$CVD, ensemble_lr_dt_classes),
         Specificity = spec_vec(test_data$CVD, ensemble_lr_dt_classes),
         F1 = f_meas_vec(test_data$CVD, ensemble_lr_dt_classes)),
  tibble(Model = "Logistic + Random Forest (xgboost) Ensemble", 
         Accuracy = accuracy_vec(test_data$CVD, ensemble_lr_rf_classes),
         Precision = precision_vec(test_data$CVD, ensemble_lr_rf_classes),
         Recall = recall_vec(test_data$CVD, ensemble_lr_rf_classes),
         Specificity = spec_vec(test_data$CVD, ensemble_lr_rf_classes),
         F1 = f_meas_vec(test_data$CVD, ensemble_lr_rf_classes)),
  tibble(Model = "Decision Tree + Random Forest (xgboost) Ensemble", 
         Accuracy = accuracy_vec(test_data$CVD, ensemble_dt_rf_classes),
         Precision = precision_vec(test_data$CVD, ensemble_dt_rf_classes),
         Recall = recall_vec(test_data$CVD, ensemble_dt_rf_classes),
         Specificity = spec_vec(test_data$CVD, ensemble_dt_rf_classes),
         F1 = f_meas_vec(test_data$CVD, ensemble_dt_rf_classes)),
  tibble(Model = "Three-Model Ensemble", 
         Accuracy = accuracy_vec(test_data$CVD, ensemble_three_classes),
         Precision = precision_vec(test_data$CVD, ensemble_three_classes),
         Recall = recall_vec(test_data$CVD, ensemble_three_classes),
         Specificity = spec_vec(test_data$CVD, ensemble_three_classes),
         F1 = f_meas_vec(test_data$CVD, ensemble_three_classes))
)
# Print performance metrics
print("Individual Model Performance Metrics:")
## [1] "Individual Model Performance Metrics:"
print(individual_model_results)
## # A tibble: 3 × 6
##   Model                   Accuracy Precision Recall Specificity    F1
##   <chr>                      <dbl>     <dbl>  <dbl>       <dbl> <dbl>
## 1 Logistic Regression        0.972     0.977  0.986       0.926 0.981
## 2 Decision Tree (xgboost)    0.973     0.982  0.982       0.943 0.982
## 3 Random Forest (xgboost)    0.970     0.981  0.98        0.939 0.980
print("Ensemble Model Performance Metrics:")
## [1] "Ensemble Model Performance Metrics:"
print(ensemble_results)
## # A tibble: 4 × 6
##   Model                              Accuracy Precision Recall Specificity    F1
##   <chr>                                 <dbl>     <dbl>  <dbl>       <dbl> <dbl>
## 1 Logistic + Decision Tree (xgboost…    0.971     0.978  0.984       0.928 0.981
## 2 Logistic + Random Forest (xgboost…    0.970     0.979  0.981       0.933 0.980
## 3 Decision Tree + Random Forest (xg…    0.973     0.982  0.982       0.943 0.982
## 4 Three-Model Ensemble                  0.973     0.981  0.983       0.939 0.982
# Define function to generate and plot confusion matrices for Training, Test, and Ensemble Sets with purple shades
generate_confusion_matrix <- function(truth, estimate, model_name, data_split) {
  conf_matrix <- conf_mat(tibble(truth = truth, estimate = estimate), truth, estimate)
  autoplot(conf_matrix, type = "heatmap") +
    labs(title = paste("Confusion Matrix -", model_name, "(", data_split, ")")) +
    scale_fill_gradient(low = "lavender", high = "purple") +
    theme_minimal() +
    theme(
      text = element_text(size = 12, family = "serif"),
      plot.title = element_text(hjust = 0.5)
    )
}
# Generate and print confusion matrices for Training and Test Sets for each model and ensemble with 4-second delay
confusion_matrices <- list(
  # Logistic Regression
  generate_confusion_matrix(train_data$CVD, log_reg_train_classes, "Logistic Regression", "Training Set"),
  generate_confusion_matrix(test_data$CVD, log_reg_test_classes, "Logistic Regression", "Test Set"),
  
  # Decision Tree
  generate_confusion_matrix(train_data$CVD, decision_tree_train_classes, "Decision Tree (xgboost)", "Training Set"),
  generate_confusion_matrix(test_data$CVD, decision_tree_test_classes, "Decision Tree (xgboost)", "Test Set"),
  
  # Random Forest
  generate_confusion_matrix(train_data$CVD, random_forest_train_classes, "Random Forest (xgboost)", "Training Set"),
  generate_confusion_matrix(test_data$CVD, random_forest_test_classes, "Random Forest (xgboost)", "Test Set"),
  
  # Ensembles on Test Set
  generate_confusion_matrix(test_data$CVD, ensemble_lr_dt_classes, "Logistic + Decision Tree Ensemble", "Test Set"),
  generate_confusion_matrix(test_data$CVD, ensemble_lr_rf_classes, "Logistic + Random Forest Ensemble", "Test Set"),
  generate_confusion_matrix(test_data$CVD, ensemble_dt_rf_classes, "Decision Tree + Random Forest Ensemble", "Test Set"),
  generate_confusion_matrix(test_data$CVD, ensemble_three_classes, "Three-Model Ensemble", "Test Set")
)
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
for (conf_mat_plot in confusion_matrices) {
  print(conf_mat_plot)
  Sys.sleep(2)  # 2-second delay after each plot
}

# Custom ggplot2 plot for Logistic Regression feature importance with Coefficient Value
log_reg_coefs <- tidy(extract_fit_parsnip(log_reg_fit)) %>%
  arrange(desc(estimate)) %>%
  mutate(term = recode(term,
                       "ANGINA" = "Angina Pectoris",
                       "MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
                       "HOSPMI" = "Hospitalized Myocardial Infarction",
                       "DIABP" = "Diastolic Blood Pressure",
                       "SYSBP" = "Systolic Blood Pressure",
                       "BMI" = "Body Mass Index",
                       "TOTCHOL" = "Total Cholesterol",
                       "HEARTRTE" = "Heart Rate",
                       "ANYCHD" = "Any Fatal Coronary Heart Disease",
                       "PREVMI" = "Prevalent Myocardial Infarction",
                       "PREVAP" = "Prevalent Angina",
                       "PERIOD" = "Examination Cycle",
                       "HYPERTEN" = "Hypertensive",
                       "PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
                       "Diabetic_Non.Diabetic" = "Non-diabetic",
                       "BPMEDS" = "Anti-hypertensive Meds",
                       "Gender_Male" = "Gender (Male)",
                       "STROKE" = "Stroke",
                       "DEATH" = "Death",
                       "AGE" = "Age",
                       "GLUCOSE" = "Glucose"
  ))
ggplot(log_reg_coefs, aes(x = reorder(term, estimate), y = estimate)) +
  geom_col(fill = "purple") +
  coord_flip() +
  labs(title = "Feature Importance - Logistic Regression (L1)", x = "Feature", y = "Coefficient Value") +
  theme_minimal() +
  theme(
    text = element_text(size = 12, family = "serif"),
    plot.title = element_text(hjust = 0.5)
  )

# Logistic Regression Top 5 Features
log_reg_importance <- tidy(extract_fit_parsnip(log_reg_fit)) %>%
  mutate(Importance = abs(estimate)) %>%
  arrange(desc(Importance)) %>%
  head(5) %>%  # Select top 5 features based on absolute coefficient values
  select(term, Importance)
print("Top 5 Features in Logistic Regression:")
## [1] "Top 5 Features in Logistic Regression:"
print(log_reg_importance)
## # A tibble: 5 × 2
##   term        Importance
##   <chr>            <dbl>
## 1 STROKE           16.9 
## 2 (Intercept)      10.6 
## 3 ANYCHD           10.5 
## 4 MI_FCHD           7.37
## 5 HOSPMI            4.48
# Feature importance for Decision Tree (xgboost) and Random Forest (xgboost) Models with Axis Labels and Centered Titles
decision_tree_importance <- xgb.importance(model = decision_tree_model)
random_forest_importance <- xgb.importance(model = random_forest_model)
# Decision Tree Top 5 Features using vip::vi()
tree_importance <- vip::vi(decision_tree_model) %>%
  arrange(desc(Importance)) %>%
  head(5)  # Select top 5 features
print("Top 5 Features in Decision Tree:")
## [1] "Top 5 Features in Decision Tree:"
print(tree_importance)  
## # A tibble: 5 × 2
##   Variable Importance
##   <chr>         <dbl>
## 1 MI_FCHD      0.324 
## 2 STROKE       0.296 
## 3 ANYCHD       0.265 
## 4 ANGINA       0.0398
## 5 BMI          0.0140

Random Forest Top 5 Features using vip::vi()

rf_importance <- vip::vi(random_forest_model$fit) %>%

arrange(desc(Importance)) %>%

head(5) # Select top 5 features

print(“Top 5 Features in Random Forest:”)

print(rf_importance)

# Custom ggplot2 plot for Decision Tree (xgboost) feature importance with Gain
decision_tree_importance_df <- as.data.frame(decision_tree_importance) %>%
  mutate(Feature = recode(Feature,
                          "ANGINA" = "Angina Pectoris",
                          "MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
                          "HOSPMI" = "Hospitalized Myocardial Infarction",
                          "DIABP" = "Diastolic Blood Pressure",
                          "SYSBP" = "Systolic Blood Pressure",
                          "BMI" = "Body Mass Index",
                          "TOTCHOL" = "Total Cholesterol",
                          "HEARTRTE" = "Heart Rate",
                          "ANYCHD" = "Any Fatal Coronary Heart Disease",
                          "PREVMI" = "Prevalent Myocardial Infarction",
                          "PREVAP" = "Prevalent Angina",
                          "PERIOD" = "Examination Cycle",
                          "HYPERTEN" = "Hypertensive",
                          "PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
                          "Diabetic_Non.Diabetic" = "Non-diabetic",
                          "BPMEDS" = "Anti-hypertensive Meds",
                          "Gender_Male" = "Gender (Male)",
                          "STROKE" = "Stroke",
                          "DEATH" = "Death",
                          "AGE" = "Age",
                          "GLUCOSE" = "Glucose"
  ))
ggplot(decision_tree_importance_df, aes(x = reorder(Feature, Gain), y = Gain)) +
  geom_col(fill = "purple") +
  coord_flip() +
  labs(title = "Feature Importance - Decision Tree (xgboost)", x = "Feature", y = "Gain") +
  theme_minimal() +
  theme(
    text = element_text(size = 12, family = "serif"),
    plot.title = element_text(hjust = 0.5)
  )

# Custom ggplot2 plot for Random Forest (xgboost) feature importance with Gain
random_forest_importance_df <- as.data.frame(random_forest_importance) %>%
  mutate(Feature = recode(Feature,
                          "ANGINA" = "Angina Pectoris",
                          "MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
                          "HOSPMI" = "Hospitalized Myocardial Infarction",
                          "DIABP" = "Diastolic Blood Pressure",
                          "SYSBP" = "Systolic Blood Pressure",
                          "BMI" = "Body Mass Index",
                          "TOTCHOL" = "Total Cholesterol",
                          "HEARTRTE" = "Heart Rate",
                          "ANYCHD" = "Any Fatal Coronary Heart Disease",
                          "PREVMI" = "Prevalent Myocardial Infarction",
                          "PREVAP" = "Prevalent Angina",
                          "PERIOD" = "Examination Cycle",
                          "HYPERTEN" = "Hypertensive",
                          "PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
                          "Diabetic_Non.Diabetic" = "Non-diabetic",
                          "BPMEDS" = "Anti-hypertensive Meds",
                          "Gender_Male" = "Gender (Male)",
                          "STROKE" = "Stroke",
                          "DEATH" = "Death",
                          "AGE" = "Age",
                          "GLUCOSE" = "Glucose"
  ))
ggplot(random_forest_importance_df, aes(x = reorder(Feature, Gain), y = Gain)) +
  geom_col(fill = "purple") +
  coord_flip() +
  labs(title = "Feature Importance - Random Forest (xgboost)", x = "Feature", y = "Gain") +
  theme_minimal() +
  theme(
    text = element_text(size = 12, family = "serif"),
    plot.title = element_text(hjust = 0.5)
  )

# Random Forest Top 5 Features # fitted_model <- extract_fit_parsnip(random_forest_model) # rf_importance <- vip::vi(rf_fitted_model$fit) %>% # arrange(desc(Importance)) %>% # head(5) # Select top 5 features # print(“Top 5 Features in Random Forest:”) # print(rf_importance)

# Display Feature Names for Each Model

# Logistic Regression Feature Names
log_reg_features <- log_reg_coefs %>%
  select(term) %>%
  rename("Logistic Regression Features" = term)
print("Logistic Regression Features:")
## [1] "Logistic Regression Features:"
print.data.frame(log_reg_features, row.names = FALSE)
##                                     Logistic Regression Features
##                                                           Stroke
##                                 Any Fatal Coronary Heart Disease
##               Myocardial Infarction Fatal Coronary Heart Disease
##  `PREVMI - Myocardial Infarction`Prevalent myocardial infarction
##                         `Education Level`High School Diploma/GED
##                                                       GenderMale
##      `PREVAP - Angina Pectoris`Prevalent disease angina pectoris
##                                           Anti-hypertensive Meds
##                                                     Hypertensive
##                   `PREVHYP - Hypertension`Prevalent hypertension
##                                             DiabeticNon Diabetic
##                                                              Age
##                                                          Glucose
##                                                Total Cholesterol
##                                                  Body Mass Index
##                           `Cigarettes/per day`Not current smoker
##                     `PREVSTRK - Stroke History`No stroke history
##                                          Systolic Blood Pressure
##                                                       Heart Rate
##                                         Diastolic Blood Pressure
##                                                     SmokerSmoker
##               `Education Level`College degree (BA, BS) or higher
##                                                Examination Cycle
##                                                            Death
##                  `Education Level`Some College/Vocational School
##           `PREVCHD - Coronary Disease`Prevalent coronary disease
##                                                  Angina Pectoris
##                               Hospitalized Myocardial Infarction
##                                                      (Intercept)
# Decision Tree Feature Names
decision_tree_features <- decision_tree_importance_df %>%
  select(Feature) %>%
  rename("Decision Tree Features" = Feature)
print("Decision Tree (xgboost) Features:")
## [1] "Decision Tree (xgboost) Features:"
print.data.frame(decision_tree_features, row.names = FALSE)
##                                          Decision Tree Features
##              Myocardial Infarction Fatal Coronary Heart Disease
##                                                          Stroke
##                                Any Fatal Coronary Heart Disease
##                                                 Angina Pectoris
##                                                 Body Mass Index
##                                                         Glucose
##                                        Diastolic Blood Pressure
##                                               Total Cholesterol
##                                                             Age
##                                                      Heart Rate
##                                                           Death
##                         Education Level_High.School.Diploma.GED
##                                         Systolic Blood Pressure
##                              Hospitalized Myocardial Infarction
##                  Education Level_Some.College.Vocational.School
##  PREVMI - Myocardial Infarction_Prevalent.myocardial.infarction
##           PREVCHD - Coronary Disease_Prevalent.coronary.disease
##                                                   Gender (Male)
##                                               Examination Cycle
##                   PREVHYP - Hypertension_Prevalent.hypertension
##               Education Level_College.degree..BA..BS..or.higher
##                                                    Hypertensive
##                                                   Smoker_Smoker
##      PREVAP - Angina Pectoris_Prevalent.disease.angina.pectoris
##                                          Anti-hypertensive Meds
# Random Forest Feature Names
random_forest_features <- random_forest_importance_df %>%
  select(Feature) %>%
  rename("Random Forest Features" = Feature)
print("Random Forest (xgboost) Features:")
## [1] "Random Forest (xgboost) Features:"
print.data.frame(random_forest_features, row.names = FALSE)
##                                          Random Forest Features
##                                                          Stroke
##                                Any Fatal Coronary Heart Disease
##              Myocardial Infarction Fatal Coronary Heart Disease
##                                                 Angina Pectoris
##                              Hospitalized Myocardial Infarction
##                                                 Body Mass Index
##                                               Total Cholesterol
##                                                         Glucose
##                                         Systolic Blood Pressure
##                                                             Age
##                                                      Heart Rate
##                                        Diastolic Blood Pressure
##                                                           Death
##  PREVMI - Myocardial Infarction_Prevalent.myocardial.infarction
##                  Education Level_Some.College.Vocational.School
##                         Education Level_High.School.Diploma.GED
##                     PREVSTRK - Stroke History_No.stroke.history
##                                                   Gender (Male)
##                                                    Hypertensive
##           PREVCHD - Coronary Disease_Prevalent.coronary.disease
##                                               Examination Cycle
##                   PREVHYP - Hypertension_Prevalent.hypertension
##      PREVAP - Angina Pectoris_Prevalent.disease.angina.pectoris
##                                                   Smoker_Smoker
##               Education Level_College.degree..BA..BS..or.higher
##                                          Anti-hypertensive Meds
##                                                    Non-diabetic
##                           Cigarettes/per day_Not.current.smoker
# Individual Decision Tree Model for Visualization (Not Part of Ensemble)
# Setting up a simpler decision tree for structure
individual_tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = 3
) %>%
  set_engine("rpart", model = TRUE) %>%
  set_mode("classification")

individual_tree_wf <- workflow() %>%
  add_model(individual_tree_spec) %>%
  add_formula(CVD ~ .)

# Cross-validation and tuning for the individual decision tree model
individual_tree_res <- tune_grid(
  individual_tree_wf,
  resamples = folds,
  grid = 10
)
# Selecting the best decision tree based on ROC AUC
best_individual_tree <- select_best(individual_tree_res, metric = "roc_auc")
final_individual_tree_wf <- finalize_workflow(individual_tree_wf, best_individual_tree)
individual_tree_fit <- fit(final_individual_tree_wf, data = train_data)
# Add ROC Curve plots for each model and ensemble on both Training and Test sets

# Function to generate and plot ROC curves with customized styling
generate_roc_plot <- function(truth, predicted_probs, model_name, data_split) {
  roc_obj <- roc(truth, predicted_probs)
  ggroc(roc_obj, color = "purple", size = 1.2) +
    geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "gray") +
    labs(title = paste("ROC Curve -", model_name, "(", data_split, ")"), 
         x = "False Positive Rate", 
         y = "True Positive Rate") +
    theme_minimal() +
    theme(
      text = element_text(size = 12, family = "Times New Roman"),
      plot.title = element_text(hjust = 0.5)
    )
}
# List of ROC plots for both Training and Test Set predictions
roc_plots <- list(
  # Logistic Regression
  generate_roc_plot(train_data$CVD, log_reg_train_preds, "Logistic Regression", "Training Set"),
  generate_roc_plot(test_data$CVD, log_reg_test_preds, "Logistic Regression", "Test Set"),
  
  # Decision Tree
  generate_roc_plot(train_data$CVD, decision_tree_train_preds, "Decision Tree (xgboost)", "Training Set"),
  generate_roc_plot(test_data$CVD, decision_tree_test_preds, "Decision Tree (xgboost)", "Test Set"),
  
  # Random Forest
  generate_roc_plot(train_data$CVD, random_forest_train_preds, "Random Forest (xgboost)", "Training Set"),
  generate_roc_plot(test_data$CVD, random_forest_test_preds, "Random Forest (xgboost)", "Test Set"),
  
  # Ensembles on Test Set
  generate_roc_plot(test_data$CVD, (log_reg_test_preds + decision_tree_test_preds) / 2, "Logistic + Decision Tree Ensemble", "Test Set"),
  generate_roc_plot(test_data$CVD, (log_reg_test_preds + random_forest_test_preds) / 2, "Logistic + Random Forest Ensemble", "Test Set"),
  generate_roc_plot(test_data$CVD, (decision_tree_test_preds + random_forest_test_preds) / 2, "Decision Tree + Random Forest Ensemble", "Test Set"),
  generate_roc_plot(test_data$CVD, (log_reg_test_preds + decision_tree_test_preds + random_forest_test_preds) / 3, "Three-Model Ensemble", "Test Set")
)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Display each ROC plot with a 4-second delay
for (roc_plot in roc_plots) {
  print(roc_plot)
  Sys.sleep(2) # 2-second delay after each plot
}
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database

##############################################################################
# Feature Engineering code - added no value to the models. Actually reduced the
# models' performance.
##############################################################################