Heart disease is the second leading cause of death in Malaysia (DOSM, 2022), and 1 in 5 heart attack patients is younger than 40. Early detection of heart disease is therefore important for reducing the mortality rate.

Dataset

This dataset contains the clinical attributes recorded for heart disease diagnosis.

An interactive notebook for this prediction model is available here: https://colab.research.google.com/drive/1oTTgcuzR6jFtnBGBHRrAvNhJ_uNWOzP0?usp=sharing

Variable Table

Feature    Data Type     Description
age        Integer       age in years
sex        Categorical   sex (1 = male; 0 = female)
cp         Categorical   chest pain type
                         - Value 1: typical angina
                         - Value 2: atypical angina
                         - Value 3: non-anginal pain
                         - Value 4: asymptomatic
trestbps   Integer       resting blood pressure (in mm Hg on admission to the hospital)
chol       Integer       serum cholesterol in mg/dl
fbs        Categorical   fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
restecg    Categorical   resting electrocardiographic results
                         - Value 0: normal
                         - Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
                         - Value 2: showing probable or definite left ventricular hypertrophy by Estes’ criteria
thalach    Integer       maximum heart rate achieved
exang      Categorical   exercise induced angina (1 = yes; 0 = no)
oldpeak    Numeric       ST depression induced by exercise relative to rest
slope      Categorical   the slope of the peak exercise ST segment
                         - Value 1: upsloping
                         - Value 2: flat
                         - Value 3: downsloping
ca         Integer       number of major vessels (0-3) colored by fluoroscopy
thal       Categorical   thallium stress test result (3 = normal; 6 = fixed defect; 7 = reversible defect)
num        Categorical   the predicted attribute: diagnosis of heart disease (angiographic disease status)
                         - Value 0: < 50% diameter narrowing
                         - Value 1: > 50% diameter narrowing


Context: The “goal” field (the num column) refers to the presence of heart disease in the patient. It is integer-valued from 0 (no presence) to 4. Experiments with the Cleveland database have concentrated on distinguishing presence (values 1, 2, 3, 4) from absence (value 0).

Part 0 - Load dataset

library('dplyr')
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library('ggplot2')
library('corrplot')
## corrplot 0.92 loaded
library('caret')
## Loading required package: lattice
# download dataset
# The original dataset is from UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/45/heart+disease
data_URL="https://drive.usercontent.google.com/download?id=1N-7oPru9BSyRHsq6VxMKGpsb9GFC75rV&export=download&authuser=0&confirm=t&uuid=99167f5a-72ec-473d-91d2-02cb2006fb5a&at=APZUnTVBQLvU28NrrHXxwlha56-A:1716345861329"
download.file(data_URL, destfile= "./cleveland.data")
df <- read.csv("cleveland.data", header=FALSE)
colnames(df) <- c("age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
                  "thalach", "exang", "oldpeak", "slope", "ca", "thal",
                  "num")
head(df, 5)
##   age sex cp trestbps chol fbs restecg thalach exang oldpeak slope  ca thal num
## 1  63   1  1      145  233   1       2     150     0     2.3     3 0.0  6.0   0
## 2  67   1  4      160  286   0       2     108     1     1.5     2 3.0  3.0   2
## 3  67   1  4      120  229   0       2     129     1     2.6     2 2.0  7.0   1
## 4  37   1  3      130  250   0       0     187     0     3.5     3 0.0  3.0   0
## 5  41   0  2      130  204   0       2     172     0     1.4     1 0.0  3.0   0

Part 1 - Data Processing

1.1 Data cleaning - Removing duplicated rows (if any)

no_of_duplicated_rows <- sum(duplicated(df))
if (no_of_duplicated_rows > 0){
  # keep only the rows that are not flagged as duplicates
  df <- df[!duplicated(df), ]
  print(paste0(no_of_duplicated_rows, " rows are removed!"))
} else {
  print("There are no duplicated rows.")
}
## [1] "There are no duplicated rows."
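
Removing duplicates relies on `duplicated()` returning one logical value per row, which is then used as a row index. A minimal sketch on a made-up data frame (the `toy` values are illustrative, not taken from the Cleveland data):

```r
# Toy data frame (hypothetical values) in which the third row repeats the first
toy <- data.frame(age = c(63, 67, 63), chol = c(233, 286, 233))

# duplicated() flags the second and later occurrences of a row
dup_flags <- duplicated(toy)
print(dup_flags)

# Keep only the rows that are not flagged; the trailing comma selects rows
toy_unique <- toy[!dup_flags, ]
print(nrow(toy_unique))
```

The trailing comma matters: without it, the logical vector would be used to select columns rather than rows.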

1.2 Data cleaning - Checking & Handling missing values

no_of_missing_values <- sum(is.na(df))
if (no_of_missing_values > 0){
  print(paste0("There are ", no_of_missing_values, " missing values"))
} else {
  print("There are no missing values")
}
## [1] "There are no missing values"

However, the dataset contains “?” symbols, which mark missing values that is.na() does not detect.

rows_with_question_mark <- filter_all(df, any_vars(. == "?"))
rows_with_question_mark
##   age sex cp trestbps chol fbs restecg thalach exang oldpeak slope  ca thal num
## 1  53   0  3      128  216   0       2     115     0     0.0     1 0.0    ?   0
## 2  52   1  3      138  223   0       0     169     0     0.0     1   ?  3.0   0
## 3  43   1  4      132  247   1       2     143     1     0.1     2   ?  7.0   1
## 4  52   1  4      128  204   1       0     156     1     1.0     2 0.0    ?   2
## 5  58   1  2      125  220   0       0     144     0     0.4     2   ?  7.0   0
## 6  38   1  3      138  175   0       0     173     0     0.0     1   ?  3.0   0
print(paste0(nrow(rows_with_question_mark), " rows with missing values are removed!"))
## [1] "6 rows with missing values are removed!"
# remove rows with '?' symbols
df <- df %>% replace(df=="?", NA) %>% na.omit()

head(df, 5)
##   age sex cp trestbps chol fbs restecg thalach exang oldpeak slope  ca thal num
## 1  63   1  1      145  233   1       2     150     0     2.3     3 0.0  6.0   0
## 2  67   1  4      160  286   0       2     108     1     1.5     2 3.0  3.0   2
## 3  67   1  4      120  229   0       2     129     1     2.6     2 2.0  7.0   1
## 4  37   1  3      130  250   0       0     187     0     3.5     3 0.0  3.0   0
## 5  41   0  2      130  204   0       2     172     0     1.4     1 0.0  3.0   0
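
An alternative to replacing “?” after loading is to declare it as a missing-value marker when the file is read. The sketch below uses the `na.strings` argument of `read.csv()` on three made-up rows (the `raw_text` values and the five-column layout are assumptions for illustration, not the real file):

```r
# Three hypothetical comma-separated rows; "?" marks a missing value
raw_text <- "53,0,3,0.0,?\n52,1,3,?,3.0\n63,1,1,0.0,6.0"

# na.strings tells read.csv() to convert "?" to NA while parsing
toy <- read.csv(text = raw_text, header = FALSE, na.strings = "?")

# The missing cells are now visible to is.na()
print(sum(is.na(toy)))

# Drop every row that contains at least one NA
toy_complete <- na.omit(toy)
print(nrow(toy_complete))
```

A side effect of this approach is that the affected columns are parsed as numeric rather than character, so any later factor conversion would produce different level labels.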

1.3 Assign correct data types

# assign correct data type to the dataframe variables

df <- df %>% mutate(
  age=as.integer(age),
  sex=as.factor(sex),
  cp=as.factor(cp),
  trestbps=as.integer(trestbps),
  chol=as.integer(chol),
  fbs=as.factor(fbs),
  restecg=as.factor(restecg),
  thalach=as.integer(thalach),
  exang=as.factor(exang),
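  # NOTE: as.integer() truncates the decimal part of oldpeak (e.g. 2.3 becomes 2);
  # as.numeric() would preserve it, but the outputs below reflect the integer conversion
  oldpeak=as.integer(oldpeak),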
  slope=as.factor(slope),
  ca=as.factor(ca),
  thal=as.factor(thal),
  num=as.integer(num)
)
glimpse(df)
## Rows: 297
## Columns: 14
## $ age      <int> 63, 67, 67, 37, 41, 56, 62, 57, 63, 53, 57, 56, 56, 44, 52, 5…
## $ sex      <fct> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1…
## $ cp       <fct> 1, 4, 4, 3, 2, 2, 4, 4, 4, 4, 4, 2, 3, 2, 3, 3, 2, 4, 3, 2, 1…
## $ trestbps <int> 145, 160, 120, 130, 130, 120, 140, 120, 130, 140, 140, 140, 1…
## $ chol     <int> 233, 286, 229, 250, 204, 236, 268, 354, 254, 203, 192, 294, 2…
## $ fbs      <fct> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0…
## $ restecg  <fct> 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 2…
## $ thalach  <int> 150, 108, 129, 187, 172, 178, 160, 163, 147, 155, 148, 153, 1…
## $ exang    <fct> 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1…
## $ oldpeak  <int> 2, 1, 2, 3, 1, 0, 3, 0, 1, 3, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1…
## $ slope    <fct> 3, 2, 2, 3, 1, 1, 3, 1, 2, 3, 2, 2, 2, 1, 1, 1, 3, 1, 1, 1, 2…
## $ ca       <fct> 0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1…
## $ thal     <fct> 6.0, 3.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 6.0, 3.0, 6…
## $ num      <int> 0, 2, 1, 0, 0, 0, 3, 0, 2, 1, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0…
summary(df)
##       age        sex     cp         trestbps          chol       fbs    
##  Min.   :29.00   0: 96   1: 23   Min.   : 94.0   Min.   :126.0   0:254  
##  1st Qu.:48.00   1:201   2: 49   1st Qu.:120.0   1st Qu.:211.0   1: 43  
##  Median :56.00           3: 83   Median :130.0   Median :243.0          
##  Mean   :54.54           4:142   Mean   :131.7   Mean   :247.4          
##  3rd Qu.:61.00                   3rd Qu.:140.0   3rd Qu.:276.0          
##  Max.   :77.00                   Max.   :200.0   Max.   :564.0          
##  restecg    thalach      exang      oldpeak       slope     ca       thal    
##  0:147   Min.   : 71.0   0:200   Min.   :0.0000   1:139   0.0:174   3.0:164  
##  1:  4   1st Qu.:133.0   1: 97   1st Qu.:0.0000   2:137   1.0: 65   6.0: 18  
##  2:146   Median :153.0           Median :0.0000   3: 21   2.0: 38   7.0:115  
##          Mean   :149.6           Mean   :0.7778           3.0: 20            
##          3rd Qu.:166.0           3rd Qu.:1.0000                              
##          Max.   :202.0           Max.   :6.0000                              
##       num        
##  Min.   :0.0000  
##  1st Qu.:0.0000  
##  Median :0.0000  
##  Mean   :0.9461  
##  3rd Qu.:2.0000  
##  Max.   :4.0000
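
The effect of the integer conversion on `oldpeak` can be checked on the first five values shown in `head(df)` earlier. This is a small sketch (the `oldpeak_raw` name is introduced here for illustration):

```r
# First five oldpeak values as printed by head(df) after loading
oldpeak_raw <- c(2.3, 1.5, 2.6, 3.5, 1.4)

# as.integer() drops the fractional part instead of rounding
oldpeak_int <- as.integer(oldpeak_raw)
print(oldpeak_int)

# as.numeric() leaves the measurements unchanged
oldpeak_num <- as.numeric(oldpeak_raw)
print(oldpeak_num)
```

The truncated values match the `oldpeak` column in the `glimpse(df)` output, which suggests that a numeric type may be the safer choice if the decimal precision matters for modelling.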

1.4 Data transformation

# we want to do the binary classification for heart disease where 0 = no heart disease, 1 = heart disease
# thus we transform the 'num' target variable with values ranging 0 to 4 (which represents the severity of heart disease) to a binary target
df$output <- factor(ifelse(df$num > 0, 1, 0), levels=c(1, 0), labels=c("Heart disease","No disease"))
df <- subset(df, select=-num)
head(df, 5)
##   age sex cp trestbps chol fbs restecg thalach exang oldpeak slope  ca thal
## 1  63   1  1      145  233   1       2     150     0       2     3 0.0  6.0
## 2  67   1  4      160  286   0       2     108     1       1     2 3.0  3.0
## 3  67   1  4      120  229   0       2     129     1       2     2 2.0  7.0
## 4  37   1  3      130  250   0       0     187     0       3     3 0.0  3.0
## 5  41   0  2      130  204   0       2     172     0       1     1 0.0  3.0
##          output
## 1    No disease
## 2 Heart disease
## 3 Heart disease
## 4    No disease
## 5    No disease
# define numerical and categorical columns
num_vars <- subset(df, select=c(age, trestbps, chol, thalach, oldpeak))
num_vars_with_target_variable <- subset(df, select=c(age, trestbps, chol, thalach, oldpeak, output))
cat_vars <- subset(df, select=c(sex, cp, fbs, restecg, exang, slope, ca, thal))
cat_vars_with_target_variable <- subset(df, select=c(sex, cp, fbs, restecg, exang, slope, ca, thal, output))

1.5 Checking & Handling Outlier (if any)

# check if there is any outliers for numerical variables
boxplot(num_vars, main="Boxplot for numerical variables")

# Look at the outliers of each numerical variable, by `output` category

library(dplyr)
library(ggplot2)
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
# `output` is already a labelled factor ("Heart disease" / "No disease"),
# so the data frame can be used as-is; comparing it with 0 would label every row "Heart disease"
heart_disease_comparison_df <- df

# Create boxplots using ggplot2
age_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = age)) +
  geom_boxplot() +
  ggtitle("Age distribution") +
  xlab("Condition") +
  ylab("Age")

trestbps_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y=trestbps)) +
  geom_boxplot() +
  ggtitle("trestbps distribution") +
  xlab("Condition") +
  ylab("trestbps")

chol_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = chol)) +
  geom_boxplot() +
  ggtitle("Cholesterol distribution") +
  xlab("Condition") +
  ylab("chol")

thalach_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = thalach)) +
  geom_boxplot() +
  ggtitle("thalach distribution") +
  xlab("Condition") +
  ylab("thalach")

oldpeak_plot <- ggplot(heart_disease_comparison_df, aes(x = output, y = oldpeak)) +
  geom_boxplot() +
  ggtitle("oldpeak distribution") +
  xlab("Condition") +
  ylab("oldpeak")

# Arrange the plots side by side
grid.arrange(age_plot, trestbps_plot, chol_plot, thalach_plot, oldpeak_plot, nrow = 2, ncol = 3)
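
The points a boxplot draws as outliers can also be listed numerically. The sketch below uses `boxplot.stats()` on made-up cholesterol-like values (the `chol_toy` vector is hypothetical) and repeats the same 1.5 × IQR rule by hand:

```r
# Hypothetical cholesterol-like values; 564 lies far above the rest
chol_toy <- c(204, 229, 233, 250, 286, 243, 211, 276, 564)

# boxplot.stats() returns the values a boxplot would draw as outliers
outliers <- boxplot.stats(chol_toy)$out
print(outliers)

# The same rule written out: values beyond 1.5 * IQR from the quartiles
q <- unname(quantile(chol_toy, c(0.25, 0.75)))
fence <- 1.5 * (q[2] - q[1])
manual <- chol_toy[chol_toy < q[1] - fence | chol_toy > q[2] + fence]
print(manual)
```

On other data the two approaches can differ slightly, because `boxplot.stats()` uses hinges while `quantile()` interpolates by default.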

Part 2 - Data visualization

2.1 Univariate analysis

# check if there is imbalance in class for target variable
output_class_count <- df %>% group_by(output) %>% summarise(count=n())
labels <- as.vector(output_class_count$output)
values <- as.vector(output_class_count$count)
values_with_percent <- paste0(format(values, big.mark=",", scientific=FALSE), " people",
                  " (", round(values / sum(values) * 100, 2), "%", ")")

pie(values, values_with_percent, main="Heart disease distribution",
    col = rainbow(length(values)))

# use the labels computed above so the legend follows the same order as the pie slices
legend("topright", labels, cex = 0.8,
   fill = rainbow(length(values)))

# histograms of the numerical variables (all patients, both classes)
par(mfrow=c(2,3))

hist(df$age, freq=FALSE, main = "age",
     xlab = "age", ylab = "Density")
hist(df$trestbps, freq=FALSE, main = "trestbps",
     xlab = "trestbps", ylab = "Density")
hist(df$chol, freq=FALSE, main = "chol",
     xlab = "chol", ylab = "Density")
hist(df$thalach, freq=FALSE, main = "thalach",
     xlab = "thalach", ylab = "Density")
hist(df$oldpeak, freq=FALSE, main = "oldpeak",
     xlab = "oldpeak", ylab = "Density")

# create a matrix plot of the numerical variables (the breakdown by the binary target variable is in section 2.2)
library(GGally)
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
ggpairs(num_vars)

# categorical var barplot
library(gridExtra)

# sex, cp, fbs, restecg, exang, slope, ca, thal, output
a = ggplot(df, aes(x=factor(sex), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="sex", y = "count") +
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

b = ggplot(df, aes(x=factor(cp), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="cp", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

c = ggplot(df, aes(x=factor(fbs), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="fbs", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

d = ggplot(df, aes(x=factor(restecg), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="restecg", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

e = ggplot(df, aes(x=factor(exang), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="exang", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

f = ggplot(df, aes(x=factor(slope), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="slope", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

g = ggplot(df, aes(x=factor(ca), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="ca", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

h = ggplot(df, aes(x=factor(thal), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="thal", y = "count")+
  theme(legend.title = element_blank()) +
  theme(legend.position = 'none')

i = ggplot(df, aes(x=factor(output), fill=output))+
  geom_bar(stat="count", width=0.7, position=position_dodge()) +
  labs(x="output", y = "count")

grid.arrange(a,b,c,d,e,f,g,h,i, nrow=3, ncol=3)

2.2 Bivariate analysis

ggpairs(num_vars_with_target_variable, ggplot2::aes(colour=output))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

2.3 Multivariate analysis

# check for strongly correlated independent variables to avoid multicollinearity issues
library(corrplot)

relation <- cor(num_vars, method='pearson')
corrplot(relation, method = "circle", type="upper")
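
Besides reading the plot, strongly correlated pairs can be listed directly from the correlation matrix. The following is a sketch on simulated data; the variable names and the 0.8 threshold are assumptions chosen for illustration, not values from the analysis above:

```r
set.seed(1)
# Simulated predictors (hypothetical): x2 is built to be strongly correlated with x1
x1 <- rnorm(100)
x2 <- x1 + rnorm(100, sd = 0.1)
x3 <- rnorm(100)
toy <- data.frame(x1, x2, x3)

corr <- cor(toy, method = "pearson")

# Look only at the upper triangle so that each pair is reported once
high <- which(abs(corr) > 0.8 & upper.tri(corr), arr.ind = TRUE)

# Collect the variable names and correlation of each flagged pair
high_pairs <- data.frame(var1 = rownames(corr)[high[, "row"]],
                         var2 = colnames(corr)[high[, "col"]],
                         r = corr[high])
print(high_pairs)
```

An empty result would indicate that no pair exceeds the chosen threshold.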

Part 3 - Modelling

3.1 Build Model

model_df <- df
library(caret)
library(klaR)
## Loading required package: MASS
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:gridExtra':
## 
##     combine
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
library(class)
library(mlbench)
GLM <-function(s, df, col) {
  # logistic regression
   set.seed(1)
   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   model <- train(as.factor(output)~., data=data_train, method="glm")
  # estimate variable importance
   importance <- varImp(model, scale=FALSE)
   # make predictions
   x_test <- data_test[, seq_len(ncol(df) - 1)]  # every column except the last (output)
   y_test <- data_test[,length(df)]
   predictions <- stats::predict(model, x_test, type="raw")
   cm<-caret::confusionMatrix(predictions, as.factor(y_test))
   return(list(cm=cm,importance=importance))
}

NB <-function(s, df, col) {
   # klaR NaiveBayes
   set.seed(1)
   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   model <- train(as.factor(output)~., data=data_train, method="nb")
   # estimate variable importance
   importance <- varImp(model, scale=FALSE)
   # make predictions
   x_test <- data_test[, seq_len(ncol(df) - 1)]  # every column except the last (output)
   y_test <- data_test[,length(df)]
   predictions <- stats::predict(model, x_test, type="raw")
   cm<-caret::confusionMatrix(predictions, as.factor(y_test))
   return(list(cm=cm,importance=importance))
}

RF <-function(s, df, col) {
   set.seed(1)
   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   model <- caret::train(as.factor(output)~., data=data_train, method="rf")
  # estimate variable importance
   importance <- varImp(model, scale=FALSE)
   # make predictions
   x_test <- data_test[, seq_len(ncol(df) - 1)]  # every column except the last (output)
   y_test <- data_test[,length(df)]
   predictions <- stats::predict(model, x_test, type="raw")
   cm<-caret::confusionMatrix(predictions, as.factor(y_test))
   return(list(cm=cm,importance=importance))
}

KNN <-function(s, df, col) {
   set.seed(1)
   trainIndex<-createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   model <- caret::train(as.factor(output)~., data=data_train, method = "knn")
   # estimate variable importance
   importance <- varImp(model, scale=FALSE)
   # make predictions
   x_test <- data_test[, seq_len(ncol(df) - 1)]  # every column except the last (output)
   y_test <- data_test[,length(df)]
   
   predictions <- stats::predict(model, x_test)
   cm<-caret::confusionMatrix(predictions, as.factor(y_test))
   return(list(cm=cm,importance=importance))
}

(a) Logistic regression with 70%/30% train/test

split<-0.70  # 70%/30% train/test
glm_result<-GLM(split, model_df, model_df$output)

# confusion matrix of logistic regression
glm_result$cm
## Confusion Matrix and Statistics
## 
##                Reference
## Prediction      Heart disease No disease
##   Heart disease            35          7
##   No disease                6         41
##                                           
##                Accuracy : 0.8539          
##                  95% CI : (0.7632, 0.9199)
##     No Information Rate : 0.5393          
##     P-Value [Acc > NIR] : 3.064e-10       
##                                           
##                   Kappa : 0.7066          
##                                           
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8537          
##             Specificity : 0.8542          
##          Pos Pred Value : 0.8333          
##          Neg Pred Value : 0.8723          
##              Prevalence : 0.4607          
##          Detection Rate : 0.3933          
##    Detection Prevalence : 0.4719          
##       Balanced Accuracy : 0.8539          
##                                           
##        'Positive' Class : Heart disease   
## 
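
The statistics reported by `confusionMatrix()` can be reproduced by hand from the four cell counts. The sketch below copies the counts from the logistic regression matrix above, with "Heart disease" as the positive class (the variable names are introduced here for illustration):

```r
# Cell counts copied from the logistic regression confusion matrix
tp <- 35  # predicted heart disease, reference heart disease
fp <- 7   # predicted heart disease, reference no disease
fn <- 6   # predicted no disease, reference heart disease
tn <- 41  # predicted no disease, reference no disease

accuracy    <- (tp + tn) / (tp + fp + fn + tn)  # share of correct predictions
sensitivity <- tp / (tp + fn)                   # share of diseased patients detected
specificity <- tn / (tn + fp)                   # share of healthy patients cleared
ppv         <- tp / (tp + fp)                   # positive predictive value
npv         <- tn / (tn + fn)                   # negative predictive value

print(round(c(accuracy = accuracy, sensitivity = sensitivity,
              specificity = specificity, ppv = ppv, npv = npv), 4))
```

In a screening setting, sensitivity is often the figure to watch, since a false negative means a patient with heart disease goes undetected.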

(b) Naive Bayes with 70%/30% train/test

split<-0.70  # 70%/30% train/test
nb_result<-NB(split, model_df, model_df$output)

# confusion matrix of Naive Bayes
nb_result$cm
## Confusion Matrix and Statistics
## 
##                Reference
## Prediction      Heart disease No disease
##   Heart disease            31          4
##   No disease               10         44
##                                           
##                Accuracy : 0.8427          
##                  95% CI : (0.7502, 0.9112)
##     No Information Rate : 0.5393          
##     P-Value [Acc > NIR] : 1.452e-09       
##                                           
##                   Kappa : 0.68            
##                                           
##  Mcnemar's Test P-Value : 0.1814          
##                                           
##             Sensitivity : 0.7561          
##             Specificity : 0.9167          
##          Pos Pred Value : 0.8857          
##          Neg Pred Value : 0.8148          
##              Prevalence : 0.4607          
##          Detection Rate : 0.3483          
##    Detection Prevalence : 0.3933          
##       Balanced Accuracy : 0.8364          
##                                           
##        'Positive' Class : Heart disease   
## 

(c) Random Forest with 70%/30% train/test

split<-0.70  # 70%/30% train/test
rf_result <-RF(split, model_df, model_df$output)

# confusion matrix of Random forest
rf_result$cm
## Confusion Matrix and Statistics
## 
##                Reference
## Prediction      Heart disease No disease
##   Heart disease            33          7
##   No disease                8         41
##                                           
##                Accuracy : 0.8315          
##                  95% CI : (0.7373, 0.9025)
##     No Information Rate : 0.5393          
##     P-Value [Acc > NIR] : 6.345e-09       
##                                           
##                   Kappa : 0.6602          
##                                           
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8049          
##             Specificity : 0.8542          
##          Pos Pred Value : 0.8250          
##          Neg Pred Value : 0.8367          
##              Prevalence : 0.4607          
##          Detection Rate : 0.3708          
##    Detection Prevalence : 0.4494          
##       Balanced Accuracy : 0.8295          
##                                           
##        'Positive' Class : Heart disease   
## 

(d) K-Nearest Neighbour (KNN) with 70%/30% train/test

split<-0.70  # 70%/30% train/test
knn_result <-KNN(split, model_df, model_df$output)

# confusion matrix of KNN
knn_result$cm
## Confusion Matrix and Statistics
## 
##                Reference
## Prediction      Heart disease No disease
##   Heart disease            23         11
##   No disease               18         37
##                                           
##                Accuracy : 0.6742          
##                  95% CI : (0.5666, 0.7698)
##     No Information Rate : 0.5393          
##     P-Value [Acc > NIR] : 0.006719        
##                                           
##                   Kappa : 0.336           
##                                           
##  Mcnemar's Test P-Value : 0.265205        
##                                           
##             Sensitivity : 0.5610          
##             Specificity : 0.7708          
##          Pos Pred Value : 0.6765          
##          Neg Pred Value : 0.6727          
##              Prevalence : 0.4607          
##          Detection Rate : 0.2584          
##    Detection Prevalence : 0.3820          
##       Balanced Accuracy : 0.6659          
##                                           
##        'Positive' Class : Heart disease   
## 
par(mfrow=c(2,2), oma=c(0, 0, 3, 0))

# Visualizing Confusion Matrix
fourfoldplot(as.table(nb_result$cm),color=c("yellow","pink"),main = "Naive Bayes classifier")

# Visualizing Confusion Matrix
fourfoldplot(as.table(glm_result$cm),color=c("yellow","pink"),main = "Logistic Regression")

# Visualizing Confusion Matrix
fourfoldplot(as.table(rf_result$cm),color=c("yellow","pink"),main = "Random Forest classifier")

# Visualizing Confusion Matrix
fourfoldplot(as.table(knn_result$cm),color=c("yellow","pink"),main = "KNN Classifier")

# Adding a title to the entire plot layout
mtext("Confusion Matrix Visualizations", side=3, outer=TRUE, line=1, cex=1.5, font=2)

3.2 Comparing Feature importance

# compare which features matter most for heart disease prediction in each model
library(gridExtra)
library(grid)

# Generate feature importance plots
fi_glm <- plot(glm_result$importance, main="Logistic Regression")
fi_nb <- plot(nb_result$importance, main="Naive Bayes")
fi_rf <- plot(rf_result$importance, main="Random Forest")
fi_knn <- plot(knn_result$importance, main="KNN")


grid.arrange(fi_glm, fi_nb, fi_rf, fi_knn, nrow = 2, ncol = 2,
   top = textGrob("Feature Importance Plots",
   gp = gpar(fontsize = 20, fontface = "bold")))

3.3 Cross-validation & Performance Evaluation

Cross-validation is a statistical method used to assess how well the results of a machine learning model generalize to an independent data set.

It estimates the accuracy of predictive models by partitioning the data into multiple subsets, training the model on some subsets, and testing it on the remaining subsets. This process is repeated multiple times to ensure a comprehensive evaluation.

In our case, cross-validation helps determine which model predicts most accurately by averaging the performance metrics across folds, which gives a more reliable estimate of each model’s generalizability than a single train/test split.
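
The partition, train and test cycle described above can be written out in base R. This is a minimal sketch on simulated data, assuming simple random fold assignment and a logistic regression fitted with `glm()`; it is not the caret procedure used in the rest of this section:

```r
set.seed(1)
# Simulated binary-classification data (hypothetical, not the Cleveland data)
n <- 200
x <- rnorm(n)
y <- factor(ifelse(x + rnorm(n) > 0, "yes", "no"))
toy <- data.frame(x = x, y = y)

k <- 10
# Randomly assign each row to one of k folds of equal size
fold_id <- sample(rep(seq_len(k), length.out = n))

fold_accuracy <- sapply(seq_len(k), function(i) {
  train_set <- toy[fold_id != i, ]  # k - 1 folds for training
  test_set  <- toy[fold_id == i, ]  # the held-out fold for testing
  fit <- glm(y ~ x, data = train_set, family = binomial)
  # Predicted probability of the second factor level ("yes")
  prob <- predict(fit, newdata = test_set, type = "response")
  pred <- ifelse(prob > 0.5, "yes", "no")
  mean(pred == test_set$y)          # accuracy on the held-out fold
})

print(fold_accuracy)
print(mean(fold_accuracy))
```

The spread of the per-fold accuracies gives a sense of how much a single train/test split could vary.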

3.3.1 Performance Evaluation with ‘Accuracy’ Metric under Cross Validation

# Cross validation with Logistic Regression
GLM_CV_accuracy <-function(s, df, col, number=10) {
   set.seed(1)
   train_control <- caret::trainControl(method='cv', number=number)

   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]

   cv_result <- train(as.factor(output) ~., data=data_train,
          trControl=train_control, method='glm',
          metric="Accuracy")  # caret metric names are case-sensitive
  return(cv_result)
}

# Cross validation with Naive Bayes
NB_CV_accuracy <-function(s, df, col, number=10) {
   set.seed(1)
   train_control <- caret::trainControl(method='cv', number=number)

   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]

   cv_result <- train(as.factor(output) ~., data=data_train,
          trControl=train_control, method='nb',
          metric="Accuracy")  # caret metric names are case-sensitive
  return(cv_result)
}

# Cross validation with Random forest
RF_CV_accuracy <-function(s, df, col, number=10) {
   set.seed(1)
   train_control <- caret::trainControl(method='cv', number=number)

   trainIndex<-caret::createDataPartition(col, p=s, list=F)
   data_train<-df[trainIndex,]

   cv_result <- train(as.factor(output)~., data=data_train,
      method="rf", # this will use the randomForest::randomForest function
      metric="Accuracy",trControl=train_control)
   return(cv_result)
}

# Cross validation with K-Nearest Neighbours (KNN)
KNN_CV_accuracy <-function(s, df, col, number=10) {
  set.seed(1)
   train_control <- caret::trainControl(method='cv', number=number)

  trainIndex<-caret::createDataPartition(col, p=s, list=F)
  data_train<-df[trainIndex,]

  cv_result <- train(as.factor(output) ~., data=data_train,
          trControl=train_control, method='knn',
          metric="Accuracy")
  return(cv_result)
}

(a) Logistic Regression

glm_cv <- GLM_CV_accuracy(0.7, model_df, model_df$output, number=10)
glm_cv
## Generalized Linear Model 
## 
## 208 samples
##  13 predictor
##   2 classes: 'Heart disease', 'No disease' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ... 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.7979221  0.5927979

(b) Naive Bayes

nb_cv <- NB_CV_accuracy(0.7, model_df, model_df$output, number=10)
nb_cv
## Naive Bayes 
## 
## 208 samples
##  13 predictor
##   2 classes: 'Heart disease', 'No disease' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ... 
## Resampling results across tuning parameters:
## 
##   usekernel  Accuracy   Kappa    
##   FALSE      0.8037518  0.6029040
##    TRUE      0.7892857  0.5644714
## 
## Tuning parameter 'fL' was held constant at a value of 0
## Tuning
##  parameter 'adjust' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were fL = 0, usekernel = FALSE and adjust
##  = 1.

(c) Random Forest

rf_cv <- RF_CV_accuracy(0.7, model_df, model_df$output, number=20)
rf_cv
## Random Forest 
## 
## 208 samples
##  13 predictor
##   2 classes: 'Heart disease', 'No disease' 
## 
## No pre-processing
## Resampling: Cross-Validated (20 fold) 
## Summary of sample sizes: 197, 198, 197, 197, 198, 198, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.8227273  0.6384762
##   11    0.7931818  0.5796070
##   20    0.7831818  0.5598600
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.

(d) KNN

knn_cv <- KNN_CV_accuracy(0.7, model_df, model_df$output, number=10)
knn_cv
## k-Nearest Neighbors 
## 
## 208 samples
##  13 predictor
##   2 classes: 'Heart disease', 'No disease' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 188, 186, 187, 187, 188, 187, ... 
## Resampling results across tuning parameters:
## 
##   k  Accuracy   Kappa    
##   5  0.6052814  0.2034778
##   7  0.6075758  0.2058725
##   9  0.6028139  0.2017203
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 7.

Comparing accuracy among different models under cross-validation

library(tidyr)

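# NOTE: random forest was cross-validated with 20 folds and the other models with 10,
# so data.frame() recycles the 10-fold accuracies to length 20 (rows 11-20 of each
# 10-fold model repeat rows 1-10); the per-model mean and median are unaffected
ml_accuracy <- data.frame(logistics_regression=glm_cv$resample$Accuracy, naive_bayes=nb_cv$resample$Accuracy,
                KNN=knn_cv$resample$Accuracy, random_forest=rf_cv$resample$Accuracy)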
ml_accuracy <- ml_accuracy %>% gather("groups", "accuracy")
ml_accuracy
##                  groups  accuracy
## 1  logistics_regression 0.7500000
## 2  logistics_regression 0.8636364
## 3  logistics_regression 0.7619048
## 4  logistics_regression 0.7142857
## 5  logistics_regression 0.6500000
## 6  logistics_regression 0.8571429
## 7  logistics_regression 0.9000000
## 8  logistics_regression 0.7727273
## 9  logistics_regression 0.9000000
## 10 logistics_regression 0.8095238
## 11 logistics_regression 0.7500000
## 12 logistics_regression 0.8636364
## 13 logistics_regression 0.7619048
## 14 logistics_regression 0.7142857
## 15 logistics_regression 0.6500000
## 16 logistics_regression 0.8571429
## 17 logistics_regression 0.9000000
## 18 logistics_regression 0.7727273
## 19 logistics_regression 0.9000000
## 20 logistics_regression 0.8095238
## 21          naive_bayes        NA
## 22          naive_bayes 0.8181818
## 23          naive_bayes 0.8571429
## 24          naive_bayes 0.6666667
## 25          naive_bayes 0.8000000
## 26          naive_bayes 0.8095238
## 27          naive_bayes 0.8500000
## 28          naive_bayes 0.7727273
## 29          naive_bayes 0.8500000
## 30          naive_bayes 0.8095238
## 31          naive_bayes        NA
## 32          naive_bayes 0.8181818
## 33          naive_bayes 0.8571429
## 34          naive_bayes 0.6666667
## 35          naive_bayes 0.8000000
## 36          naive_bayes 0.8095238
## 37          naive_bayes 0.8500000
## 38          naive_bayes 0.7727273
## 39          naive_bayes 0.8500000
## 40          naive_bayes 0.8095238
## 41                  KNN 0.7727273
## 42                  KNN 0.4000000
## 43                  KNN 0.7619048
## 44                  KNN 0.5714286
## 45                  KNN 0.6190476
## 46                  KNN 0.4500000
## 47                  KNN 0.6363636
## 48                  KNN 0.6000000
## 49                  KNN 0.7142857
## 50                  KNN 0.5500000
## 51                  KNN 0.7727273
## 52                  KNN 0.4000000
## 53                  KNN 0.7619048
## 54                  KNN 0.5714286
## 55                  KNN 0.6190476
## 56                  KNN 0.4500000
## 57                  KNN 0.6363636
## 58                  KNN 0.6000000
## 59                  KNN 0.7142857
## 60                  KNN 0.5500000
## 61        random_forest 0.9090909
## 62        random_forest 0.9000000
## 63        random_forest 0.7000000
## 64        random_forest 0.7272727
## 65        random_forest 0.8181818
## 66        random_forest 0.6363636
## 67        random_forest 0.9000000
## 68        random_forest 0.9000000
## 69        random_forest 0.8181818
## 70        random_forest 0.8181818
## 71        random_forest 0.9090909
## 72        random_forest 0.8000000
## 73        random_forest 1.0000000
## 74        random_forest 0.6000000
## 75        random_forest 0.7000000
## 76        random_forest 1.0000000
## 77        random_forest 0.9000000
## 78        random_forest 0.8000000
## 79        random_forest 0.8181818
## 80        random_forest 0.8000000
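
In the table above, rows 11-20 of logistics_regression, naive_bayes and KNN repeat rows 1-10, while random_forest has 20 distinct values. This is what `data.frame()` does when one column is longer than the others: it recycles the shorter vectors. Exact duplication leaves the mean and median unchanged, but it doubles the apparent number of resamples. Below is a minimal sketch of one way to avoid this, assuming the per-model accuracies are kept in a named list. The model names and values are made up for illustration and are not taken from the models above.

```r
# Hypothetical per-resample accuracies of unequal length (illustrative values only)
acc_list <- list(
  model_a = c(0.75, 0.86, 0.76),
  model_b = c(0.91, 0.90, 0.70, 0.73, 0.82, 0.64)
)

# stack() builds a long data frame without recycling the shorter vector
ml_accuracy_long <- stack(acc_list)
names(ml_accuracy_long) <- c("accuracy", "groups")

# Each model keeps its own number of rows
table(ml_accuracy_long$groups)
```

With this layout the `group_by()` and `summarise()` step can be applied unchanged, and the boxplot reflects the true number of resamples per model.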
# calculate average accuracy of the models
accuracy_summary <- ml_accuracy %>% group_by(groups) %>% summarise(avg_accuracy=mean(accuracy, na.rm=TRUE),
                                      median_accuracy=median(accuracy, na.rm=TRUE))
accuracy_summary
## # A tibble: 4 × 3
##   groups               avg_accuracy median_accuracy
##   <chr>                       <dbl>           <dbl>
## 1 KNN                         0.608           0.610
## 2 logistics_regression        0.798           0.791
## 3 naive_bayes                 0.804           0.810
## 4 random_forest               0.823           0.818
# plot a boxplot of the model accuracies
ggplot(ml_accuracy,aes(x=groups,y=accuracy)) +
  geom_boxplot() + ggtitle("Boxplot - Model accuracies under 10-fold Cross Validation") +
   xlab("ML models") + coord_flip()

3.3.2 Performance evaluation with ROC Curve under Cross Validation

library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
GLM_CV_ROC <-function(s, df, col) {
   set.seed(1)
   train_control <- trainControl(method = "boot",
                           number = 10,
                           returnResamp = 'none',
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions = TRUE)

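   # train/test split
   trainIndex<-caret::createDataPartition(col, p=s, list=FALSE)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   # predictors are all columns except the last; the last column is the outcome
   x_test <- data_test[,1:(length(df)-1)]
   y_test <- data_test[,length(df)]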

   model <- caret::train(make.names(output)~., data=data_train,
            trControl=train_control,
            method="glm", metric="ROC")
   # make predictions
   predictions <- stats::predict(model, x_test, type='prob')
   roc_curve <- roc(y_test, predictions[,2])
   return(roc_curve)
}

NB_CV_ROC <-function(s, df, col) {
   set.seed(1)
   train_control <- trainControl(method = "boot",
                           number = 10,
                           returnResamp = 'none',
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions = TRUE)

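   # train/test split
   trainIndex<-caret::createDataPartition(col, p=s, list=FALSE)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   # predictors are all columns except the last; the last column is the outcome
   x_test <- data_test[,1:(length(df)-1)]
   y_test <- data_test[,length(df)]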

   model <- caret::train(make.names(output)~., data=data_train,
            trControl=train_control,
            method='nb', metric="ROC")
   # make predictions
   predictions <- stats::predict(model, x_test, type='prob')
   roc_curve <- roc(y_test, predictions[,2])
   return(roc_curve)
}


RF_CV_ROC <-function(s, df, col) {
   set.seed(1)
   # train test split
   trainIndex<-caret::createDataPartition(col, p=s, list=FALSE)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   # predictors are all columns except the last; the last column is the outcome
   x_test <- data_test[,1:(length(df)-1)]
   y_test <- data_test[,length(df)]
   # Random forest tuned on 10 bootstrap resamples (method = "boot"), with ROC as the metric
   train_control <- trainControl(method = "boot",
                           number = 10,
                           returnResamp = 'none',
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions = TRUE)
   model <- caret::train(make.names(output)~., data=data_train,
            trControl=train_control,
            method="rf", metric="ROC")
   # make predictions
   predictions <- stats::predict(model, x_test, type='prob')
   roc_curve <- roc(y_test, predictions[,2])
   return(roc_curve)
}


KNN_CV_ROC <-function(s, df, col) {
   set.seed(1)
   # train test split
   trainIndex<-caret::createDataPartition(col, p=s, list=FALSE)
   data_train<-df[trainIndex,]
   data_test<-df[-trainIndex,]
   # predictors are all columns except the last; the last column is the outcome
   x_test <- data_test[,1:(length(df)-1)]
   y_test <- data_test[,length(df)]
   # KNN tuned on 10 bootstrap resamples (method = "boot"), with ROC as the metric
   train_control <- trainControl(method = "boot",
                           number = 10,
                           returnResamp = 'none',
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions = TRUE)
   model <- caret::train(make.names(output)~., data=data_train,
            trControl=train_control,
            method="knn", metric="ROC")
   # make predictions
   predictions <- stats::predict(model, x_test, type='prob')
   roc_curve <- roc(y_test, predictions[,2])
   return(roc_curve)
}
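
A note on the column indexing used for `x_test` in the functions above: in R the `:` operator binds more tightly than binary minus, so `1:n-1` is evaluated as `(1:n) - 1`, that is `0:(n-1)`. Used as a column index it still drops only the last column, but only because R silently ignores a zero index. The small sketch below uses a toy data frame with hypothetical column names to show the difference.

```r
# Toy data frame; the last column plays the role of the outcome
toy <- data.frame(a = 1:3, b = 4:6, output = c("x", "y", "x"))
n <- length(toy)                 # number of columns, here 3

idx_unparenthesised <- 1:n-1     # parsed as (1:n) - 1, i.e. 0 1 2
idx_explicit <- 1:(n - 1)        # 1 2

# Both select the same columns, because a 0 index is silently dropped
same_cols <- identical(names(toy[, idx_unparenthesised]),
                       names(toy[, idx_explicit]))

# A negative index states the intent most directly
predictors <- toy[, -n]
names(predictors)
```

Writing `1:(n - 1)` or `-n` keeps the selection correct even if the expression is later reused somewhere a zero index is not ignored.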
split <- 0.7

roc_glm <- GLM_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_nb <- NB_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_rf <- RF_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
roc_knn <- KNN_CV_ROC(split, model_df, model_df$output)
## Setting levels: control = Heart disease, case = No disease
## Setting direction: controls < cases
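
The `trainControl()` calls above use `method = "boot"` with `number = 10`, which in caret means 10 bootstrap resamples rather than 10-fold cross validation (that would be `method = "cv"`). The two schemes differ in how rows are drawn. The base-R sketch below illustrates this on hypothetical row indices; it shows the idea only and is not caret's internal code.

```r
set.seed(1)
n <- 20                      # hypothetical number of training rows

# 10-fold cross validation: every row is held out exactly once
fold_id <- sample(rep(1:10, length.out = n))
held_out_cv <- split(seq_len(n), fold_id)

# Bootstrap: each resample draws n rows with replacement;
# rows that were never drawn form the hold-out ("out-of-bag") set
boot_rows <- lapply(1:10, function(i) sample(n, n, replace = TRUE))
held_out_boot <- lapply(boot_rows, function(rows) setdiff(seq_len(n), rows))

# Under CV the hold-out sets partition the data;
# under bootstrap their sizes vary from resample to resample
sort(unlist(held_out_cv, use.names = FALSE))
lengths(held_out_boot)
```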
auc_labels <- c(paste0("GLM (AUC=", round(roc_glm$auc,4), ")"),
      paste0("NB (AUC=", round(roc_nb$auc,4), ")"),
      paste0("RF (AUC=", round(roc_rf$auc, 4), ")"),
      paste0("KNN (AUC=", round(roc_knn$auc, 4), ")"))

# Plot all four ROC curves on the same plot
ggroc(list(GLM=roc_glm, NB = roc_nb, RF = roc_rf, KNN = roc_knn), legacy.axes=TRUE) +
  ggtitle("Average ROC Curves after 10-fold cross validation") + scale_color_discrete(labels=auc_labels)
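
The AUC values placed in the legend summarise each curve as a single number: the probability that a randomly chosen case receives a higher score than a randomly chosen control. As a sketch of what that means, the base-R function below computes it from ranks using the Mann-Whitney relationship. `auc_by_rank` is a hypothetical helper, not part of pROC, and the scores and labels are made up.

```r
# Hypothetical helper: AUC via the rank-sum (Mann-Whitney) identity
auc_by_rank <- function(scores, is_case) {
  r <- rank(scores)                 # average ranks handle ties
  n_case <- sum(is_case)
  n_control <- sum(!is_case)
  (sum(r[is_case]) - n_case * (n_case + 1) / 2) / (n_case * n_control)
}

# Made-up predicted probabilities and true labels, for illustration only
scores  <- c(0.9, 0.8, 0.35, 0.6, 0.2, 0.1)
is_case <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)

# 8 of the 9 case/control pairs are ordered correctly, so the result is 8/9
auc_by_rank(scores, is_case)
```

An AUC of 0.5 corresponds to random ordering and 1 to perfect separation of cases from controls.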

Part 4 - Conclusion

Of the four models we trained to predict heart disease, the results are as follows:

  • Naive Bayes is the best model by ROC-AUC (93.19%), with 80.95% median accuracy under 10-fold cross validation.
  • Random Forest is second best by ROC-AUC (90.9%), with 81.8% median accuracy.