GROUP MEMBERS (GROUP 3)
CHEW CHI YEW 17095779
RADZIAH ZAINUDDIN S2156690
FARAH JASMIEN BINTI ZUKARI 17172774
LEE YING QIU 17108552
TAN JIA YUE 17114152
Video: https://drive.google.com/file/d/11VXe5f2Qjx1xg3BoOfQqN-BR9ek4J3fr/view?usp=share_link
Stroke Prediction
Source: https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset
# data import and wrangling
library(readr)
library(readxl)
library(dplyr)
library(tidyr)
library(stringr)
library(magrittr)
# visualisation
library(ggplot2)
library(ggpubr)
library(reshape2)
library(RColorBrewer)
library(skimr)
library(graphics)
# modelling and evaluation
library(caret)
library(Metrics)
library(e1071)
library(pROC)
library(performanceEstimation)
library(superml)
library(randomForest)
In this stage, we load the data and display the first six rows of the dataset.
stroke <- read.csv("/cloud/project/Stroke Prediction.csv")
head(stroke)
## id gender age hypertension heart_disease ever_married work_type
## 1 9046 Male 67 0 1 Married Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 F 79 1 0 Yes Self-employed
## 6 56669 Male 81 0 0 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 N/A never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24 never smoked 1
## 6 Urban 186.21 29 formerly smoked 1
Now, let's take a look at the internal structure of our data. There are a total of 5110 observations of 12 variables in the stroke prediction dataset. From the output, we can see that the data types of the bmi and stroke columns are abnormal: bmi is stored as character and stroke as integer.
str(stroke)
## 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : chr "0" "0" "0" "0" ...
## $ heart_disease : chr "1" "0" "1" "0" ...
## $ ever_married : chr "Married" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : chr "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
Thereafter, we begin the data preprocessing by converting the bmi column to a numeric data type and the stroke column from integer to character.
# as.numeric() coerces the "N/A" strings in bmi to NA (with a warning); these are imputed later
stroke$bmi <- as.numeric(stroke$bmi)
stroke$stroke <- as.character(stroke$stroke)
str(stroke)
## 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : chr "0" "0" "0" "0" ...
## $ heart_disease : chr "1" "0" "1" "0" ...
## $ ever_married : chr "Married" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : num 36.6 NA 32.5 34.4 24 29 27.4 22.8 NA 24.2 ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : chr "1" "1" "1" "1" ...
Next, we check whether there are any missing values in each column of the dataset.
colSums(is.na(stroke))
## id gender age hypertension
## 0 0 0 0
## heart_disease ever_married work_type Residence_type
## 0 0 0 0
## avg_glucose_level bmi smoking_status stroke
## 0 188 0 0
Next, we look at a per-category summary of every categorical column in order to discover any abnormalities. Here, we use the group_by command to inspect each variable.
stroke %>%
  group_by(gender) %>%
  summarise(Count = n())
## # A tibble: 5 × 2
## gender Count
## <chr> <int>
## 1 F 20
## 2 Female 2974
## 3 M 9
## 4 Male 2106
## 5 Other 1
stroke %>%
  group_by(hypertension) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
## hypertension Count
## <chr> <int>
## 1 0 4604
## 2 1 489
## 3 No 8
## 4 Yes 9
stroke %>%
  group_by(heart_disease) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
## heart_disease Count
## <chr> <int>
## 1 0 4821
## 2 1 270
## 3 No 13
## 4 Yes 6
stroke %>%
  group_by(ever_married) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
## ever_married Count
## <chr> <int>
## 1 Married 22
## 2 No 1738
## 3 Single 19
## 4 Yes 3331
stroke %>%
  group_by(work_type) %>%
  summarise(Count = n())
## # A tibble: 5 × 2
## work_type Count
## <chr> <int>
## 1 children 687
## 2 Govt_job 657
## 3 Never_worked 22
## 4 Private 2925
## 5 Self-employed 819
stroke %>%
  group_by(Residence_type) %>%
  summarise(Count = n())
## # A tibble: 2 × 2
## Residence_type Count
## <chr> <int>
## 1 Rural 2514
## 2 Urban 2596
stroke %>%
  group_by(smoking_status) %>%
  summarise(Count = n())
## # A tibble: 4 × 2
## smoking_status Count
## <chr> <int>
## 1 formerly smoked 885
## 2 never smoked 1892
## 3 smokes 789
## 4 Unknown 1544
stroke %>%
  group_by(stroke) %>%
  summarise(Count = n())
## # A tibble: 2 × 2
## stroke Count
## <chr> <int>
## 1 0 4861
## 2 1 249
Based on the summary tables above, data issues are found in four categorical columns: gender, hypertension, heart_disease and ever_married. First, we replace inconsistent categories with the correct strings; for instance, 'F' is replaced with 'Female' as they represent the same category. Then we drop unwanted rows by filtering out the 'Other' category in gender and the 'Unknown' category in smoking_status.
stroke1 <- stroke %>%
  mutate(gender = replace(gender, gender == 'F', 'Female')) %>%
  mutate(gender = replace(gender, gender == 'M', 'Male')) %>%
  mutate(hypertension = replace(hypertension, hypertension == 'Yes', '1')) %>%
  mutate(hypertension = replace(hypertension, hypertension == 'No', '0')) %>%
  mutate(heart_disease = replace(heart_disease, heart_disease == 'Yes', '1')) %>%
  mutate(heart_disease = replace(heart_disease, heart_disease == 'No', '0')) %>%
  mutate(ever_married = replace(ever_married, ever_married == 'Single', 'No')) %>%
  mutate(ever_married = replace(ever_married, ever_married == 'Married', 'Yes')) %>%
  filter(!(gender == "Other" | smoking_status == "Unknown"))
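For reference, the same cleaning can be written more compactly with dplyr::recode; the sketch below is an equivalent alternative, not part of the pipeline (stroke1_alt is a hypothetical name).
# compact equivalent of the recoding pipeline above, using dplyr::recode
stroke1_alt <- stroke %>%
  mutate(gender        = recode(gender, "F" = "Female", "M" = "Male"),
         hypertension  = recode(hypertension, "Yes" = "1", "No" = "0"),
         heart_disease = recode(heart_disease, "Yes" = "1", "No" = "0"),
         ever_married  = recode(ever_married, "Single" = "No", "Married" = "Yes")) %>%
  filter(gender != "Other", smoking_status != "Unknown")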
The summary below reveals issues with the numerical variables as well. For example, there are negative values in the avg_glucose_level column, and it is impossible for bmi to have a value of zero.
summary(stroke1[c("age","bmi","avg_glucose_level")])
## age bmi avg_glucose_level
## Min. :10.00 Min. : 0.00 Min. :-129.54
## 1st Qu.:34.00 1st Qu.:25.20 1st Qu.: 77.16
## Median :50.00 Median :29.10 Median : 92.37
## Mean :48.86 Mean :30.25 Mean : 108.21
## 3rd Qu.:63.00 3rd Qu.:34.10 3rd Qu.: 116.44
## Max. :82.00 Max. :92.00 Max. : 271.74
## NA's :135
As shown in the summary above, 135 missing values remain in the bmi column of the filtered data (out of the 188 present in the raw dataset). To deal with this problem, we impute them with the median of the bmi variable. Once the code has been executed, we can confirm that no NA values remain.
stroke1$bmi[is.na(stroke1$bmi)] <- median(stroke1$bmi, na.rm = T)
stroke1$avg_glucose_level[is.na(stroke1$avg_glucose_level)] <- median(stroke1$avg_glucose_level, na.rm = T)
colSums(is.na(stroke1))
## id gender age hypertension
## 0 0 0 0
## heart_disease ever_married work_type Residence_type
## 0 0 0 0
## avg_glucose_level bmi smoking_status stroke
## 0 0 0 0
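Equivalently, the median imputation could be written with tidyr's replace_na; a minimal sketch on a hypothetical copy (stroke1_alt):
# median-impute bmi on a copy using tidyr::replace_na (tidyr is loaded above)
stroke1_alt <- stroke1 %>%
  replace_na(list(bmi = median(stroke1$bmi, na.rm = TRUE)))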
To solve the problem of the negative and zero values present in the numerical columns, we keep only the rows in which every numeric value is greater than 0.5. Note that the result must be assigned back to stroke1 for the filter to take effect.
stroke1 <- stroke1 %>% filter_if(is.numeric, all_vars(. > 0.5))
After all this pre-processing, we can take a glance at the cleaned dataset.
head(stroke1)
## id gender age hypertension heart_disease ever_married work_type
## 1 9046 Male 67 0 1 Yes Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 Female 79 1 0 Yes Self-employed
## 6 56669 Male 81 0 0 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 29.1 never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24.0 never smoked 1
## 6 Urban 186.21 29.0 formerly smoked 1
skimr::skim(stroke1)
| Name | stroke1 |
| Number of rows | 3565 |
| Number of columns | 12 |
| _______________________ | |
| Column type frequency: | |
| character | 8 |
| numeric | 4 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
| hypertension | 0 | 1 | 1 | 1 | 0 | 2 | 0 |
| heart_disease | 0 | 1 | 1 | 1 | 0 | 2 | 0 |
| ever_married | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| work_type | 0 | 1 | 7 | 13 | 0 | 5 | 0 |
| Residence_type | 0 | 1 | 5 | 5 | 0 | 2 | 0 |
| smoking_status | 0 | 1 | 6 | 15 | 0 | 3 | 0 |
| stroke | 0 | 1 | 1 | 1 | 0 | 2 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| id | 0 | 1 | 36780.32 | 21240.50 | 67.00 | 18040.00 | 37446.00 | 54946.00 | 72915.00 | ▇▇▇▇▇ |
| age | 0 | 1 | 48.86 | 18.87 | 10.00 | 34.00 | 50.00 | 63.00 | 82.00 | ▃▆▇▇▆ |
| avg_glucose_level | 0 | 1 | 108.21 | 49.53 | -129.54 | 77.16 | 92.37 | 116.44 | 271.74 | ▁▁▇▂▂ |
| bmi | 0 | 1 | 30.20 | 7.24 | 0.00 | 25.40 | 29.10 | 33.80 | 92.00 | ▁▇▂▁▁ |
Lastly, we end this stage by saving the cleaned dataset to a local file.
write_csv(stroke1, file = "Stroke Prediction_cleaned.csv")
Exploratory Data Analysis (EDA) is a process in which we run several explorations to gain a first glimpse of the insights in the dataset and build a better understanding of it. After the data pre-processing, let us review the current variables and their data types before exploring further.
str(stroke1)
## 'data.frame': 3565 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 12109 12095 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 81 61 ...
## $ hypertension : chr "0" "0" "0" "0" ...
## $ heart_disease : chr "1" "0" "1" "0" ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : num 36.6 29.1 32.5 34.4 24 29 27.4 22.8 29.7 36.8 ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : chr "1" "1" "1" "1" ...
We will explore the mean, median and mode of each numerical variable; since a Mode function is not available by default in R, we create one manually, as shown below.
Mode <- function(x) {
  ta  <- table(x)
  tam <- max(ta)
  if (all(ta == tam)) {
    mod <- NA                 # every value is equally frequent: no single mode
  } else if (is.numeric(x)) {
    mod <- as.numeric(names(ta)[ta == tam])
  } else {
    mod <- names(ta)[ta == tam]
  }
  return(mod)
}
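A quick sanity check of the helper on small illustrative vectors:
Mode(c(1, 2, 2, 3))   # returns 2, the most frequent value
Mode(c("a", "b"))     # returns NA: every value is equally frequent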
#age
age1 = stroke1$age
mean(age1)
## [1] 48.86031
median(age1)
## [1] 50
Mode(age1)
## [1] 54
#avg_glucose_level
agl1 = stroke1$avg_glucose_level
mean(agl1)
## [1] 108.2052
median(agl1)
## [1] 92.37
Mode(agl1)
## [1] 0
#bmi
bmi1 = stroke1$bmi
mean(bmi1)
## [1] 30.20471
median(bmi1)
## [1] 29.1
Mode(bmi1)
## [1] 29.1
Next, we calculate the standard deviation and variance of each numerical variable.
#age
sd(age1)
## [1] 18.87314
var(age1)
## [1] 356.1954
#avg_glucose_level
sd(agl1)
## [1] 49.53281
var(agl1)
## [1] 2453.5
#bmi
sd(bmi1)
## [1] 7.243686
var(bmi1)
## [1] 52.47099
Next, we smooth (bin) the variables that span a wide range of values, namely age, avg_glucose_level and bmi, because their raw distributions were hard to read when we visualised them for EDA. The average-glucose levels are based on references highlighted by the Centers for Disease Control and Prevention (CDC), as well as myDr from Australia.
stroke2 <- stroke1
# Smoothing the age column into ten-year bins
stroke2$age <- cut(stroke2$age, c(0, 10, 20, 30, 40, 50, 60, 70, 80, Inf),
                   c("0-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70-79", ">80"),
                   include.lowest = TRUE)
# Smoothing the avg_glucose_level column into 4 levels based on the Glucose Tolerance Test indicator from https://www.cdc.gov/diabetes/basics/getting-tested.html#:~:text=A%20fasting%20blood%20sugar%20level,higher%20indicates%20you%20have%20diabetes.
# negative avg glucose level - https://www.mydr.com.au/tests-investigations/diabetes-and-urine-glucose-monitoring/
stroke2$avg_glucose_level <- cut(stroke2$avg_glucose_level, c(-Inf, 0, 140, 200, Inf),
                                 c("Hypoglycaemia", "Normal", "Pre-diabetes", "Diabetes"),
                                 include.lowest = TRUE)
# Smoothing the bmi column based on the interpretation at https://www.cdc.gov/healthyweight/assessing/bmi/adult_bmi/index.html
stroke2$bmi <- cut(stroke2$bmi, c(0, 18.5, 25, 30, Inf),
                   c("Underweight", "Healthy", "Overweight", "Obesity"),
                   include.lowest = TRUE)
By plotting a graph for each column, we can see the distribution and trend of the data. We can also identify extreme values in a particular column and confirm that no null values remain after data pre-processing.
for (i in names(stroke2)) {
  barplot(table(stroke2[[i]]), main = i, col = brewer.pal(5, "Set2"))
}
Here, we identify the percentage of persons who suffered a stroke within each age group, using the age bins created in the smoothing step above.
# Plotting percentage distribution of stroke by age group
temp <- subset.data.frame(stroke2, select = c(age, stroke)) %>%
  group_by(age, stroke) %>%
  summarize(n = n(), .groups = 'drop') %>%
  group_by(age) %>%
  summarize(percentage_stroke = round(sum(n[stroke == 1] / sum(n)), digits = 2),
            percentage_not_stroke = 1 - percentage_stroke, .groups = 'drop') %>%
  arrange(desc(percentage_stroke)) %>%
  slice(1:9)
melted_temp <- melt(temp, id = "age")
ggplot(melted_temp, aes(x = age, y = value, fill = variable)) +
  geom_bar(position = "fill", stat = "identity", color = "black", width = 1) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.6)) +
  scale_y_continuous(labels = scales::percent) +
  geom_text(aes(label = paste0(value * 100, "%")),
            position = position_stack(vjust = 0.6), size = 2) +
  ggtitle("Percentage Distribution of Stroke by Age Group") +
  xlab("Age Group") +
  ylab("Stroke (percentage)")
#checking any missing values after EDA
colSums(is.na(stroke2))
## id gender age hypertension
## 0 0 0 0
## heart_disease ever_married work_type Residence_type
## 0 0 0 0
## avg_glucose_level bmi smoking_status stroke
## 0 0 0 0
To study the correlation between the features, we use a heatmap to visualise the relationships. We first transform the qualitative features into quantitative ones.
# transform qualitative features into quantitative features
stroke3<-stroke1
# gender
stroke3$gender[stroke3$gender == "Male"] <- 0
stroke3$gender[stroke3$gender == "Female"] <- 1
# ever_married
stroke3$ever_married[stroke3$ever_married == "Yes"] <- 1
stroke3$ever_married[stroke3$ever_married == "No"] <- 0
# work_type
stroke3$work_type[stroke3$work_type == "Private"] <- 0
stroke3$work_type[stroke3$work_type == "children"] <- 1
stroke3$work_type[stroke3$work_type == "Govt_job"] <- 2
stroke3$work_type[stroke3$work_type == "Never_worked"] <- 3
stroke3$work_type[stroke3$work_type == "Self-employed"] <- 4
# Residence_type
stroke3$Residence_type[stroke3$Residence_type == "Rural"] <- 0
stroke3$Residence_type[stroke3$Residence_type == "Urban"] <- 1
# smoking_status
stroke3$smoking_status[stroke3$smoking_status == "formerly smoked"] <- 0
stroke3$smoking_status[stroke3$smoking_status == "never smoked"] <- 1
stroke3$smoking_status[stroke3$smoking_status == "smokes"] <- 2
# convert characters into numbers
for (i in colnames(stroke3)) {
  suppressWarnings(stroke3[[i]] <- as.numeric(as.character(stroke3[[i]])))
}
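As an aside, the same label encoding can be expressed with a named lookup vector instead of repeated replacements; a sketch for work_type (work_map and work_type_alt are hypothetical names):
# map work_type labels to integer codes via a named vector
work_map <- c("Private" = 0, "children" = 1, "Govt_job" = 2,
              "Never_worked" = 3, "Self-employed" = 4)
work_type_alt <- unname(work_map[stroke1$work_type])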
# remove column "id" to build correlation coefficient matrix
stroke_hm <- subset(stroke3, select= -id)
corrmatrix <- round(cor(stroke_hm),2)
melted_corrmat <- melt(corrmatrix)
head(melted_corrmat)
## Var1 Var2 value
## 1 gender gender 1.00
## 2 age gender -0.04
## 3 hypertension gender -0.03
## 4 heart_disease gender -0.10
## 5 ever_married gender -0.02
## 6 work_type gender 0.01
# plot the correlation heatmap
ggplot(data = melted_corrmat, aes(x = Var1, y = Var2, fill = value)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       midpoint = 0, limit = c(-1, 1), space = "Lab",
                       name = "Pearson\nCorrelation") +
  geom_text(aes(Var2, Var1, label = value), color = "black", size = 4) +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12, hjust = 1))
From the heatmap, the strongest positive correlation is between ever_married and age, with a coefficient of 0.52. Since no pairwise Pearson coefficient exceeds 0.7, we conclude that there is no multicollinearity in this dataset. Other notable correlations are between age and several features, namely hypertension, heart_disease, work_type, avg_glucose_level and stroke. We can conclude that age may be the main factor in stroke.
There are three numerical variables (age, avg_glucose_level and bmi), each on a different scale. To standardise the scale across the three variables, min-max scaling is performed.
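Conceptually, min-max scaling maps each column into [0, 1]; the sketch below (minmax is a hypothetical helper) shows the per-column transformation that caret's preProcess(method = "range") applies.
# rescale a numeric vector to [0, 1], as method = "range" does per column
minmax <- function(x) (x - min(x, na.rm = TRUE)) / (max(x, na.rm = TRUE) - min(x, na.rm = TRUE))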
summary(stroke3)
## id gender age hypertension
## Min. : 67 Min. :0.0000 Min. :10.00 Min. :0.0000
## 1st Qu.:18040 1st Qu.:0.0000 1st Qu.:34.00 1st Qu.:0.0000
## Median :37446 Median :1.0000 Median :50.00 Median :0.0000
## Mean :36780 Mean :0.6053 Mean :48.86 Mean :0.1251
## 3rd Qu.:54946 3rd Qu.:1.0000 3rd Qu.:63.00 3rd Qu.:0.0000
## Max. :72915 Max. :1.0000 Max. :82.00 Max. :1.0000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Min. :0.0000 Min. :0.000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:1.0000 1st Qu.:0.000 1st Qu.:0.0000
## Median :0.00000 Median :1.0000 Median :0.000 Median :1.0000
## Mean :0.06396 Mean :0.7602 Mean :1.075 Mean :0.5088
## 3rd Qu.:0.00000 3rd Qu.:1.0000 3rd Qu.:2.000 3rd Qu.:1.0000
## Max. :1.00000 Max. :1.0000 Max. :4.000 Max. :1.0000
## avg_glucose_level bmi smoking_status stroke
## Min. :-129.54 Min. : 0.0 Min. :0.0000 Min. :0.00000
## 1st Qu.: 77.16 1st Qu.:25.4 1st Qu.:1.0000 1st Qu.:0.00000
## Median : 92.37 Median :29.1 Median :1.0000 Median :0.00000
## Mean : 108.21 Mean :30.2 Mean :0.9734 Mean :0.05666
## 3rd Qu.: 116.44 3rd Qu.:33.8 3rd Qu.:1.0000 3rd Qu.:0.00000
## Max. : 271.74 Max. :92.0 Max. :2.0000 Max. :1.00000
# columns 3, 9 and 10 are age, avg_glucose_level and bmi
preproc <- preProcess(stroke3[, c(3, 9, 10)], method = c("range"))
norm_data <- predict(preproc, stroke3)
summary(norm_data)
## id gender age hypertension
## Min. : 67 Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:18040 1st Qu.:0.0000 1st Qu.:0.3333 1st Qu.:0.0000
## Median :37446 Median :1.0000 Median :0.5556 Median :0.0000
## Mean :36780 Mean :0.6053 Mean :0.5397 Mean :0.1251
## 3rd Qu.:54946 3rd Qu.:1.0000 3rd Qu.:0.7361 3rd Qu.:0.0000
## Max. :72915 Max. :1.0000 Max. :1.0000 Max. :1.0000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Min. :0.0000 Min. :0.000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:1.0000 1st Qu.:0.000 1st Qu.:0.0000
## Median :0.00000 Median :1.0000 Median :0.000 Median :1.0000
## Mean :0.06396 Mean :0.7602 Mean :1.075 Mean :0.5088
## 3rd Qu.:0.00000 3rd Qu.:1.0000 3rd Qu.:2.000 3rd Qu.:1.0000
## Max. :1.00000 Max. :1.0000 Max. :4.000 Max. :1.0000
## avg_glucose_level bmi smoking_status stroke
## Min. :0.0000 Min. :0.0000 Min. :0.0000 Min. :0.00000
## 1st Qu.:0.5151 1st Qu.:0.2761 1st Qu.:1.0000 1st Qu.:0.00000
## Median :0.5530 Median :0.3163 Median :1.0000 Median :0.00000
## Mean :0.5925 Mean :0.3283 Mean :0.9734 Mean :0.05666
## 3rd Qu.:0.6130 3rd Qu.:0.3674 3rd Qu.:1.0000 3rd Qu.:0.00000
## Max. :1.0000 Max. :1.0000 Max. :2.0000 Max. :1.00000
The dataset is split into 80% training data and 20% testing data: a total of 2852 rows are partitioned into the training set and 713 rows into the testing set.
# createDataPartition function
stroke4 <- norm_data
set.seed(123)
trainIndex <- createDataPartition(stroke4$stroke, p=0.8, list=FALSE) # stratified random split of 80:20
train_data <- stroke4[trainIndex,]
test_data <- stroke4[-trainIndex,]
# check number of rows
nrow(train_data)
## [1] 2852
nrow(train_data[train_data$stroke=="1",])
## [1] 163
nrow(train_data[train_data$stroke=="0",])
## [1] 2689
nrow(test_data)
## [1] 713
nrow(test_data[test_data$stroke=="1",])
## [1] 39
nrow(test_data[test_data$stroke=="0",])
## [1] 674
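The class imbalance discussed next can be confirmed with a quick proportion check; a sketch (output omitted):
# class proportions in the training split: roughly 94% non-stroke vs 6% stroke
prop.table(table(train_data$stroke))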
The dataset has an uneven class distribution: 94% of the training rows (2689) are non-stroke cases and only 6% (163) are stroke cases. We handle this with SMOTE (Synthetic Minority Oversampling Technique), a popular method for dealing with imbalanced datasets, which often occur in classification problems where one class has a disproportionate number of observations. SMOTE creates synthetic examples of the minority class in order to balance the class distribution. After applying SMOTE, stroke cases increase to about 41% of the training set, while non-stroke cases make up about 59%.
train_data$stroke <- as.factor(train_data$stroke)
levels(train_data$stroke) <- c("Non_Stroke","Stroke")
# perc.over refers to the oversampling of minority cases and perc.under refers to the undersampling of majority cases
# set perc.over = 10 (10 times of over-sample in the minority cases)
# set perc.under = 1.6 (1.6 times of under-sample in majority cases)
set.seed(789)
smote_train_data <- smote(stroke ~ ., train_data, perc.over = 10, k = 5, perc.under = 1.6)
# Check number of rows
nrow(smote_train_data)
## [1] 4401
nrow(smote_train_data[smote_train_data$stroke=="Stroke",])
## [1] 1793
nrow(smote_train_data[smote_train_data$stroke=="Non_Stroke",])
## [1] 2608
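Given the counts above, the resampled class balance can be verified the same way (2608 / 4401 ≈ 59% non-stroke, 1793 / 4401 ≈ 41% stroke):
# verify the post-SMOTE class balance
prop.table(table(smote_train_data$stroke))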
Random Forest is an ensemble method that builds multiple decision trees on random subsets of the input data and features, then combines their predictions to make the final decision. It suits this problem for several reasons:
Handling of non-linear relationships: Random Forest can handle non-linear relationships between variables, which is important for stroke prediction because the relationship between risk factors and stroke is often complex.
Handling of categorical variables: Random Forest can handle categorical variables, which are commonly found in medical datasets.
Feature importance: Random Forest provides an estimate of feature importance, which helps identify the most influential features in stroke prediction; this can improve our understanding of the mechanisms that lead to stroke and inform preventative strategies.
table(smote_train_data$stroke)
##
## Non_Stroke Stroke
## 2608 1793
smote_train_data01 <- smote_train_data[-1] # remove id column
floor(sqrt(ncol(smote_train_data01) - 1)) #3
## [1] 3
set.seed(1122)
mtry <- tuneRF(smote_train_data01[-length(smote_train_data01)], smote_train_data01$stroke,
               ntreeTry = 500, stepFactor = 1.5, improve = 0.01, trace = TRUE, plot = TRUE)
## mtry = 3 OOB error = 4.54%
## Searching left ...
## mtry = 2 OOB error = 6.23%
## -0.37 0.01
## Searching right ...
## mtry = 4 OOB error = 4.41%
## 0.03 0.01
## mtry = 6 OOB error = 4.61%
## -0.04639175 0.01
best_m <- mtry[mtry[, 2] == min(mtry[, 2]), 1]
print(mtry)
## mtry OOBError
## 2.OOB 2 0.06225858
## 3.OOB 3 0.04544422
## 4.OOB 4 0.04408089
## 6.OOB 6 0.04612588
print(best_m) #4
## [1] 4
set.seed(1122)
rf <- randomForest(stroke ~ ., data = smote_train_data01, mtry = best_m, importance = TRUE, ntree = 500)
print(rf)
##
## Call:
## randomForest(formula = stroke ~ ., data = smote_train_data01, mtry = best_m, importance = TRUE, ntree = 500)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 4
##
## OOB estimate of error rate: 4.84%
## Confusion matrix:
## Non_Stroke Stroke class.error
## Non_Stroke 2509 99 0.03796012
## Stroke 114 1679 0.06358059
#Evaluate variable importance
importance(rf)
## Non_Stroke Stroke MeanDecreaseAccuracy MeanDecreaseGini
## gender 41.73723 65.94787 70.04894 60.61531
## age 120.74353 151.19372 165.53890 806.47361
## hypertension 29.18763 37.96014 38.89608 48.20356
## heart_disease 29.98562 49.79057 48.02564 68.62065
## ever_married 34.64929 35.45674 40.94102 71.88809
## work_type 50.59063 63.05101 64.36997 140.27072
## Residence_type 42.68955 59.97943 63.61224 60.27460
## avg_glucose_level 77.27040 105.65118 115.44780 342.34315
## bmi 69.23443 87.68249 101.86878 261.72633
## smoking_status 47.68864 84.66070 74.01676 244.67301
varImpPlot(rf)
Next, we evaluate the Random Forest classifier on the held-out test data.
rf_test_data <- test_data[-1]
rf_test_data$stroke <- as.factor(rf_test_data$stroke)
levels(rf_test_data$stroke) <- c("Non_Stroke","Stroke")
predicted <- predict(rf, rf_test_data)
conf_matrix <- confusionMatrix(as.factor(predicted), as.factor(rf_test_data$stroke), positive = "Stroke")
conf_matrix$table
## Reference
## Prediction Non_Stroke Stroke
## Non_Stroke 622 31
## Stroke 52 8
# Evaluate model performance
accuracy    <- conf_matrix$overall[1]
sensitivity <- conf_matrix$byClass[1]   # byClass[1] is sensitivity, i.e. recall for the "Stroke" class
specificity <- conf_matrix$byClass[2]   # byClass[2] is specificity
# Print evaluation metrics
print(paste("Accuracy:", accuracy))
## [1] "Accuracy: 0.8835904628331"
print(paste("Sensitivity:", sensitivity))
## [1] "Sensitivity: 0.205128205128205"
print(paste("Specificity:", specificity))
## [1] "Specificity: 0.922848664688427"
Logistic Regression is a supervised machine-learning algorithm used for classification tasks. It is a type of generalized linear model that relates a binary outcome variable to one or more predictor variables, using the logistic function (also called the sigmoid function) to model the probability of the outcome being in one of the two classes (e.g. stroke or no stroke).
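For intuition, the sigmoid mapping can be written directly; a minimal sketch (the helper name sigmoid is ours):
# logistic (sigmoid) function: maps any real-valued linear predictor eta into (0, 1)
sigmoid <- function(eta) 1 / (1 + exp(-eta))
sigmoid(0)   # 0.5: a linear predictor of zero corresponds to a 50% predicted probability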
lr <- glm(stroke ~., data = smote_train_data01, family = "binomial")
summary(lr)
##
## Call:
## glm(formula = stroke ~ ., family = "binomial", data = smote_train_data01)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.2900 -0.7551 -0.2527 0.8429 2.4642
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -5.83661 0.32190 -18.132 < 2e-16 ***
## gender 0.25342 0.07950 3.187 0.001435 **
## age 5.99596 0.23286 25.749 < 2e-16 ***
## hypertension 0.06562 0.09706 0.676 0.499010
## heart_disease 0.40716 0.12715 3.202 0.001364 **
## ever_married 0.25925 0.13543 1.914 0.055591 .
## work_type -0.10906 0.02363 -4.616 3.91e-06 ***
## Residence_type 0.26097 0.07639 3.416 0.000635 ***
## avg_glucose_level 2.05565 0.29586 6.948 3.70e-12 ***
## bmi -0.80739 0.60562 -1.333 0.182477
## smoking_status -0.14308 0.05564 -2.571 0.010130 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 5949.3 on 4400 degrees of freedom
## Residual deviance: 4269.6 on 4390 degrees of freedom
## AIC: 4291.6
##
## Number of Fisher Scoring iterations: 5
lr_test_data <- test_data[-1] # remove the 'id' column
lr_test_data$stroke <- as.factor(lr_test_data$stroke)
levels(lr_test_data$stroke) <- c(0,1)
glm.probs <- predict(lr, newdata = lr_test_data, type = "response")
lr_test_data$lr <- ifelse(glm.probs > 0.5, 1, 0)
lr_test_data$lr <- as.factor(lr_test_data$lr)
confusion_matrix <- confusionMatrix(lr_test_data$stroke,lr_test_data$lr)
print(confusion_matrix)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 528 146
## 1 15 24
##
## Accuracy : 0.7742
## 95% CI : (0.7417, 0.8044)
## No Information Rate : 0.7616
## P-Value [Acc > NIR] : 0.2286
##
## Kappa : 0.1544
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.9724
## Specificity : 0.1412
## Pos Pred Value : 0.7834
## Neg Pred Value : 0.6154
## Prevalence : 0.7616
## Detection Rate : 0.7405
## Detection Prevalence : 0.9453
## Balanced Accuracy : 0.5568
##
## 'Positive' Class : 0
##
# Evaluate model performance
accuracy    <- confusion_matrix$overall[1]
sensitivity <- confusion_matrix$byClass[1]   # byClass[1] is sensitivity for the positive class ("0")
specificity <- confusion_matrix$byClass[2]   # byClass[2] is specificity
# Print evaluation metrics
print(paste("Accuracy:", accuracy))
## [1] "Accuracy: 0.774193548387097"
print(paste("Sensitivity:", sensitivity))
## [1] "Sensitivity: 0.972375690607735"
print(paste("Specificity:", specificity))
## [1] "Specificity: 0.141176470588235"
# convert stroke and lr to numeric in preparation for a ROC analysis
lr_test_data$stroke <- as.numeric(lr_test_data$stroke)
lr_test_data$lr <- as.numeric(lr_test_data$lr)
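With the predictions in numeric form, a ROC analysis can follow; a sketch using the pROC package loaded at the start (roc_obj is a hypothetical name, output not shown):
# ROC curve and AUC for the logistic model's predicted probabilities
roc_obj <- roc(lr_test_data$stroke, glm.probs)
auc(roc_obj)
plot(roc_obj, main = "ROC Curve - Logistic Regression")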
Stroke is a dangerous disease, and it does not strike only the old or middle-aged population. Throughout this project, we found that, besides age, several other factors contribute to stroke, notably avg_glucose_level and bmi. We used two algorithms, Random Forest and Logistic Regression, to predict stroke from a set of patients' data, and Random Forest obtained higher accuracy than Logistic Regression. However, the models still need to be tuned to achieve higher accuracy and a better ROC. We hope that everyone stays healthy and eats well; avoiding the smoking habit is also a good way to prevent stroke.