Summary

This study evaluates three classification models for stroke prediction: a Random Forest, a K-Nearest Neighbors model and a Logistic Regression model.

The data comes from the Stroke Prediction dataset available on Kaggle at the following link: https://www.kaggle.com/fedesoriano/stroke-prediction-dataset/metadata

This study and the models created are only for educational purposes.

Load data and libraries

library(tidyverse)
library(data.table)
library(caret)
library(pROC)
library(cvms)
library(imbalance)

data <- read.csv('./data/healthcare-dataset-stroke-data.csv', na.strings = c('N/A'))
data <- as.data.table(data)
summary(data)
##        id           gender               age         hypertension    
##  Min.   :   67   Length:5110        Min.   : 0.08   Min.   :0.00000  
##  1st Qu.:17741   Class :character   1st Qu.:25.00   1st Qu.:0.00000  
##  Median :36932   Mode  :character   Median :45.00   Median :0.00000  
##  Mean   :36518                      Mean   :43.23   Mean   :0.09746  
##  3rd Qu.:54682                      3rd Qu.:61.00   3rd Qu.:0.00000  
##  Max.   :72940                      Max.   :82.00   Max.   :1.00000  
##                                                                      
##  heart_disease     ever_married        work_type         Residence_type    
##  Min.   :0.00000   Length:5110        Length:5110        Length:5110       
##  1st Qu.:0.00000   Class :character   Class :character   Class :character  
##  Median :0.00000   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :0.05401                                                           
##  3rd Qu.:0.00000                                                           
##  Max.   :1.00000                                                           
##                                                                            
##  avg_glucose_level      bmi        smoking_status         stroke       
##  Min.   : 55.12    Min.   :10.30   Length:5110        Min.   :0.00000  
##  1st Qu.: 77.25    1st Qu.:23.50   Class :character   1st Qu.:0.00000  
##  Median : 91.89    Median :28.10   Mode  :character   Median :0.00000  
##  Mean   :106.15    Mean   :28.89                      Mean   :0.04873  
##  3rd Qu.:114.09    3rd Qu.:33.10                      3rd Qu.:0.00000  
##  Max.   :271.74    Max.   :97.60                      Max.   :1.00000  
##                    NA's   :201
str(data)
## Classes 'data.table' and 'data.frame':   5110 obs. of  12 variables:
##  $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
##  $ gender           : chr  "Male" "Female" "Male" "Female" ...
##  $ age              : num  67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : int  0 0 0 0 1 0 1 0 0 0 ...
##  $ heart_disease    : int  1 0 1 0 0 0 1 0 0 0 ...
##  $ ever_married     : chr  "Yes" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num  229 202 106 171 174 ...
##  $ bmi              : num  36.6 NA 32.5 34.4 24 29 27.4 22.8 NA 24.2 ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : int  1 1 1 1 1 1 1 1 1 1 ...
##  - attr(*, ".internal.selfref")=<externalptr>

We will drop the id variable and check for columns with missing values.

# Drop the ID column
data$id <- NULL

# Check cols with NA
colnames(data)[colSums(is.na(data)) > 0]
## [1] "bmi"

The bmi variable has missing values. We will replace the missing values with the mean bmi based on patient gender.

# Mean BMI per gender (rows are in alphabetical order: Female, Male, Other)
mean_bmi_per_gender <- data %>% group_by(gender) %>% summarise(bmi = mean(bmi, na.rm = TRUE))


# Replace NA in BMI with the mean for the patient's gender
data[gender == 'Female' & is.na(bmi), bmi := mean_bmi_per_gender$bmi[1]]
data[gender == 'Male'   & is.na(bmi), bmi := mean_bmi_per_gender$bmi[2]]
data[gender == 'Other'  & is.na(bmi), bmi := mean_bmi_per_gender$bmi[3]]
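
As a quick check (not part of the original output), the imputation should leave no missing values behind:

# Confirm no missing BMI values remain
sum(is.na(data$bmi))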

Exploratory Data Analysis

First, we will compare the stroke events based on patient gender:

colors <- c("tomato", "royalblue", "olivedrab1")


tbl <- with(data, table(gender, stroke))

barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"), main = "Stroke events by gender")

barplot(tbl[, 2], legend = TRUE, col = colors, main = "Confirmed stroke by gender")

Based on the data set, there are more confirmed stroke events among female patients than male patients. Note, however, that the data set also contains more female patients overall, so raw counts alone do not imply a higher per-patient risk.
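
To account for the unequal group sizes, we can also look at the stroke rate within each gender (an added check, not part of the original analysis):

# Row proportions: share of stroke events within each gender
round(prop.table(tbl, margin = 1), 3)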

Now, we will check the influence of work type on stroke events.

colors <- c("tomato", "royalblue", "olivedrab1", "mediumpurple", "turquoise")
  
tbl <- with(data, table(work_type, stroke))

barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"), main = "Stroke events by patient's work type")

barplot(tbl[, 2], col = colors, main = "Confirmed stroke events by patient's work type")

Most patients in the data set work in the private sector. That is reflected in both charts, where the private sector shows the highest bar for both the no-stroke and the confirmed-stroke groups.

The real influence of work type on stroke events needs more investigation; computing the stroke rate within each work type (as done for gender above) would be a reasonable first step.

Next, we will compare the patients' residence type with the registered stroke events.

colors <- c("tomato", "royalblue")

tbl <- with(data, table(Residence_type, stroke))

barplot(tbl, legend = TRUE, beside = TRUE, col = colors, 
        names.arg = c("No Stroke", "Stroke"),
        main = "Stroke events by patient's Residence type")

barplot(tbl[, 2], col = colors,
        main = "Confirmed stroke events by patient's Residence type")

There is not much difference between patients from rural and urban areas.

Now, we will check the relation between age and stroke events.

tbl <- with(data, table(age, stroke))

barplot(tbl[, 1], col = "royalblue", main = "Patients without stroke by age")

barplot(tbl[, 2], col = "tomato", main = "Patients with stroke events by age")

In the second chart we can see that older patients have a higher chance of stroke.

Next, the relation between smoking habits and stroke events.

colors <- c("tomato", "royalblue", "olivedrab1", "mediumpurple")

tbl <- with(data, table(smoking_status, stroke))

barplot(tbl, legend = TRUE, beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"), main = "Stroke events by smoking habits")

barplot(tbl[, 2], col = colors, 
        main = "Confirmed stroke events by smoking habits")

Surprisingly, patients that never smoked or that smoked in the past show more stroke events than active smokers. However, we have to keep in mind that a notable portion of the data has no clear record of the patient's smoking habits, represented by the Unknown category.

We will compare the hypertension factor with stroke events now.

colors <- c("royalblue", "tomato")

tbl <- with(data, table(hypertension, stroke))

# Rows of tbl are hypertension = 0, then 1, so the legend follows that order
barplot(tbl, legend.text = c("No Hypertension", "Hypertension"),
        beside = TRUE, col = colors,
        names.arg = c("No Stroke", "Stroke"),
        main = "Stroke events by hypertension diagnosis")

barplot(tbl[, 2], col = colors,
        main = "Confirmed stroke events by hypertension diagnosis",
        names.arg = c("Without Hypertension", "With Hypertension"))

Again, it is surprising that most confirmed stroke events come from patients without a hypertension diagnosis.

The next analysis will be the comparison of stroke events with a heart disease background.

colors <- c("royalblue", "tomato")

tbl <- with(data, table(heart_disease, stroke))

barplot(tbl, legend.text = c("Without heart disease", "With heart disease"),
        beside = TRUE, col = colors,
        names.arg = c('No Stroke', 'Stroke'), 
        main = "Stroke events by heart disease background")

barplot(tbl[, 2], col = colors, main = "Confirmed stroke events by heart disease background",
        names.arg = c("Without heart disease", "With heart disease"))

As shown in the second chart, most patients with stroke do not have a heart disease.

Finally, we will check the distribution of BMI and the average glucose level of patients.

hist(data$bmi, col = "royalblue", main = "BMI distribution", xlab = 'BMI')

hist(data$avg_glucose_level, col = "tomato", main = "Average glucose levels",
     xlab = "Average glucose levels")

Data transformation

Before training the models, the data must be prepared. We will use the one-hot encoding technique, converting each categorical variable into a set of binary indicator variables, one per category, each taking a value of 0 or 1.

The age, avg_glucose_level and bmi variables will also be standardized (zero mean, unit standard deviation).

# Standardize the numeric variables (z-score: subtract the mean, divide by the sd)
data$age <- (data$age - mean(data$age)) / sd(data$age)
data$bmi <- (data$bmi - mean(data$bmi)) / sd(data$bmi)
data$avg_glucose_level <- (data$avg_glucose_level - mean(data$avg_glucose_level)) / sd(data$avg_glucose_level)

# One-hot encode the categorical variables with caret's dummyVars
dummy <- dummyVars(" ~ . ", data = data)
data <- data.frame(predict(dummy, newdata = data))
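
As a quick sanity check (added here, not in the original output), the standardized columns should now have a mean of roughly 0 and a standard deviation of roughly 1:

# Verify the z-score standardization of the numeric columns
sapply(data[, c("age", "bmi", "avg_glucose_level")], function(x) c(mean = mean(x), sd = sd(x)))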

Now, we will check for class imbalance.

table(data$stroke)
## 
##    0    1 
## 4861  249

We will use MWMOTE (Majority Weighted Minority Oversampling Technique) to oversample the stroke class and reduce the class imbalance.

# Generate 500 synthetic minority (stroke) examples
oversampled <- mwmote(data, classAttr = "stroke", numInstances = 500)

# Round so the synthetic one-hot columns are back to 0/1 values
# (note this also coarsens the standardized numeric columns)
oversampled <- round(oversampled)
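
mwmote() returns only the newly generated synthetic rows, so they should all belong to the minority class (an added check, assuming the imbalance package's documented behavior):

# All synthetic instances should have stroke = 1
table(oversampled$stroke)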

Model training

First, we need to create training and testing data sets. We will use an 80:20 split: 80% of the data for the training set and 20% for the final test.

We will also use K-fold cross-validation with K = 5 on the training set.

set.seed(1203)

fullData <- rbind(data, oversampled)

# Target class needs to be a factor
fullData$stroke <- factor(fullData$stroke)

sample <- createDataPartition(y = fullData$stroke, p = 0.8, list = FALSE)
train <- fullData[sample, ]
test <- fullData[-sample, ]

train_control <- trainControl(method = "cv", number = 5)

The models to evaluate are Random Forest, K-Nearest Neighbors and Logistic Regression.

1. Random Forest

randomForest <- train(stroke ~ ., data = train, method = "rf", trControl = train_control)
randomForest
## Random Forest 
## 
## 4489 samples
##   21 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3591, 3591, 3591, 3592, 3591 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.8850524  0.2193004
##   11    0.9556694  0.7806806
##   21    0.9536647  0.7731322
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 11.

2. K-Nearest Neighbor

knn <- train(stroke ~ ., data = train, method = "knn", trControl = train_control)
knn
## k-Nearest Neighbors 
## 
## 4489 samples
##   21 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3592, 3591, 3591, 3591, 3591 
## Resampling results across tuning parameters:
## 
##   k  Accuracy   Kappa    
##   5  0.8663384  0.3410047
##   7  0.8643320  0.2785663
##   9  0.8705720  0.2851689
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 9.

3. Logistic Regression

logisticRegression <- train(stroke ~ ., data = train, method = "glm",
                            trControl = train_control,
                            family = "binomial")
logisticRegression
## Generalized Linear Model 
## 
## 4489 samples
##   21 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3591, 3592, 3591, 3591, 3591 
## Resampling results:
## 
##   Accuracy  Kappa    
##   0.870796  0.2169308

The model with the best cross-validated accuracy is the Random Forest at 95.57%, so that is the model we will evaluate in the following step.
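
For a side-by-side comparison (not part of the original analysis), caret's resamples() can summarize the cross-validation results of all three fitted models:

# Collect and summarize the resampling results of the three models
results <- resamples(list(rf = randomForest, knn = knn, glm = logisticRegression))
summary(results)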

Model testing

We will evaluate the chosen model on the test set created before, building a confusion matrix to assess the results. The positive class is 1, which corresponds to a stroke event.

# Predict the stroke class on the held-out test set
test$prediction <- predict(randomForest, newdata = test)

# cvms::evaluate() expects the prediction column as character
test$prediction <- as.character(test$prediction)
conf_matrix <- evaluate(test, target_col = "stroke", prediction_cols = "prediction",
                        type = "binomial", positive = "1")

plot_confusion_matrix(conf_matrix)

Conclusions

As seen in the confusion matrix, the model has a high specificity (99.8%), correctly predicting patients without stroke, but a sensitivity (correctly predicting stroke) of only 60.4%, which should be higher for real-world use.
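
One way to trade some specificity for higher sensitivity is to tune the classification threshold on predicted probabilities instead of using hard class labels. A minimal sketch using pROC (loaded at the start but unused so far), assuming the Random Forest can output class probabilities for this data:

# Predicted probability of the positive class (second column = class "1")
probs <- predict(randomForest, newdata = test, type = "prob")[, 2]

# ROC curve; coords() suggests the threshold maximizing Youden's J
# (sensitivity + specificity - 1)
roc_obj <- roc(response = test$stroke, predictor = probs, levels = c("0", "1"))
auc(roc_obj)
coords(roc_obj, x = "best", best.method = "youden")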

The original dataset has a very high class imbalance, so an oversampling technique was used. It would be interesting to train a new model on a fully balanced dataset.

This analysis is a good practical exercise and useful study material, but the models created here are not suitable for real-world applications, as they would need further validation and research.