Heart disease is the #2 cause of death in Malaysia (DOSM, 2022), and one in five heart attack patients is younger than 40. Early detection of heart disease is therefore crucial to reducing the mortality rate.
The dataset contains the following features used for heart-disease diagnosis.
An interactive notebook for this prediction model is available here: https://colab.research.google.com/drive/1oTTgcuzR6jFtnBGBHRrAvNhJ_uNWOzP0?usp=sharing
| Feature | Data Type | Description |
|---|---|---|
| age | Integer | age in years |
| sex | Categorical | sex (1 = male; 0 = female) |
| cp | Categorical | chest pain type - Value 1: typical angina - Value 2: atypical angina - Value 3: non-anginal pain - Value 4: asymptomatic |
| trestbps | Integer | resting blood pressure (in mm Hg on admission to the hospital) |
| chol | Integer | serum cholesterol in mg/dl |
| fbs | Categorical | (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false) |
| restecg | Categorical | resting electrocardiographic results - Value 0: normal - Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV) - Value 2: showing probable or definite left ventricular hypertrophy by Estes’ criteria |
| thalach | Integer | maximum heart rate achieved |
| exang | Categorical | exercise induced angina (1 = yes; 0 = no) |
| oldpeak | Float | ST depression induced by exercise relative to rest |
| slope | Categorical | the slope of the peak exercise ST segment - Value 1: upsloping - Value 2: flat - Value 3: downsloping |
| ca | Integer | number of major vessels (0-3) colored by fluoroscopy |
| thal | Categorical | Thalassemia (3 = normal; 6 = fixed defect; 7 = reversible defect) |
| num | Categorical | the predicted attribute, representing the diagnosis of heart disease (angiographic disease status) - Value 0: < 50% diameter narrowing - Value 1: > 50% diameter narrowing |
Context: The “goal” field (num above) refers to the presence of heart disease in the patient. It is integer-valued from 0 (no presence) to 4. Experiments with the Cleveland database have concentrated on simply attempting to distinguish presence (values 1, 2, 3, 4) from absence (value 0).
library('dplyr')
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library('ggplot2')
library('corrplot')
## corrplot 0.92 loaded
library('caret')
## Loading required package: lattice
# download dataset
# The original dataset is from UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/45/heart+disease
data_URL="https://drive.usercontent.google.com/download?id=1N-7oPru9BSyRHsq6VxMKGpsb9GFC75rV&export=download&authuser=0&confirm=t&uuid=99167f5a-72ec-473d-91d2-02cb2006fb5a&at=APZUnTVBQLvU28NrrHXxwlha56-A:1716345861329"
download.file(data_URL, destfile= "./cleveland.data")
df <- read.csv("cleveland.data", header=FALSE)
colnames(df) <- c("age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
"thalach", "exang", "oldpeak", "slope", "ca", "thal",
"num")
head(df, 5)
## age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
## 1 63 1 1 145 233 1 2 150 0 2.3 3 0.0 6.0 0
## 2 67 1 4 160 286 0 2 108 1 1.5 2 3.0 3.0 2
## 3 67 1 4 120 229 0 2 129 1 2.6 2 2.0 7.0 1
## 4 37 1 3 130 250 0 0 187 0 3.5 3 0.0 3.0 0
## 5 41 0 2 130 204 0 2 172 0 1.4 1 0.0 3.0 0
no_of_duplicated_rows <- sum(duplicated(df))
if (no_of_duplicated_rows > 0){
df <- df[!duplicated(df), ] # keep only the first occurrence of each duplicated row
print(paste0(no_of_duplicated_rows, " rows are removed!"))
} else {
print("There are no duplicated rows.")
}
## [1] "There are no duplicated rows."
no_of_missing_values <- sum(is.na(df))
if (no_of_missing_values > 0){
print(paste0("There are ", no_of_missing_values, " missing values"))
} else {
print("There are no missing values")
}
## [1] "There are no missing values"
However, the dataset encodes missing values with the “?” symbol, so the NA count above is misleading!
rows_with_question_mark <- filter_all(df, any_vars(. == "?"))
rows_with_question_mark
## age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
## 1 53 0 3 128 216 0 2 115 0 0.0 1 0.0 ? 0
## 2 52 1 3 138 223 0 0 169 0 0.0 1 ? 3.0 0
## 3 43 1 4 132 247 1 2 143 1 0.1 2 ? 7.0 1
## 4 52 1 4 128 204 1 0 156 1 1.0 2 0.0 ? 2
## 5 58 1 2 125 220 0 0 144 0 0.4 2 ? 7.0 0
## 6 38 1 3 138 175 0 0 173 0 0.0 1 ? 3.0 0
print(paste0(nrow(rows_with_question_mark), " rows with missing values will be removed!"))
## [1] "6 rows with missing values will be removed!"
# remove rows with '?' symbols
df <- df %>% replace(df=="?", NA) %>% na.omit()
head(df, 5)
## age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
## 1 63 1 1 145 233 1 2 150 0 2.3 3 0.0 6.0 0
## 2 67 1 4 160 286 0 2 108 1 1.5 2 3.0 3.0 2
## 3 67 1 4 120 229 0 2 129 1 2.6 2 2.0 7.0 1
## 4 37 1 3 130 250 0 0 187 0 3.5 3 0.0 3.0 0
## 5 41 0 2 130 204 0 2 172 0 1.4 1 0.0 3.0 0
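As a side note, the same cleanup could be done in one step at read time; a minimal sketch, assuming the column names are then reassigned as above:
# alternative sketch: declare "?" as the NA marker while reading, then drop incomplete rows
df <- read.csv("cleveland.data", header = FALSE, na.strings = "?") %>% na.omit()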
# assign correct data type to the dataframe variables
df <- df %>% mutate(
age=as.integer(age),
sex=as.factor(sex),
cp=as.factor(cp),
trestbps=as.integer(trestbps),
chol=as.integer(chol),
fbs=as.factor(fbs),
restecg=as.factor(restecg),
thalach=as.integer(thalach),
exang=as.factor(exang),
oldpeak=as.integer(oldpeak), # note: this truncates oldpeak's decimal values (e.g. 2.3 becomes 2)
slope=as.factor(slope),
ca=as.factor(ca),
thal=as.factor(thal),
num=as.integer(num)
)
glimpse(df)
## Rows: 297
## Columns: 14
## $ age <int> 63, 67, 67, 37, 41, 56, 62, 57, 63, 53, 57, 56, 56, 44, 52, 5…
## $ sex <fct> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1…
## $ cp <fct> 1, 4, 4, 3, 2, 2, 4, 4, 4, 4, 4, 2, 3, 2, 3, 3, 2, 4, 3, 2, 1…
## $ trestbps <int> 145, 160, 120, 130, 130, 120, 140, 120, 130, 140, 140, 140, 1…
## $ chol <int> 233, 286, 229, 250, 204, 236, 268, 354, 254, 203, 192, 294, 2…
## $ fbs <fct> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0…
## $ restecg <fct> 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 2…
## $ thalach <int> 150, 108, 129, 187, 172, 178, 160, 163, 147, 155, 148, 153, 1…
## $ exang <fct> 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1…
## $ oldpeak <int> 2, 1, 2, 3, 1, 0, 3, 0, 1, 3, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1…
## $ slope <fct> 3, 2, 2, 3, 1, 1, 3, 1, 2, 3, 2, 2, 2, 1, 1, 1, 3, 1, 1, 1, 2…
## $ ca <fct> 0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1…
## $ thal <fct> 6.0, 3.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 6.0, 3.0, 6…
## $ num <int> 0, 2, 1, 0, 0, 0, 3, 0, 2, 1, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0…
summary(df)
## age sex cp trestbps chol fbs
## Min. :29.00 0: 96 1: 23 Min. : 94.0 Min. :126.0 0:254
## 1st Qu.:48.00 1:201 2: 49 1st Qu.:120.0 1st Qu.:211.0 1: 43
## Median :56.00 3: 83 Median :130.0 Median :243.0
## Mean :54.54 4:142 Mean :131.7 Mean :247.4
## 3rd Qu.:61.00 3rd Qu.:140.0 3rd Qu.:276.0
## Max. :77.00 Max. :200.0 Max. :564.0
## restecg thalach exang oldpeak slope ca thal
## 0:147 Min. : 71.0 0:200 Min. :0.0000 1:139 0.0:174 3.0:164
## 1: 4 1st Qu.:133.0 1: 97 1st Qu.:0.0000 2:137 1.0: 65 6.0: 18
## 2:146 Median :153.0 Median :0.0000 3: 21 2.0: 38 7.0:115
## Mean :149.6 Mean :0.7778 3.0: 20
## 3rd Qu.:166.0 3rd Qu.:1.0000
## Max. :202.0 Max. :6.0000
## num
## Min. :0.0000
## 1st Qu.:0.0000
## Median :0.0000
## Mean :0.9461
## 3rd Qu.:2.0000
## Max. :4.0000
# we want to perform binary classification for heart disease, where 0 = no heart disease and 1 = heart disease
# thus we transform the 'num' target variable, whose values range from 0 to 4 (representing the severity of heart disease), into a binary target
df$output <- factor(ifelse(df$num > 0, 1, 0), levels=c(1, 0), labels=c("Heart disease","No disease"))
df <- subset(df, select=-num)
head(df, 5)
## age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal
## 1 63 1 1 145 233 1 2 150 0 2 3 0.0 6.0
## 2 67 1 4 160 286 0 2 108 1 1 2 3.0 3.0
## 3 67 1 4 120 229 0 2 129 1 2 2 2.0 7.0
## 4 37 1 3 130 250 0 0 187 0 3 3 0.0 3.0
## 5 41 0 2 130 204 0 2 172 0 1 1 0.0 3.0
## output
## 1 No disease
## 2 Heart disease
## 3 Heart disease
## 4 No disease
## 5 No disease
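A quick sanity check (a hedged one-liner) confirms the class counts after the recode:
# sanity-check sketch: rows per class in the binarized target
table(df$output)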
# define numerical and categorical columns
num_vars <- subset(df, select=c(age, trestbps, chol, thalach, oldpeak))
num_vars_with_target_variable <- subset(df, select=c(age, trestbps, chol, thalach, oldpeak, output))
cat_vars <- subset(df, select=c(sex, fbs, restecg, exang, slope, ca, thal))
cat_vars_with_target_variable <- subset(df, select=c(sex, fbs, restecg, exang, slope, ca, thal, output))
# check for outliers in the numerical variables
boxplot(num_vars, main="Boxplot for numerical variables")
# Looking into the outliers of each numerical variable, by the `output` category
library(dplyr)
library(ggplot2)
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
# The 'output' factor already carries descriptive labels ("Heart disease"/"No disease"),
# so no relabelling is needed (comparing the factor against 0 would mislabel every row)
heart_disease_comparison_df <- df
# Create boxplots using ggplot2
age_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = age)) +
geom_boxplot() +
ggtitle("Age distribution") +
xlab("Condition") +
ylab("Age")
trestbps_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y=trestbps)) +
geom_boxplot() +
ggtitle("trestbps distribution") +
xlab("Condition") +
ylab("trestbps")
chol_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = chol)) +
geom_boxplot() +
ggtitle("Cholesterol distribution") +
xlab("Condition") +
ylab("chol")
thalach_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = thalach)) +
geom_boxplot() +
ggtitle("thalach distribution") +
xlab("Condition") +
ylab("thalach")
oldpeak_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = oldpeak)) +
geom_boxplot() +
ggtitle("oldpeak distribution") +
xlab("Condition") +
ylab("oldpeak")
# Arrange the plots side by side
grid.arrange(age_plot, trestbps_plot, chol_plot, thalach_plot, oldpeak_plot, nrow = 2, ncol = 3)
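The five near-identical plot definitions above could also be collapsed into a single faceted plot; a sketch, assuming tidyr is available for the reshape:
# faceting sketch: long-format data lets one ggplot call draw all five boxplots
library(tidyr)
heart_disease_comparison_df %>%
  pivot_longer(c(age, trestbps, chol, thalach, oldpeak),
               names_to = "variable", values_to = "value") %>%
  ggplot(aes(x = output, y = value)) +
  geom_boxplot() +
  facet_wrap(~ variable, scales = "free_y", nrow = 2) +
  labs(x = "Condition", y = NULL)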
# check for class imbalance in the target variable
output_class_count <- df %>% group_by(output) %>% summarise(count=n())
labels <- as.vector(output_class_count$output)
values <- as.vector(output_class_count$count)
values_with_percent <- paste0(format(values, big.mark=",", scientific=FALSE), " people",
" (", round(values / sum(values) * 100, 2), "%", ")")
pie(values, values_with_percent, main="Heart disease distribution",
col = rainbow(length(values)))
legend("topright", c("No Disease","Heart Disease"), cex = 0.8,
fill = rainbow(length(values)))
# numeric var histogram
par(mfrow=c(2,3))
hist(df$age, freq=F, main = "age",
xlab = "age", ylab = "Density")
hist(df$trestbps, freq=F, main = "trestbps",
xlab = "trestbps", ylab = "Density")
hist(df$chol, freq=F, main = "chol",
xlab = "chol", ylab = "Density")
hist(df$thalach, freq=F, main = "thalach",
xlab = "thalach", ylab = "Density")
hist(df$oldpeak, freq=F, main = "oldpeak",
xlab = "oldpeak", ylab = "Density")
# create a scatter-plot matrix of the numerical variables (the breakdown by the binary target follows later)
library(GGally)
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
ggpairs(num_vars)
# categorical var barplot
library(gridExtra)
# sex, cp, fbs, restecg, exang, slope, ca, thal, output
a = ggplot(df, aes(x=factor(sex), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="sex", y = "count") +
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
b = ggplot(df, aes(x=factor(cp), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="cp", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
c = ggplot(df, aes(x=factor(fbs), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="fbs", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
d = ggplot(df, aes(x=factor(restecg), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="restecg", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
e = ggplot(df, aes(x=factor(exang), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="exang", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
f = ggplot(df, aes(x=factor(slope), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="slope", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
g = ggplot(df, aes(x=factor(ca), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="ca", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
h = ggplot(df, aes(x=factor(thal), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="thal", y = "count")+
theme(legend.title = element_blank()) +
theme(legend.position = 'none')
i = ggplot(df, aes(x=factor(output), fill=output))+
geom_bar(stat="count", width=0.7, position=position_dodge()) +
labs(x="output", y = "count")
grid.arrange(a,b,c,d,e,f,g,h,i, nrow=3, ncol=3)
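The nine blocks above differ only in the plotted column, so they could also be generated in a loop; a sketch using the same gridExtra layout:
# loop sketch: build the same dodged barplots programmatically
cat_cols <- c("sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal", "output")
plots <- lapply(cat_cols, function(v) {
  ggplot(df, aes(x = factor(.data[[v]]), fill = output)) +
    geom_bar(width = 0.7, position = position_dodge()) +
    labs(x = v, y = "count") +
    theme(legend.title = element_blank(),
          legend.position = if (v == "output") "right" else "none")
})
do.call(grid.arrange, c(plots, nrow = 3, ncol = 3))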
ggpairs(num_vars_with_target_variable, ggplot2::aes(colour=output))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# check for heavily correlated independent variables to avoid multicollinearity issues
library(corrplot)
relation <- cor(num_vars, method='pearson')
corrplot(relation, method = "circle", type="upper")
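The correlation plot only shows pairwise relationships; variance inflation factors give a complementary, model-based check on collinearity. A hedged sketch, assuming the car package is installed:
# VIF sketch: values near 1 suggest little collinearity; above ~5 is a common concern threshold
library(car)
vif_model <- glm(output ~ age + trestbps + chol + thalach + oldpeak,
                 data = df, family = binomial)
vif(vif_model)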
model_df <- df
library(caret)
library(klaR)
## Loading required package: MASS
##
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
##
## select
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:gridExtra':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
## The following object is masked from 'package:dplyr':
##
## combine
library(class)
library(mlbench)
GLM <-function(s, df, col) {
# logistics regression
set.seed(1)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
model <- train(as.factor(output)~., data=data_train, method="glm")
# estimate variable importance
importance <- varImp(model, scale=FALSE)
# make predictions
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
predictions <- stats::predict(model, x_test, type="raw")
cm<-caret::confusionMatrix(predictions, as.factor(y_test))
return(list(cm=cm,importance=importance))
}
NB <-function(s, df, col) {
# klaR NaiveBayes
set.seed(1)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
model <- train(as.factor(output)~., data=data_train, method="nb")
# estimate variable importance
importance <- varImp(model, scale=FALSE)
# make predictions
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
predictions <- stats::predict(model, x_test, type="raw")
cm<-caret::confusionMatrix(predictions, as.factor(y_test))
return(list(cm=cm,importance=importance))
}
RF <-function(s, df, col) {
set.seed(1)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
model <- caret::train(as.factor(output)~., data=data_train, method="rf")
# estimate variable importance
importance <- varImp(model, scale=FALSE)
# make predictions
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
predictions <- stats::predict(model, x_test, type="raw")
cm<-caret::confusionMatrix(predictions, as.factor(y_test))
return(list(cm=cm,importance=importance))
}
KNN <-function(s, df, col) {
set.seed(1)
trainIndex<-createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
model <- caret::train(as.factor(output)~., data=data_train, method = "knn")
# estimate variable importance
importance <- varImp(model, scale=FALSE)
# make predictions
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
predictions <- stats::predict(model, x_test)
cm<-caret::confusionMatrix(predictions, as.factor(y_test))
return(list(cm=cm,importance=importance))
}
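The four helpers above differ only in the caret method string; here is a refactor sketch that could replace them (the per-model versions are kept below for readability):
# refactor sketch: one evaluator parameterized by the caret method name
evaluate_model <- function(s, df, col, method) {
  set.seed(1)
  trainIndex <- caret::createDataPartition(col, p=s, list=FALSE)
  data_train <- df[trainIndex,]
  data_test <- df[-trainIndex,]
  model <- caret::train(as.factor(output)~., data=data_train, method=method)
  importance <- caret::varImp(model, scale=FALSE)
  predictions <- stats::predict(model, data_test[, 1:(ncol(df)-1)])
  cm <- caret::confusionMatrix(predictions, data_test[, ncol(df)])
  return(list(cm=cm, importance=importance))
}
# e.g. glm_result <- evaluate_model(0.70, model_df, model_df$output, "glm")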
split<-0.70 # 70%/30% train/test
glm_result<-GLM(split, model_df, model_df$output)
# confusion matrix of logistic regression
glm_result$cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Heart disease No disease
## Heart disease 35 7
## No disease 6 41
##
## Accuracy : 0.8539
## 95% CI : (0.7632, 0.9199)
## No Information Rate : 0.5393
## P-Value [Acc > NIR] : 3.064e-10
##
## Kappa : 0.7066
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8537
## Specificity : 0.8542
## Pos Pred Value : 0.8333
## Neg Pred Value : 0.8723
## Prevalence : 0.4607
## Detection Rate : 0.3933
## Detection Prevalence : 0.4719
## Balanced Accuracy : 0.8539
##
## 'Positive' Class : Heart disease
##
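As a sanity check, the headline statistics follow directly from the four cells of the confusion matrix; a small arithmetic sketch:
# arithmetic check of the reported GLM statistics (cells read off the matrix above)
TP <- 35; FP <- 7; FN <- 6; TN <- 41
(TP + TN) / (TP + FP + FN + TN) # accuracy    = 0.8539
TP / (TP + FN)                  # sensitivity = 0.8537
TN / (TN + FP)                  # specificity = 0.8542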
split<-0.70 # 70%/30% train/test
nb_result<-NB(split, model_df, model_df$output)
# confusion matrix of Naive Bayes
nb_result$cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Heart disease No disease
## Heart disease 31 4
## No disease 10 44
##
## Accuracy : 0.8427
## 95% CI : (0.7502, 0.9112)
## No Information Rate : 0.5393
## P-Value [Acc > NIR] : 1.452e-09
##
## Kappa : 0.68
##
## Mcnemar's Test P-Value : 0.1814
##
## Sensitivity : 0.7561
## Specificity : 0.9167
## Pos Pred Value : 0.8857
## Neg Pred Value : 0.8148
## Prevalence : 0.4607
## Detection Rate : 0.3483
## Detection Prevalence : 0.3933
## Balanced Accuracy : 0.8364
##
## 'Positive' Class : Heart disease
##
split<-0.70 # 70%/30% train/test
rf_result <-RF(split, model_df, model_df$output)
# confusion matrix of Random forest
rf_result$cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Heart disease No disease
## Heart disease 33 7
## No disease 8 41
##
## Accuracy : 0.8315
## 95% CI : (0.7373, 0.9025)
## No Information Rate : 0.5393
## P-Value [Acc > NIR] : 6.345e-09
##
## Kappa : 0.6602
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8049
## Specificity : 0.8542
## Pos Pred Value : 0.8250
## Neg Pred Value : 0.8367
## Prevalence : 0.4607
## Detection Rate : 0.3708
## Detection Prevalence : 0.4494
## Balanced Accuracy : 0.8295
##
## 'Positive' Class : Heart disease
##
split<-0.70 # 70%/30% train/test
knn_result <-KNN(split, model_df, model_df$output)
# confusion matrix of KNN
knn_result$cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Heart disease No disease
## Heart disease 23 11
## No disease 18 37
##
## Accuracy : 0.6742
## 95% CI : (0.5666, 0.7698)
## No Information Rate : 0.5393
## P-Value [Acc > NIR] : 0.006719
##
## Kappa : 0.336
##
## Mcnemar's Test P-Value : 0.265205
##
## Sensitivity : 0.5610
## Specificity : 0.7708
## Pos Pred Value : 0.6765
## Neg Pred Value : 0.6727
## Prevalence : 0.4607
## Detection Rate : 0.2584
## Detection Prevalence : 0.3820
## Balanced Accuracy : 0.6659
##
## 'Positive' Class : Heart disease
##
par(mfrow=c(2,2), oma=c(0, 0, 3, 0))
# Visualizing Confusion Matrix
fourfoldplot(as.table(nb_result$cm),color=c("yellow","pink"),main = "Naive Bayes classifier")
# Visualizing Confusion Matrix
fourfoldplot(as.table(glm_result$cm),color=c("yellow","pink"),main = "Logistic Regression")
# Visualizing Confusion Matrix
fourfoldplot(as.table(rf_result$cm),color=c("yellow","pink"),main = "Random Forest classifier")
# Visualizing Confusion Matrix
fourfoldplot(as.table(knn_result$cm),color=c("yellow","pink"),main = "KNN Classifier")
# Adding a title to the entire plot layout
mtext("Confusion Matrix Visualizations", side=3, outer=TRUE, line=1, cex=1.5, font=2)
# compare: which features are better for heart disease prediction?
library(gridExtra)
library(grid)
# Generate feature importance plots
fi_glm <- plot(glm_result$importance, main="Logistics Regression")
fi_nb <- plot(nb_result$importance, main="Naive Bayes")
fi_rf <- plot(rf_result$importance, main="Random Forest")
fi_knn <- plot(knn_result$importance, main="KNN")
grid.arrange(fi_glm, fi_nb, fi_rf, fi_knn, nrow = 2, ncol = 2,
top = textGrob("Feature Importance Plots",
gp = gpar(fontsize = 20, fontface = "bold")))
Cross-validation is a statistical method used to assess how well the results of a machine learning model generalize to an independent data set.
It estimates the accuracy of predictive models by partitioning the data into multiple subsets, training the model on some subsets, and testing it on the remaining subsets. This process is repeated multiple times to ensure a comprehensive evaluation.
In our case, cross-validation helps determine which model predicts heart disease most accurately by averaging performance metrics across iterations, providing a reliable estimate of each model’s generalizability.
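To make the partitioning concrete, here is a small sketch of how caret assigns rows to folds:
# fold sketch: createFolds returns class-stratified, roughly equal-sized test-index sets
set.seed(1)
folds <- caret::createFolds(model_df$output, k = 10)
sapply(folds, length)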
# Cross validation with Logistic Regression
GLM_CV_accuracy <-function(s, df, col, number=10) {
set.seed(1)
train_control <- caret::trainControl(method='cv', number=number)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
cv_result <- train(as.factor(output) ~., data=data_train,
trControl=train_control, method='glm',
metric="accuracy")
return(cv_result)
}
# Cross validation with Naive Bayes
NB_CV_accuracy <-function(s, df, col, number=10) {
set.seed(1)
train_control <- caret::trainControl(method='cv', number=number)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
cv_result <- train(as.factor(output) ~., data=data_train,
trControl=train_control, method='nb',
metric="accuracy")
return(cv_result)
}
# Cross validation with Random forest
RF_CV_accuracy <-function(s, df, col, number=10) {
set.seed(1)
train_control <- caret::trainControl(method='cv', number=number)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
cv_result <- train(as.factor(output)~., data=data_train,
method="rf", # this will use the randomForest::randomForest function
metric="Accuracy",trControl=train_control)
return(cv_result)
}
# Cross validation with K-Nearest Neighbours (KNN)
KNN_CV_accuracy <-function(s, df, col, number=10) {
set.seed(1)
train_control <- caret::trainControl(method='cv', number=number)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
cv_result <- train(as.factor(output) ~., data=data_train,
trControl=train_control, method='knn',
metric="Accuracy")
return(cv_result)
}
glm_cv <- GLM_CV_accuracy(0.7, model_df, model_df$output, number=10)
glm_cv
## Generalized Linear Model
##
## 208 samples
## 13 predictor
## 2 classes: 'Heart disease', 'No disease'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ...
## Resampling results:
##
## Accuracy Kappa
## 0.7979221 0.5927979
nb_cv <- NB_CV_accuracy(0.7, model_df, model_df$output, number=10)
nb_cv
## Naive Bayes
##
## 208 samples
## 13 predictor
## 2 classes: 'Heart disease', 'No disease'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ...
## Resampling results across tuning parameters:
##
## usekernel Accuracy Kappa
## FALSE 0.8037518 0.6029040
## TRUE 0.7892857 0.5644714
##
## Tuning parameter 'fL' was held constant at a value of 0
## Tuning
## parameter 'adjust' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were fL = 0, usekernel = FALSE and adjust
## = 1.
rf_cv <- RF_CV_accuracy(0.7, model_df, model_df$output, number=20)
rf_cv
## Random Forest
##
## 208 samples
## 13 predictor
## 2 classes: 'Heart disease', 'No disease'
##
## No pre-processing
## Resampling: Cross-Validated (20 fold)
## Summary of sample sizes: 197, 198, 197, 197, 198, 198, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8227273 0.6384762
## 11 0.7931818 0.5796070
## 20 0.7831818 0.5598600
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
knn_cv <- KNN_CV_accuracy(0.7, model_df, model_df$output, number=10)
knn_cv
## k-Nearest Neighbors
##
## 208 samples
## 13 predictor
## 2 classes: 'Heart disease', 'No disease'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ...
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 5 0.6052814 0.2034778
## 7 0.6075758 0.2058725
## 9 0.6028139 0.2017203
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 7.
library(tidyr)
# caveat: rf_cv has 20 resamples while the other models have 10, so data.frame()
# recycles the shorter accuracy vectors to length 20 (hence the duplicated rows below)
ml_accuracy <- data.frame(logistics_regression=glm_cv$resample$Accuracy, naive_bayes=nb_cv$resample$Accuracy,
KNN=knn_cv$resample$Accuracy, random_forest=rf_cv$resample$Accuracy)
ml_accuracy <- ml_accuracy %>% gather("groups", "accuracy")
ml_accuracy
## groups accuracy
## 1 logistics_regression 0.7500000
## 2 logistics_regression 0.8636364
## 3 logistics_regression 0.7619048
## 4 logistics_regression 0.7142857
## 5 logistics_regression 0.6500000
## 6 logistics_regression 0.8571429
## 7 logistics_regression 0.9000000
## 8 logistics_regression 0.7727273
## 9 logistics_regression 0.9000000
## 10 logistics_regression 0.8095238
## 11 logistics_regression 0.7500000
## 12 logistics_regression 0.8636364
## 13 logistics_regression 0.7619048
## 14 logistics_regression 0.7142857
## 15 logistics_regression 0.6500000
## 16 logistics_regression 0.8571429
## 17 logistics_regression 0.9000000
## 18 logistics_regression 0.7727273
## 19 logistics_regression 0.9000000
## 20 logistics_regression 0.8095238
## 21 naive_bayes NA
## 22 naive_bayes 0.8181818
## 23 naive_bayes 0.8571429
## 24 naive_bayes 0.6666667
## 25 naive_bayes 0.8000000
## 26 naive_bayes 0.8095238
## 27 naive_bayes 0.8500000
## 28 naive_bayes 0.7727273
## 29 naive_bayes 0.8500000
## 30 naive_bayes 0.8095238
## 31 naive_bayes NA
## 32 naive_bayes 0.8181818
## 33 naive_bayes 0.8571429
## 34 naive_bayes 0.6666667
## 35 naive_bayes 0.8000000
## 36 naive_bayes 0.8095238
## 37 naive_bayes 0.8500000
## 38 naive_bayes 0.7727273
## 39 naive_bayes 0.8500000
## 40 naive_bayes 0.8095238
## 41 KNN 0.7727273
## 42 KNN 0.4000000
## 43 KNN 0.7619048
## 44 KNN 0.5714286
## 45 KNN 0.6190476
## 46 KNN 0.4500000
## 47 KNN 0.6363636
## 48 KNN 0.6000000
## 49 KNN 0.7142857
## 50 KNN 0.5500000
## 51 KNN 0.7727273
## 52 KNN 0.4000000
## 53 KNN 0.7619048
## 54 KNN 0.5714286
## 55 KNN 0.6190476
## 56 KNN 0.4500000
## 57 KNN 0.6363636
## 58 KNN 0.6000000
## 59 KNN 0.7142857
## 60 KNN 0.5500000
## 61 random_forest 0.9090909
## 62 random_forest 0.9000000
## 63 random_forest 0.7000000
## 64 random_forest 0.7272727
## 65 random_forest 0.8181818
## 66 random_forest 0.6363636
## 67 random_forest 0.9000000
## 68 random_forest 0.9000000
## 69 random_forest 0.8181818
## 70 random_forest 0.8181818
## 71 random_forest 0.9090909
## 72 random_forest 0.8000000
## 73 random_forest 1.0000000
## 74 random_forest 0.6000000
## 75 random_forest 0.7000000
## 76 random_forest 1.0000000
## 77 random_forest 0.9000000
## 78 random_forest 0.8000000
## 79 random_forest 0.8181818
## 80 random_forest 0.8000000
# calculate average accuracy of the models
accuracy_summary <- ml_accuracy %>% group_by(groups) %>% summarise(avg_accuracy=mean(accuracy, na.rm=TRUE),
median_accuracy=median(accuracy, na.rm=TRUE))
accuracy_summary
## # A tibble: 4 × 3
## groups avg_accuracy median_accuracy
## <chr> <dbl> <dbl>
## 1 KNN 0.608 0.610
## 2 logistics_regression 0.798 0.791
## 3 naive_bayes 0.804 0.810
## 4 random_forest 0.823 0.818
# plotting boxplot of model accuracies
ggplot(ml_accuracy,aes(x=groups,y=accuracy)) +
geom_boxplot() + ggtitle("Boxplot - Model accuracies under k-fold Cross Validation") +
xlab("ML models") + coord_flip()
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
GLM_CV_ROC <-function(s, df, col) {
set.seed(1)
train_control <- trainControl(method = "boot",
number = 10,
returnResamp = 'none',
summaryFunction = twoClassSummary,
classProbs = TRUE,
savePredictions = TRUE)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
model <- caret::train(make.names(output)~., data=data_train,
trControl=train_control,
method="glm", metric="ROC")
# make predictions
predictions <- stats::predict(model, x_test, type='prob')
roc_curve <- roc(y_test, predictions[,2])
return(roc_curve)
}
NB_CV_ROC <-function(s, df, col) {
set.seed(1)
train_control <- trainControl(method = "boot",
number = 10,
returnResamp = 'none',
summaryFunction = twoClassSummary,
classProbs = TRUE,
savePredictions = TRUE)
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
model <- caret::train(make.names(output)~., data=data_train,
trControl=train_control,
method='nb', metric="ROC")
# make predictions
predictions <- stats::predict(model, x_test, type='prob')
roc_curve <- roc(y_test, predictions[,2])
return(roc_curve)
}
RF_CV_ROC <-function(s, df, col) {
set.seed(1)
# train test split
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
# Random forest under k-folder cross validation with ROC as metric
train_control <- trainControl(method = "boot",
number = 10,
returnResamp = 'none',
summaryFunction = twoClassSummary,
classProbs = TRUE,
savePredictions = TRUE)
model <- caret::train(make.names(output)~., data=data_train,
trControl=train_control,
method="rf", metric="ROC")
# make predictions
predictions <- stats::predict(model, x_test, type='prob')
roc_curve <- roc(y_test, predictions[,2])
return(roc_curve)
}
KNN_CV_ROC <-function(s, df, col) {
set.seed(1)
# train test split
trainIndex<-caret::createDataPartition(col, p=s, list=F)
data_train<-df[trainIndex,]
data_test<-df[-trainIndex,]
x_test <- data_test[,1:length(df)-1]
y_test <- data_test[,length(df)]
# KNN under k-folder cross validation with ROC as metric
train_control <- trainControl(method = "boot",
number = 10,
returnResamp = 'none',
summaryFunction = twoClassSummary,
classProbs = TRUE,
savePredictions = TRUE)
model <- caret::train(make.names(output)~., data=data_train,
trControl=train_control,
method="knn", metric="ROC")
# make predictions
predictions <- stats::predict(model, x_test, type='prob')
roc_curve <- roc(y_test, predictions[,2])
return(roc_curve)
}
split <- 0.7
roc_glm <- GLM_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_nb <- NB_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_rf <- RF_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_knn <- KNN_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
auc_labels <- c(paste0("GLM (AUC=", round(roc_glm$auc,4), ")"),
paste0("NB (AUC=", round(roc_nb$auc,4), ")"),
paste0("RF (AUC=", round(roc_rf$auc, 4), ")"),
paste0("KNN (AUC=", round(roc_knn$auc, 4), ")"))
# Plotting all four ROC curves on the same plot
ggroc(list(GLM=roc_glm, NB = roc_nb, RF = roc_rf, KNN = roc_knn), legacy.axes=TRUE) +
ggtitle("ROC Curves on the held-out test set") + scale_color_discrete(labels=auc_labels)
Among the four models we tried for predicting heart disease, here are the results: