library(dplyr)
library(rmdformats)
library(prettydoc)
library(hrbrthemes)
library(tint)
library(tufte)
library(readr)
library(e1071)
library(gridExtra)
library(kableExtra)
library(tidyverse)
library(ROSE) # For handling class imbalance
library(ggplot2)
library(rpart)
library(rpart.plot)
library(caret)
library(pROC)
library(randomForest)
library(corrplot)Data
The dataset used in this study was obtained from Kaggle.It includes a combination of socio-demographic details and health-related factors that are associated with the risk of stroke.
# Load the dataset
url<-"https://raw.githubusercontent.com/NikoletaEm/datasps/refs/heads/main/healthcare-dataset-stroke-data.csv"
stroke<- read_csv(url)## spc_tbl_ [5,110 × 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ id : num [1:5110] 9046 51676 31112 60182 1665 ...
## $ gender : chr [1:5110] "Male" "Female" "Male" "Female" ...
## $ age : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : num [1:5110] 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : num [1:5110] 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : chr [1:5110] "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr [1:5110] "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr [1:5110] "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
## $ bmi : chr [1:5110] "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status : chr [1:5110] "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : num [1:5110] 1 1 1 1 1 1 1 1 1 1 ...
## - attr(*, "spec")=
## .. cols(
## .. id = col_double(),
## .. gender = col_character(),
## .. age = col_double(),
## .. hypertension = col_double(),
## .. heart_disease = col_double(),
## .. ever_married = col_character(),
## .. work_type = col_character(),
## .. Residence_type = col_character(),
## .. avg_glucose_level = col_double(),
## .. bmi = col_character(),
## .. smoking_status = col_character(),
## .. stroke = col_double()
## .. )
## - attr(*, "problems")=<externalptr>
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.23 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Length:5110 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 Class :character Class :character 1st Qu.:0.00000
## Median : 91.89 Mode :character Mode :character Median :0.00000
## Mean :106.15 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:0.00000
## Max. :271.74 Max. :1.00000
We can spot some “N/A”s that need handling
EDA
# Convert 'bmi' to numeric and handle "N/A" values
stroke$bmi <- as.numeric(ifelse(stroke$bmi == "N/A", NA, stroke$bmi))
# Impute missing BMI values with the mean
stroke$bmi[is.na(stroke$bmi)] <- mean(stroke$bmi, na.rm = TRUE)
# Convert categorical variables to factors
stroke$gender <- as.factor(stroke$gender)
stroke$ever_married <- as.factor(stroke$ever_married)
stroke$work_type <- as.factor(stroke$work_type)
stroke$Residence_type <- as.factor(stroke$Residence_type)
stroke$smoking_status <- as.factor(stroke$smoking_status)
# Drop the 'id' column as it is unnecessary for analysis
stroke <- stroke %>% dplyr::select(-id)
# Check the structure of the cleaned data
str(stroke)## tibble [5,110 × 11] (S3: tbl_df/tbl/data.frame)
## $ gender : Factor w/ 3 levels "Female","Male",..: 2 1 2 1 1 2 2 1 1 1 ...
## $ age : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : num [1:5110] 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : num [1:5110] 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 1 2 2 ...
## $ work_type : Factor w/ 5 levels "children","Govt_job",..: 4 5 4 4 5 4 4 4 4 4 ...
## $ Residence_type : Factor w/ 2 levels "Rural","Urban": 2 1 1 2 1 2 1 2 1 2 ...
## $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
## $ bmi : num [1:5110] 36.6 28.9 32.5 34.4 24 ...
## $ smoking_status : Factor w/ 4 levels "formerly smoked",..: 1 2 2 3 2 1 2 2 4 4 ...
## $ stroke : num [1:5110] 1 1 1 1 1 1 1 1 1 1 ...
## [1] 5110 11
| Var1 | Freq |
|---|---|
| 0 | 4861 |
| 1 | 249 |
# Converting numeric binary variables (0 and 1) to "No" and "Yes" factors will make my dataset more readable and interpretable
binary_vars <- c("hypertension", "heart_disease", "stroke")
stroke[binary_vars] <- lapply(stroke[binary_vars], function(x) {
factor(ifelse(x == 1, "Yes", "No"), levels = c("No", "Yes"))
})
str(stroke)## tibble [5,110 × 11] (S3: tbl_df/tbl/data.frame)
## $ gender : Factor w/ 3 levels "Female","Male",..: 2 1 2 1 1 2 2 1 1 1 ...
## $ age : num [1:5110] 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 1 2 1 1 1 ...
## $ heart_disease : Factor w/ 2 levels "No","Yes": 2 1 2 1 1 1 2 1 1 1 ...
## $ ever_married : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 1 2 2 ...
## $ work_type : Factor w/ 5 levels "children","Govt_job",..: 4 5 4 4 5 4 4 4 4 4 ...
## $ Residence_type : Factor w/ 2 levels "Rural","Urban": 2 1 1 2 1 2 1 2 1 2 ...
## $ avg_glucose_level: num [1:5110] 229 202 106 171 174 ...
## $ bmi : num [1:5110] 36.6 28.9 32.5 34.4 24 ...
## $ smoking_status : Factor w/ 4 levels "formerly smoked",..: 1 2 2 3 2 1 2 2 4 4 ...
## $ stroke : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 2 2 2 ...
# Set up a plotting area with multiple rows and columns
par(mfrow = c(2, 2))
# Plot histogram for age
hist(stroke$age, main = "Age Distribution", xlab = "Age", col = "blue", border = "black")
# Plot histogram for average glucose levels
hist(stroke$avg_glucose_level, main = "Average Glucose Level Distribution", xlab = "Avg Glucose Level", col = "green", border = "black")
# Plot histogram for BMI
hist(stroke$bmi, main = "BMI Distribution", xlab = "BMI", col = "orange", border = "black")
par(mfrow = c(1, 1))# Set up a new plotting area for boxplots
par(mfrow = c(2, 2)) # Set the plotting area to 2x2 grid for boxplots
# Boxplot for age (horizontal)
boxplot(stroke$age, main = "Age Boxplot", xlab = "Age", horizontal = TRUE, col = "blue")
# Boxplot for average glucose level (horizontal)
boxplot(stroke$avg_glucose_level, main = "Average Glucose Level Boxplot", xlab = "Avg Glucose Level", horizontal = TRUE, col = "green")
# Boxplot for BMI (horizontal)
boxplot(stroke$bmi, main = "BMI Boxplot", xlab = "BMI", horizontal = TRUE, col = "orange")
# Reset back to a single plotting area
par(mfrow = c(1, 1))A few comments about the plots above:
Age Distribution:The histogram shows that the ages are approximately uniformly distributed between 0 and 80 years, with a slight dip after 80 and the boxplot indicates a balanced spread of ages, with no significant outliers.
Average Glucose Level Distribution: The histogram reveals a right-skewed distribution, with most valuesconcentrated between 70 and 150 mg/dL and the boxplot shows several outliers in the higherrange, suggesting that extreme glucose levels may need special attention in further analysis.
BMI Distribution: The histogram depicts a positively skewed distribution, with most individuals having a BMI between 20 and 50 and the boxplot highlights the presence of multiple outliers.
theme_set(theme_minimal())
# Plot distribution for gender
ggplot(stroke, aes(x = gender)) +
geom_bar(fill = "skyblue") +
labs(title = "Distribution of Gender", x = "Gender", y = "Count")# Plot distribution for ever_married
ggplot(stroke, aes(x = ever_married)) +
geom_bar(fill = "orange") +
labs(title = "Distribution of Marital Status", x = "Ever Married", y = "Count")# Plot distribution for work_type
ggplot(stroke, aes(x = work_type)) +
geom_bar(fill = "green") +
labs(title = "Distribution of Work Type", x = "Work Type", y = "Count")# Plot distribution for Residence_type
ggplot(stroke, aes(x = Residence_type)) +
geom_bar(fill = "purple") +
labs(title = "Distribution of Residence Type", x = "Residence Type", y = "Count")# Plot distribution for smoking_status
ggplot(stroke, aes(x = smoking_status)) +
geom_bar(fill = "pink") +
labs(title = "Distribution of Smoking Status", x = "Smoking Status", y = "Count")# Set up the plot layout
par(mfrow = c(3, 2), # 3 rows, 2 columns
mar = c(5, 6, 4, 3), # Increased left margin
las = 1) # Rotate x-axis labels vertically
# Plot distribution for gender
plot(stroke$gender,
main = "Distribution of Gender",
xlab = "Gender",
ylab = "Count",
col = "skyblue")
# Plot distribution for ever_married
plot(stroke$ever_married,
main = "Distribution of Marital Status",
xlab = "Ever Married",
ylab = "Count",
col = "orange")
# Plot distribution for work_type
plot(stroke$work_type,
main = "Distribution of Work Type",
xlab = "Work Type",
ylab = "Count",
col = "green")
# Plot distribution for Residence_type
plot(stroke$Residence_type,
main = "Distribution of Residence Type",
xlab = "Residence Type",
ylab = "Count",
col = "purple")
# Plot distribution for smoking_status
plot(stroke$smoking_status,
main = "Distribution of Smoking Status",
xlab = "Smoking Status",
ylab = "Count",
col = "pink")
# Optional: Reset the plot layout after finishing
par(mfrow = c(1, 1))A few observations for the bar plots above:
Gender: The dataset has a higher proportion of females than males, with very few entries labeled as “Other.”
Marital Status:The majority of individuals are married, as evidenced by the larger bar for “Yes.”
Work Type:Most individuals are employed in the private sector. The “children” and “government job” categories represent a smaller proportion of the population.
Residence Type: The distribution between rural and urban residence types is nearly equal.
Smoking Status:The smoking habits are fairly distributed across the categories.
Now I will be examining the distribution of binary health conditions: hypertension and heart disease
par(mfrow = c(1, 2))
# Hypertension distribution
plot(stroke$hypertension,
main = "Distribution of Hypertension",
xlab = "Hypertension",
ylab = "Count",
col = "coral")
# Heart disease distribution
plot(stroke$heart_disease,
main = "Distribution of Heart Disease",
xlab = "Heart Disease",
ylab = "Count",
col = "darkred")Now we explore relationships between variables using a correlation plot. This helps identify multicollinearity or related patterns
numeric_data <- stroke %>%
dplyr::select_if(is.numeric)
corr_matrix <- cor(numeric_data)
# Plot the correlation matrix
corrplot(corr_matrix, method = "color", type = "upper",
tl.col = "black", tl.srt = 70,
addCoef.col = "black", number.cex = 0.7,
col = colorRampPalette(c("blue", "white", "red"))(200))The strongest relationship observed is between age and bmi (r = 0.33), which is moderate at best. Age and avg_glucose_level have a low positive correlation (r = 0.24).Bmi and avg_glucose_level are also weakly correlated (r = 0.17). All correlations are below 0.8, meaning we don’t need to worry about multicollinearity!
# In a stroke prediction context, having patients with an age below 1 year (newborns or infants) is extremely rare or even not applicable.
# Filter for age values less than 1
invalid_ages <- stroke %>% filter(age < 1)
print(invalid_ages)## # A tibble: 43 × 11
## gender age hypertension heart_disease ever_married work_type Residence_type
## <fct> <dbl> <fct> <fct> <fct> <fct> <fct>
## 1 Female 0.64 No No No children Urban
## 2 Female 0.88 No No No children Rural
## 3 Female 0.32 No No No children Rural
## 4 Male 0.88 No No No children Rural
## 5 Male 0.24 No No No children Rural
## 6 Female 0.32 No No No children Rural
## 7 Female 0.72 No No No children Urban
## 8 Male 0.8 No No No children Rural
## 9 Male 0.4 No No No children Urban
## 10 Female 0.08 No No No children Urban
## # ℹ 33 more rows
## # ℹ 4 more variables: avg_glucose_level <dbl>, bmi <dbl>, smoking_status <fct>,
## # stroke <fct>
## Handling Imbalanced data
set.seed(123)
balanced_data <- ROSE(stroke ~ ., data = stroke)$data
table(balanced_data$stroke)##
## No Yes
## 2571 2539
# Split data into train and test sets (70-30)
set.seed(123)
split <- createDataPartition(balanced_data$stroke, p = 0.7, list = FALSE)
train <- balanced_data[split, ]
test <- balanced_data[-split, ]# Feature Selection
# Build a preliminary random forest to get feature importance
set.seed(123)
rf_temp <- randomForest(as.factor(stroke) ~ ., data = train, importance = TRUE, ntree = 200)
importance_df <- importance(rf_temp)
importance_df <- data.frame(Feature = rownames(importance_df), importance_df)
importance_df <- importance_df[order(importance_df$MeanDecreaseGini, decreasing = TRUE), ]
ggplot(importance_df, aes(x = reorder(Feature, MeanDecreaseGini), y = MeanDecreaseGini)) +
geom_col(fill = "steelblue") +
coord_flip() +
labs(title = "Feature Importance Ranking (Random Forest)",
x = "Feature", y = "Mean Decrease in Gini") +
theme_minimal()As shown in the plot, age, average glucose level, and BMI emerged as the most influential predictors of stroke. These variables likely capture underlying health risks more effectively than others. In contrast, variables such as gender, residence type, and heart disease showed comparatively lower importance in this model.
Decision Trees
# ==== BASIC TREE ====
# Baseline tree using default settings
set.seed(123)
baseline_tree <- rpart(stroke ~ ., data = train, method = "class")
rpart.plot(baseline_tree, main = "Baseline Decision Tree")baseline_preds <- predict(baseline_tree, test, type = "class")
confusion_matrix_bt<-confusionMatrix(baseline_preds, as.factor(test$stroke))
print(confusion_matrix_bt)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 568 134
## Yes 203 627
##
## Accuracy : 0.78
## 95% CI : (0.7584, 0.8005)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5603
##
## Mcnemar's Test P-Value : 0.0002121
##
## Sensitivity : 0.7367
## Specificity : 0.8239
## Pos Pred Value : 0.8091
## Neg Pred Value : 0.7554
## Prevalence : 0.5033
## Detection Rate : 0.3708
## Detection Prevalence : 0.4582
## Balanced Accuracy : 0.7803
##
## 'Positive' Class : No
##
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## Baseline Decision Tree AUC: 0.8124507
# ==== RANDOM FOREST ====
# Using only top 3 features based on variable importance
top_features <- c("age", "avg_glucose_level", "bmi", "stroke")
train_top <- train[, top_features]
test_top <- test[, top_features]
set.seed(123)
rf_top <- randomForest(as.factor(stroke) ~ ., data = train_top, ntree = 500, importance = TRUE)
rf_preds <- predict(rf_top, test_top, type = "response")
confusion_matrix_rf <- confusionMatrix(rf_preds, as.factor(test_top$stroke))
print(confusion_matrix_rf)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 543 153
## Yes 228 608
##
## Accuracy : 0.7513
## 95% CI : (0.7289, 0.7728)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.5029
##
## Mcnemar's Test P-Value : 0.00015
##
## Sensitivity : 0.7043
## Specificity : 0.7989
## Pos Pred Value : 0.7802
## Neg Pred Value : 0.7273
## Prevalence : 0.5033
## Detection Rate : 0.3544
## Detection Prevalence : 0.4543
## Balanced Accuracy : 0.7516
##
## 'Positive' Class : No
##
# AUC
rf_probs <- predict(rf_top, test_top, type = "prob")[, "Yes"]
roc_rf <- roc(test_top$stroke, rf_probs)## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## Random Forest AUC (Top Features): 0.8311168
# ==== IMPROVED PRUNED TREE ====
ctrl <- trainControl(
method = "cv",
number = 10,
classProbs = TRUE,
summaryFunction = twoClassSummary,
verboseIter = TRUE
)
# defining grid for cp (complexity parameter)
grid <- expand.grid(cp = seq(0.001, 0.03, by = 0.001))
tree_model <- train(
stroke ~ .,
data = train,
method = "rpart",
trControl = ctrl,
tuneGrid = grid,
metric = "ROC"
)## + Fold01: cp=0.001
## - Fold01: cp=0.001
## + Fold02: cp=0.001
## - Fold02: cp=0.001
## + Fold03: cp=0.001
## - Fold03: cp=0.001
## + Fold04: cp=0.001
## - Fold04: cp=0.001
## + Fold05: cp=0.001
## - Fold05: cp=0.001
## + Fold06: cp=0.001
## - Fold06: cp=0.001
## + Fold07: cp=0.001
## - Fold07: cp=0.001
## + Fold08: cp=0.001
## - Fold08: cp=0.001
## + Fold09: cp=0.001
## - Fold09: cp=0.001
## + Fold10: cp=0.001
## - Fold10: cp=0.001
## Aggregating results
## Selecting tuning parameters
## Fitting cp = 0.001 on full training set
## Warning: labs do not fit even at cex 0.15, there may be some overplotting
tree_preds <- predict(tree_model, test)
tree_probs <- predict(tree_model, test, type = "prob")[, "Yes"]
conf_matrix <- confusionMatrix(tree_preds, test$stroke, positive = "Yes")
print(conf_matrix)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 578 158
## Yes 193 603
##
## Accuracy : 0.7709
## 95% CI : (0.749, 0.7917)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.5419
##
## Mcnemar's Test P-Value : 0.06956
##
## Sensitivity : 0.7924
## Specificity : 0.7497
## Pos Pred Value : 0.7575
## Neg Pred Value : 0.7853
## Prevalence : 0.4967
## Detection Rate : 0.3936
## Detection Prevalence : 0.5196
## Balanced Accuracy : 0.7710
##
## 'Positive' Class : Yes
##
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## Improved Pruned Tree AUC: 0.828014
# Starting with empty results table
tree_results <- data.frame(
Model = character(),
Accuracy = numeric(),
Sensitivity = numeric(),
Specificity = numeric(),
AUC = numeric(),
stringsAsFactors = FALSE
)
# Clean helper function: returns a single-row data frame
extract_metrics_tree <- function(cm, auc_value) {
data.frame(
Accuracy = round(cm$overall["Accuracy"], 4),
Sensitivity = round(cm$byClass["Sensitivity"], 4),
Specificity = round(cm$byClass["Specificity"], 4),
AUC = round(auc_value, 4)
)
}
# Append Baseline Tree results
tree_results <- rbind(
tree_results,
cbind(Model = "Baseline Tree", extract_metrics_tree(confusion_matrix_bt, auc_baseline))
)
# Append Pruned Tree results
tree_results <- rbind(
tree_results,
cbind(Model = "Random Forest", extract_metrics_tree(confusion_matrix_rf , auc_rf))
)
# Append Improved Tree results
tree_results <- rbind(
tree_results,
cbind(Model = "Improved Tree", extract_metrics_tree(conf_matrix, auc_tree))
)
rownames(tree_results) <- NULL
print(tree_results)## Model Accuracy Sensitivity Specificity AUC
## 1 Baseline Tree 0.7800 0.7367 0.8239 0.8125
## 2 Random Forest 0.7513 0.7043 0.7989 0.8311
## 3 Improved Tree 0.7709 0.7924 0.7497 0.8280
SVM
## Scaling is critical because SVMs are sensitive to the scale of features
scale_features <- c("bmi", "avg_glucose_level")
# Scaling only numeric variables
preproc <- preProcess(train[, scale_features], method = c("center", "scale"))
train_scaled <- train
test_scaled <- test
train_scaled[, scale_features] <- predict(preproc, train[, scale_features])
test_scaled[, scale_features] <- predict(preproc, test[, scale_features])# ==== SVM RADIAL ====
# I will train a Support Vector Machine (SVM) using the radial basis function (RBF) kernel
set.seed(123)
svm_radial <- svm(stroke ~ ., data = train_scaled,
kernel = "radial", probability = TRUE)
pred_radial <- predict(svm_radial, test_scaled)
confusion_matrix <- confusionMatrix(pred_radial, test_scaled$stroke)
print(confusion_matrix)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 557 110
## Yes 214 651
##
## Accuracy : 0.7885
## 95% CI : (0.7672, 0.8087)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5774
##
## Mcnemar's Test P-Value : 1.051e-08
##
## Sensitivity : 0.7224
## Specificity : 0.8555
## Pos Pred Value : 0.8351
## Neg Pred Value : 0.7526
## Prevalence : 0.5033
## Detection Rate : 0.3636
## Detection Prevalence : 0.4354
## Balanced Accuracy : 0.7889
##
## 'Positive' Class : No
##
pred_with_probs <- predict(svm_radial, test_scaled, probability = TRUE)
prob_matrix <- attr(pred_with_probs, "probabilities")
# Checking what columns exist in the prob matrix
print(colnames(prob_matrix)) # Just to verify## [1] "No" "Yes"
if (!"Yes" %in% colnames(prob_matrix)) {
stop("Class 'Yes' not found in prediction output. Check if 'yes' is present in test data.")
}
prob_radial <- prob_matrix[, "Yes"]
# ROC & AUC
roc_radial <- roc(test_scaled$stroke, prob_radial)## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## AUC - Radial Kernel: 0.8579144
plot(roc_radial, main = "ROC Curve - Radial Kernel", col = "blue", lwd = 2)
legend("bottomright", legend = paste("AUC =", round(auc_radial, 3)),
col = "blue", lwd = 2, bty = "n")# ==== SVM RADIAL TOP FEATURES ====
top_features <- c("age", "avg_glucose_level", "bmi", "stroke")
train_top <- train_scaled[, top_features]
test_top <- test_scaled[, top_features]
# Setting up the cross-validation control
ctrl <- trainControl(method = "cv", number = 3,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = "final")
# Create tuning grid for radial kernel
tune_grid <- expand.grid(
sigma = 2^seq(-15, -5, 2), # gamma-related
C = 2^seq(-1, 5, 2)
)
set.seed(123)
svm_top_tuned <- train(
stroke ~ .,
data = train_top,
method = "svmRadial",
metric = "ROC",
trControl = ctrl,
tuneLength = 5
)
print(svm_top_tuned)## Support Vector Machines with Radial Basis Function Kernel
##
## 3578 samples
## 3 predictor
## 2 classes: 'No', 'Yes'
##
## No pre-processing
## Resampling: Cross-Validated (3 fold)
## Summary of sample sizes: 2386, 2385, 2385
## Resampling results across tuning parameters:
##
## C ROC Sens Spec
## 0.25 0.8250242 0.7072222 0.8166741
## 0.50 0.8262765 0.7066667 0.8200496
## 1.00 0.8270642 0.7077778 0.8149877
## 2.00 0.8271718 0.7050000 0.8144266
## 4.00 0.8281791 0.7038889 0.8116160
##
## Tuning parameter 'sigma' was held constant at a value of 0.5875391
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.5875391 and C = 4.
# Predict class
preds_top <- predict(svm_top_tuned, test_top)
confusion_matrix_top <- confusionMatrix(preds_top, test_top$stroke)
print(confusion_matrix_top)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 559 142
## Yes 212 619
##
## Accuracy : 0.7689
## 95% CI : (0.747, 0.7898)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5381
##
## Mcnemar's Test P-Value : 0.0002451
##
## Sensitivity : 0.7250
## Specificity : 0.8134
## Pos Pred Value : 0.7974
## Neg Pred Value : 0.7449
## Prevalence : 0.5033
## Detection Rate : 0.3649
## Detection Prevalence : 0.4576
## Balanced Accuracy : 0.7692
##
## 'Positive' Class : No
##
# Predict probabilities for AUC
probs_top <- predict(svm_top_tuned, test_top, type = "prob")[, "Yes"]
roc_top <- roc(test_top$stroke, probs_top)## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## Tuned SVM with Selected Features AUC: 0.8417605
plot(roc_top,
main = "ROC Curve - Radial Kernel with Top Features",
col = "purple",
lwd = 2)
legend("bottomright", legend = paste("AUC =", round(auc_top, 3)),
col = "purple", lwd = 2, bty = "n")# ==== SVM LINEAR ====
# I need to define a smaller tuning grid for linear kernel
tune_grid_linear <- expand.grid(C = 2^seq(-3, 3, 1)) # Reasonable range for cost for linear kernel
# Train the SVM with Linear Kernel
set.seed(123)
svm_tuned_linear <- tune(
svm,
stroke ~ .,
data = train_scaled,
kernel = "linear",
ranges = tune_grid_linear,
probability = TRUE
)
print(svm_tuned_linear)##
## Parameter tuning of 'svm':
##
## - sampling method: 10-fold cross validation
##
## - best parameters:
## C
## 0.125
##
## - best performance: 0.2339413
# Getting the best model
svm_best_linear <- svm_tuned_linear$best.model
# predicting on the test set
pred_linear <- predict(svm_best_linear, test_scaled)
confusion_matrix_linear <- confusionMatrix(pred_linear, test_scaled$stroke)
print(confusion_matrix_linear)## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 560 133
## Yes 211 628
##
## Accuracy : 0.7755
## 95% CI : (0.7537, 0.7961)
## No Information Rate : 0.5033
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5512
##
## Mcnemar's Test P-Value : 3.302e-05
##
## Sensitivity : 0.7263
## Specificity : 0.8252
## Pos Pred Value : 0.8081
## Neg Pred Value : 0.7485
## Prevalence : 0.5033
## Detection Rate : 0.3655
## Detection Prevalence : 0.4523
## Balanced Accuracy : 0.7758
##
## 'Positive' Class : No
##
# Predicting the probabilities for ROC/AUC
pred_with_probs_linear <- predict(svm_best_linear, test_scaled, probability = TRUE)
prob_matrix_linear <- attr(pred_with_probs_linear, "probabilities")
# Extract probabilities for "Yes" class (stroke = Yes)
prob_linear <- prob_matrix_linear[, "Yes"]
# making sure no NA values
complete_cases_linear <- complete.cases(test_scaled$stroke, prob_linear)
# Compute ROC and AUC
roc_linear <- roc(test_scaled$stroke[complete_cases_linear], prob_linear[complete_cases_linear])## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
## AUC - Linear Kernel: 0.8448795
plot(roc_linear, main = "ROC Curve - Linear Kernel", col = "red", lwd = 2)
legend("bottomright", legend = paste("AUC =", round(auc_linear, 3)),
col = "red", lwd = 2, bty = "n")# Creating an empty data frame to store metrics
svm_results <- data.frame(
Kernel = character(),
Accuracy = numeric(),
Sensitivity = numeric(),
Specificity = numeric(),
AUC = numeric(),
stringsAsFactors = FALSE
)
# Function to extract metrics from confusion matrix
extract_metrics <- function(cm, auc_value) {
accuracy <- cm$overall["Accuracy"]
sensitivity <- cm$byClass["Sensitivity"]
specificity <- cm$byClass["Specificity"]
return(data.frame(
Accuracy = round(accuracy, 4),
Sensitivity = round(sensitivity, 4),
Specificity = round(specificity, 4),
AUC = round(auc_value, 4)
))
}
svm_results <- rbind(
svm_results,
cbind(Kernel = "SVM Radial", extract_metrics(confusion_matrix, auc_radial)),
cbind(Kernel = "SVM Radial Top Features", extract_metrics(confusion_matrix_top, auc_top)),
cbind(Kernel = "SVM Linear", extract_metrics(confusion_matrix_linear, auc_linear))
)
rownames(svm_results) <- NULL
# View the final comparison table
print(svm_results)## Kernel Accuracy Sensitivity Specificity AUC
## 1 SVM Radial 0.7885 0.7224 0.8555 0.8579
## 2 SVM Radial Top Features 0.7689 0.7250 0.8134 0.8418
## 3 SVM Linear 0.7755 0.7263 0.8252 0.8449
Conclusion
# Standardize column names before combining
svm_results$Model <- svm_results$Kernel
svm_results$Kernel <- NULL # Drop the original 'Kernel' column
# Reordering columns to match tree_results
svm_results <- svm_results[, c("Model", "Accuracy", "Sensitivity", "Specificity", "AUC")]
combined_results <- rbind(tree_results, svm_results)
rownames(combined_results) <- NULL
kable(combined_results, caption = "Model Comparison: Tree-based vs SVM Models", digits = 4) %>%
kable_styling(bootstrap_options = c("striped", "hover", "condensed"), full_width = F) %>%
column_spec(1, bold = TRUE) %>%
column_spec(2:5, color = "black", background = "white") %>%
add_header_above(c(" " = 1, "Model Metrics" = 4)) %>%
row_spec(0, bold = TRUE)| Model | Accuracy | Sensitivity | Specificity | AUC |
|---|---|---|---|---|
| Baseline Tree | 0.7800 | 0.7367 | 0.8239 | 0.8125 |
| Random Forest | 0.7513 | 0.7043 | 0.7989 | 0.8311 |
| Improved Tree | 0.7709 | 0.7924 | 0.7497 | 0.8280 |
| SVM Radial | 0.7885 | 0.7224 | 0.8555 | 0.8579 |
| SVM Radial Top Features | 0.7689 | 0.7250 | 0.8134 | 0.8418 |
| SVM Linear | 0.7755 | 0.7263 | 0.8252 | 0.8449 |