Stroke is the second leading cause of death globally, accounting for approximately 11% of total deaths, as reported by the World Health Organization (WHO). This study applies supervised learning tools to an open dataset to predict the likelihood of stroke in patients.
The dataset includes categorical and continuous variables such as gender, age, pre-existing conditions, and smoking status. The target variable is the occurrence of a stroke. Several supervised learning models, together with threshold and hyperparameter tuning, are used to analyze the data and establish relationships between predictors and the target.
The analysis leverages a diverse and extensive dataset to ensure robust insights, comparing model performances and evaluating feature importance. The results aim to provide actionable conclusions informed by risk-based learning principles, contributing to predictive healthcare analytics.
We begin this project with some initial data cleaning and filtering. The first step is to load the data and summarize every column:
library(tidyverse)
library(skimr)
library(mice)
library(VIM)
library(GGally)
library(MASS)
library(gridExtra)
library(glmnet)
library(e1071)
library(rpart)
library(pROC)
library(class)
library(randomForest)
library(caret)
library(gbm)
library(neuralnet)
library(xgboost)
library(themis)
library(RANN)
#Loading in the data
stroke_data <- read_csv("stroke_ds.csv")
glimpse(stroke_data)
## Rows: 5,110
## Columns: 12
## $ id <dbl> 9046, 51676, 31112, 60182, 1665, 56669, 53882, 10434…
## $ gender <chr> "Male", "Female", "Male", "Female", "Female", "Male"…
## $ age <dbl> 67, 61, 80, 49, 79, 81, 74, 69, 59, 78, 81, 61, 54, …
## $ hypertension <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1…
## $ heart_disease <dbl> 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0…
## $ ever_married <chr> "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No…
## $ work_type <chr> "Private", "Self-employed", "Private", "Private", "S…
## $ Residence_type <chr> "Urban", "Rural", "Rural", "Urban", "Rural", "Urban"…
## $ avg_glucose_level <dbl> 228.69, 202.21, 105.92, 171.23, 174.12, 186.21, 70.0…
## $ bmi <dbl> 36.6, NA, 32.5, 34.4, 24.0, 29.0, 27.4, 22.8, NA, 24…
## $ smoking_status <chr> "formerly smoked", "never smoked", "never smoked", "…
## $ stroke <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.25 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
## NA's :15
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
##
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Min. :10.30 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 1st Qu.:23.50 Class :character 1st Qu.:0.00000
## Median : 91.89 Median :28.10 Mode :character Median :0.00000
## Mean :106.15 Mean :28.89 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:33.10 3rd Qu.:0.00000
## Max. :271.74 Max. :97.60 Max. :1.00000
## NA's :201
On initial inspection, the data set offers a mix of variable types that we can use to predict the likelihood of a patient having a stroke. We can also spot some NA values: 201 in the bmi column and 15 in the age column. Let's handle these before moving forward with our analysis.
aggr(stroke_data, numbers = TRUE, sortVars = TRUE, labels = names(stroke_data), cex.axis = .5, gap = 1, ylab = c('Missing data', 'Pattern'))
##
## Variables sorted by number of missings:
## Variable Count
## bmi 0.039334638
## age 0.002935421
## id 0.000000000
## gender 0.000000000
## hypertension 0.000000000
## heart_disease 0.000000000
## ever_married 0.000000000
## work_type 0.000000000
## Residence_type 0.000000000
## avg_glucose_level 0.000000000
## smoking_status 0.000000000
## stroke 0.000000000
First, we must handle the missing values shown in the graphic above. One option is KNN imputation, which fills each gap from the subjects most similar to the row with the missing value rather than from a single overall statistic; that approach is left commented out below. The code that actually runs replaces the missing bmi and age entries with the column median.
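As a toy illustration of the difference between a global median fill and a nearest-neighbour fill, the sketch below imputes one missing bmi value both ways. The data values and the helper name `knn_fill` are made up for this example; it is a simplified stand-in, not caret's `knnImpute` implementation.

```r
# Toy data (values are invented): bmi is missing for one older, high-glucose subject.
toy <- data.frame(
  age     = c(25, 30, 62, 65, 70),
  glucose = c(80, 85, 150, 160, 155),
  bmi     = c(22, 23, 31, 33, NA)
)

# Global median fill: ignores which subjects the incomplete row resembles.
median_fill <- median(toy$bmi, na.rm = TRUE)

# Hypothetical helper: fill each missing target value with the mean of the
# k most similar complete rows (Euclidean distance on scaled predictors).
knn_fill <- function(df, target, k = 2) {
  predictors <- setdiff(names(df), target)
  scaled <- scale(df[predictors])          # put predictors on a common scale
  missing_rows <- which(is.na(df[[target]]))
  donor_rows <- which(!is.na(df[[target]]))
  filled <- df[[target]]
  for (i in missing_rows) {
    diffs <- sweep(scaled[donor_rows, , drop = FALSE], 2, scaled[i, ])
    dist_to_donors <- sqrt(rowSums(diffs^2))
    nearest <- donor_rows[order(dist_to_donors)[seq_len(k)]]
    filled[i] <- mean(df[[target]][nearest])
  }
  filled
}

knn_bmi <- knn_fill(toy, "bmi", k = 2)
print(c(median = median_fill, knn = knn_bmi[5]))
```

In this toy case the median fill lands between the two groups of subjects, while the neighbour-based fill follows the two older, high-glucose subjects.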
#preProcess_missingdata <- preProcess(stroke_data, method='knnImpute')
#training.imp <- predict(preProcess_missingdata, newdata = stroke_data)
#anyNA(training.imp)
# Impute missing entries with the column median
stroke_data$bmi[is.na(stroke_data$bmi)] = median(stroke_data$bmi, na.rm = TRUE)
stroke_data$age[is.na(stroke_data$age)] = median(stroke_data$age, na.rm = TRUE)
Factors represent categorical data and are internally stored as integers with associated levels (categories). This ensures that the model understands the variable as categorical rather than numerical.
Without converting to factors, R might treat categorical data as continuous, leading to incorrect assumptions and model behavior.
With the missing values taken care of, we convert our categorical variables into factors early on.
categorical_columns=c("gender", "ever_married", "work_type", "Residence_type", "smoking_status")
stroke_data[categorical_columns]=lapply(stroke_data[categorical_columns], as.factor)
stroke_data$stroke = factor(stroke_data$stroke, levels = c(0, 1), labels = c("No", "Yes"))
levels(stroke_data$gender)
## [1] "Female" "Male" "Other"
## [1] "formerly smoked" "never smoked" "smokes" "Unknown"
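To make the role of factors concrete, here is a small sketch of how R stores a factor and how a model formula expands it into indicator columns. The four-element vector is an invented sample, not taken from the data set.

```r
# Invented sample of smoking statuses
status <- c("never smoked", "smokes", "formerly smoked", "never smoked")
f <- factor(status)

# Levels are sorted alphabetically; the values are stored as integer codes.
print(levels(f))
print(as.integer(f))

# In a model formula the first level becomes the baseline and every other
# level gets its own 0/1 indicator column.
mm <- model.matrix(~ f)
print(colnames(mm))
```

This is why the model outputs later list terms such as `smoking_statusnever smoked` and `smoking_statussmokes` but none for `formerly smoked`: the first level serves as the reference category.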
With our factors in place, we can start on some basic feature engineering.
Looking at the variables available to us in the dataset, we can reasonably assume that columns such as age, hypertension, and smoking status reflect an individual's physical health and lifestyle choices, which affect the probability of having a stroke. The id column carries no information about how likely an individual is to have a stroke, so it is excluded from the modeling data (stroke_data_filt in the code below).
With the stage set, we can start creating some tables and visualizations to gain insight into our data.
Let's first observe how many people in our data set did and did not have a stroke:
##
## No Yes
## 4861 249
##
## No Yes
## 0.95127202 0.04872798
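The imbalance in the counts above matters for every accuracy figure that follows. A short worked calculation, using only the two counts printed above, shows the accuracy of a trivial model that always predicts "No":

```r
# Class counts taken from the table output above
counts <- c(No = 4861, Yes = 249)
prop <- counts / sum(counts)

# A classifier that always predicts the majority class is right this often:
baseline_accuracy <- max(prop)
print(round(prop, 4))
print(round(baseline_accuracy, 4))
```

Any model must therefore be judged on more than raw accuracy, since roughly 95% accuracy is achievable without detecting a single stroke.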
smoke_table = table(stroke_data_filt$stroke, stroke_data_filt$smoking_status)
smoke_table_prop <- prop.table(smoke_table, margin = 2)
smoke_table_prop
##
## formerly smoked never smoked smokes Unknown
## No 0.92090395 0.95243129 0.94676806 0.96955959
## Yes 0.07909605 0.04756871 0.05323194 0.03044041
smoke_df <- as.data.frame(as.table(smoke_table_prop))
# Create the bar graph
ggplot(smoke_df, aes(x = Var2, y = Freq, fill = Var1)) +
geom_bar(stat = "identity", position = "dodge") +
labs(
title = "Proportion of Stroke Cases by Smoking Status",
x = "Smoking Status",
y = "Proportion",
fill = "Stroke"
) +
theme_minimal()
We can immediately see a sharp contrast between the proportion of people in our data set who have and have not had a stroke: only about 4.9% of subjects had one. Looking at the proportion of stroke cases by smoking status, there is no clear-cut pattern despite smoking's known health risks: former smokers show the highest stroke rate (about 7.9%), while current smokers (about 5.3%) sit only slightly above never-smokers (about 4.8%).
We can continue breaking down this proportion by other variables to get a better idea of how certain metrics affect the probability of having a stroke.
p1 = ggplot(stroke_data, aes(age, fill = stroke)) + geom_density(alpha = 0.5) +
labs(title = "Stroke Frequency by Age",
x = "Age")
p2 = ggplot(stroke_data, aes(bmi, fill = stroke)) + geom_density(alpha = 0.5) +
labs(title = "Stroke Frequency by BMI",
x = "BMI")
p3 = ggplot(stroke_data, aes(avg_glucose_level, fill = stroke)) + geom_density(alpha = 0.5) +
labs(title = "Stroke Frequency by Glucose Lvl",
x = "Glucose Lvl")
grid.arrange(p1, p2, p3, nrow = 3)
These plots give an initial view of how each variable relates to stroke: stroke cases become more common as age and glucose level increase, whereas BMI shows no distinct relationship.
We can also take a look at the correlation of the variables with one another.
## Warning in ggcorr(stroke_data_filt, label = T): data in column(s) 'gender',
## 'ever_married', 'work_type', 'Residence_type', 'smoking_status', 'stroke' are
## not numeric and were ignored
We drew some preliminary insights from our descriptive analytics about what makes an individual more likely to have a stroke. So where does supervised learning come into play?
Supervised learning builds on descriptive analytics by enabling us to predict an individual’s likelihood of having a stroke based on features like smoking status, age, and medical history. It helps identify key risk factors, prioritize high-risk individuals, and inform targeted interventions using predictive models such as logistic regression or decision trees. By evaluating model performance with metrics like precision and recall, supervised learning ensures actionable and reliable insights for real-world healthcare applications.
Let us begin by separating our data into training and testing sets.
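The partition below is random, and no call to `set.seed` is visible in the code shown, so the exact counts in later confusion matrices can change from run to run. As a minimal sketch of a seeded, stratified 80/20 split in base R (the helper name `stratified_split`, the seed value, and the toy labels are assumptions for illustration, not part of the original analysis):

```r
# Hypothetical helper: sample a fraction p of the indices within each class,
# so the class balance is preserved in the training set.
stratified_split <- function(y, p = 0.8, seed = 42) {
  set.seed(seed)  # fixing the seed makes the split repeatable
  idx <- unlist(lapply(split(seq_along(y), y), function(i) {
    i[sample.int(length(i), size = round(p * length(i)))]
  }))
  sort(unname(idx))
}

# Toy labels with roughly the same 95/5 imbalance as the stroke data
y <- factor(rep(c("No", "Yes"), times = c(95, 5)))
train_idx <- stratified_split(y)

print(table(y[train_idx]))   # training set
print(table(y[-train_idx]))  # test set
```

caret's `createDataPartition` already stratifies on the outcome; calling `set.seed()` immediately before it would be enough to make the original split repeatable.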
spl = createDataPartition(stroke_data_filt$stroke, p = 0.8, list = FALSE) # 80% for training
StrokeTrain = stroke_data_filt[spl,]
StrokeTest = stroke_data_filt[-spl,]
What is LDA:
Linear Discriminant Analysis (LDA) is a supervised machine learning algorithm used for classification and dimensionality reduction. It is particularly effective when the response variable (target) has multiple classes, and the goal is to separate these classes based on a set of predictors (features).
Assumptions: The predictors within each class are normally distributed, and all classes share the same covariance matrix (homoscedasticity). Under these assumptions the boundary between classes is linear in the predictors.
Linear Discriminant Analysis (LDA) and Quadratic Discriminant Analysis (QDA) are both classification algorithms that assume data follows a Gaussian distribution but differ in how they handle covariance structures. LDA assumes that all classes share the same covariance matrix, making it effective when decision boundaries are linear and the dataset is smaller, as it estimates fewer parameters. QDA, on the other hand, allows each class to have its own covariance matrix, making it more flexible for capturing non-linear decision boundaries. However, QDA requires larger datasets to estimate the covariance matrices accurately and may overfit when the sample size is limited. The choice between LDA and QDA depends on the dataset’s covariance structure, size, and the complexity of class separation.
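Because stroke cases make up only about 5% of the data, leaving few observations from which to estimate a separate covariance matrix for that class, we will work with LDA.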
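The "fewer parameters" argument can be made concrete with a short count. The sketch below assumes 16 dummy-coded predictors (the number of coefficient rows in the LDA output that follows) and two classes; the helper name `cov_params` is made up for this illustration.

```r
# Number of free parameters in a symmetric p x p covariance matrix: p(p + 1) / 2.
# LDA estimates one shared matrix; QDA estimates one per class.
cov_params <- function(p, K, shared = TRUE) {
  per_matrix <- p * (p + 1) / 2
  if (shared) per_matrix else K * per_matrix
}

p <- 16  # dummy-coded predictors (assumed from the model output)
K <- 2   # classes: No, Yes

lda_cov <- cov_params(p, K, shared = TRUE)
qda_cov <- cov_params(p, K, shared = FALSE)
print(c(LDA = lda_cov, QDA = qda_cov))
```

Under QDA, the covariance matrix for the stroke class would have to be estimated from the stroke cases in the training set alone, which is the main reason the smaller model is attractive here.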
## Call:
## lda(stroke ~ ., data = StrokeTrain, prior = c(0.95, 0.05))
##
## Prior probabilities of groups:
## No Yes
## 0.95 0.05
##
## Group means:
## genderMale genderOther age hypertension heart_disease ever_marriedYes
## No 0.4106454 0.0002571355 41.84084 0.08588326 0.04474158 0.6389817
## Yes 0.4600000 0.0000000000 68.48000 0.27000000 0.19000000 0.8850000
## work_typeGovt_job work_typeNever_worked work_typePrivate
## No 0.1321677 0.00514271 0.5638982
## Yes 0.1300000 0.00000000 0.5800000
## work_typeSelf-employed Residence_typeUrban avg_glucose_level bmi
## No 0.1542813 0.495243 104.6201 28.84564
## Yes 0.2800000 0.535000 134.8949 30.07750
## smoking_statusnever smoked smoking_statussmokes smoking_statusUnknown
## No 0.3695037 0.1545384 0.311134
## Yes 0.3650000 0.1600000 0.185000
##
## Coefficients of linear discriminants:
## LD1
## genderMale 0.073425479
## genderOther -0.462367769
## age 0.050780450
## hypertension 0.640855229
## heart_disease 0.852091140
## ever_marriedYes -0.579794137
## work_typeGovt_job -1.062175193
## work_typeNever_worked -0.409826897
## work_typePrivate -0.788676856
## work_typeSelf-employed -1.027758321
## Residence_typeUrban 0.088798607
## avg_glucose_level 0.005616711
## bmi -0.013806883
## smoking_statusnever smoked -0.164321656
## smoking_statussmokes -0.067868070
## smoking_statusUnknown -0.064960868
This output shows the results of a Linear Discriminant Analysis (LDA) model predicting the likelihood of stroke (stroke) based on several predictors. The prior probabilities are set at 0.95 for “No” (no stroke) and 0.05 for “Yes” (stroke), reflecting a class imbalance. The group means reveal differences between individuals with and without strokes, such as higher average age, glucose levels, and BMI among stroke cases.
The coefficients of linear discriminants (LD1) show how each predictor contributes to the discriminant score. Hypertension, heart disease, and age have positive coefficients, so higher values push an observation toward the stroke class. Each coefficient is on the scale of its predictor, so the small per-year coefficient for age (0.051) still adds up across the wide age range in the data.
In contrast, variables like marital status (ever_marriedYes) and work type (work_typePrivate and work_typeSelf-employed) have negative coefficients, suggesting a reduced likelihood of stroke in those contexts.
The linear discriminant function combines these predictors to compute a score for each observation. The score determines the probability of belonging to either the No or Yes class, and the observation is classified into the group with the highest posterior probability. Let's take a look at the probabilities our LDA model produced for the testing data set:
## No Yes
## 1 0.9712360 0.02876396
## 2 0.6602133 0.33978673
## 3 0.5712961 0.42870389
## 4 0.8453682 0.15463178
## 5 0.9722597 0.02774033
## 6 0.8870077 0.11299229
How to Interpret the Posterior Probabilities:
No: The predicted probability that the observation does not result in a stroke.
Yes: The predicted probability that the observation results in a stroke.
Row Values: Each row contains two probabilities that sum to 1 (since they represent probabilities for mutually exclusive outcomes).
The observation is classified into the class with the higher posterior probability.
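How a prior and the class densities combine into a posterior can be illustrated with a single predictor. The sketch below uses the age group means from the LDA output (41.84 for No, 68.48 for Yes) and the 0.95/0.05 priors; the common standard deviation of 20 years is an assumed value chosen only for illustration, and the real model uses all predictors jointly.

```r
# One-predictor Bayes rule with equal-variance normal class densities.
# common_sd = 20 is an assumption, not an estimate from the data.
posterior_yes <- function(age, prior_yes = 0.05,
                          mean_no = 41.84, mean_yes = 68.48, common_sd = 20) {
  weighted_yes <- dnorm(age, mean_yes, common_sd) * prior_yes
  weighted_no  <- dnorm(age, mean_no, common_sd) * (1 - prior_yes)
  weighted_yes / (weighted_yes + weighted_no)
}

ages <- c(40, 60, 80)
p_yes <- posterior_yes(ages)
p_no <- 1 - p_yes  # the two posteriors in each row sum to 1
print(round(cbind(age = ages, No = p_no, Yes = p_yes), 4))
```

Even at age 80 the posterior for Yes stays well below 0.5 in this sketch, because the 0.95 prior for No weighs heavily; this is the same effect that keeps the default classification at "No" for almost every test row.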
This first method reads through the rows of the probability matrix above and outputs, for each row, the column with the highest probability, since that is the most likely outcome.
## [1] No No No No No No
## Levels: No Yes
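The row-wise rule described above can be sketched on the six posterior rows printed earlier. This is an illustration in base R, not necessarily the code the original analysis used.

```r
# The six posterior rows shown earlier, typed in by hand
post <- matrix(
  c(0.9712360, 0.02876396,
    0.6602133, 0.33978673,
    0.5712961, 0.42870389,
    0.8453682, 0.15463178,
    0.9722597, 0.02774033,
    0.8870077, 0.11299229),
  ncol = 2, byrow = TRUE,
  dimnames = list(1:6, c("No", "Yes"))
)

# For each row, pick the column holding the largest probability
winner <- colnames(post)[max.col(post, ties.method = "first")]
pred <- factor(winner, levels = c("No", "Yes"))
print(pred)
```

With two classes this rule is equivalent to predicting "Yes" only when its posterior exceeds 0.5, which none of these six rows does.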
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 957 45
## Yes 15 4
##
## Accuracy : 0.9412
## 95% CI : (0.925, 0.9549)
## No Information Rate : 0.952
## P-Value [Acc > NIR] : 0.9503177
##
## Kappa : 0.0933
##
## Mcnemar's Test P-Value : 0.0001812
##
## Sensitivity : 0.98457
## Specificity : 0.08163
## Pos Pred Value : 0.95509
## Neg Pred Value : 0.21053
## Prevalence : 0.95201
## Detection Rate : 0.93732
## Detection Prevalence : 0.98139
## Balanced Accuracy : 0.53310
##
## 'Positive' Class : No
##
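Because caret treats "No" as the positive class here, the labels in the output above are easy to misread. The following sketch recomputes the main statistics from the four cell counts; the helper name `cm_metrics` is made up for this example.

```r
# Cell counts from the confusion matrix above (rows = prediction, columns = reference)
no_as_no   <- 957  # actual No,  predicted No
yes_as_no  <- 45   # actual Yes, predicted No  (missed strokes)
no_as_yes  <- 15   # actual No,  predicted Yes (false alarms)
yes_as_yes <- 4    # actual Yes, predicted Yes

# Hypothetical helper mirroring caret's definitions with "No" as the positive class
cm_metrics <- function(no_as_no, yes_as_no, no_as_yes, yes_as_yes) {
  sensitivity <- no_as_no / (no_as_no + no_as_yes)      # share of actual No found
  specificity <- yes_as_yes / (yes_as_yes + yes_as_no)  # share of actual Yes found
  total <- no_as_no + yes_as_no + no_as_yes + yes_as_yes
  c(accuracy          = (no_as_no + yes_as_yes) / total,
    sensitivity       = sensitivity,
    specificity       = specificity,
    balanced_accuracy = (sensitivity + specificity) / 2)
}

default_metrics <- cm_metrics(no_as_no, yes_as_no, no_as_yes, yes_as_yes)
print(round(default_metrics, 4))
```

Read this way, the specificity of about 0.08 means that only 4 of the 49 stroke cases in the test set are detected at the default threshold.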
Importance of Thresholds: The threshold determines how strictly the model assigns an observation to a class. For example, lowering the threshold for "Yes" catches more stroke cases (higher recall for that class) at the cost of more false alarms (lower precision).
Right now, the model assigns "Yes" only when the posterior probability of stroke exceeds 0.5. We can choose a different threshold with the help of an ROC curve.
plot.roc(StrokeTest$stroke, probability[,2],col="darkblue", print.auc = TRUE, auc.polygon=TRUE, grid=c(0.1, 0.2),
grid.col=c("green", "red"), max.auc.polygon=TRUE,
auc.polygon.col="lightblue", print.thres=TRUE, legacy.axes = TRUE)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
AUC (Area Under the Curve):
The AUC is 0.853, indicating good model performance. The model effectively distinguishes between the two classes (stroke = Yes vs. No).
Given the ROC curve, the threshold of 0.029 appears optimal for this model, as it balances sensitivity (0.662) and specificity (0.918). However, the right choice depends on the business or clinical need: whether you prioritize reducing missed strokes or minimizing false alerts.
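One common way to pick a "best" point on an ROC curve is Youden's J statistic (sensitivity + specificity - 1). The sketch below applies that idea in base R to a small invented set of probabilities; the helper name `youden_threshold` and the toy data are assumptions, and pROC's own threshold search places candidate cut points between observed values rather than on them, so its numbers can differ slightly.

```r
# Hypothetical helper: evaluate every observed probability as a cut point and
# return the one maximising Youden's J, treating "Yes" as the event of interest.
youden_threshold <- function(prob, truth, positive = "Yes") {
  cuts <- sort(unique(prob))
  stats <- t(sapply(cuts, function(th) {
    flagged <- prob >= th
    c(threshold   = th,
      sensitivity = mean(flagged[truth == positive]),   # events caught
      specificity = mean(!flagged[truth != positive]))  # non-events left alone
  }))
  stats <- as.data.frame(stats)
  stats$youden <- stats$sensitivity + stats$specificity - 1
  stats[which.max(stats$youden), ]
}

# Invented predicted probabilities and true labels
prob  <- c(0.01, 0.02, 0.03, 0.05, 0.10, 0.20, 0.40, 0.02)
truth <- c("No", "No", "No", "Yes", "No", "Yes", "Yes", "No")

best <- youden_threshold(prob, truth)
print(best)
```

As in the stroke model, the selected cut point sits far below 0.5 because events are rare and their predicted probabilities are small in absolute terms.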
Let's base our model on this threshold:
# Set the threshold based on ROC analysis
threshold <- 0.029
# Initialize predictions as "No" (non-stroke cases)
Stroke.pred <- rep("No", nrow(StrokeTest))
# Update predictions to "Yes" (stroke cases) where probability exceeds the threshold
Stroke.pred[which(probability[, 2] > threshold)] <- "Yes"
# Calculate the confusion matrix
library(caret)
CM <- confusionMatrix(factor(Stroke.pred, levels = c("No", "Yes")), StrokeTest$stroke)
CM
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 642 11
## Yes 330 38
##
## Accuracy : 0.666
## 95% CI : (0.6361, 0.6949)
## No Information Rate : 0.952
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1066
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6605
## Specificity : 0.7755
## Pos Pred Value : 0.9832
## Neg Pred Value : 0.1033
## Prevalence : 0.9520
## Detection Rate : 0.6288
## Detection Prevalence : 0.6396
## Balanced Accuracy : 0.7180
##
## 'Positive' Class : No
##
Breakdown:
Accuracy drops sharply, from 94.1% to 66.6%, because the model now predicts many more "Yes" cases, which increases misclassifications of the "No" class (330 false alarms instead of 15).
Sensitivity for the "No" class (the 'Positive' class in the output) drops from 0.9846 to 0.6605, meaning the model is less effective at identifying non-stroke cases than the default model.
Specificity improves from 0.0816 to 0.7755, meaning the model now correctly identifies 38 of the 49 stroke cases instead of 4. This is a significant improvement over the default threshold.
The balanced accuracy of 0.7180, up from 0.5331, shows that this model balances performance across both classes better than the default model.
Default Model: threshold 0.5, accuracy 0.9412, sensitivity 0.9846, specificity 0.0816, balanced accuracy 0.5331.
Threshold Model: threshold 0.029, accuracy 0.6660, sensitivity 0.6605, specificity 0.7755, balanced accuracy 0.7180.
Recommendation
Note: Altering the threshold does not affect the discriminant coefficients we derived earlier.
As with LDA, we can use logistic regression to model stroke risk.
logit.model <- glm(stroke ~ ., family=binomial(link='logit'), data=StrokeTrain)
summary(logit.model)
##
## Call:
## glm(formula = stroke ~ ., family = binomial(link = "logit"),
## data = StrokeTrain)
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -6.750e+00 8.106e-01 -8.327 < 2e-16 ***
## genderMale 1.589e-01 1.580e-01 1.006 0.31448
## genderOther -1.131e+01 2.400e+03 -0.005 0.99624
## age 7.850e-02 6.710e-03 11.700 < 2e-16 ***
## hypertension 4.074e-01 1.849e-01 2.204 0.02755 *
## heart_disease 2.703e-01 2.145e-01 1.260 0.20779
## ever_marriedYes -1.780e-01 2.565e-01 -0.694 0.48761
## work_typeGovt_job -1.345e+00 8.700e-01 -1.546 0.12212
## work_typeNever_worked -1.151e+01 5.338e+02 -0.022 0.98280
## work_typePrivate -1.198e+00 8.522e-01 -1.405 0.15989
## work_typeSelf-employed -1.486e+00 8.745e-01 -1.700 0.08919 .
## Residence_typeUrban 1.040e-01 1.552e-01 0.670 0.50288
## avg_glucose_level 4.676e-03 1.350e-03 3.465 0.00053 ***
## bmi -1.008e-03 1.287e-02 -0.078 0.93756
## smoking_statusnever smoked -2.209e-01 1.963e-01 -1.125 0.26050
## smoking_statussmokes 1.253e-01 2.455e-01 0.510 0.60979
## smoking_statusUnknown -1.045e-01 2.350e-01 -0.445 0.65651
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1597.1 on 4088 degrees of freedom
## Residual deviance: 1241.5 on 4072 degrees of freedom
## AIC: 1275.5
##
## Number of Fisher Scoring iterations: 15
Based on the p-values (Pr(>|z|)):
age: Highly significant (< 2e-16). This suggests that age is a strong predictor of stroke.
avg_glucose_level: Highly significant (0.00053 ***). Elevated glucose levels are associated with a higher likelihood of stroke.
hypertension: Significant at the 0.05 level (0.02755 *).
work_typeSelf-employed: Borderline (0.08919 .). Suggests a potential relationship but not at the typical 0.05 level.
Other variables, such as heart_disease (0.20779), genderMale, smoking_status, and bmi, are not statistically significant in this model.
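Logistic regression coefficients are on the log-odds scale, so exponentiating them gives odds ratios that are easier to interpret. A short worked calculation using three estimates copied from the summary above:

```r
# Estimates copied from the coefficient table above (log-odds scale)
coefs <- c(age = 7.850e-02, hypertension = 4.074e-01, avg_glucose_level = 4.676e-03)

# Odds ratio for a one-unit increase in each predictor
odds_ratio <- exp(coefs)
print(round(odds_ratio, 4))

# Effect of ten additional years of age, holding the other predictors fixed
ten_year_age <- unname(exp(10 * coefs["age"]))
print(round(ten_year_age, 3))
```

Each additional year of age multiplies the odds of stroke by roughly 1.08, which compounds to a bit more than a doubling of the odds per decade according to this model.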
In comparison with LDA:
LDA is useful for identifying predictors that maximize class separability, especially in balanced datasets. Here it points to hypertension and heart_disease as key discriminators.
Logistic regression offers statistical significance tests, which help confirm the relevance of predictors. It highlights age, avg_glucose_level, and hypertension as statistically significant predictors, providing a better understanding of how these features contribute to stroke risk.
Now that we have that established, let us look at how accurately the logistic regression model predicts stroke probabilities.
# showing probabilities of each row having a stroke
probability <- predict(logit.model, newdata=StrokeTest, type='response')
head(probability)
## 1 2 3 4 5 6
## 0.03614128 0.34023327 0.19776007 0.09715800 0.03659783 0.14764199
Our probabilities are reported differently than for LDA: each row has a single value, the predicted probability that the patient has a stroke. For the sake of simplicity, we will again use a threshold of 0.029 to make our decision.
# Fix the level order explicitly so it matches StrokeTest$stroke
prediction <- factor(ifelse(probability > 0.029, "Yes", "No"), levels = c("No", "Yes"))
confusionMatrix(prediction, StrokeTest$stroke)
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 615 8
## Yes 357 41
##
## Accuracy : 0.6425
## 95% CI : (0.6122, 0.6719)
## No Information Rate : 0.952
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1071
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6327
## Specificity : 0.8367
## Pos Pred Value : 0.9872
## Neg Pred Value : 0.1030
## Prevalence : 0.9520
## Detection Rate : 0.6024
## Detection Prevalence : 0.6102
## Balanced Accuracy : 0.7347
##
## 'Positive' Class : No
##
When directly comparing to LDA at the same 0.029 threshold:
LDA is slightly ahead on accuracy (0.6660 vs 0.6425) and on sensitivity for the "No" class (0.6605 vs 0.6327).
Logistic regression detects more strokes ("Yes" cases), 41 of 49 versus 38 of 49, which gives it higher specificity (0.8367 vs 0.7755) and balanced accuracy (0.7347 vs 0.7180). Because the 0.029 threshold was chosen from the LDA ROC curve, logistic regression may benefit from its own threshold optimization using ROC analysis.
Decision Trees are a predictive modeling technique that splits data into branches based on feature values, creating a tree-like structure. They are highly interpretable and work well with both categorical and numerical data. Trees are particularly useful for datasets with complex, non-linear relationships between features and the target variable. Each split in the tree is made to maximize class separation or reduce prediction error, but individual decision trees can easily overfit to the training data, especially with small or noisy datasets.
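The idea of choosing splits that maximize class separation can be sketched with the Gini impurity, which is the default splitting criterion for classification trees in rpart. The toy ages and labels below are invented, and the helper names `gini` and `split_gain` are assumptions for this illustration.

```r
# Gini impurity of a set of class labels: 1 - sum of squared class proportions
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Reduction in impurity from splitting x at a cut point
split_gain <- function(x, y, cut) {
  left <- y[x < cut]
  right <- y[x >= cut]
  n <- length(y)
  gini(y) - (length(left) / n) * gini(left) - (length(right) / n) * gini(right)
}

# Invented toy data: strokes occur only among the oldest subjects
age    <- c(30, 35, 40, 45, 60, 65, 70, 75)
stroke <- c("No", "No", "No", "No", "No", "Yes", "Yes", "Yes")

gain_50 <- split_gain(age, stroke, 50)
gain_62 <- split_gain(age, stroke, 62.5)
print(c(cut_50 = gain_50, cut_62.5 = gain_62))
```

A tree-growing algorithm would prefer the cut at 62.5 here, since it separates the two classes perfectly and removes all impurity.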
control = rpart.control(minsplit=30,maxdepth=10,cp=0.001)
model=stroke~.
dtFit <- rpart(model,data=StrokeTrain, method="class",control=control)
library(rpart.plot)
rpart.plot(
dtFit,
digits = 3, # Number of significant digits shown in the node labels
extra = 104, # Show the probability of each class and the percentage of observations in each node
under = TRUE, # Place the extra text under the node box
main = "Decision Tree for Stroke Prediction"
)
dtPred <- predict(dtFit, StrokeTest, type = "class")
dtProb <- predict(dtFit, StrokeTest, type = "prob") # class probabilities are needed to apply a custom threshold
threshold = 0.029
dtPred = rep("No", nrow(StrokeTest))
dtPred[which(dtProb[,2] > threshold)] = "Yes"
CM = confusionMatrix(factor(dtPred, levels = c("No", "Yes")), StrokeTest$stroke)
CM
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 820 26
## Yes 152 23
##
## Accuracy : 0.8257
## 95% CI : (0.801, 0.8485)
## No Information Rate : 0.952
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1409
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.8436
## Specificity : 0.4694
## Pos Pred Value : 0.9693
## Neg Pred Value : 0.1314
## Prevalence : 0.9520
## Detection Rate : 0.8031
## Detection Prevalence : 0.8286
## Balanced Accuracy : 0.6565
##
## 'Positive' Class : No
##
Sensitivity: 84.36%. The decision tree identifies most non-stroke cases (the No class) correctly.
Specificity: 46.94%. The tree catches fewer than half of the strokes (23 of the 49 Yes cases), well below LDA and logistic regression.
Balanced Accuracy: 65.65%. The higher overall accuracy (82.57%) comes from favoring the majority class, so the balance between sensitivity and specificity is weaker than for the previous models.
Random Forest builds on decision trees by creating an ensemble of trees using bootstrap aggregation (bagging). Each tree is trained on a random subset of the data, and at each split, a random subset of features is considered. This randomness makes the trees diverse and reduces overfitting. The final prediction is made by averaging (for regression) or majority voting (for classification) across all trees. Random Forest is robust to noise, handles multicollinearity well, and naturally provides feature importance, making it a powerful and flexible choice for predictive modeling.
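As a rough illustration of the bootstrap step described above, the sketch below draws one bootstrap sample and measures how many rows are left out of it. The sample size and seed are arbitrary choices for the example.

```r
set.seed(1)   # arbitrary seed so the example is repeatable
n <- 10000    # arbitrary number of training rows

# One bootstrap sample: n row indices drawn with replacement
boot_idx <- sample.int(n, size = n, replace = TRUE)

# Rows never drawn are "out-of-bag" (OOB) for the tree grown on this sample
oob_fraction <- 1 - length(unique(boot_idx)) / n
print(round(c(observed = oob_fraction, theoretical = exp(-1)), 4))
```

Roughly 37% of the rows are left out of each tree's sample, and these out-of-bag rows are what randomForest uses to report an OOB error estimate while the forest is being grown.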
rf.train <- randomForest(stroke ~ ., data = StrokeTrain, ntree = 200, mtry = 10, cutoff = c(0.75, 0.25), importance = TRUE, do.trace = TRUE)
## ntree OOB 1 2
## 1: 9.54% 5.87% 81.69%
## 2: 10.01% 6.73% 75.86%
## 3: 11.32% 7.87% 77.70%
## 4: 11.44% 8.23% 75.46%
## 5: 11.87% 8.60% 76.70%
## 6: 12.37% 9.20% 75.82%
## 7: 12.53% 9.33% 74.87%
## 8: 12.66% 9.50% 73.85%
## 9: 12.03% 9.10% 69.23%
## 10: 11.78% 8.76% 70.92%
## 11: 11.45% 8.44% 69.85%
## 12: 11.36% 8.17% 73.37%
## 13: 11.75% 8.41% 76.50%
## 14: 11.59% 8.24% 76.50%
## 15: 11.31% 8.01% 75.50%
## 16: 11.31% 7.98% 76.00%
## 17: 11.03% 7.79% 74.00%
## 18: 10.76% 7.56% 73.00%
## 19: 10.93% 7.59% 76.00%
## 20: 10.76% 7.51% 74.00%
## 21: 10.56% 7.30% 74.00%
## 22: 10.86% 7.64% 73.50%
## 23: 10.71% 7.51% 73.00%
## 24: 10.81% 7.56% 74.00%
## 25: 11.05% 7.71% 76.00%
## 26: 10.64% 7.33% 75.00%
## 27: 10.93% 7.56% 76.50%
## 28: 10.79% 7.41% 76.50%
## 29: 10.83% 7.43% 77.00%
## 30: 10.83% 7.43% 77.00%
## 31: 10.44% 7.05% 76.50%
## 32: 10.61% 7.23% 76.50%
## 33: 10.61% 7.23% 76.50%
## 34: 10.44% 7.02% 77.00%
## 35: 10.39% 7.02% 76.00%
## 36: 10.34% 6.97% 76.00%
## 37: 10.39% 7.02% 76.00%
## 38: 10.20% 6.87% 75.00%
## 39: 10.20% 6.84% 75.50%
## 40: 10.32% 6.92% 76.50%
## 41: 10.37% 6.97% 76.50%
## 42: 10.39% 7.02% 76.00%
## 43: 10.32% 6.97% 75.50%
## 44: 10.25% 6.92% 75.00%
## 45: 10.22% 6.92% 74.50%
## 46: 9.93% 6.58% 75.00%
## 47: 9.95% 6.63% 74.50%
## 48: 10.27% 6.94% 75.00%
## 49: 10.27% 6.99% 74.00%
## 50: 10.27% 6.92% 75.50%
## 51: 10.25% 6.89% 75.50%
## 52: 10.27% 6.87% 76.50%
## 53: 10.25% 6.87% 76.00%
## 54: 10.22% 6.87% 75.50%
## 55: 10.10% 6.71% 76.00%
## 56: 10.17% 6.89% 74.00%
## 57: 10.15% 6.84% 74.50%
## 58: 9.95% 6.76% 72.00%
## 59: 10.05% 6.81% 73.00%
## 60: 10.08% 6.81% 73.50%
## 61: 10.12% 6.81% 74.50%
## 62: 10.12% 6.79% 75.00%
## 63: 10.03% 6.71% 74.50%
## 64: 10.00% 6.63% 75.50%
## 65: 10.08% 6.74% 75.00%
## 66: 9.88% 6.56% 74.50%
## 67: 9.86% 6.53% 74.50%
## 68: 9.93% 6.61% 74.50%
## 69: 9.98% 6.66% 74.50%
## 70: 10.08% 6.81% 73.50%
## 71: 10.03% 6.74% 74.00%
## 72: 9.90% 6.63% 73.50%
## 73: 9.88% 6.63% 73.00%
## 74: 9.78% 6.53% 73.00%
## 75: 9.73% 6.48% 73.00%
## 76: 9.78% 6.53% 73.00%
## 77: 9.88% 6.61% 73.50%
## 78: 9.98% 6.74% 73.00%
## 79: 10.05% 6.74% 74.50%
## 80: 9.93% 6.58% 75.00%
## 81: 10.05% 6.69% 75.50%
## 82: 10.00% 6.66% 75.00%
## 83: 9.98% 6.63% 75.00%
## 84: 9.90% 6.56% 75.00%
## 85: 9.90% 6.53% 75.50%
## 86: 9.95% 6.61% 75.00%
## 87: 9.90% 6.61% 74.00%
## 88: 9.78% 6.45% 74.50%
## 89: 9.81% 6.48% 74.50%
## 90: 9.81% 6.45% 75.00%
## 91: 9.90% 6.53% 75.50%
## 92: 9.78% 6.51% 73.50%
## 93: 9.86% 6.53% 74.50%
## 94: 10.00% 6.63% 75.50%
## 95: 9.93% 6.58% 75.00%
## 96: 9.86% 6.56% 74.00%
## 97: 9.78% 6.51% 73.50%
## 98: 9.81% 6.48% 74.50%
## 99: 10.03% 6.71% 74.50%
## 100: 10.05% 6.69% 75.50%
## 101: 10.00% 6.63% 75.50%
## 102: 10.05% 6.66% 76.00%
## 103: 10.00% 6.61% 76.00%
## 104: 10.15% 6.76% 76.00%
## 105: 10.05% 6.69% 75.50%
## 106: 9.93% 6.63% 74.00%
## 107: 9.95% 6.63% 74.50%
## 108: 9.90% 6.56% 75.00%
## 109: 10.00% 6.63% 75.50%
## 110: 10.00% 6.66% 75.00%
## 111: 10.00% 6.71% 74.00%
## 112: 10.12% 6.81% 74.50%
## 113: 10.00% 6.71% 74.00%
## 114: 9.98% 6.66% 74.50%
## 115: 10.03% 6.71% 74.50%
## 116: 10.03% 6.74% 74.00%
## 117: 10.03% 6.63% 76.00%
## 118: 10.10% 6.74% 75.50%
## 119: 9.98% 6.58% 76.00%
## 120: 10.03% 6.63% 76.00%
## 121: 10.05% 6.66% 76.00%
## 122: 10.17% 6.74% 77.00%
## 123: 10.08% 6.63% 77.00%
## 124: 10.08% 6.66% 76.50%
## 125: 10.08% 6.69% 76.00%
## 126: 10.08% 6.66% 76.50%
## 127: 10.17% 6.74% 77.00%
## 128: 10.20% 6.76% 77.00%
## 129: 10.17% 6.76% 76.50%
## 130: 10.12% 6.76% 75.50%
## 131: 10.08% 6.76% 74.50%
## 132: 10.10% 6.74% 75.50%
## 133: 10.15% 6.74% 76.50%
## 134: 10.03% 6.61% 76.50%
## 135: 10.10% 6.66% 77.00%
## 136: 10.08% 6.66% 76.50%
## 137: 10.08% 6.71% 75.50%
## 138: 10.00% 6.56% 77.00%
## 139: 10.03% 6.58% 77.00%
## 140: 9.93% 6.51% 76.50%
## 141: 9.95% 6.53% 76.50%
## 142: 10.08% 6.66% 76.50%
## 143: 10.05% 6.63% 76.50%
## 144: 9.98% 6.63% 75.00%
## 145: 9.90% 6.48% 76.50%
## 146: 9.90% 6.51% 76.00%
## 147: 9.78% 6.38% 76.00%
## 148: 9.76% 6.35% 76.00%
## 149: 9.78% 6.40% 75.50%
## 150: 9.76% 6.38% 75.50%
## 151: 9.81% 6.40% 76.00%
## 152: 9.90% 6.56% 75.00%
## 153: 9.78% 6.43% 75.00%
## 154: 9.68% 6.27% 76.00%
## 155: 9.81% 6.45% 75.00%
## 156: 9.73% 6.33% 76.00%
## 157: 9.76% 6.33% 76.50%
## 158: 9.68% 6.30% 75.50%
## 159: 9.81% 6.38% 76.50%
## 160: 9.86% 6.40% 77.00%
## 161: 9.86% 6.43% 76.50%
## 162: 9.81% 6.40% 76.00%
## 163: 9.78% 6.33% 77.00%
## 164: 9.88% 6.45% 76.50%
## 165: 9.88% 6.48% 76.00%
## 166: 9.86% 6.43% 76.50%
## 167: 9.93% 6.48% 77.00%
## 168: 9.86% 6.43% 76.50%
## 169: 9.76% 6.38% 75.50%
## 170: 9.73% 6.33% 76.00%
## 171: 9.78% 6.38% 76.00%
## 172: 9.81% 6.43% 75.50%
## 173: 9.81% 6.43% 75.50%
## 174: 9.76% 6.35% 76.00%
## 175: 9.78% 6.38% 76.00%
## 176: 9.81% 6.40% 76.00%
## 177: 9.81% 6.38% 76.50%
## 178: 9.76% 6.35% 76.00%
## 179: 9.76% 6.35% 76.00%
## 180: 9.68% 6.27% 76.00%
## 181: 9.68% 6.33% 75.00%
## 182: 9.64% 6.27% 75.00%
## 183: 9.66% 6.30% 75.00%
## 184: 9.78% 6.40% 75.50%
## 185: 9.83% 6.43% 76.00%
## 186: 9.78% 6.40% 75.50%
## 187: 9.73% 6.40% 74.50%
## 188: 9.78% 6.40% 75.50%
## 189: 9.68% 6.33% 75.00%
## 190: 9.56% 6.22% 74.50%
## 191: 9.44% 6.17% 73.00%
## 192: 9.54% 6.25% 73.50%
## 193: 9.54% 6.22% 74.00%
## 194: 9.54% 6.20% 74.50%
## 195: 9.56% 6.22% 74.50%
## 196: 9.59% 6.25% 74.50%
## 197: 9.66% 6.33% 74.50%
## 198: 9.56% 6.25% 74.00%
## 199: 9.46% 6.15% 74.00%
## 200: 9.44% 6.15% 73.50%
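The `cutoff = c(0.75, 0.25)` argument in the call above changes how tree votes are turned into a class. According to the randomForest documentation, the winning class is the one with the largest ratio of vote share to cutoff. A minimal sketch of that rule (the helper name `vote_winner` and the vote shares are made up for illustration):

```r
# Hypothetical helper: pick the class with the largest vote-share / cutoff ratio
vote_winner <- function(votes, cutoff) {
  stopifnot(length(votes) == length(cutoff))
  names(votes)[which.max(votes / cutoff)]
}

cutoff <- c(No = 0.75, Yes = 0.25)

# 30% of trees vote Yes: 0.30 / 0.25 = 1.2 beats 0.70 / 0.75 = 0.93
w1 <- vote_winner(c(No = 0.70, Yes = 0.30), cutoff)

# 20% of trees vote Yes: 0.20 / 0.25 = 0.8 loses to 0.80 / 0.75 = 1.07
w2 <- vote_winner(c(No = 0.80, Yes = 0.20), cutoff)

print(c(w1, w2))
```

With these cutoffs an observation is labelled "Yes" once more than a quarter of the trees vote for it, which is one way of compensating for the rarity of stroke cases during training.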
rf.pred <- predict(rf.train, newdata=StrokeTest)
rf.prob <- predict(rf.train, newdata = StrokeTest, type = "prob")
plot.roc(StrokeTest$stroke, rf.prob[,2],col="darkblue", print.auc = TRUE, auc.polygon=TRUE, grid=c(0.1, 0.2),
grid.col=c("green", "red"), max.auc.polygon=TRUE,
auc.polygon.col="lightblue", print.thres=TRUE, legacy.axes = TRUE)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
The ROC curve above suggests an optimal cutoff of 0.022, which we use to tune the model's predictions.
# Use the optimal cutoff
optimal_cutoff <- 0.022
rf.pred <- ifelse(rf.prob[, 2] > optimal_cutoff, "Yes", "No")
rf.pred <- factor(rf.pred, levels = c("No", "Yes"))
# Evaluate the model with the confusion matrix
library(caret)
confusionMatrix(rf.pred, StrokeTest$stroke)
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 616 6
## Yes 356 43
##
## Accuracy : 0.6454
## 95% CI : (0.6152, 0.6748)
## No Information Rate : 0.952
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1164
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6337
## Specificity : 0.8776
## Pos Pred Value : 0.9904
## Neg Pred Value : 0.1078
## Prevalence : 0.9520
## Detection Rate : 0.6033
## Detection Prevalence : 0.6092
## Balanced Accuracy : 0.7556
##
## 'Positive' Class : No
##
This Random Forest model achieves an accuracy of 64.54%, which is slightly lower than LDA (66.60%) but comparable to logistic regression (64.25%). Its specificity (87.76%) is the highest of the models so far, indicating that it is effective at identifying true "Yes" (stroke) cases: it catches 43 of 49, compared with 41 for logistic regression and 38 for LDA. Its sensitivity (63.37%) lags slightly behind LDA (66.05%) and is in line with logistic regression (63.27%). With a balanced accuracy of 75.56%, it edges out logistic regression (73.47%) and LDA (71.80%), suggesting it strikes the best balance between sensitivity and specificity among these three. The AUC from the ROC curve (0.814) further indicates that the Random Forest model discriminates well between the classes and offers flexibility in adjusting thresholds to optimize for either sensitivity or specificity.
Overall, the differences between LDA, logistic regression, and Random Forest are small at these low thresholds: LDA keeps a slight edge in accuracy and sensitivity, while Random Forest leads in specificity and balanced accuracy. Random Forest also offers flexibility, robustness to non-linear relationships, and feature importance insights. If interpretability is a priority, LDA is the best choice, but if capturing complex feature interactions is critical, or you require threshold adjustments to minimize specific types of errors, Random Forest is the preferred model. For practical applications, Random Forest's ability to balance metrics with proper threshold tuning makes it a strong contender, especially when combined with ROC-based optimization.
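To put the four thresholded models side by side, the sketch below recomputes the headline metrics directly from the confusion matrices printed in this report. Only the cell counts come from the outputs above; the table layout and column names are choices made for this summary.

```r
# Cell counts from the confusion matrices above
# (LDA, logistic regression and decision tree at threshold 0.029; random forest at 0.022)
cm <- rbind(
  "LDA"                 = c(642, 11, 330, 38),
  "Logistic regression" = c(615,  8, 357, 41),
  "Decision tree"       = c(820, 26, 152, 23),
  "Random forest"       = c(616,  6, 356, 43)
)
colnames(cm) <- c("no_as_no", "yes_as_no", "no_as_yes", "yes_as_yes")

# caret's definitions with "No" as the positive class
sensitivity <- cm[, "no_as_no"] / (cm[, "no_as_no"] + cm[, "no_as_yes"])
specificity <- cm[, "yes_as_yes"] / (cm[, "yes_as_yes"] + cm[, "yes_as_no"])

comparison <- data.frame(
  accuracy          = (cm[, "no_as_no"] + cm[, "yes_as_yes"]) / rowSums(cm),
  sensitivity       = sensitivity,
  specificity       = specificity,
  balanced_accuracy = (sensitivity + specificity) / 2
)
print(round(comparison, 4))
```

Sorting this table by `specificity` ranks the models by how many of the 49 test-set strokes they detect, which is usually the clinically relevant ordering.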