WQD7004 Programming for Data Science

Lecturer: Assoc. Prof. Dr. Ang Tan Fong

Group Project - Group 3

GROUP MEMBERS (GROUP 3)

CHEW CHI YEW 17095779
RADZIAH ZAINUDDIN S2156690
FARAH JASMIEN BINTI ZUKARI 17172774
LEE YING QIU 17108552
TAN JIA YUE 17114152

Video: https://drive.google.com/file/d/11VXe5f2Qjx1xg3BoOfQqN-BR9ek4J3fr/view?usp=share_link

1 Introduction

  1. In this group assignment, we adopted the Stroke Prediction Dataset from Kaggle, which studies the likelihood of a patient having a stroke based on input parameters such as gender, age, various diseases, and smoking status.
  2. According to the World Health Organization (WHO), the top 10 causes of death accounted for 55% of the 55.4 million deaths worldwide in 2019.
  3. Since 2000, the largest increase in deaths has been for ischaemic heart disease, rising by more than 2 million to 8.9 million deaths in 2019.
  4. Stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.
  5. Therefore, a predictive model of the likelihood of getting stroke could help high-risk individuals/patients become aware of their health condition and take precautionary steps to reduce their risk of stroke.

1.1 Research Questions

  1. Which factors are important in predicting the stroke outcome of a patient, i.e., had a stroke or not?
  2. How do machine learning classification and regression models perform in predicting stroke outcome?
  3. Which machine learning algorithm has the highest accuracy?
  4. How does the predictive model perform across different evaluation metrics?

1.2 Research Objectives

  1. To build machine learning classification and regression models.
  2. To compare the performance metrics of different machine learning models.
  3. To identify the most important factors associated with stroke outcome.

2 Data Preprocessing

2.2 Load libraries

library(readr)                  # reading/writing CSV files
library(caret)                  # data partitioning, preprocessing and model evaluation
library(readxl)                 # reading Excel files
library(Metrics)                # evaluation metrics
library(dplyr)                  # data manipulation verbs (mutate, filter, group_by, ...)
library(ggplot2)                # plotting
library(skimr)                  # quick data summaries
library(tidyr)                  # data tidying
library(reshape2)               # melt() for reshaping data for the heatmap
library(ggpubr)                 # publication-ready ggplot2 helpers
library(stringr)                # string manipulation
library(e1071)                  # miscellaneous statistics and ML functions
library(pROC)                   # ROC curves and AUC
library(magrittr)               # pipe operator %>%
library(RColorBrewer)           # colour palettes
library(performanceEstimation)  # smote() for class balancing
library(superml)                # ML helper utilities
library(randomForest)           # random forest and tuneRF
library(graphics)               # base R graphics

2.3 Load Data

In this stage, we will load the data and display the first six rows of the dataset.

stroke <- read.csv("/cloud/project/Stroke Prediction.csv")
head(stroke)
##      id gender age hypertension heart_disease ever_married     work_type
## 1  9046   Male  67            0             1      Married       Private
## 2 51676 Female  61            0             0          Yes Self-employed
## 3 31112   Male  80            0             1          Yes       Private
## 4 60182 Female  49            0             0          Yes       Private
## 5  1665      F  79            1             0          Yes Self-employed
## 6 56669   Male  81            0             0          Yes       Private
##   Residence_type avg_glucose_level  bmi  smoking_status stroke
## 1          Urban            228.69 36.6 formerly smoked      1
## 2          Rural            202.21  N/A    never smoked      1
## 3          Rural            105.92 32.5    never smoked      1
## 4          Urban            171.23 34.4          smokes      1
## 5          Rural            174.12   24    never smoked      1
## 6          Urban            186.21   29 formerly smoked      1

2.4 Explore dataset

Now, let’s have a look at the internal structure of our data. There is a total of 5110 observations of 12 variables in the stroke prediction dataset. From the output presented, we can observe that the data types of the bmi and stroke columns are abnormal: bmi is stored as character (because of literal "N/A" strings) and stroke as integer.

str(stroke)
## 'data.frame':    5110 obs. of  12 variables:
##  $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
##  $ gender           : chr  "Male" "Female" "Male" "Female" ...
##  $ age              : num  67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : chr  "0" "0" "0" "0" ...
##  $ heart_disease    : chr  "1" "0" "1" "0" ...
##  $ ever_married     : chr  "Married" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num  229 202 106 171 174 ...
##  $ bmi              : chr  "36.6" "N/A" "32.5" "34.4" ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : int  1 1 1 1 1 1 1 1 1 1 ...

2.4.1 Change column data type

Thereafter, we will start the data preprocessing by converting the bmi column to a numeric data type and the stroke column from integer to character, since stroke is a categorical label.

stroke$bmi <- as.numeric(stroke$bmi)        # literal "N/A" strings are coerced to NA
stroke$stroke <- as.character(stroke$stroke)

str(stroke)
## 'data.frame':    5110 obs. of  12 variables:
##  $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
##  $ gender           : chr  "Male" "Female" "Male" "Female" ...
##  $ age              : num  67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : chr  "0" "0" "0" "0" ...
##  $ heart_disease    : chr  "1" "0" "1" "0" ...
##  $ ever_married     : chr  "Married" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num  229 202 106 171 174 ...
##  $ bmi              : num  36.6 NA 32.5 34.4 24 29 27.4 22.8 NA 24.2 ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : chr  "1" "1" "1" "1" ...

2.4.2 Check for missing values in all columns

Next, we will check whether there are any missing values in each column of the dataset.

colSums(is.na(stroke))
##                id            gender               age      hypertension 
##                 0                 0                 0                 0 
##     heart_disease      ever_married         work_type    Residence_type 
##                 0                 0                 0                 0 
## avg_glucose_level               bmi    smoking_status            stroke 
##                 0               188                 0                 0

2.5 Data Preprocessing for Categorical Variables

2.5.1 View categorical variable

Next, we will look into the per-category counts of every categorical column in order to discover any abnormality. Here, we use the group_by() command to inspect each variable (a more compact alternative is sketched after the counts below).

stroke %>%
  group_by(gender) %>%
  summarise(Count = n())
## # A tibble: 5 × 2
##   gender Count
##   <chr>  <int>
## 1 F         20
## 2 Female  2974
## 3 M          9
## 4 Male    2106
## 5 Other      1
stroke %>%
  group_by(hypertension) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
##   hypertension Count
##   <chr>        <int>
## 1 0             4604
## 2 1              489
## 3 No               8
## 4 Yes              9
stroke %>%
  group_by(heart_disease) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
##   heart_disease Count
##   <chr>         <int>
## 1 0              4821
## 2 1               270
## 3 No               13
## 4 Yes               6
stroke %>%
  group_by(ever_married) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
##   ever_married Count
##   <chr>        <int>
## 1 Married         22
## 2 No            1738
## 3 Single          19
## 4 Yes           3331
stroke %>%
  group_by(work_type) %>%
  summarise(Count = n())
## # A tibble: 5 × 2
##   work_type     Count
##   <chr>         <int>
## 1 children        687
## 2 Govt_job        657
## 3 Never_worked     22
## 4 Private        2925
## 5 Self-employed   819
stroke %>%
  group_by(Residence_type) %>%
  summarise(Count = n())
## # A tibble: 2 × 2
##   Residence_type Count
##   <chr>          <int>
## 1 Rural           2514
## 2 Urban           2596
stroke %>%
  group_by(smoking_status) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
##   smoking_status  Count
##   <chr>           <int>
## 1 formerly smoked   885
## 2 never smoked     1892
## 3 smokes            789
## 4 Unknown          1544
stroke %>%
  group_by(stroke) %>%
  summarise(Count = n())
## # A tibble: 2 × 2
##   stroke Count
##   <chr>  <int>
## 1 0       4861
## 2 1        249
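For a more compact inspection, a minimal alternative sketch (equivalent in effect to the group_by() chains above) tabulates every categorical column in one loop:

# Tabulate each categorical column in a single pass
cat_cols <- c("gender", "hypertension", "heart_disease", "ever_married",
              "work_type", "Residence_type", "smoking_status", "stroke")
for (col in cat_cols) {
  cat("\n--", col, "--\n")
  print(table(stroke[[col]]))
}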

2.5.2 Data cleaning for categorical columns

Based on the summary tables presented above, data issues are found in four categorical columns: gender, hypertension, heart_disease and ever_married. First, we will replace some categories with the correct strings; for instance, F is replaced with Female as they represent the same category. Then, we will drop rows with unwanted categories, such as ‘Other’ in the gender variable and ‘Unknown’ in smoking_status. A quick sanity check follows the cleaning code.

stroke1<- stroke %>% 
  mutate(gender = replace(gender, gender == 'F', 'Female')) %>% 
  mutate(gender = replace(gender, gender == 'M', 'Male'))%>% 
  mutate(hypertension = replace(hypertension, hypertension == 'Yes', '1')) %>% 
  mutate(hypertension = replace(hypertension, hypertension == 'No', '0')) %>% 
  mutate(heart_disease = replace(heart_disease, heart_disease == 'Yes', '1')) %>% 
  mutate(heart_disease = replace(heart_disease, heart_disease == 'No', '0')) %>%
  mutate(ever_married = replace(ever_married, ever_married == 'Single', 'No')) %>%
  mutate(ever_married = replace(ever_married, ever_married == 'Married', 'Yes')) %>%
  filter(!((gender=="Other" | smoking_status=="Unknown") ))
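As a quick sanity check (a minimal sketch that simply re-runs the category counts on the cleaned data), we can confirm that each cleaned column now contains only the expected levels:

# Re-check cleaned categorical levels
table(stroke1$gender)        # expect only Female / Male
table(stroke1$hypertension)  # expect only 0 / 1
table(stroke1$ever_married)  # expect only No / Yes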

2.6 Data Preprocessing for Numerical Variables

2.6.1 View numerical variable

From the summary table below, we can see that there are some issues in the numerical variables. For example, there are negative values in the avg_glucose_level column, and it is impossible for bmi to have a zero value.

summary(stroke1[c("age","bmi","avg_glucose_level")])
##       age             bmi        avg_glucose_level
##  Min.   :10.00   Min.   : 0.00   Min.   :-129.54  
##  1st Qu.:34.00   1st Qu.:25.20   1st Qu.:  77.16  
##  Median :50.00   Median :29.10   Median :  92.37  
##  Mean   :48.86   Mean   :30.25   Mean   : 108.21  
##  3rd Qu.:63.00   3rd Qu.:34.10   3rd Qu.: 116.44  
##  Max.   :82.00   Max.   :92.00   Max.   : 271.74  
##                  NA's   :135

2.6.2 Impute with median values for numerical columns

The original bmi column had 188 missing values, of which 135 remain in stroke1 after the filtering in 2.5.2 (see the table above). To deal with this, we will impute them with the median of the bmi variable. Subsequently, we can confirm that there are no NA values once the code has been executed.

stroke1$bmi[is.na(stroke1$bmi)] <- median(stroke1$bmi, na.rm = T)
stroke1$avg_glucose_level[is.na(stroke1$avg_glucose_level)] <- median(stroke1$avg_glucose_level, na.rm = T)
colSums(is.na(stroke1))
##                id            gender               age      hypertension 
##                 0                 0                 0                 0 
##     heart_disease      ever_married         work_type    Residence_type 
##                 0                 0                 0                 0 
## avg_glucose_level               bmi    smoking_status            stroke 
##                 0                 0                 0                 0

2.6.3 Delete rows with negative values and exclude zero values in continuous columns

To solve the problems of negative values and zero values in the numerical columns, we keep only the rows in which every numerical value is greater than 0.5; this removes both the negative glucose readings and the zero BMI values. Note that the filtered result must be assigned back to stroke1 for the change to take effect.

stroke1 <- stroke1 %>% filter_if(is.numeric, ~ .x > 0.5)

2.6.4 Briefly view the statistics of the data

After all this pre-processing, we can take a glance at the cleaned dataset.

head(stroke1)
##      id gender age hypertension heart_disease ever_married     work_type
## 1  9046   Male  67            0             1          Yes       Private
## 2 51676 Female  61            0             0          Yes Self-employed
## 3 31112   Male  80            0             1          Yes       Private
## 4 60182 Female  49            0             0          Yes       Private
## 5  1665 Female  79            1             0          Yes Self-employed
## 6 56669   Male  81            0             0          Yes       Private
##   Residence_type avg_glucose_level  bmi  smoking_status stroke
## 1          Urban            228.69 36.6 formerly smoked      1
## 2          Rural            202.21 29.1    never smoked      1
## 3          Rural            105.92 32.5    never smoked      1
## 4          Urban            171.23 34.4          smokes      1
## 5          Rural            174.12 24.0    never smoked      1
## 6          Urban            186.21 29.0 formerly smoked      1
skimr::skim(stroke1)
Data summary
Name stroke1
Number of rows 3565
Number of columns 12
_______________________
Column type frequency:
character 8
numeric 4
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
gender 0 1 4 6 0 2 0
hypertension 0 1 1 1 0 2 0
heart_disease 0 1 1 1 0 2 0
ever_married 0 1 2 3 0 2 0
work_type 0 1 7 13 0 5 0
Residence_type 0 1 5 5 0 2 0
smoking_status 0 1 6 15 0 3 0
stroke 0 1 1 1 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1 36780.32 21240.50 67.00 18040.00 37446.00 54946.00 72915.00 ▇▇▇▇▇
age 0 1 48.86 18.87 10.00 34.00 50.00 63.00 82.00 ▃▆▇▇▆
avg_glucose_level 0 1 108.21 49.53 -129.54 77.16 92.37 116.44 271.74 ▁▁▇▂▂
bmi 0 1 30.20 7.24 0.00 25.40 29.10 33.80 92.00 ▁▇▂▁▁

2.6.5 Save cleaned dataset

Lastly, we will end this stage by saving the cleaned dataset to a local file.

write_csv(stroke1, file = "Stroke Prediction_cleaned.csv")

3 Exploratory Data Analysis (EDA)

Exploratory Data Analysis (EDA) is the process of running several explorations over the dataset to gain initial insights and a better understanding of the data.

3.1 Data Summarization

After the data pre-processing, let us analyse the current variables and their data types before further exploration.

str(stroke1)
## 'data.frame':    3565 obs. of  12 variables:
##  $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 12109 12095 ...
##  $ gender           : chr  "Male" "Female" "Male" "Female" ...
##  $ age              : num  67 61 80 49 79 81 74 69 81 61 ...
##  $ hypertension     : chr  "0" "0" "0" "0" ...
##  $ heart_disease    : chr  "1" "0" "1" "0" ...
##  $ ever_married     : chr  "Yes" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num  229 202 106 171 174 ...
##  $ bmi              : num  36.6 29.1 32.5 34.4 24 29 27.4 22.8 29.7 36.8 ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : chr  "1" "1" "1" "1" ...

3.1.1 Central Tendency

We will explore the mean, median and mode of each numerical variable. Since a mode function is not available by default in R, we will create one manually as shown below.

Mode <- function(x) {
  ta <- table(x)
  tam <- max(ta)
  if (all(ta == tam)) {
    mod <- NA # no unique mode: every value occurs equally often
  } else if (is.numeric(x)) {
    mod <- as.numeric(names(ta)[ta == tam])
  } else {
    mod <- names(ta)[ta == tam]
  }
  return(mod)
}
#age
age1 = stroke1$age
mean(age1)
## [1] 48.86031
median(age1)
## [1] 50
Mode(age1)
## [1] 54
#avg_glucose_level
agl1 = stroke1$avg_glucose_level
mean(agl1)
## [1] 108.2052
median(agl1)
## [1] 92.37
Mode(agl1)
## [1] 0
#bmi
bmi1 = stroke1$bmi
mean(bmi1)
## [1] 30.20471
median(bmi1)
## [1] 29.1
Mode(bmi1)
## [1] 29.1

3.1.2 Measure of Variability

Calculating standard deviation and variance for the numerical variables.

#age
sd(age1)
## [1] 18.87314
var(age1)
## [1] 356.1954
#avg_glucose_level
sd(agl1)
## [1] 49.53281
var(agl1)
## [1] 2453.5
#bmi
sd(bmi1)
## [1] 7.243686
var(bmi1)
## [1] 52.47099

3.2 Data Visualization

3.2.1 Smoothen Data

We smooth (bin) the variables that span a wide range of values, such as age, avg_glucose_level and bmi, because their raw distributions were hard to read when we visualized them for EDA.

The average glucose levels are binned based on references from the Centers for Disease Control and Prevention (CDC), as well as myDr from Australia.

stroke2<-stroke1
#Smoothing age column
stroke2["age"] = cut(stroke2$age, c(0, 10, 20, 30, 40, 50, 60,  70, 80, Inf), c("0-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70-79",  ">80"), include.lowest=TRUE)

#Smoothing avg_glucose_level column by converting it into 4 levels ("Hypoglycaemia", "Normal", "Pre-diabetes", "Diabetes") based on the Glucose Tolerance Test indicator from https://www.cdc.gov/diabetes/basics/getting-tested.html#:~:text=A%20fasting%20blood%20sugar%20level,higher%20indicates%20you%20have%20diabetes.
#negative avg glucose level - https://www.mydr.com.au/tests-investigations/diabetes-and-urine-glucose-monitoring/

stroke2["avg_glucose_level"] = cut(stroke2$avg_glucose_level, c(-Inf, 0, 140, 200, Inf), c("Hypoglycaemia", "Normal", "Pre-diabetes",  "Diabetes"), include.lowest=TRUE)

#Smoothing bmi column based on https://www.cdc.gov/healthyweight/assessing/bmi/adult_bmi/index.html interpretation.
stroke2["bmi"] = cut(stroke2$bmi, c(0, 18.5, 25, 30, Inf), c("Underweight", "Healthy",  "Overweight", "Obesity"), include.lowest=TRUE)

3.2.2 Plotting Graph for Each Column

By plotting a graph for each column, we can see the distribution and the trend of the data in that column. We can also identify any extreme values and verify that no null values remain after data pre-processing.

for (i in names(stroke2)) {
  barplot(table(stroke2[[i]]), main = i, col = brewer.pal(5, "Set2"))
}

3.2.3 Identifying percentage distribution by age group

Here, we are going to identify the percentage of persons who had a stroke within each age group, using the age groups binned in the task above.

# Plotting percentage distribution of stroke by age group
temp <-subset.data.frame(stroke2, select = c(age, stroke)) %>% 
  group_by(age, stroke) %>% 
  summarize(n = n(), .groups = 'drop') %>% 
  group_by(age) %>% 
  summarize(percentage_stroke = round(sum(n[stroke==1]/sum(n)), digits=2), 
            percentage_not_stroke = 1-percentage_stroke, .groups = 'drop') %>% 
  arrange(desc(percentage_stroke)) %>% 
  slice(1:9)
melted_temp<-melt(temp, id = "age")

ggplot(melted_temp, aes(x = age, y = value, fill = variable)) + 
  geom_bar(position = "fill", 
           stat = "identity", 
           color = "black", 
           width = 1) + 
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.6)) + 
  scale_y_continuous(labels = scales::percent) + 
  geom_text(aes(label = paste0(value*100,"%")), 
            position = position_stack(vjust = 0.6), size = 2) + 
  ggtitle("Percentage Distribution of Stroke by Age Group") + 
  xlab("Age Group") + 
  ylab("Stroke (percentage)") 

#checking any missing values after EDA
colSums(is.na(stroke2))
##                id            gender               age      hypertension 
##                 0                 0                 0                 0 
##     heart_disease      ever_married         work_type    Residence_type 
##                 0                 0                 0                 0 
## avg_glucose_level               bmi    smoking_status            stroke 
##                 0                 0                 0                 0

3.2.4 Determining correlations between the features

To study the correlations between the features, we use a heatmap to visualize the relationships. We first transform the qualitative features into quantitative ones.

# transform qualitative features into quantitative features
stroke3<-stroke1
# gender
stroke3$gender[stroke3$gender == "Male"] <- 0
stroke3$gender[stroke3$gender == "Female"] <- 1
# ever_married
stroke3$ever_married[stroke3$ever_married == "Yes"] <- 1
stroke3$ever_married[stroke3$ever_married == "No"] <- 0
# work_type
stroke3$work_type[stroke3$work_type == "Private"] <- 0
stroke3$work_type[stroke3$work_type == "children"] <- 1
stroke3$work_type[stroke3$work_type == "Govt_job"] <- 2
stroke3$work_type[stroke3$work_type == "Never_worked"] <- 3
stroke3$work_type[stroke3$work_type == "Self-employed"] <- 4
# Residence_type
stroke3$Residence_type[stroke3$Residence_type == "Rural"] <- 0
stroke3$Residence_type[stroke3$Residence_type == "Urban"] <- 1
# smoking_status
stroke3$smoking_status[stroke3$smoking_status == "formerly smoked"] <- 0
stroke3$smoking_status[stroke3$smoking_status == "never smoked"] <- 1
stroke3$smoking_status[stroke3$smoking_status == "smokes"] <- 2

# convert characters into numbers
for (i in colnames(stroke3)){
    suppressWarnings(stroke3[[i]] <- as.numeric(as.character(stroke3[[i]])))
}
# remove column "id" to build correlation coefficient matrix
stroke_hm <- subset(stroke3, select= -id)
corrmatrix <- round(cor(stroke_hm),2)
melted_corrmat <- melt(corrmatrix)
head(melted_corrmat)
##            Var1   Var2 value
## 1        gender gender  1.00
## 2           age gender -0.04
## 3  hypertension gender -0.03
## 4 heart_disease gender -0.10
## 5  ever_married gender -0.02
## 6     work_type gender  0.01
# plot the correlation heatmap
ggplot(data = melted_corrmat, aes(x=Var1, y=Var2, fill=value)) + 
  geom_tile()+
  scale_fill_gradient2(low = "blue", high = "red", mid = "white", 
                       midpoint = 0, limit = c(-1,1), space = "Lab", name="Pearson\nCorrelation")+
  geom_text(aes(Var2, Var1, label = value), color = "black", size = 4)+
  theme(axis.text.x = element_text(angle = 45, vjust = 1, 
                                   size = 12, hjust = 1))

From the heatmap, the strongest positive correlation is between ever_married and age, with a coefficient of 0.52. Since no Pearson coefficient between features exceeds 0.7, we conclude that there is no multicollinearity in this dataset. Other notable correlations are between age and several features: hypertension, heart_disease, work_type, avg_glucose_level and stroke. This suggests that age may be a main factor in the occurrence of stroke.
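One caveat: the integer codes above impose an arbitrary ordering on nominal variables such as work_type, which is a pragmatic simplification for a quick correlation heatmap. An alternative sketch (illustration only, not used in this analysis) one-hot encodes such a variable with base R's model.matrix():

# One-hot encode work_type into 0/1 indicator columns (illustrative alternative)
work_type_onehot <- model.matrix(~ work_type - 1, data = stroke1)
head(work_type_onehot)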

4 Feature Scaling

There are three numerical variables (age, avg_glucose_level and bmi), each on a different scale. Hence, to standardize the scale across the three variables, min-max scaling is performed, as sketched below.
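Min-max scaling maps each value x to (x - min) / (max - min), so every scaled value lies in [0, 1]. A minimal manual sketch of what caret's preProcess(method = "range") computes per selected column (illustration only; the actual pipeline below uses preProcess itself):

# Manual min-max scaling of one column, equivalent in effect to method = "range"
minmax <- function(x) (x - min(x)) / (max(x) - min(x))
head(minmax(stroke3$age))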

summary(stroke3)
##        id            gender            age         hypertension   
##  Min.   :   67   Min.   :0.0000   Min.   :10.00   Min.   :0.0000  
##  1st Qu.:18040   1st Qu.:0.0000   1st Qu.:34.00   1st Qu.:0.0000  
##  Median :37446   Median :1.0000   Median :50.00   Median :0.0000  
##  Mean   :36780   Mean   :0.6053   Mean   :48.86   Mean   :0.1251  
##  3rd Qu.:54946   3rd Qu.:1.0000   3rd Qu.:63.00   3rd Qu.:0.0000  
##  Max.   :72915   Max.   :1.0000   Max.   :82.00   Max.   :1.0000  
##  heart_disease      ever_married      work_type     Residence_type  
##  Min.   :0.00000   Min.   :0.0000   Min.   :0.000   Min.   :0.0000  
##  1st Qu.:0.00000   1st Qu.:1.0000   1st Qu.:0.000   1st Qu.:0.0000  
##  Median :0.00000   Median :1.0000   Median :0.000   Median :1.0000  
##  Mean   :0.06396   Mean   :0.7602   Mean   :1.075   Mean   :0.5088  
##  3rd Qu.:0.00000   3rd Qu.:1.0000   3rd Qu.:2.000   3rd Qu.:1.0000  
##  Max.   :1.00000   Max.   :1.0000   Max.   :4.000   Max.   :1.0000  
##  avg_glucose_level      bmi       smoking_status       stroke       
##  Min.   :-129.54   Min.   : 0.0   Min.   :0.0000   Min.   :0.00000  
##  1st Qu.:  77.16   1st Qu.:25.4   1st Qu.:1.0000   1st Qu.:0.00000  
##  Median :  92.37   Median :29.1   Median :1.0000   Median :0.00000  
##  Mean   : 108.21   Mean   :30.2   Mean   :0.9734   Mean   :0.05666  
##  3rd Qu.: 116.44   3rd Qu.:33.8   3rd Qu.:1.0000   3rd Qu.:0.00000  
##  Max.   : 271.74   Max.   :92.0   Max.   :2.0000   Max.   :1.00000
preproc <- preProcess(stroke3[,c(3,9,10)], method=c("range"))
norm_data <- predict(preproc, stroke3)
summary(norm_data)
##        id            gender            age          hypertension   
##  Min.   :   67   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:18040   1st Qu.:0.0000   1st Qu.:0.3333   1st Qu.:0.0000  
##  Median :37446   Median :1.0000   Median :0.5556   Median :0.0000  
##  Mean   :36780   Mean   :0.6053   Mean   :0.5397   Mean   :0.1251  
##  3rd Qu.:54946   3rd Qu.:1.0000   3rd Qu.:0.7361   3rd Qu.:0.0000  
##  Max.   :72915   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##  heart_disease      ever_married      work_type     Residence_type  
##  Min.   :0.00000   Min.   :0.0000   Min.   :0.000   Min.   :0.0000  
##  1st Qu.:0.00000   1st Qu.:1.0000   1st Qu.:0.000   1st Qu.:0.0000  
##  Median :0.00000   Median :1.0000   Median :0.000   Median :1.0000  
##  Mean   :0.06396   Mean   :0.7602   Mean   :1.075   Mean   :0.5088  
##  3rd Qu.:0.00000   3rd Qu.:1.0000   3rd Qu.:2.000   3rd Qu.:1.0000  
##  Max.   :1.00000   Max.   :1.0000   Max.   :4.000   Max.   :1.0000  
##  avg_glucose_level      bmi         smoking_status       stroke       
##  Min.   :0.0000    Min.   :0.0000   Min.   :0.0000   Min.   :0.00000  
##  1st Qu.:0.5151    1st Qu.:0.2761   1st Qu.:1.0000   1st Qu.:0.00000  
##  Median :0.5530    Median :0.3163   Median :1.0000   Median :0.00000  
##  Mean   :0.5925    Mean   :0.3283   Mean   :0.9734   Mean   :0.05666  
##  3rd Qu.:0.6130    3rd Qu.:0.3674   3rd Qu.:1.0000   3rd Qu.:0.00000  
##  Max.   :1.0000    Max.   :1.0000   Max.   :2.0000   Max.   :1.00000

5 Data Partition

The dataset is split into 80% training data and 20% testing data: a total of 2852 rows is partitioned into the training set and 713 rows into the testing set.

library(caret)
# createDataPartition function 
stroke4 <- norm_data
set.seed(123)
trainIndex <- createDataPartition(stroke4$stroke, p=0.8, list=FALSE) # stratified random split of 80:20

train_data <- stroke4[trainIndex,] 
test_data <- stroke4[-trainIndex,]

# check number of rows
nrow(train_data) 
## [1] 2852
nrow(train_data[train_data$stroke=="1",]) 
## [1] 163
nrow(train_data[train_data$stroke=="0",]) 
## [1] 2689
nrow(test_data) 
## [1] 713
nrow(test_data[test_data$stroke=="1",]) 
## [1] 39
nrow(test_data[test_data$stroke=="0",]) 
## [1] 674

6 SMOTE to Balance the Dataset

The dataset has an uneven proportion of cases per class: 94% (2689 rows) of the training data are non-stroke cases and 6% (163 rows) are stroke cases. This is handled with SMOTE (Synthetic Minority Oversampling Technique), a popular method for dealing with imbalanced datasets, which often occur in classification problems where one class has a disproportionate number of observations compared to the other. SMOTE creates synthetic examples of the minority class in order to balance the class distribution. After applying SMOTE, stroke cases increase to about 41% and non-stroke cases to about 59%, as verified below.

train_data$stroke <- as.factor(train_data$stroke)
levels(train_data$stroke) <- c("Non_Stroke","Stroke")

# perc.over controls over-sampling of the minority class; perc.under controls under-sampling of the majority class
# perc.over = 10: generate 10 synthetic minority cases per original minority case (163 * 10 = 1630 new cases)
# perc.under = 1.6: keep 1.6 majority cases for each synthetic minority case generated (1.6 * 1630 = 2608)
set.seed(789)
smote_train_data <- smote(stroke ~ ., train_data, perc.over = 10, k = 5, perc.under = 1.6)

# Check number of rows
nrow(smote_train_data) 
## [1] 4401
nrow(smote_train_data[smote_train_data$stroke=="Stroke",]) 
## [1] 1793
nrow(smote_train_data[smote_train_data$stroke=="Non_Stroke",])
## [1] 2608
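The new class proportions quoted above can be confirmed directly; 1793/4401 is roughly 41% stroke and 2608/4401 roughly 59% non-stroke:

# Class proportions after SMOTE
round(prop.table(table(smote_train_data$stroke)), 2)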

7 Data Modeling

7.1 Random Forest Classification

Random Forest is an ensemble method that builds multiple decision trees on random subsets of the input data and features, then combines their predictions to make the final decision. It suits stroke prediction for three reasons:

  1. Handling of non-linear relationships: Random Forest can model non-linear relationships between variables, which is important here because the relationship between risk factors and stroke is often complex.

  2. Handling of categorical variables: Random Forest copes well with categorical variables, which are commonly found in medical datasets.

  3. Feature importance: Random Forest provides an estimate of feature importance. This helps identify the most important features in stroke prediction, which in turn can aid understanding of the underlying mechanisms that lead to stroke and inform preventative strategies.

table(smote_train_data$stroke)
## 
## Non_Stroke     Stroke 
##       2608       1793
smote_train_data01 <- smote_train_data[-1] # remove the id column
floor(sqrt(ncol(smote_train_data01) - 1)) # default mtry for classification: floor(sqrt(number of predictors)) = 3
## [1] 3
set.seed(1122)
mtry <- tuneRF(smote_train_data01[-length(smote_train_data01)],smote_train_data01$stroke, ntreeTry=500,stepFactor=1.5,improve=0.01, trace=TRUE, plot=TRUE)
## mtry = 3  OOB error = 4.54% 
## Searching left ...
## mtry = 2     OOB error = 6.23% 
## -0.37 0.01 
## Searching right ...
## mtry = 4     OOB error = 4.41% 
## 0.03 0.01 
## mtry = 6     OOB error = 4.61% 
## -0.04639175 0.01

best_m <- mtry[mtry[, 2] == min(mtry[, 2]), 1]
print(mtry)
##       mtry   OOBError
## 2.OOB    2 0.06225858
## 3.OOB    3 0.04544422
## 4.OOB    4 0.04408089
## 6.OOB    6 0.04612588
print(best_m) #4
## [1] 4
set.seed(1122)
rf <-randomForest(stroke ~ ., data=smote_train_data01, mtry=best_m, importance=TRUE,ntree=500)
print(rf)
## 
## Call:
##  randomForest(formula = stroke ~ ., data = smote_train_data01,      mtry = best_m, importance = TRUE, ntree = 500) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 4
## 
##         OOB estimate of  error rate: 4.84%
## Confusion matrix:
##            Non_Stroke Stroke class.error
## Non_Stroke       2509     99  0.03796012
## Stroke            114   1679  0.06358059
#Evaluate variable importance
importance(rf)
##                   Non_Stroke    Stroke MeanDecreaseAccuracy MeanDecreaseGini
## gender              41.73723  65.94787             70.04894         60.61531
## age                120.74353 151.19372            165.53890        806.47361
## hypertension        29.18763  37.96014             38.89608         48.20356
## heart_disease       29.98562  49.79057             48.02564         68.62065
## ever_married        34.64929  35.45674             40.94102         71.88809
## work_type           50.59063  63.05101             64.36997        140.27072
## Residence_type      42.68955  59.97943             63.61224         60.27460
## avg_glucose_level   77.27040 105.65118            115.44780        342.34315
## bmi                 69.23443  87.68249            101.86878        261.72633
## smoking_status      47.68864  84.66070             74.01676        244.67301
varImpPlot(rf)

Accuracy (Random Forest Classification): 0.8836

rf_test_data <- test_data[-1]
rf_test_data$stroke <- as.factor(rf_test_data$stroke)
levels(rf_test_data$stroke) <- c("Non_Stroke","Stroke")

predicted <- predict(rf, rf_test_data)
conf_matrix <- confusionMatrix(as.factor(predicted), as.factor(rf_test_data$stroke), positive = "Stroke")
conf_matrix$table
##             Reference
## Prediction   Non_Stroke Stroke
##   Non_Stroke        622     31
##   Stroke             52      8
# Evaluate model performance (byClass[1] is Sensitivity and byClass[2] is Specificity)
accuracy <- conf_matrix$overall["Accuracy"]
sensitivity <- conf_matrix$byClass["Sensitivity"] # recall of the positive (Stroke) class
specificity <- conf_matrix$byClass["Specificity"]

# Print evaluation metrics
print(paste("Accuracy:", accuracy))
## [1] "Accuracy: 0.8835904628331"
print(paste("Sensitivity (Recall):", sensitivity))
## [1] "Sensitivity (Recall): 0.205128205128205"
print(paste("Specificity:", specificity))
## [1] "Specificity: 0.922848664688427"

7.2 Logistic Regression

Logistic Regression is a supervised machine learning algorithm that is used for classification tasks. It is a type of generalized linear model that is used to model the relationship between a binary outcome variable and one or more predictor variables. Logistic Regression uses a logistic function (also called the sigmoid function) to model the probability of the outcome variable being in one of the two classes (e.g. stroke or no stroke).
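For intuition, the logistic (sigmoid) function itself is simple; a minimal sketch (illustration only, glm() applies this internally through the binomial link):

# The sigmoid maps any linear predictor z to a probability in (0, 1)
sigmoid <- function(z) 1 / (1 + exp(-z))
sigmoid(0) # 0.5: a linear predictor of zero corresponds to a 50% probability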

lr <- glm(stroke ~., data = smote_train_data01, family = "binomial")
summary(lr)
## 
## Call:
## glm(formula = stroke ~ ., family = "binomial", data = smote_train_data01)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.2900  -0.7551  -0.2527   0.8429   2.4642  
## 
## Coefficients:
##                   Estimate Std. Error z value Pr(>|z|)    
## (Intercept)       -5.83661    0.32190 -18.132  < 2e-16 ***
## gender             0.25342    0.07950   3.187 0.001435 ** 
## age                5.99596    0.23286  25.749  < 2e-16 ***
## hypertension       0.06562    0.09706   0.676 0.499010    
## heart_disease      0.40716    0.12715   3.202 0.001364 ** 
## ever_married       0.25925    0.13543   1.914 0.055591 .  
## work_type         -0.10906    0.02363  -4.616 3.91e-06 ***
## Residence_type     0.26097    0.07639   3.416 0.000635 ***
## avg_glucose_level  2.05565    0.29586   6.948 3.70e-12 ***
## bmi               -0.80739    0.60562  -1.333 0.182477    
## smoking_status    -0.14308    0.05564  -2.571 0.010130 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 5949.3  on 4400  degrees of freedom
## Residual deviance: 4269.6  on 4390  degrees of freedom
## AIC: 4291.6
## 
## Number of Fisher Scoring iterations: 5
lr_test_data <- test_data[-1] # remove the 'id' column
lr_test_data$stroke <- as.factor(lr_test_data$stroke)
levels(lr_test_data$stroke) <- c(0,1)

glm.probs <- predict(lr, newdata = lr_test_data, type = "response")
lr_test_data$lr <- ifelse(glm.probs > 0.5, 1, 0)
lr_test_data$lr <- as.factor(lr_test_data$lr)
confusion_matrix <- confusionMatrix(lr_test_data$stroke,lr_test_data$lr)

print(confusion_matrix)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 528 146
##          1  15  24
##                                           
##                Accuracy : 0.7742          
##                  95% CI : (0.7417, 0.8044)
##     No Information Rate : 0.7616          
##     P-Value [Acc > NIR] : 0.2286          
##                                           
##                   Kappa : 0.1544          
##                                           
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 0.9724          
##             Specificity : 0.1412          
##          Pos Pred Value : 0.7834          
##          Neg Pred Value : 0.6154          
##              Prevalence : 0.7616          
##          Detection Rate : 0.7405          
##    Detection Prevalence : 0.9453          
##       Balanced Accuracy : 0.5568          
##                                           
##        'Positive' Class : 0               
## 
# Evaluate model performance (byClass[1] is Sensitivity and byClass[2] is Specificity;
# note that the positive class here is 0, i.e., non-stroke)
accuracy <- confusion_matrix$overall["Accuracy"]
sensitivity <- confusion_matrix$byClass["Sensitivity"]
specificity <- confusion_matrix$byClass["Specificity"]

# Print evaluation metrics
print(paste("Accuracy:", accuracy))
## [1] "Accuracy: 0.774193548387097"
print(paste("Sensitivity:", sensitivity))
## [1] "Sensitivity: 0.972375690607735"
print(paste("Specificity:", specificity))
## [1] "Specificity: 0.141176470588235"
# Convert lr_test_data$stroke and lr_test_data$lr to numeric for subsequent ROC analysis
lr_test_data$stroke <- as.numeric(lr_test_data$stroke)
lr_test_data$lr <- as.numeric(lr_test_data$lr)
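With pROC already loaded in section 2.2, a minimal ROC/AUC sketch for the logistic model might look as follows (an illustrative addition, not part of the original pipeline; it reuses the predicted probabilities in glm.probs):

# ROC curve and AUC for the logistic regression probabilities (illustrative sketch);
# the factor-to-numeric conversion above yields 1/2 codes, which roc() treats as the two classes
roc_obj <- roc(lr_test_data$stroke, glm.probs)
plot(roc_obj, main = "ROC Curve - Logistic Regression")
auc(roc_obj)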

8 Conclusion

Stroke is a dangerous disease, and it does not only affect the old or middle-aged population. Throughout this project, we found that several factors besides age contribute to stroke, namely avg_glucose_level and bmi. We used two algorithms, random forest and logistic regression, to predict stroke from a set of patient data, and random forest obtained a higher accuracy than logistic regression.

However, the models still need to be tuned to achieve higher accuracy and ROC.

We hope that everyone stays healthy and eats well; avoiding the smoking habit could be a good way to prevent stroke.