This study evaluates three classification models for stroke prediction: a Random Forest, a K-Nearest Neighbors, and a Logistic Regression model.
The data comes from the Stroke Prediction dataset available on Kaggle at the following link: https://www.kaggle.com/fedesoriano/stroke-prediction-dataset/metadata
This study and the models created are for educational purposes only.
library(tidyverse)   # data manipulation and plotting helpers
library(data.table)  # fast data tables and := updates
library(caret)       # model training, tuning and preprocessing
library(pROC)        # ROC curves and AUC
library(cvms)        # confusion matrix evaluation and plotting
library(imbalance)   # MWMOTE oversampling
data <- read.csv('./data/healthcare-dataset-stroke-data.csv', na.strings = c('N/A'))
data <- as.data.table(data)
summary(data)
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.23 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
##
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
##
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Min. :10.30 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 1st Qu.:23.50 Class :character 1st Qu.:0.00000
## Median : 91.89 Median :28.10 Mode :character Median :0.00000
## Mean :106.15 Mean :28.89 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:33.10 3rd Qu.:0.00000
## Max. :271.74 Max. :97.60 Max. :1.00000
## NA's :201
str(data)
## Classes 'data.table' and 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : int 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : int 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : num 36.6 NA 32.5 34.4 24 29 27.4 22.8 NA 24.2 ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
## - attr(*, ".internal.selfref")=<externalptr>
We will drop the id variable and check for columns with missing values.
# Drop the id column, which carries no predictive information
data$id <- NULL
# List the columns that contain missing values
colnames(data)[colSums(is.na(data)) > 0]
## [1] "bmi"
The bmi variable has missing values. We will replace the missing values with the mean bmi based on patient gender.
# Mean BMI per gender (summarise sorts the groups alphabetically: Female, Male, Other)
mean_bmi_per_gender <- data %>% group_by(gender) %>% summarise(bmi = mean(bmi, na.rm = TRUE))
# Replace NA in BMI with the mean for each gender
data[gender == 'Female' & is.na(bmi), bmi := mean_bmi_per_gender$bmi[1]]
data[gender == 'Male' & is.na(bmi), bmi := mean_bmi_per_gender$bmi[2]]
data[gender == 'Other' & is.na(bmi), bmi := mean_bmi_per_gender$bmi[3]]
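A more concise alternative, if preferred, is to let data.table compute and apply the group means in a single grouped update; a minimal sketch that is equivalent to the three lines above:
# Impute missing bmi with the per-gender mean in one step
data[, bmi := fifelse(is.na(bmi), mean(bmi, na.rm = TRUE), bmi), by = gender]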
First, we will compare the stroke events based on patient gender:
colors <- c("tomato", "royalblue", "olivedrab1")
tbl <- with(data, table(gender, stroke))
barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
names.arg = c("No Stroke", "Stroke"), main = "Stroke events by gender")
barplot(tbl[, 2], legend = TRUE, col = colors, main = "Confirmed stroke by gender")
Based on the data set, more female patients had strokes than male patients, although the data set also contains more female patients overall.
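Since the gender groups have different sizes, raw counts can be misleading; comparing the stroke rate within each gender is more informative. A minimal sketch using the table built above:
# Stroke rate within each gender (each row sums to 1)
prop.table(tbl, margin = 1)
The same per-group rate check applies to the other categorical comparisons below.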
Now, we will check the influence of work type on stroke events.
colors <- c("tomato", "royalblue", "olivedrab1", "mediumpurple", "turquoise")
tbl <- with(data, table(work_type, stroke))
barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
names.arg = c("No Stroke", "Stroke"), main = "Stroke events by patient's work type")
barplot(tbl[, 2], col = colors, main = "Confirmed stroke events by patient's work type")
Most patients in the data set work in the private sector. That is reflected in both charts, where the private sector shows the highest counts for both no stroke and confirmed stroke.
The real influence of work type on stroke events needs further investigation.
Next, we will compare the residence of the patients with the stroke events registered.
colors <- c("tomato", "royalblue")
tbl <- with(data, table(Residence_type, stroke))
barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
names.arg = c("No Stroke", "Stroke"),
main = "Stroke events by patient's Residence type")
barplot(tbl[, 2], col = colors,
main = "Confirmed stroke events by patient's Residence type")
There is not much difference between patients from rural and urban areas.
Now, we will check the relation between age and stroke events.
tbl <- with(data, table(age, stroke))
barplot(tbl[, 1], col = "royalblue", main = "Patients without stroke by age")
barplot(tbl[, 2], col = "tomato", main = "Patients with stroke events by age")
The second chart shows that older patients have a higher chance of stroke.
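A complementary view is to compare the age distributions directly; a minimal sketch using the raw (not yet standardized) age values:
# Age distribution for patients with and without stroke
boxplot(age ~ stroke, data = data, col = c("royalblue", "tomato"),
        names = c("No Stroke", "Stroke"), main = "Age by stroke event")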
Next, the relation between smoking habits and stroke events.
colors <- c("tomato", "royalblue", "olivedrab1", "mediumpurple")
tbl <- with(data, table(smoking_status, stroke))
barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
names.arg = c("No Stroke", "Stroke"), main = "Stroke events by smoking habits")
barplot(tbl[, 2], col = colors,
main = "Confirmed stroke events by smoking habits")
Surprisingly, patients who never smoked or who smoked in the past have more recorded stroke events than active smokers. However, we have to keep in mind that a notable portion of the data set lacks a clear record of the patients' smoking habits, represented by the Unknown category.
We will compare the hypertension factor with stroke events now.
colors <- c("royalblue", "tomato")
tbl <- with(data, table(hypertension, stroke))
barplot(tbl, legend.text = c("No Hypertension", "Hypertension"),
        beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"),
        main = "Stroke events by hypertension diagnosis")
barplot(tbl[, 2], col = colors,
main = "Confirmed stroke events by hypertension diagnosis",
names.arg = c("Without Hypertension", "With Hypertension"))
Again, it is surprising that most confirmed stroke events come from patients without a hypertension diagnosis.
Next, we will compare stroke events with a heart disease background.
colors <- c("royalblue", "tomato")
tbl <- with(data, table(heart_disease, stroke))
barplot(tbl, legend.text = c("Without heart disease", "With heart disease"),
        beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"),
        main = "Stroke events by heart disease background")
barplot(tbl[, 2], col = colors, main = "Confirmed stroke events by heart disease background",
names.arg = c("Without heart disease", "With heart disease"))
As shown in the second chart, most patients with stroke do not have a heart disease background.
Finally, we will check the distribution of BMI and the average glucose level of patients.
hist(data$bmi, col = "royalblue", main = "BMI distribution", xlab = 'BMI')
hist(data$avg_glucose_level, col = "tomato", main = "Average glucose levels",
xlab = "Average glucose levels")
Before training the models, the data must be prepared. We will use one-hot encoding, converting each categorical variable into multiple indicator variables, each with a value of 0 or 1.
The age, average glucose level, and bmi variables will also be standardized.
# Standardize the numeric variables (z-score: subtract the mean, divide by the standard deviation)
data$age <- (data$age - mean(data$age)) / sd(data$age)
data$bmi <- (data$bmi - mean(data$bmi)) / sd(data$bmi)
data$avg_glucose_level <- (data$avg_glucose_level - mean(data$avg_glucose_level)) / sd(data$avg_glucose_level)
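As an alternative to the manual scaling above, caret's preProcess() can learn the centering and scaling parameters and apply them via predict(); a minimal sketch, equivalent in effect to the three lines above (num_cols and pp are names chosen here for illustration):
# Learn center/scale parameters for the numeric columns and apply them
num_cols <- c("age", "bmi", "avg_glucose_level")
pp <- preProcess(data[, ..num_cols], method = c("center", "scale"))
data[, (num_cols) := predict(pp, data[, ..num_cols])]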
# One-hot encode the categorical variables with caret's dummyVars
dummy <- dummyVars(" ~ . ", data = data)
data <- data.frame(predict(dummy, newdata = data))
Now, we will check for class imbalance.
table(data$stroke)
##
## 0 1
## 4861 249
We will use MWMOTE (Majority Weighted Minority Oversampling Technique) to oversample the minority (stroke) class and reduce the class imbalance.
# Generate 500 synthetic minority-class (stroke = 1) samples
oversampled <- mwmote(data, classAttr = "stroke", numInstances = 500)
# Round so the one-hot indicator columns are 0/1 again (this also rounds the scaled numeric columns)
oversampled <- round(oversampled)
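To verify the effect of the oversampling before moving on, we can re-check the class counts with the synthetic rows appended; a quick sketch:
# Class balance after appending the synthetic minority samples
table(rbind(data, oversampled)$stroke)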
First, we need to create training and testing data sets. We will use an 80:20 split: 80% of the data for training and 20% for final testing.
We will also use K-fold cross-validation with K = 5 on the training set.
set.seed(1203)
fullData <- rbind(data, oversampled)
# Target class needs to be a factor
fullData$stroke <- factor(fullData$stroke)
sample <- createDataPartition(y = fullData$stroke, p = 0.8, list = FALSE)
train <- fullData[sample, ]
test <- fullData[-sample, ]
train_control <- trainControl(method = "cv", number = 5)
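A side note on the tuning metric: with imbalanced classes, accuracy can be misleading, and caret can instead select models by ROC AUC. This requires class probabilities and factor levels that are valid R names, so the sketch below (not applied in this study) relabels the classes first:
# Sketch (not applied here): tune on ROC AUC instead of accuracy
# fullData$stroke <- factor(fullData$stroke, levels = c(0, 1), labels = c("no", "yes"))
# train_control_roc <- trainControl(method = "cv", number = 5,
#                                   classProbs = TRUE, summaryFunction = twoClassSummary)
# rf_roc <- train(stroke ~ ., data = train, method = "rf",
#                 metric = "ROC", trControl = train_control_roc)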
The models to evaluate are Random Forest, K-Nearest Neighbors, and Logistic Regression.
1. Random Forest
randomForest <- train(stroke ~ ., data = train, method = "rf", trControl = train_control)
randomForest
## Random Forest
##
## 4489 samples
## 21 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3591, 3591, 3591, 3592, 3591
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8850524 0.2193004
## 11 0.9556694 0.7806806
## 21 0.9536647 0.7731322
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 11.
2. K-Nearest Neighbors
knn <- train(stroke ~ ., data = train, method = "knn", trControl = train_control)
knn
## k-Nearest Neighbors
##
## 4489 samples
## 21 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3592, 3591, 3591, 3591, 3591
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 5 0.8663384 0.3410047
## 7 0.8643320 0.2785663
## 9 0.8705720 0.2851689
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 9.
3. Logistic Regression
logisticRegression <- train(stroke ~ ., data = train, method = "glm",
trControl = train_control,
family = "binomial")
logisticRegression
## Generalized Linear Model
##
## 4489 samples
## 21 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3591, 3592, 3591, 3591, 3591
## Resampling results:
##
## Accuracy Kappa
## 0.870796 0.2169308
The model with the best cross-validated accuracy is the Random Forest, at 95.57%, so that is the model we will test in the next step.
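As a convenient cross-check, caret's resamples() helper can summarize the three fitted models over their cross-validation folds side by side; a minimal sketch, assuming the three train objects created above:
# Compare cross-validated accuracy and kappa across the three models
results <- resamples(list(RF = randomForest, KNN = knn, GLM = logisticRegression))
summary(results)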
We will evaluate the model on the test set created earlier, building a confusion matrix to assess the results. The positive class is 1, which corresponds to a stroke.
test$prediction <- predict(randomForest, newdata = test)
test$prediction <- as.character(test$prediction)
conf_matrix <- evaluate(test, target_col = "stroke", prediction_cols = "prediction",
type = "binomial", positive = "1")
plot_confusion_matrix(conf_matrix)
As seen in the confusion matrix, the model has a high specificity (99.8%), correctly identifying patients without stroke, but a sensitivity (correctly identifying stroke) of only 60.4%, which would need to be higher for real-world use.
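Since pROC is loaded above, we can also look at a threshold-independent metric; a minimal sketch, assuming the second column returned by predict(type = "prob") corresponds to the positive class "1":
# Class probabilities for the positive class (1 = stroke)
probs <- predict(randomForest, newdata = test, type = "prob")
roc_obj <- roc(response = test$stroke, predictor = probs[, 2], levels = c("0", "1"))
auc(roc_obj)
plot(roc_obj, main = "ROC curve - Random Forest")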
The original data set has a very high class imbalance, so an oversampling technique was used. It would be interesting to create a new model using a naturally balanced data set.
This analysis is a good practical exercise and study material, but the models created here are not suitable for real-world applications; they would need further validation and research.