Data 622 Homework 3

Nikoleta Emanouilidi

library(dplyr)
library(rmdformats)
library(prettydoc)
library(hrbrthemes)
library(tint)
library(tufte)
library(readr)
library(e1071)
library(gridExtra)
library(kableExtra)
library(tidyverse)
library(ROSE)   # For handling class imbalance
library(ggplot2)
library(rpart)
library(rpart.plot)
library(caret)
library(pROC)
library(randomForest)
library(corrplot)

Data

The dataset used in this study was obtained from Kaggle.It includes a combination of socio-demographic details and health-related factors that are associated with the risk of stroke.

# Load the dataset
url<-"https://raw.githubusercontent.com/NikoletaEm/datasps/refs/heads/main/healthcare-dataset-stroke-data.csv"
stroke<- read_csv(url)
# View dataset structure
str(stroke)
## spc_tbl_ [5,110 × 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ id               : num [1:5110] 9046 51676 31112 60182 1665 ...
##  $ gender           : chr [1:5110] "Male" "Female" "Male" "Female" ...
##  $ age              : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : num [1:5110] 0 0 0 0 1 0 1 0 0 0 ...
##  $ heart_disease    : num [1:5110] 1 0 1 0 0 0 1 0 0 0 ...
##  $ ever_married     : chr [1:5110] "Yes" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr [1:5110] "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr [1:5110] "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
##  $ bmi              : chr [1:5110] "36.6" "N/A" "32.5" "34.4" ...
##  $ smoking_status   : chr [1:5110] "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : num [1:5110] 1 1 1 1 1 1 1 1 1 1 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   id = col_double(),
##   ..   gender = col_character(),
##   ..   age = col_double(),
##   ..   hypertension = col_double(),
##   ..   heart_disease = col_double(),
##   ..   ever_married = col_character(),
##   ..   work_type = col_character(),
##   ..   Residence_type = col_character(),
##   ..   avg_glucose_level = col_double(),
##   ..   bmi = col_character(),
##   ..   smoking_status = col_character(),
##   ..   stroke = col_double()
##   .. )
##  - attr(*, "problems")=<externalptr>
summary(stroke)
##        id           gender               age         hypertension    
##  Min.   :   67   Length:5110        Min.   : 0.08   Min.   :0.00000  
##  1st Qu.:17741   Class :character   1st Qu.:25.00   1st Qu.:0.00000  
##  Median :36932   Mode  :character   Median :45.00   Median :0.00000  
##  Mean   :36518                      Mean   :43.23   Mean   :0.09746  
##  3rd Qu.:54682                      3rd Qu.:61.00   3rd Qu.:0.00000  
##  Max.   :72940                      Max.   :82.00   Max.   :1.00000  
##  heart_disease     ever_married        work_type         Residence_type    
##  Min.   :0.00000   Length:5110        Length:5110        Length:5110       
##  1st Qu.:0.00000   Class :character   Class :character   Class :character  
##  Median :0.00000   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :0.05401                                                           
##  3rd Qu.:0.00000                                                           
##  Max.   :1.00000                                                           
##  avg_glucose_level     bmi            smoking_status         stroke       
##  Min.   : 55.12    Length:5110        Length:5110        Min.   :0.00000  
##  1st Qu.: 77.25    Class :character   Class :character   1st Qu.:0.00000  
##  Median : 91.89    Mode  :character   Mode  :character   Median :0.00000  
##  Mean   :106.15                                          Mean   :0.04873  
##  3rd Qu.:114.09                                          3rd Qu.:0.00000  
##  Max.   :271.74                                          Max.   :1.00000

We can spot some “N/A”s that need handling

EDA

# Convert 'bmi' to numeric and handle "N/A" values
stroke$bmi <- as.numeric(ifelse(stroke$bmi == "N/A", NA, stroke$bmi))

# Impute missing BMI values with the mean
stroke$bmi[is.na(stroke$bmi)] <- mean(stroke$bmi, na.rm = TRUE)

# Convert categorical variables to factors
stroke$gender <- as.factor(stroke$gender)
stroke$ever_married <- as.factor(stroke$ever_married)
stroke$work_type <- as.factor(stroke$work_type)
stroke$Residence_type <- as.factor(stroke$Residence_type)
stroke$smoking_status <- as.factor(stroke$smoking_status)

# Drop the 'id' column as it is unnecessary for analysis
stroke <- stroke %>% dplyr::select(-id)

# Check the structure of the cleaned data
str(stroke)
## tibble [5,110 × 11] (S3: tbl_df/tbl/data.frame)
##  $ gender           : Factor w/ 3 levels "Female","Male",..: 2 1 2 1 1 2 2 1 1 1 ...
##  $ age              : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : num [1:5110] 0 0 0 0 1 0 1 0 0 0 ...
##  $ heart_disease    : num [1:5110] 1 0 1 0 0 0 1 0 0 0 ...
##  $ ever_married     : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 1 2 2 ...
##  $ work_type        : Factor w/ 5 levels "children","Govt_job",..: 4 5 4 4 5 4 4 4 4 4 ...
##  $ Residence_type   : Factor w/ 2 levels "Rural","Urban": 2 1 1 2 1 2 1 2 1 2 ...
##  $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
##  $ bmi              : num [1:5110] 36.6 28.9 32.5 34.4 24 ...
##  $ smoking_status   : Factor w/ 4 levels "formerly smoked",..: 1 2 2 3 2 1 2 2 4 4 ...
##  $ stroke           : num [1:5110] 1 1 1 1 1 1 1 1 1 1 ...
dim(stroke) 
## [1] 5110   11
# Check class imbalance
cls <- table(stroke$stroke)

knitr::kable(cls, caption = "Class Imbalance")
Class Imbalance
Var1 Freq
0 4861
1 249
# Converting numeric binary variables (0 and 1) to "No" and "Yes" factors will make my dataset more readable and interpretable

binary_vars <- c("hypertension", "heart_disease", "stroke")

stroke[binary_vars] <- lapply(stroke[binary_vars], function(x) {
  factor(ifelse(x == 1, "Yes", "No"), levels = c("No", "Yes"))
})

str(stroke)
## tibble [5,110 × 11] (S3: tbl_df/tbl/data.frame)
##  $ gender           : Factor w/ 3 levels "Female","Male",..: 2 1 2 1 1 2 2 1 1 1 ...
##  $ age              : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 1 2 1 1 1 ...
##  $ heart_disease    : Factor w/ 2 levels "No","Yes": 2 1 2 1 1 1 2 1 1 1 ...
##  $ ever_married     : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 1 2 2 ...
##  $ work_type        : Factor w/ 5 levels "children","Govt_job",..: 4 5 4 4 5 4 4 4 4 4 ...
##  $ Residence_type   : Factor w/ 2 levels "Rural","Urban": 2 1 1 2 1 2 1 2 1 2 ...
##  $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
##  $ bmi              : num [1:5110] 36.6 28.9 32.5 34.4 24 ...
##  $ smoking_status   : Factor w/ 4 levels "formerly smoked",..: 1 2 2 3 2 1 2 2 4 4 ...
##  $ stroke           : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 2 2 2 ...
# Set up a plotting area with multiple rows and columns
par(mfrow = c(2, 2))  

# Plot histogram for age
hist(stroke$age, main = "Age Distribution", xlab = "Age", col = "blue", border = "black")

# Plot histogram for average glucose levels
hist(stroke$avg_glucose_level, main = "Average Glucose Level Distribution", xlab = "Avg Glucose Level", col = "green", border = "black")

# Plot histogram for BMI
hist(stroke$bmi, main = "BMI Distribution", xlab = "BMI", col = "orange", border = "black")

par(mfrow = c(1, 1))

# Set up a new plotting area for boxplots
par(mfrow = c(2, 2))  # Set the plotting area to 2x2 grid for boxplots

# Boxplot for age (horizontal)
boxplot(stroke$age, main = "Age Boxplot", xlab = "Age", horizontal = TRUE, col = "blue")

# Boxplot for average glucose level (horizontal)
boxplot(stroke$avg_glucose_level, main = "Average Glucose Level Boxplot", xlab = "Avg Glucose Level", horizontal = TRUE, col = "green")

# Boxplot for BMI (horizontal)
boxplot(stroke$bmi, main = "BMI Boxplot", xlab = "BMI", horizontal = TRUE, col = "orange")

# Reset back to a single plotting area
par(mfrow = c(1, 1))

A few comments about the plots above:

  • Age Distribution:The histogram shows that the ages are approximately uniformly distributed between 0 and 80 years, with a slight dip after 80 and the boxplot indicates a balanced spread of ages, with no significant outliers.

  • Average Glucose Level Distribution: The histogram reveals a right-skewed distribution, with most valuesconcentrated between 70 and 150 mg/dL and the boxplot shows several outliers in the higherrange, suggesting that extreme glucose levels may need special attention in further analysis.

  • BMI Distribution: The histogram depicts a positively skewed distribution, with most individuals having a BMI between 20 and 50 and the boxplot highlights the presence of multiple outliers.

theme_set(theme_minimal())

# Plot distribution for gender
ggplot(stroke, aes(x = gender)) +
  geom_bar(fill = "skyblue") +
  labs(title = "Distribution of Gender", x = "Gender", y = "Count")

# Plot distribution for ever_married
ggplot(stroke, aes(x = ever_married)) +
  geom_bar(fill = "orange") +
  labs(title = "Distribution of Marital Status", x = "Ever Married", y = "Count")

# Plot distribution for work_type
ggplot(stroke, aes(x = work_type)) +
  geom_bar(fill = "green") +
  labs(title = "Distribution of Work Type", x = "Work Type", y = "Count")

# Plot distribution for Residence_type
ggplot(stroke, aes(x = Residence_type)) +
  geom_bar(fill = "purple") +
  labs(title = "Distribution of Residence Type", x = "Residence Type", y = "Count")

# Plot distribution for smoking_status
ggplot(stroke, aes(x = smoking_status)) +
  geom_bar(fill = "pink") +
  labs(title = "Distribution of Smoking Status", x = "Smoking Status", y = "Count")

# Set up the plot layout
par(mfrow = c(3, 2),    # 3 rows, 2 columns
    mar = c(5, 6, 4, 3), # Increased left margin
    las = 1)            # Rotate x-axis labels vertically

# Plot distribution for gender
plot(stroke$gender, 
     main = "Distribution of Gender", 
     xlab = "Gender", 
     ylab = "Count", 
     col = "skyblue")

# Plot distribution for ever_married
plot(stroke$ever_married, 
     main = "Distribution of Marital Status",
     xlab = "Ever Married",
     ylab = "Count",
     col = "orange")

# Plot distribution for work_type
plot(stroke$work_type, 
     main = "Distribution of Work Type",
     xlab = "Work Type",
     ylab = "Count",
     col = "green")

# Plot distribution for Residence_type
plot(stroke$Residence_type,
     main = "Distribution of Residence Type",
     xlab = "Residence Type",
     ylab = "Count",
     col = "purple")

# Plot distribution for smoking_status
plot(stroke$smoking_status,
     main = "Distribution of Smoking Status",
     xlab = "Smoking Status",
     ylab = "Count",
     col = "pink")

# Optional: Reset the plot layout after finishing
par(mfrow = c(1, 1))

A few observations for the bar plots above:

  • Gender: The dataset has a higher proportion of females than males, with very few entries labeled as “Other.”

  • Marital Status:The majority of individuals are married, as evidenced by the larger bar for “Yes.”

  • Work Type:Most individuals are employed in the private sector. The “children” and “government job” categories represent a smaller proportion of the population.

  • Residence Type: The distribution between rural and urban residence types is nearly equal.

  • Smoking Status:The smoking habits are fairly distributed across the categories.

Now I will be examining the distribution of binary health conditions: hypertension and heart disease

par(mfrow = c(1, 2))  

# Hypertension distribution
plot(stroke$hypertension,
     main = "Distribution of Hypertension",
     xlab = "Hypertension",
     ylab = "Count",
     col = "coral")

# Heart disease distribution
plot(stroke$heart_disease,
     main = "Distribution of Heart Disease",
     xlab = "Heart Disease",
     ylab = "Count",
     col = "darkred")

par(mfrow = c(1, 1))  

Now we explore relationships between variables using a correlation plot. This helps identify multicollinearity or related patterns

numeric_data <- stroke %>% 
  dplyr::select_if(is.numeric)

corr_matrix <- cor(numeric_data)

# Plot the correlation matrix
corrplot(corr_matrix, method = "color", type = "upper", 
         tl.col = "black", tl.srt = 70, 
         addCoef.col = "black", number.cex = 0.7, 
         col = colorRampPalette(c("blue", "white", "red"))(200))

The strongest relationship observed is between age and bmi (r = 0.33), which is moderate at best. Age and avg_glucose_level have a low positive correlation (r = 0.24).Bmi and avg_glucose_level are also weakly correlated (r = 0.17). All correlations are below 0.8, meaning we don’t need to worry about multicollinearity!

# In a stroke prediction context, having patients with an age below 1 year (newborns or infants) is extremely rare or even not applicable. 

# Filter for age values less than 1
invalid_ages <- stroke %>% filter(age < 1)

print(invalid_ages)
## # A tibble: 43 × 11
##    gender   age hypertension heart_disease ever_married work_type Residence_type
##    <fct>  <dbl> <fct>        <fct>         <fct>        <fct>     <fct>         
##  1 Female  0.64 No           No            No           children  Urban         
##  2 Female  0.88 No           No            No           children  Rural         
##  3 Female  0.32 No           No            No           children  Rural         
##  4 Male    0.88 No           No            No           children  Rural         
##  5 Male    0.24 No           No            No           children  Rural         
##  6 Female  0.32 No           No            No           children  Rural         
##  7 Female  0.72 No           No            No           children  Urban         
##  8 Male    0.8  No           No            No           children  Rural         
##  9 Male    0.4  No           No            No           children  Urban         
## 10 Female  0.08 No           No            No           children  Urban         
## # ℹ 33 more rows
## # ℹ 4 more variables: avg_glucose_level <dbl>, bmi <dbl>, smoking_status <fct>,
## #   stroke <fct>
stroke <- stroke %>%
  mutate(age = ceiling(age * 10) / 10)
## Handling Imbalanced data 
set.seed(123)
balanced_data <- ROSE(stroke ~ ., data = stroke)$data
table(balanced_data$stroke)
## 
##   No  Yes 
## 2571 2539
# Split data into train and test sets (70-30)
set.seed(123)
split <- createDataPartition(balanced_data$stroke, p = 0.7, list = FALSE)
train <- balanced_data[split, ]
test <- balanced_data[-split, ]
# Feature Selection
# Build a preliminary random forest to get feature importance
set.seed(123)
rf_temp <- randomForest(as.factor(stroke) ~ ., data = train, importance = TRUE, ntree = 200)

importance_df <- importance(rf_temp)
importance_df <- data.frame(Feature = rownames(importance_df), importance_df)
importance_df <- importance_df[order(importance_df$MeanDecreaseGini, decreasing = TRUE), ]

ggplot(importance_df, aes(x = reorder(Feature, MeanDecreaseGini), y = MeanDecreaseGini)) +
  geom_col(fill = "steelblue") +
  coord_flip() +
  labs(title = "Feature Importance Ranking (Random Forest)",
       x = "Feature", y = "Mean Decrease in Gini") +
  theme_minimal()

As shown in the plot, age, average glucose level, and BMI emerged as the most influential predictors of stroke. These variables likely capture underlying health risks more effectively than others. In contrast, variables such as gender, residence type, and heart disease showed comparatively lower importance in this model.

Decision Trees

# ==== BASIC  TREE ====
# Baseline tree using default settings
set.seed(123)
baseline_tree <- rpart(stroke ~ ., data = train, method = "class")


rpart.plot(baseline_tree, main = "Baseline Decision Tree")

baseline_preds <- predict(baseline_tree, test, type = "class")
confusion_matrix_bt<-confusionMatrix(baseline_preds, as.factor(test$stroke))
print(confusion_matrix_bt)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  568 134
##        Yes 203 627
##                                           
##                Accuracy : 0.78            
##                  95% CI : (0.7584, 0.8005)
##     No Information Rate : 0.5033          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.5603          
##                                           
##  Mcnemar's Test P-Value : 0.0002121       
##                                           
##             Sensitivity : 0.7367          
##             Specificity : 0.8239          
##          Pos Pred Value : 0.8091          
##          Neg Pred Value : 0.7554          
##              Prevalence : 0.5033          
##          Detection Rate : 0.3708          
##    Detection Prevalence : 0.4582          
##       Balanced Accuracy : 0.7803          
##                                           
##        'Positive' Class : No              
## 
baseline_probs <- predict(baseline_tree, test)[,2]
roc_baseline <- roc(test$stroke, baseline_probs)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
auc_baseline <- auc(roc_baseline)
cat("Baseline Decision Tree AUC:", auc_baseline, "\n")
## Baseline Decision Tree AUC: 0.8124507
# ==== RANDOM FOREST ====

# Using only top 3 features based on variable importance
top_features <- c("age", "avg_glucose_level", "bmi", "stroke")
train_top <- train[, top_features]
test_top <- test[, top_features]

set.seed(123)
rf_top <- randomForest(as.factor(stroke) ~ ., data = train_top, ntree = 500, importance = TRUE)


rf_preds <- predict(rf_top, test_top, type = "response")
confusion_matrix_rf <- confusionMatrix(rf_preds, as.factor(test_top$stroke))
print(confusion_matrix_rf)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  543 153
##        Yes 228 608
##                                           
##                Accuracy : 0.7513          
##                  95% CI : (0.7289, 0.7728)
##     No Information Rate : 0.5033          
##     P-Value [Acc > NIR] : < 2e-16         
##                                           
##                   Kappa : 0.5029          
##                                           
##  Mcnemar's Test P-Value : 0.00015         
##                                           
##             Sensitivity : 0.7043          
##             Specificity : 0.7989          
##          Pos Pred Value : 0.7802          
##          Neg Pred Value : 0.7273          
##              Prevalence : 0.5033          
##          Detection Rate : 0.3544          
##    Detection Prevalence : 0.4543          
##       Balanced Accuracy : 0.7516          
##                                           
##        'Positive' Class : No              
## 
# AUC
rf_probs <- predict(rf_top, test_top, type = "prob")[, "Yes"]
roc_rf <- roc(test_top$stroke, rf_probs)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
auc_rf <- auc(roc_rf)
cat("Random Forest AUC (Top Features):", auc_rf, "\n")
## Random Forest AUC (Top Features): 0.8311168
# ==== IMPROVED PRUNED TREE ====
ctrl <- trainControl(
  method = "cv", 
  number = 10, 
  classProbs = TRUE, 
  summaryFunction = twoClassSummary,
  verboseIter = TRUE
)

#  defining  grid for cp (complexity parameter)
grid <- expand.grid(cp = seq(0.001, 0.03, by = 0.001))


tree_model <- train(
  stroke ~ ., 
  data = train, 
  method = "rpart",
  trControl = ctrl,
  tuneGrid = grid,
  metric = "ROC"  
)
## + Fold01: cp=0.001 
## - Fold01: cp=0.001 
## + Fold02: cp=0.001 
## - Fold02: cp=0.001 
## + Fold03: cp=0.001 
## - Fold03: cp=0.001 
## + Fold04: cp=0.001 
## - Fold04: cp=0.001 
## + Fold05: cp=0.001 
## - Fold05: cp=0.001 
## + Fold06: cp=0.001 
## - Fold06: cp=0.001 
## + Fold07: cp=0.001 
## - Fold07: cp=0.001 
## + Fold08: cp=0.001 
## - Fold08: cp=0.001 
## + Fold09: cp=0.001 
## - Fold09: cp=0.001 
## + Fold10: cp=0.001 
## - Fold10: cp=0.001 
## Aggregating results
## Selecting tuning parameters
## Fitting cp = 0.001 on full training set
rpart.plot(tree_model$finalModel, main = "Tuned & Pruned Decision Tree")
## Warning: labs do not fit even at cex 0.15, there may be some overplotting

tree_preds <- predict(tree_model, test)
tree_probs <- predict(tree_model, test, type = "prob")[, "Yes"]


conf_matrix <- confusionMatrix(tree_preds, test$stroke, positive = "Yes")
print(conf_matrix)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  578 158
##        Yes 193 603
##                                          
##                Accuracy : 0.7709         
##                  95% CI : (0.749, 0.7917)
##     No Information Rate : 0.5033         
##     P-Value [Acc > NIR] : < 2e-16        
##                                          
##                   Kappa : 0.5419         
##                                          
##  Mcnemar's Test P-Value : 0.06956        
##                                          
##             Sensitivity : 0.7924         
##             Specificity : 0.7497         
##          Pos Pred Value : 0.7575         
##          Neg Pred Value : 0.7853         
##              Prevalence : 0.4967         
##          Detection Rate : 0.3936         
##    Detection Prevalence : 0.5196         
##       Balanced Accuracy : 0.7710         
##                                          
##        'Positive' Class : Yes            
## 
roc_tree <- roc(test$stroke, tree_probs)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
auc_tree <- auc(roc_tree)
cat("Improved Pruned Tree AUC:", auc_tree, "\n")
## Improved Pruned Tree AUC: 0.828014
# Starting with empty results table
tree_results <- data.frame(
  Model = character(),
  Accuracy = numeric(),
  Sensitivity = numeric(),
  Specificity = numeric(),
  AUC = numeric(),
  stringsAsFactors = FALSE
)

# Clean helper function: returns a single-row data frame
extract_metrics_tree <- function(cm, auc_value) {
  data.frame(
    Accuracy = round(cm$overall["Accuracy"], 4),
    Sensitivity = round(cm$byClass["Sensitivity"], 4),
    Specificity = round(cm$byClass["Specificity"], 4),
    AUC = round(auc_value, 4)
  )
}

# Append Baseline Tree results
tree_results <- rbind(
  tree_results,
  cbind(Model = "Baseline Tree", extract_metrics_tree(confusion_matrix_bt, auc_baseline))
)

# Append Pruned Tree results
tree_results <- rbind(
  tree_results,
  cbind(Model = "Random Forest", extract_metrics_tree(confusion_matrix_rf , auc_rf))
)

# Append Improved Tree results
tree_results <- rbind(
  tree_results,
  cbind(Model = "Improved Tree", extract_metrics_tree(conf_matrix, auc_tree))
)

rownames(tree_results) <- NULL
print(tree_results)
##           Model Accuracy Sensitivity Specificity    AUC
## 1 Baseline Tree   0.7800      0.7367      0.8239 0.8125
## 2 Random Forest   0.7513      0.7043      0.7989 0.8311
## 3 Improved Tree   0.7709      0.7924      0.7497 0.8280

SVM

##  Scaling is critical because SVMs are sensitive to the scale of features
scale_features <- c("bmi", "avg_glucose_level")

# Scaling only numeric variables 
preproc <- preProcess(train[, scale_features], method = c("center", "scale"))
train_scaled <- train
test_scaled <- test
train_scaled[, scale_features] <- predict(preproc, train[, scale_features])
test_scaled[, scale_features] <- predict(preproc, test[, scale_features])
# ==== SVM RADIAL ====
# I will train a Support Vector Machine (SVM) using the radial basis function (RBF) kernel

set.seed(123)

svm_radial <- svm(stroke ~ ., data = train_scaled,
                  kernel = "radial", probability = TRUE)

pred_radial <- predict(svm_radial, test_scaled)


confusion_matrix <- confusionMatrix(pred_radial, test_scaled$stroke)
print(confusion_matrix)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  557 110
##        Yes 214 651
##                                           
##                Accuracy : 0.7885          
##                  95% CI : (0.7672, 0.8087)
##     No Information Rate : 0.5033          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.5774          
##                                           
##  Mcnemar's Test P-Value : 1.051e-08       
##                                           
##             Sensitivity : 0.7224          
##             Specificity : 0.8555          
##          Pos Pred Value : 0.8351          
##          Neg Pred Value : 0.7526          
##              Prevalence : 0.5033          
##          Detection Rate : 0.3636          
##    Detection Prevalence : 0.4354          
##       Balanced Accuracy : 0.7889          
##                                           
##        'Positive' Class : No              
## 
pred_with_probs <- predict(svm_radial, test_scaled, probability = TRUE)
prob_matrix <- attr(pred_with_probs, "probabilities")

# Checking what columns exist in the prob matrix
print(colnames(prob_matrix))  # Just to verify
## [1] "No"  "Yes"
if (!"Yes" %in% colnames(prob_matrix)) {
  stop("Class 'Yes' not found in prediction output. Check if 'yes' is present in test data.")
}
prob_radial <- prob_matrix[, "Yes"]

# ROC & AUC
roc_radial <- roc(test_scaled$stroke, prob_radial)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
auc_radial <- auc(roc_radial)
cat("AUC - Radial Kernel:", auc_radial, "\n")
## AUC - Radial Kernel: 0.8579144
plot(roc_radial, main = "ROC Curve - Radial Kernel", col = "blue", lwd = 2)
legend("bottomright", legend = paste("AUC =", round(auc_radial, 3)), 
       col = "blue", lwd = 2, bty = "n")

# ==== SVM RADIAL TOP FEATURES ====

top_features <- c("age", "avg_glucose_level", "bmi", "stroke")
train_top <- train_scaled[, top_features]
test_top <- test_scaled[, top_features]


# Setting up the cross-validation control
ctrl <- trainControl(method = "cv", number = 3,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     savePredictions = "final")

# Create tuning grid for radial kernel
tune_grid <- expand.grid(
  sigma = 2^seq(-15, -5, 2),  # gamma-related
  C = 2^seq(-1, 5, 2)
)

set.seed(123)
svm_top_tuned <- train(
  stroke ~ .,
  data = train_top,
  method = "svmRadial",
  metric = "ROC",
  trControl = ctrl,
  tuneLength = 5
)


print(svm_top_tuned)
## Support Vector Machines with Radial Basis Function Kernel 
## 
## 3578 samples
##    3 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (3 fold) 
## Summary of sample sizes: 2386, 2385, 2385 
## Resampling results across tuning parameters:
## 
##   C     ROC        Sens       Spec     
##   0.25  0.8250242  0.7072222  0.8166741
##   0.50  0.8262765  0.7066667  0.8200496
##   1.00  0.8270642  0.7077778  0.8149877
##   2.00  0.8271718  0.7050000  0.8144266
##   4.00  0.8281791  0.7038889  0.8116160
## 
## Tuning parameter 'sigma' was held constant at a value of 0.5875391
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.5875391 and C = 4.
plot(svm_top_tuned)

# Predict class
preds_top <- predict(svm_top_tuned, test_top)
confusion_matrix_top <- confusionMatrix(preds_top, test_top$stroke)
print(confusion_matrix_top)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  559 142
##        Yes 212 619
##                                          
##                Accuracy : 0.7689         
##                  95% CI : (0.747, 0.7898)
##     No Information Rate : 0.5033         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.5381         
##                                          
##  Mcnemar's Test P-Value : 0.0002451      
##                                          
##             Sensitivity : 0.7250         
##             Specificity : 0.8134         
##          Pos Pred Value : 0.7974         
##          Neg Pred Value : 0.7449         
##              Prevalence : 0.5033         
##          Detection Rate : 0.3649         
##    Detection Prevalence : 0.4576         
##       Balanced Accuracy : 0.7692         
##                                          
##        'Positive' Class : No             
## 
# Predict probabilities for AUC
probs_top <- predict(svm_top_tuned, test_top, type = "prob")[, "Yes"]
roc_top <- roc(test_top$stroke, probs_top)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
auc_top <- auc(roc_top)
cat("Tuned SVM with Selected Features AUC:", auc_top, "\n")
## Tuned SVM with Selected Features AUC: 0.8417605
plot(roc_top, 
     main = "ROC Curve - Radial Kernel with Top Features",
     col = "purple", 
     lwd = 2)


legend("bottomright", legend = paste("AUC =", round(auc_top, 3)), 
       col = "purple", lwd = 2, bty = "n")

# ==== SVM LINEAR ====
# I need to define a smaller tuning grid for linear kernel 
tune_grid_linear <- expand.grid(C = 2^seq(-3, 3, 1))  # Reasonable range for cost for linear kernel

# Train the SVM with Linear Kernel 
set.seed(123)
svm_tuned_linear <- tune(
  svm,
  stroke ~ .,
  data = train_scaled,
  kernel = "linear",
  ranges = tune_grid_linear,
  probability = TRUE
)

print(svm_tuned_linear)
## 
## Parameter tuning of 'svm':
## 
## - sampling method: 10-fold cross validation 
## 
## - best parameters:
##      C
##  0.125
## 
## - best performance: 0.2339413
# Getting the best model
svm_best_linear <- svm_tuned_linear$best.model

# predicting on the test set
pred_linear <- predict(svm_best_linear, test_scaled)

confusion_matrix_linear <- confusionMatrix(pred_linear, test_scaled$stroke)
print(confusion_matrix_linear)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  560 133
##        Yes 211 628
##                                           
##                Accuracy : 0.7755          
##                  95% CI : (0.7537, 0.7961)
##     No Information Rate : 0.5033          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.5512          
##                                           
##  Mcnemar's Test P-Value : 3.302e-05       
##                                           
##             Sensitivity : 0.7263          
##             Specificity : 0.8252          
##          Pos Pred Value : 0.8081          
##          Neg Pred Value : 0.7485          
##              Prevalence : 0.5033          
##          Detection Rate : 0.3655          
##    Detection Prevalence : 0.4523          
##       Balanced Accuracy : 0.7758          
##                                           
##        'Positive' Class : No              
## 
# Predicting the probabilities for ROC/AUC
pred_with_probs_linear <- predict(svm_best_linear, test_scaled, probability = TRUE)
prob_matrix_linear <- attr(pred_with_probs_linear, "probabilities")

# Extract probabilities for "Yes" class (stroke = Yes)
prob_linear <- prob_matrix_linear[, "Yes"]

# making sure no NA values
complete_cases_linear <- complete.cases(test_scaled$stroke, prob_linear)


# Compute ROC and AUC
roc_linear <- roc(test_scaled$stroke[complete_cases_linear], prob_linear[complete_cases_linear])
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
# AUC for linear SVM
auc_linear <- auc(roc_linear)
cat("AUC - Linear Kernel:", auc_linear, "\n")
## AUC - Linear Kernel: 0.8448795
plot(roc_linear, main = "ROC Curve - Linear Kernel", col = "red", lwd = 2)
legend("bottomright", legend = paste("AUC =", round(auc_linear, 3)), 
       col = "red", lwd = 2, bty = "n")

# Creating an empty data frame to store metrics
svm_results <- data.frame(
  Kernel = character(),
  Accuracy = numeric(),
  Sensitivity = numeric(),
  Specificity = numeric(),
  AUC = numeric(),
  stringsAsFactors = FALSE
)

# Function to extract metrics from confusion matrix
extract_metrics <- function(cm, auc_value) {
  accuracy <- cm$overall["Accuracy"]
  sensitivity <- cm$byClass["Sensitivity"]
  specificity <- cm$byClass["Specificity"]
  
  return(data.frame(
    Accuracy = round(accuracy, 4),
    Sensitivity = round(sensitivity, 4),
    Specificity = round(specificity, 4),
    AUC = round(auc_value, 4)
  ))
}

svm_results <- rbind(
  svm_results,
  cbind(Kernel = "SVM Radial", extract_metrics(confusion_matrix, auc_radial)),
  cbind(Kernel = "SVM Radial Top Features", extract_metrics(confusion_matrix_top, auc_top)),
  cbind(Kernel = "SVM Linear", extract_metrics(confusion_matrix_linear, auc_linear))
)

rownames(svm_results) <- NULL
# View the final comparison table
print(svm_results)
##                    Kernel Accuracy Sensitivity Specificity    AUC
## 1              SVM Radial   0.7885      0.7224      0.8555 0.8579
## 2 SVM Radial Top Features   0.7689      0.7250      0.8134 0.8418
## 3              SVM Linear   0.7755      0.7263      0.8252 0.8449

Conclusion

# Standardize column names before combining
svm_results$Model <- svm_results$Kernel
svm_results$Kernel <- NULL  # Drop the original 'Kernel' column

# Reordering  columns to match tree_results
svm_results <- svm_results[, c("Model", "Accuracy", "Sensitivity", "Specificity", "AUC")]

combined_results <- rbind(tree_results, svm_results)

rownames(combined_results) <- NULL


kable(combined_results, caption = "Model Comparison: Tree-based vs SVM Models", digits = 4) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"), full_width = F) %>%
  column_spec(1, bold = TRUE) %>%
  column_spec(2:5, color = "black", background = "white") %>%
  add_header_above(c(" " = 1, "Model Metrics" = 4)) %>%
  row_spec(0, bold = TRUE)
Model Comparison: Tree-based vs SVM Models
Model Metrics
Model Accuracy Sensitivity Specificity AUC
Baseline Tree 0.7800 0.7367 0.8239 0.8125
Random Forest 0.7513 0.7043 0.7989 0.8311
Improved Tree 0.7709 0.7924 0.7497 0.8280
SVM Radial 0.7885 0.7224 0.8555 0.8579
SVM Radial Top Features 0.7689 0.7250 0.8134 0.8418
SVM Linear 0.7755 0.7263 0.8252 0.8449