According to the World Health Organization (WHO), stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths. This dataset is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about one patient.
Data: Stroke Data (310KB)
library(tidyverse)
## Warning: package 'tidyverse' was built under R version 4.0.5
## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.4 v purrr 0.3.4
## v tibble 3.1.2 v dplyr 1.0.7
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## Warning: package 'ggplot2' was built under R version 4.0.5
## Warning: package 'tibble' was built under R version 4.0.5
## Warning: package 'tidyr' was built under R version 4.0.5
## Warning: package 'readr' was built under R version 4.0.5
## Warning: package 'dplyr' was built under R version 4.0.5
## Warning: package 'forcats' was built under R version 4.0.5
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(ggplot2)
# Load the stroke dataset. readr guesses the column types; `bmi` is read as
# character, so it is converted to numeric further down in the script.
base_stroke <- read_csv(file = "healthcare-dataset-stroke-data.csv")
##
## -- Column specification --------------------------------------------------------
## cols(
## id = col_double(),
## gender = col_character(),
## age = col_double(),
## hypertension = col_double(),
## heart_disease = col_double(),
## ever_married = col_character(),
## work_type = col_character(),
## Residence_type = col_character(),
## avg_glucose_level = col_double(),
## bmi = col_character(),
## smoking_status = col_character(),
## stroke = col_double()
## )
# Preview the first six rows of the freshly loaded data.
head(x = base_stroke, n = 6L)
## # A tibble: 6 x 12
## id gender age hypertension heart_disease ever_married work_type
## <dbl> <chr> <dbl> <dbl> <dbl> <chr> <chr>
## 1 9046 Male 67 0 1 Yes Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 Female 79 1 0 Yes Self-employed
## 6 56669 Male 81 0 0 Yes Private
## # ... with 5 more variables: Residence_type <chr>, avg_glucose_level <dbl>,
## # bmi <chr>, smoking_status <chr>, stroke <dbl>
# Add a 1-based row index and move it to the front. seq_len() is safe for an
# empty table (1:nrow() would yield c(1, 0)), and relocating by name does not
# depend on the table having exactly 12 original columns.
base_stroke$Indice <- seq_len(nrow(base_stroke))
base_stroke <- base_stroke %>%
  select(Indice, everything())

# Convert the categorical columns (and the 0/1 outcome `stroke`) to factors,
# and parse `bmi`, which was read as character, into numbers. Entries of `bmi`
# that are not numeric become NA and are imputed later.
base_stroke <- base_stroke %>%
  mutate(
    across(
      c(gender, ever_married, work_type, Residence_type, smoking_status, stroke),
      as_factor
    ),
    bmi = as.numeric(bmi)
  )
Detecting NAs in the data
# Count the missing values in every column, in column order. vapply() walks
# the columns directly, so no result vector is grown inside a loop and there
# is no 1:length() index that misbehaves on an empty table.
qtd_NAs <- vapply(
  base_stroke,
  function(column) sum(is.na(column)),
  integer(1),
  USE.NAMES = FALSE
)
qtd_NAs
## [1] 0 0 0 0 0 0 0 0 0 0 201 0 0
# Mean BMI per gender, ignoring missing values; `n` is the group size.
media_bmi <- base_stroke %>%
  group_by(gender) %>%
  summarize(media_Bmi = mean(bmi, na.rm = TRUE), n = n())

# Impute each missing BMI with the mean BMI of that patient's gender.
# The mean is looked up by gender value rather than by row position in
# `media_bmi`, so the result does not depend on the order of the factor
# levels, and every gender present in the data is handled without a
# per-row loop.
missing_bmi <- is.na(base_stroke$bmi)
base_stroke$bmi[missing_bmi] <- media_bmi$media_Bmi[
  match(base_stroke$gender[missing_bmi], media_bmi$gender)
]
# Bar chart: number of stroke patients per gender.
ggplot(
  data = filter(base_stroke, stroke == 1),
  mapping = aes(x = gender, fill = gender)
) +
  geom_bar(color = "#AEC0D2") +
  ylim(0, 300) +
  theme_linedraw() +
  labs(title = "People who had a stroke: male and female.")

# Histogram (10 bins): age distribution of stroke patients.
ggplot(
  data = filter(base_stroke, stroke == 1),
  mapping = aes(x = age)
) +
  geom_histogram(bins = 10, fill = "blue", color = "white") +
  theme_linedraw() +
  labs(title = "Age distribution of people who had a stroke.")
# Turn the numeric hypertension flag into a factor and give both factors
# readable labels. Labels are assigned in level order; as_factor() on a
# numeric column orders the levels numerically, so the first label goes to
# the smaller value (assumes both columns hold only 0 and 1 - TODO confirm).
base_stroke$hypertension <- as_factor(base_stroke$hypertension)
levels(base_stroke$hypertension) <- c("Without hypertension", "With hypertension")
levels(base_stroke$stroke) <- c("No_had_stroke", "Had_stroke")

# Dodged bars: patient counts per hypertension status, split by stroke
# outcome. The chart shows both outcome classes, so the title describes the
# comparison instead of naming only the patients without a stroke.
base_stroke %>%
  count(hypertension, stroke) %>%
  ggplot(mapping = aes(x = hypertension, y = n, fill = stroke)) +
  geom_col(position = "dodge", color = "black") +
  ggtitle(label = "Stroke outcome by hypertension status.") +
  theme_light()
# Polar bar chart: number of patients in each stroke outcome class.
ggplot(
  data = count(base_stroke, stroke),
  mapping = aes(x = stroke, y = n, fill = stroke)
) +
  geom_col(color = "#363636") +
  coord_polar() +
  labs(title = "Number of people who suffered a stroke.") +
  theme_minimal()
# Scatter plot of the mean BMI at each age for patients without a stroke.
# Note: this reuses the name `media_bmi`, replacing the per-gender summary
# table with a ggplot object.
media_bmi <- base_stroke %>%
  filter(stroke == "No_had_stroke") %>%
  group_by(age, stroke) %>%
  summarize(
    average_body_mass_index = mean(bmi, na.rm = TRUE),
    N = n()
  ) %>%
  ungroup() %>%
  ggplot(mapping = aes(x = age, y = average_body_mass_index)) +
  geom_point(color = "#00BFFF") +
  labs(
    title = "Average body mass index of people without stroke.",
    x = "Age",
    y = "Body mass index."
  ) +
  ylim(0, 100) +
  theme_linedraw()
## `summarise()` has grouped output by 'age'. You can override using the `.groups` argument.
# Scatter plot of every individual BMI value against age for patients
# without a stroke. Inside aes(), `bmi` refers to the data column, not to the
# plot object being assigned here.
bmi <- base_stroke %>%
  filter(stroke == "No_had_stroke") %>%
  ggplot(mapping = aes(x = age, y = bmi)) +
  geom_point(color = "#009ACD") +
  labs(
    title = "Body mass index of people without stroke.",
    x = "Age",
    y = "Body mass index."
  ) +
  theme_linedraw()
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
# Draw the individual-value plot above the per-age average plot.
grid.arrange(grobs = list(bmi, media_bmi))
# Scatter plot of the mean BMI at each age for patients who had a stroke.
media_bmi <- base_stroke %>%
  filter(stroke == "Had_stroke") %>%
  group_by(age, stroke) %>%
  summarize(
    average_body_mass_index = mean(bmi, na.rm = TRUE),
    N = n()
  ) %>%
  ungroup() %>%
  ggplot(mapping = aes(x = age, y = average_body_mass_index)) +
  geom_point(color = "#EE5C42") +
  labs(
    title = "Average body mass index of people who had a stroke.",
    x = "Age",
    y = "Body mass index."
  ) +
  ylim(10, 60) +
  theme_linedraw()
## `summarise()` has grouped output by 'age'. You can override using the `.groups` argument.
#########
# Scatter plot of every individual BMI value against age for patients who
# had a stroke, on the same 10-60 y-range as the per-age average plot.
bmi <- base_stroke %>%
  filter(stroke == "Had_stroke") %>%
  ggplot(mapping = aes(x = age, y = bmi)) +
  geom_point(color = "#CD3700") +
  labs(
    title = "Body mass index of people with stroke.",
    x = "Age",
    y = "Body mass index."
  ) +
  ylim(10, 60) +
  theme_linedraw()

# Individual values on top, per-age averages below.
grid.arrange(grobs = list(bmi, media_bmi))
# Tabulate work types for one stroke class and add, per work type, its share
# of the class (`fraction`) plus the cumulative bounds (`ymin`, `ymax`) that
# place its segment on the ring chart.
work_type_shares <- function(data, stroke_level) {
  shares <- data %>%
    filter(stroke == stroke_level) %>%
    count(work_type)
  shares$fraction <- shares$n / sum(shares$n)
  shares$ymax <- cumsum(shares$fraction)
  shares$ymin <- c(0, head(shares$ymax, n = -1))
  shares
}

# Draw a table produced by work_type_shares() as a ring chart: rectangles
# spanning x = 3..4 are wrapped around the y axis, and the x-limit starting
# at 1.5 leaves the hole in the middle.
work_type_ring <- function(shares, title, border) {
  ggplot(
    data = shares,
    mapping = aes(ymax = ymax, ymin = ymin, xmax = 4, xmin = 3, fill = work_type)
  ) +
    geom_rect(color = border) +
    ggtitle(label = title) +
    coord_polar(theta = "y") +
    xlim(c(1.5, 4)) +
    theme_light()
}

# Patients without a stroke.
pie <- work_type_shares(base_stroke, "No_had_stroke")
piegraf1 <- work_type_ring(
  pie,
  title = "People who have not suffered a stroke./Types of work.",
  border = "#0000EE"
)

# Patients who had a stroke.
pie2 <- work_type_shares(base_stroke, "Had_stroke")
piegraf2 <- work_type_ring(
  pie2,
  title = "People who suffered a stroke./Types of work.",
  border = "black"
)

library(gridExtra)
grid.arrange(piegraf1, piegraf2)
library(caTools)
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:gridExtra':
##
## combine
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
base_stroke$heart_disease <- as_factor(base_stroke$heart_disease)

# Switch the readable labels back to 0/1 for modelling. `levels<-` renames
# the levels in their existing order, which is "negative" first for both
# factors ("Without hypertension", "No_had_stroke"), so the replacement must
# be c("0", "1"). The earlier c(1, 0) for `stroke` labelled patients without
# a stroke as "1" and stroke patients as "0", inverting every downstream
# table and chart.
levels(base_stroke$hypertension) <- c("0", "1")
levels(base_stroke$stroke) <- c("0", "1")

# Drop the two identifier columns by name so they cannot act as predictors.
base_stroke <- base_stroke[, !(names(base_stroke) %in% c("Indice", "id"))]

# 70/30 split on the outcome column; the seed makes the split reproducible.
set.seed(10)
divide <- sample.split(Y = base_stroke$stroke, SplitRatio = 7 / 10)
base_train <- subset(base_stroke, subset = divide == TRUE)
base_test <- subset(base_stroke, subset = divide == FALSE)

# Random forest with 20 trees on all remaining columns; seeded so the fit is
# reproducible.
set.seed(1)
model_RF <- randomForest(formula = stroke ~ ., data = base_train, ntree = 20)
model_RF
##
## Call:
## randomForest(formula = stroke ~ ., data = base_train, ntree = 20)
## Type of random forest: classification
## Number of trees: 20
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 5.62%
## Confusion matrix:
## 1 0 class.error
## 1 3371 31 0.009112287
## 0 170 4 0.977011494
# Predict the held-out rows using every column except the outcome, then
# cross-tabulate actual outcome (rows) against prediction (columns).
prevision <- predict(object = model_RF, newdata = select(base_test, -stroke))
confusionMatrix <- table(base_test$stroke, prevision)
confusionMatrix
## prevision
## 1 0
## 1 1456 2
## 0 73 2
# Overall accuracy: correctly classified cases (the diagonal of the
# confusion matrix) over all cases. diag() replaces the linear indices 1 and
# 4, which are only correct for a 2 x 2 table.
# NOTE(review): with classes this unbalanced, accuracy is dominated by the
# majority class; consider reporting per-class error as well.
accuracy <- sum(diag(confusionMatrix)) / sum(confusionMatrix)
accuracy
## [1] 0.9510763
# Side-by-side table of the actual outcome and the model's prediction for
# each test row; show the first six rows.
compare <- data.frame(
  Real = base_test$stroke,
  Prevision = prevision
)
head(x = compare, n = 6L)
## Real Prevision
## 1 0 1
## 2 0 1
## 3 0 1
## 4 0 1
## 5 0 1
## 6 0 1
# Bar chart of the actual outcome counts in the test set, with the count
# printed above each bar.
real_valus <- compare %>%
  count(Real) %>%
  ggplot(mapping = aes(x = Real, y = n)) +
  geom_col(fill = "#87CEFA", color = "#4682B4") +
  geom_text(mapping = aes(label = n), vjust = -1, size = 5) +
  labs(
    title = "Actual database values.",
    x = "Result.",
    y = "Quantity"
  ) +
  ylim(0, 1700) +
  theme_light()

# Same chart for the predicted outcome counts. Note: this reuses the name
# `prevision`, replacing the prediction vector with a ggplot object.
prevision <- compare %>%
  count(Prevision) %>%
  ggplot(mapping = aes(x = Prevision, y = n)) +
  geom_col(fill = "#9ACD32", color = "#698B22") +
  geom_text(mapping = aes(label = n), vjust = -0.5, size = 5) +
  labs(
    title = "Values obtained with the Random Forest forecast model.",
    x = "Result.",
    y = "Quantity"
  ) +
  ylim(0, 1700) +
  theme_light()

# Actual counts on top, predicted counts below.
grid.arrange(grobs = list(real_valus, prevision))