# Load required libraries
library(readr)
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.7 ✔ recipes 1.1.0
## ✔ dials 1.3.0 ✔ rsample 1.2.1
## ✔ dplyr 1.1.4 ✔ tibble 3.2.1
## ✔ ggplot2 3.5.1 ✔ tidyr 1.3.1
## ✔ infer 1.0.7 ✔ tune 1.2.1
## ✔ modeldata 1.4.0 ✔ workflows 1.1.4
## ✔ parsnip 1.2.1 ✔ workflowsets 1.1.0
## ✔ purrr 1.0.2 ✔ yardstick 1.3.1
## Warning: package 'dplyr' was built under R version 4.4.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(xgboost)
##
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
##
## slice
library(glmnet)
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.1-8
library(dplyr)
library(ggplot2)
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
library(broom)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(yardstick)
library(purrr)
library(tidyr)
library(hardhat)
library(magrittr) # Ensure the pipe operator %>% is available
##
## Attaching package: 'magrittr'
## The following object is masked from 'package:tidyr':
##
## extract
## The following object is masked from 'package:purrr':
##
## set_names
library(rpart)
##
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
##
## prune
library(rpart.plot) # For visualizing individual decision tree structure
library(vip)
# Loading training and test datasets
train_data <- read_csv("C:/Users/JOYCE/OneDrive/Desktop/FHS/FHS-Cleaned-V2/train.csv")
## Rows: 6115 Columns: 27
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): Gender, Education Level, Smoker, Cigarettes/per day, Diabetic, PRE...
## dbl (17): TOTCHOL, AGE, SYSBP, DIABP, BMI, BPMEDS, HEARTRTE, GLUCOSE, PERIOD...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
test_data <- read_csv("C:/Users/JOYCE/OneDrive/Desktop/FHS/FHS-Cleaned-V2/test.csv")
## Rows: 2039 Columns: 27
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): Gender, Education Level, Smoker, Cigarettes/per day, Diabetic, PRE...
## dbl (17): TOTCHOL, AGE, SYSBP, DIABP, BMI, BPMEDS, HEARTRTE, GLUCOSE, PERIOD...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Define a simple decision tree with limited depth for visualization purposes
simple_tree <- rpart(CVD ~ ., data = train_data, method = "class", control = rpart.control(maxdepth = 5, cp = 0.01))
# Plot the decision tree
rpart.plot(simple_tree, main="Decision Tree for Cardiovascular Disease Prediction (Depth = 5)",
type = 3, extra = 4, under = TRUE, cex = 0.7, box.palette = "RdBu", shadow.col = "gray", nn = TRUE)
# Convert categorical columns to factors in `train_data` and `test_data`
categorical_cols <- c("Gender", "Education Level", "Smoker", "Cigarettes/per day",
"Diabetic", "PREVCHD - Coronary Disease", "PREVAP - Angina Pectoris",
"PREVMI - Myocardial Infarction", "PREVSTRK - Stroke History",
"PREVHYP - Hypertension")
train_data <- train_data %>%
mutate(CVD = as.factor(CVD), across(all_of(categorical_cols), as.factor))
test_data <- test_data %>%
mutate(CVD = as.factor(CVD), across(all_of(categorical_cols), as.factor))
# Convert categorical variables to dummy variables and normalize for xgboost
train_data_numeric <- train_data %>%
select(-CVD) %>%
recipe(~ .) %>%
step_dummy(all_nominal(), -all_outcomes()) %>%
step_normalize(all_predictors()) %>% # Add normalization step
prep() %>%
bake(new_data = NULL) %>%
as.matrix()
test_data_numeric <- test_data %>%
select(-CVD) %>%
recipe(~ .) %>%
step_dummy(all_nominal(), -all_outcomes()) %>%
step_normalize(all_predictors()) %>% # Add normalization step
prep() %>%
bake(new_data = NULL) %>%
as.matrix()
# Create the xgboost DMatrix
train_matrix <- xgb.DMatrix(data = train_data_numeric, label = as.numeric(train_data$CVD) - 1)
test_matrix <- xgb.DMatrix(data = test_data_numeric, label = as.numeric(test_data$CVD) - 1)
# Logistic Regression with L1 Regularization
set.seed(123)
log_reg_spec <- logistic_reg(
penalty = tune(),
mixture = 1 # Set mixture to 1 for pure L1 regularization (lasso)
) %>%
set_engine("glmnet") %>%
set_mode("classification")
log_reg_wf <- workflow() %>%
add_model(log_reg_spec) %>%
add_formula(CVD ~ .)
# Cross-validation setup and tuning for logistic regression
folds <- vfold_cv(train_data, v = 5)
log_reg_params <- extract_parameter_set_dials(log_reg_wf) %>%
update(
penalty = penalty(range = c(-6, 1)) # Range for penalty strength in logistic regression
)
log_reg_res <- tune_grid(
log_reg_wf,
resamples = folds,
param_info = log_reg_params,
grid = 10,
control = control_grid()
)
# Extract best parameters and fit the final logistic regression model
best_log_reg <- select_best(log_reg_res, metric = "roc_auc")
final_log_reg_wf <- finalize_workflow(log_reg_wf, best_log_reg)
log_reg_fit <- fit(final_log_reg_wf, data = train_data)
# Decision Tree (xgboost) Model with Adjusted Parameters for Complexity Control
decision_tree_params <- list(
booster = "gbtree",
eta = 0.3, # Lower learning rate for better generalization
max_depth = 3, # Reduced depth to limit tree complexity
min_child_weight = 5, # Increase to avoid overly specific splits
alpha = 0.1, # L1 regularization term
lambda = 0.1, # L2 regularization term
objective = "binary:logistic",
nrounds = 100 # Increased rounds to compensate for lower eta
)
decision_tree_model <- xgboost(
params = decision_tree_params,
data = train_matrix,
nrounds = decision_tree_params$nrounds,
verbose = 0
)
## [22:58:53] WARNING: src/learner.cc:767:
## Parameters: { "nrounds" } are not used.
# Random Forest (xgboost) Model with L1 Regularization
random_forest_params <- list(
booster = "gbtree",
eta = 0.1,
max_depth = 6,
alpha = 0.1,
lambda = 0,
subsample = 0.8,
colsample_bytree = 0.8,
objective = "binary:logistic",
nrounds = 500
)
random_forest_model <- xgboost(
params = random_forest_params,
data = train_matrix,
nrounds = random_forest_params$nrounds,
verbose = 0
)
## [22:58:54] WARNING: src/learner.cc:767:
## Parameters: { "nrounds" } are not used.
# Predictions for Training and Test Sets
log_reg_train_preds <- predict(log_reg_fit, train_data, type = "prob") %>% pull(.pred_1)
log_reg_test_preds <- predict(log_reg_fit, test_data, type = "prob") %>% pull(.pred_1)
decision_tree_train_preds <- predict(decision_tree_model, newdata = train_matrix)
decision_tree_test_preds <- predict(decision_tree_model, newdata = test_matrix)
random_forest_train_preds <- predict(random_forest_model, newdata = train_matrix)
random_forest_test_preds <- predict(random_forest_model, newdata = test_matrix)
# Convert predictions to binary classes for Training and Test Sets
log_reg_train_classes <- factor(ifelse(log_reg_train_preds > 0.5, "1", "0"))
log_reg_test_classes <- factor(ifelse(log_reg_test_preds > 0.5, "1", "0"))
decision_tree_train_classes <- factor(ifelse(decision_tree_train_preds > 0.5, "1", "0"))
decision_tree_test_classes <- factor(ifelse(decision_tree_test_preds > 0.5, "1", "0"))
random_forest_train_classes <- factor(ifelse(random_forest_train_preds > 0.5, "1", "0"))
random_forest_test_classes <- factor(ifelse(random_forest_test_preds > 0.5, "1", "0"))
# Performance metrics for each model
individual_model_results <- bind_rows(
tibble(Model = "Logistic Regression",
Accuracy = accuracy_vec(test_data$CVD, log_reg_test_classes),
Precision = precision_vec(test_data$CVD, log_reg_test_classes),
Recall = recall_vec(test_data$CVD, log_reg_test_classes),
Specificity = spec_vec(test_data$CVD, log_reg_test_classes),
F1 = f_meas_vec(test_data$CVD, log_reg_test_classes)),
tibble(Model = "Decision Tree (xgboost)",
Accuracy = accuracy_vec(test_data$CVD, decision_tree_test_classes),
Precision = precision_vec(test_data$CVD, decision_tree_test_classes),
Recall = recall_vec(test_data$CVD, decision_tree_test_classes),
Specificity = spec_vec(test_data$CVD, decision_tree_test_classes),
F1 = f_meas_vec(test_data$CVD, decision_tree_test_classes)),
tibble(Model = "Random Forest (xgboost)",
Accuracy = accuracy_vec(test_data$CVD, random_forest_test_classes),
Precision = precision_vec(test_data$CVD, random_forest_test_classes),
Recall = recall_vec(test_data$CVD, random_forest_test_classes),
Specificity = spec_vec(test_data$CVD, random_forest_test_classes),
F1 = f_meas_vec(test_data$CVD, random_forest_test_classes))
)
# Convert ensemble predictions to factors for metrics calculations
ensemble_lr_dt_classes <- factor(ifelse((log_reg_test_preds + decision_tree_test_preds) / 2 > 0.5, "1", "0"))
ensemble_lr_rf_classes <- factor(ifelse((log_reg_test_preds + random_forest_test_preds) / 2 > 0.5, "1", "0"))
ensemble_dt_rf_classes <- factor(ifelse((decision_tree_test_preds + random_forest_test_preds) / 2 > 0.5, "1", "0"))
ensemble_three_classes <- factor(ifelse((log_reg_test_preds + decision_tree_test_preds + random_forest_test_preds) / 3 > 0.5, "1", "0"))
# Ensemble performance metrics
ensemble_results <- bind_rows(
tibble(Model = "Logistic + Decision Tree (xgboost) Ensemble",
Accuracy = accuracy_vec(test_data$CVD, ensemble_lr_dt_classes),
Precision = precision_vec(test_data$CVD, ensemble_lr_dt_classes),
Recall = recall_vec(test_data$CVD, ensemble_lr_dt_classes),
Specificity = spec_vec(test_data$CVD, ensemble_lr_dt_classes),
F1 = f_meas_vec(test_data$CVD, ensemble_lr_dt_classes)),
tibble(Model = "Logistic + Random Forest (xgboost) Ensemble",
Accuracy = accuracy_vec(test_data$CVD, ensemble_lr_rf_classes),
Precision = precision_vec(test_data$CVD, ensemble_lr_rf_classes),
Recall = recall_vec(test_data$CVD, ensemble_lr_rf_classes),
Specificity = spec_vec(test_data$CVD, ensemble_lr_rf_classes),
F1 = f_meas_vec(test_data$CVD, ensemble_lr_rf_classes)),
tibble(Model = "Decision Tree + Random Forest (xgboost) Ensemble",
Accuracy = accuracy_vec(test_data$CVD, ensemble_dt_rf_classes),
Precision = precision_vec(test_data$CVD, ensemble_dt_rf_classes),
Recall = recall_vec(test_data$CVD, ensemble_dt_rf_classes),
Specificity = spec_vec(test_data$CVD, ensemble_dt_rf_classes),
F1 = f_meas_vec(test_data$CVD, ensemble_dt_rf_classes)),
tibble(Model = "Three-Model Ensemble",
Accuracy = accuracy_vec(test_data$CVD, ensemble_three_classes),
Precision = precision_vec(test_data$CVD, ensemble_three_classes),
Recall = recall_vec(test_data$CVD, ensemble_three_classes),
Specificity = spec_vec(test_data$CVD, ensemble_three_classes),
F1 = f_meas_vec(test_data$CVD, ensemble_three_classes))
)
# Print performance metrics
print("Individual Model Performance Metrics:")
## [1] "Individual Model Performance Metrics:"
print(individual_model_results)
## # A tibble: 3 × 6
## Model Accuracy Precision Recall Specificity F1
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Logistic Regression 0.972 0.977 0.986 0.926 0.981
## 2 Decision Tree (xgboost) 0.973 0.982 0.982 0.943 0.982
## 3 Random Forest (xgboost) 0.970 0.981 0.98 0.939 0.980
print("Ensemble Model Performance Metrics:")
## [1] "Ensemble Model Performance Metrics:"
print(ensemble_results)
## # A tibble: 4 × 6
## Model Accuracy Precision Recall Specificity F1
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Logistic + Decision Tree (xgboost… 0.971 0.978 0.984 0.928 0.981
## 2 Logistic + Random Forest (xgboost… 0.970 0.979 0.981 0.933 0.980
## 3 Decision Tree + Random Forest (xg… 0.973 0.982 0.982 0.943 0.982
## 4 Three-Model Ensemble 0.973 0.981 0.983 0.939 0.982
# Define function to generate and plot confusion matrices for Training, Test, and Ensemble Sets with purple shades
generate_confusion_matrix <- function(truth, estimate, model_name, data_split) {
conf_matrix <- conf_mat(tibble(truth = truth, estimate = estimate), truth, estimate)
autoplot(conf_matrix, type = "heatmap") +
labs(title = paste("Confusion Matrix -", model_name, "(", data_split, ")")) +
scale_fill_gradient(low = "lavender", high = "purple") +
theme_minimal() +
theme(
text = element_text(size = 12, family = "serif"),
plot.title = element_text(hjust = 0.5)
)
}
# Generate and print confusion matrices for Training and Test Sets for each model and ensemble with 4-second delay
confusion_matrices <- list(
# Logistic Regression
generate_confusion_matrix(train_data$CVD, log_reg_train_classes, "Logistic Regression", "Training Set"),
generate_confusion_matrix(test_data$CVD, log_reg_test_classes, "Logistic Regression", "Test Set"),
# Decision Tree
generate_confusion_matrix(train_data$CVD, decision_tree_train_classes, "Decision Tree (xgboost)", "Training Set"),
generate_confusion_matrix(test_data$CVD, decision_tree_test_classes, "Decision Tree (xgboost)", "Test Set"),
# Random Forest
generate_confusion_matrix(train_data$CVD, random_forest_train_classes, "Random Forest (xgboost)", "Training Set"),
generate_confusion_matrix(test_data$CVD, random_forest_test_classes, "Random Forest (xgboost)", "Test Set"),
# Ensembles on Test Set
generate_confusion_matrix(test_data$CVD, ensemble_lr_dt_classes, "Logistic + Decision Tree Ensemble", "Test Set"),
generate_confusion_matrix(test_data$CVD, ensemble_lr_rf_classes, "Logistic + Random Forest Ensemble", "Test Set"),
generate_confusion_matrix(test_data$CVD, ensemble_dt_rf_classes, "Decision Tree + Random Forest Ensemble", "Test Set"),
generate_confusion_matrix(test_data$CVD, ensemble_three_classes, "Three-Model Ensemble", "Test Set")
)
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
for (conf_mat_plot in confusion_matrices) {
print(conf_mat_plot)
Sys.sleep(2) # 2-second delay after each plot
}
# Custom ggplot2 plot for Logistic Regression feature importance with Coefficient Value
log_reg_coefs <- tidy(extract_fit_parsnip(log_reg_fit)) %>%
arrange(desc(estimate)) %>%
mutate(term = recode(term,
"ANGINA" = "Angina Pectoris",
"MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
"HOSPMI" = "Hospitalized Myocardial Infarction",
"DIABP" = "Diastolic Blood Pressure",
"SYSBP" = "Systolic Blood Pressure",
"BMI" = "Body Mass Index",
"TOTCHOL" = "Total Cholesterol",
"HEARTRTE" = "Heart Rate",
"ANYCHD" = "Any Fatal Coronary Heart Disease",
"PREVMI" = "Prevalent Myocardial Infarction",
"PREVAP" = "Prevalent Angina",
"PERIOD" = "Examination Cycle",
"HYPERTEN" = "Hypertensive",
"PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
"Diabetic_Non.Diabetic" = "Non-diabetic",
"BPMEDS" = "Anti-hypertensive Meds",
"Gender_Male" = "Gender (Male)",
"STROKE" = "Stroke",
"DEATH" = "Death",
"AGE" = "Age",
"GLUCOSE" = "Glucose"
))
ggplot(log_reg_coefs, aes(x = reorder(term, estimate), y = estimate)) +
geom_col(fill = "purple") +
coord_flip() +
labs(title = "Feature Importance - Logistic Regression (L1)", x = "Feature", y = "Coefficient Value") +
theme_minimal() +
theme(
text = element_text(size = 12, family = "serif"),
plot.title = element_text(hjust = 0.5)
)
# Logistic Regression Top 5 Features
log_reg_importance <- tidy(extract_fit_parsnip(log_reg_fit)) %>%
mutate(Importance = abs(estimate)) %>%
arrange(desc(Importance)) %>%
head(5) %>% # Select top 5 features based on absolute coefficient values
select(term, Importance)
print("Top 5 Features in Logistic Regression:")
## [1] "Top 5 Features in Logistic Regression:"
print(log_reg_importance)
## # A tibble: 5 × 2
## term Importance
## <chr> <dbl>
## 1 STROKE 16.9
## 2 (Intercept) 10.6
## 3 ANYCHD 10.5
## 4 MI_FCHD 7.37
## 5 HOSPMI 4.48
# Feature importance for Decision Tree (xgboost) and Random Forest (xgboost) Models with Axis Labels and Centered Titles
decision_tree_importance <- xgb.importance(model = decision_tree_model)
random_forest_importance <- xgb.importance(model = random_forest_model)
# Decision Tree Top 5 Features using vip::vi()
tree_importance <- vip::vi(decision_tree_model) %>%
arrange(desc(Importance)) %>%
head(5) # Select top 5 features
print("Top 5 Features in Decision Tree:")
## [1] "Top 5 Features in Decision Tree:"
print(tree_importance)
## # A tibble: 5 × 2
## Variable Importance
## <chr> <dbl>
## 1 MI_FCHD 0.324
## 2 STROKE 0.296
## 3 ANYCHD 0.265
## 4 ANGINA 0.0398
## 5 BMI 0.0140
vip::vi()
# Custom ggplot2 plot for Decision Tree (xgboost) feature importance with Gain
decision_tree_importance_df <- as.data.frame(decision_tree_importance) %>%
mutate(Feature = recode(Feature,
"ANGINA" = "Angina Pectoris",
"MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
"HOSPMI" = "Hospitalized Myocardial Infarction",
"DIABP" = "Diastolic Blood Pressure",
"SYSBP" = "Systolic Blood Pressure",
"BMI" = "Body Mass Index",
"TOTCHOL" = "Total Cholesterol",
"HEARTRTE" = "Heart Rate",
"ANYCHD" = "Any Fatal Coronary Heart Disease",
"PREVMI" = "Prevalent Myocardial Infarction",
"PREVAP" = "Prevalent Angina",
"PERIOD" = "Examination Cycle",
"HYPERTEN" = "Hypertensive",
"PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
"Diabetic_Non.Diabetic" = "Non-diabetic",
"BPMEDS" = "Anti-hypertensive Meds",
"Gender_Male" = "Gender (Male)",
"STROKE" = "Stroke",
"DEATH" = "Death",
"AGE" = "Age",
"GLUCOSE" = "Glucose"
))
ggplot(decision_tree_importance_df, aes(x = reorder(Feature, Gain), y = Gain)) +
geom_col(fill = "purple") +
coord_flip() +
labs(title = "Feature Importance - Decision Tree (xgboost)", x = "Feature", y = "Gain") +
theme_minimal() +
theme(
text = element_text(size = 12, family = "serif"),
plot.title = element_text(hjust = 0.5)
)
# Custom ggplot2 plot for Random Forest (xgboost) feature importance with Gain
random_forest_importance_df <- as.data.frame(random_forest_importance) %>%
mutate(Feature = recode(Feature,
"ANGINA" = "Angina Pectoris",
"MI_FCHD" = "Myocardial Infarction Fatal Coronary Heart Disease",
"HOSPMI" = "Hospitalized Myocardial Infarction",
"DIABP" = "Diastolic Blood Pressure",
"SYSBP" = "Systolic Blood Pressure",
"BMI" = "Body Mass Index",
"TOTCHOL" = "Total Cholesterol",
"HEARTRTE" = "Heart Rate",
"ANYCHD" = "Any Fatal Coronary Heart Disease",
"PREVMI" = "Prevalent Myocardial Infarction",
"PREVAP" = "Prevalent Angina",
"PERIOD" = "Examination Cycle",
"HYPERTEN" = "Hypertensive",
"PREVSTRK - Stroke History_No.Stroke.History" = "No Stroke History",
"Diabetic_Non.Diabetic" = "Non-diabetic",
"BPMEDS" = "Anti-hypertensive Meds",
"Gender_Male" = "Gender (Male)",
"STROKE" = "Stroke",
"DEATH" = "Death",
"AGE" = "Age",
"GLUCOSE" = "Glucose"
))
ggplot(random_forest_importance_df, aes(x = reorder(Feature, Gain), y = Gain)) +
geom_col(fill = "purple") +
coord_flip() +
labs(title = "Feature Importance - Random Forest (xgboost)", x = "Feature", y = "Gain") +
theme_minimal() +
theme(
text = element_text(size = 12, family = "serif"),
plot.title = element_text(hjust = 0.5)
)
# Random Forest Top 5 Features # fitted_model <-
extract_fit_parsnip(random_forest_model) # rf_importance <-
vip::vi(rf_fitted_model$fit) %>% # arrange(desc(Importance)) %>% #
head(5) # Select top 5 features # print(“Top 5 Features in Random
Forest:”) # print(rf_importance)
# Display Feature Names for Each Model
# Logistic Regression Feature Names
log_reg_features <- log_reg_coefs %>%
select(term) %>%
rename("Logistic Regression Features" = term)
print("Logistic Regression Features:")
## [1] "Logistic Regression Features:"
print.data.frame(log_reg_features, row.names = FALSE)
## Logistic Regression Features
## Stroke
## Any Fatal Coronary Heart Disease
## Myocardial Infarction Fatal Coronary Heart Disease
## `PREVMI - Myocardial Infarction`Prevalent myocardial infarction
## `Education Level`High School Diploma/GED
## GenderMale
## `PREVAP - Angina Pectoris`Prevalent disease angina pectoris
## Anti-hypertensive Meds
## Hypertensive
## `PREVHYP - Hypertension`Prevalent hypertension
## DiabeticNon Diabetic
## Age
## Glucose
## Total Cholesterol
## Body Mass Index
## `Cigarettes/per day`Not current smoker
## `PREVSTRK - Stroke History`No stroke history
## Systolic Blood Pressure
## Heart Rate
## Diastolic Blood Pressure
## SmokerSmoker
## `Education Level`College degree (BA, BS) or higher
## Examination Cycle
## Death
## `Education Level`Some College/Vocational School
## `PREVCHD - Coronary Disease`Prevalent coronary disease
## Angina Pectoris
## Hospitalized Myocardial Infarction
## (Intercept)
# Decision Tree Feature Names
decision_tree_features <- decision_tree_importance_df %>%
select(Feature) %>%
rename("Decision Tree Features" = Feature)
print("Decision Tree (xgboost) Features:")
## [1] "Decision Tree (xgboost) Features:"
print.data.frame(decision_tree_features, row.names = FALSE)
## Decision Tree Features
## Myocardial Infarction Fatal Coronary Heart Disease
## Stroke
## Any Fatal Coronary Heart Disease
## Angina Pectoris
## Body Mass Index
## Glucose
## Diastolic Blood Pressure
## Total Cholesterol
## Age
## Heart Rate
## Death
## Education Level_High.School.Diploma.GED
## Systolic Blood Pressure
## Hospitalized Myocardial Infarction
## Education Level_Some.College.Vocational.School
## PREVMI - Myocardial Infarction_Prevalent.myocardial.infarction
## PREVCHD - Coronary Disease_Prevalent.coronary.disease
## Gender (Male)
## Examination Cycle
## PREVHYP - Hypertension_Prevalent.hypertension
## Education Level_College.degree..BA..BS..or.higher
## Hypertensive
## Smoker_Smoker
## PREVAP - Angina Pectoris_Prevalent.disease.angina.pectoris
## Anti-hypertensive Meds
# Random Forest Feature Names
random_forest_features <- random_forest_importance_df %>%
select(Feature) %>%
rename("Random Forest Features" = Feature)
print("Random Forest (xgboost) Features:")
## [1] "Random Forest (xgboost) Features:"
print.data.frame(random_forest_features, row.names = FALSE)
## Random Forest Features
## Stroke
## Any Fatal Coronary Heart Disease
## Myocardial Infarction Fatal Coronary Heart Disease
## Angina Pectoris
## Hospitalized Myocardial Infarction
## Body Mass Index
## Total Cholesterol
## Glucose
## Systolic Blood Pressure
## Age
## Heart Rate
## Diastolic Blood Pressure
## Death
## PREVMI - Myocardial Infarction_Prevalent.myocardial.infarction
## Education Level_Some.College.Vocational.School
## Education Level_High.School.Diploma.GED
## PREVSTRK - Stroke History_No.stroke.history
## Gender (Male)
## Hypertensive
## PREVCHD - Coronary Disease_Prevalent.coronary.disease
## Examination Cycle
## PREVHYP - Hypertension_Prevalent.hypertension
## PREVAP - Angina Pectoris_Prevalent.disease.angina.pectoris
## Smoker_Smoker
## Education Level_College.degree..BA..BS..or.higher
## Anti-hypertensive Meds
## Non-diabetic
## Cigarettes/per day_Not.current.smoker
# Individual Decision Tree Model for Visualization (Not Part of Ensemble)
# Setting up a simpler decision tree for structure
individual_tree_spec <- decision_tree(
cost_complexity = tune(),
tree_depth = 3
) %>%
set_engine("rpart", model = TRUE) %>%
set_mode("classification")
individual_tree_wf <- workflow() %>%
add_model(individual_tree_spec) %>%
add_formula(CVD ~ .)
# Cross-validation and tuning for the individual decision tree model
individual_tree_res <- tune_grid(
individual_tree_wf,
resamples = folds,
grid = 10
)
# Selecting the best decision tree based on ROC AUC
best_individual_tree <- select_best(individual_tree_res, metric = "roc_auc")
final_individual_tree_wf <- finalize_workflow(individual_tree_wf, best_individual_tree)
individual_tree_fit <- fit(final_individual_tree_wf, data = train_data)
# Add ROC Curve plots for each model and ensemble on both Training and Test sets
# Function to generate and plot ROC curves with customized styling
generate_roc_plot <- function(truth, predicted_probs, model_name, data_split) {
roc_obj <- roc(truth, predicted_probs)
ggroc(roc_obj, color = "purple", size = 1.2) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "gray") +
labs(title = paste("ROC Curve -", model_name, "(", data_split, ")"),
x = "False Positive Rate",
y = "True Positive Rate") +
theme_minimal() +
theme(
text = element_text(size = 12, family = "Times New Roman"),
plot.title = element_text(hjust = 0.5)
)
}
# List of ROC plots for both Training and Test Set predictions
roc_plots <- list(
# Logistic Regression
generate_roc_plot(train_data$CVD, log_reg_train_preds, "Logistic Regression", "Training Set"),
generate_roc_plot(test_data$CVD, log_reg_test_preds, "Logistic Regression", "Test Set"),
# Decision Tree
generate_roc_plot(train_data$CVD, decision_tree_train_preds, "Decision Tree (xgboost)", "Training Set"),
generate_roc_plot(test_data$CVD, decision_tree_test_preds, "Decision Tree (xgboost)", "Test Set"),
# Random Forest
generate_roc_plot(train_data$CVD, random_forest_train_preds, "Random Forest (xgboost)", "Training Set"),
generate_roc_plot(test_data$CVD, random_forest_test_preds, "Random Forest (xgboost)", "Test Set"),
# Ensembles on Test Set
generate_roc_plot(test_data$CVD, (log_reg_test_preds + decision_tree_test_preds) / 2, "Logistic + Decision Tree Ensemble", "Test Set"),
generate_roc_plot(test_data$CVD, (log_reg_test_preds + random_forest_test_preds) / 2, "Logistic + Random Forest Ensemble", "Test Set"),
generate_roc_plot(test_data$CVD, (decision_tree_test_preds + random_forest_test_preds) / 2, "Decision Tree + Random Forest Ensemble", "Test Set"),
generate_roc_plot(test_data$CVD, (log_reg_test_preds + decision_tree_test_preds + random_forest_test_preds) / 3, "Three-Model Ensemble", "Test Set")
)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Display each ROC plot with a 4-second delay
for (roc_plot in roc_plots) {
print(roc_plot)
Sys.sleep(2) # 2-second delay after each plot
}
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family
## not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font family not found in Windows font database
##############################################################################
# Feature Engineering code - added no value to the models. Actually reduced the
# models' performance.
##############################################################################