Load Data

# Load the heart-failure clinical records and show the first six rows
df <- read.csv(file = "heart_failure_clinical_records_dataset.csv")
head(df, n = 6L)
##   age anaemia creatinine_phosphokinase diabetes ejection_fraction
## 1  75       0                      582        0                20
## 2  55       0                     7861        0                38
## 3  65       0                      146        0                20
## 4  50       1                      111        0                20
## 5  65       1                      160        1                20
## 6  90       1                       47        0                40
##   high_blood_pressure platelets serum_creatinine serum_sodium sex smoking time
## 1                   1    265000              1.9          130   1       0    4
## 2                   0    263358              1.1          136   1       0    6
## 3                   0    162000              1.3          129   1       1    7
## 4                   0    210000              1.9          137   1       0    7
## 5                   0    327000              2.7          116   0       0    8
## 6                   1    204000              2.1          132   1       1    8
##   DEATH_EVENT
## 1           1
## 2           1
## 3           1
## 4           1
## 5           1
## 6           1
# Column types and the first few values of each column
str(object = df)
## 'data.frame':    299 obs. of  13 variables:
##  $ age                     : num  75 55 65 50 65 90 75 60 65 80 ...
##  $ anaemia                 : int  0 0 0 1 1 1 1 1 0 1 ...
##  $ creatinine_phosphokinase: int  582 7861 146 111 160 47 246 315 157 123 ...
##  $ diabetes                : int  0 0 0 0 1 0 0 1 0 0 ...
##  $ ejection_fraction       : int  20 38 20 20 20 40 15 60 65 35 ...
##  $ high_blood_pressure     : int  1 0 0 0 0 1 0 0 0 1 ...
##  $ platelets               : num  265000 263358 162000 210000 327000 ...
##  $ serum_creatinine        : num  1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
##  $ serum_sodium            : int  130 136 129 137 116 132 137 131 138 133 ...
##  $ sex                     : int  1 1 1 1 0 1 1 1 0 1 ...
##  $ smoking                 : int  0 0 1 0 0 1 0 1 0 1 ...
##  $ time                    : int  4 6 7 7 8 8 10 10 10 10 ...
##  $ DEATH_EVENT             : int  1 1 1 1 1 1 1 1 1 1 ...
# Per-column summary statistics
summary(object = df)
##       age           anaemia       creatinine_phosphokinase    diabetes     
##  Min.   :40.00   Min.   :0.0000   Min.   :  23.0           Min.   :0.0000  
##  1st Qu.:51.00   1st Qu.:0.0000   1st Qu.: 116.5           1st Qu.:0.0000  
##  Median :60.00   Median :0.0000   Median : 250.0           Median :0.0000  
##  Mean   :60.83   Mean   :0.4314   Mean   : 581.8           Mean   :0.4181  
##  3rd Qu.:70.00   3rd Qu.:1.0000   3rd Qu.: 582.0           3rd Qu.:1.0000  
##  Max.   :95.00   Max.   :1.0000   Max.   :7861.0           Max.   :1.0000  
##  ejection_fraction high_blood_pressure   platelets      serum_creatinine
##  Min.   :14.00     Min.   :0.0000      Min.   : 25100   Min.   :0.500   
##  1st Qu.:30.00     1st Qu.:0.0000      1st Qu.:212500   1st Qu.:0.900   
##  Median :38.00     Median :0.0000      Median :262000   Median :1.100   
##  Mean   :38.08     Mean   :0.3512      Mean   :263358   Mean   :1.394   
##  3rd Qu.:45.00     3rd Qu.:1.0000      3rd Qu.:303500   3rd Qu.:1.400   
##  Max.   :80.00     Max.   :1.0000      Max.   :850000   Max.   :9.400   
##   serum_sodium        sex            smoking            time      
##  Min.   :113.0   Min.   :0.0000   Min.   :0.0000   Min.   :  4.0  
##  1st Qu.:134.0   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.: 73.0  
##  Median :137.0   Median :1.0000   Median :0.0000   Median :115.0  
##  Mean   :136.6   Mean   :0.6488   Mean   :0.3211   Mean   :130.3  
##  3rd Qu.:140.0   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:203.0  
##  Max.   :148.0   Max.   :1.0000   Max.   :1.0000   Max.   :285.0  
##   DEATH_EVENT    
##  Min.   :0.0000  
##  1st Qu.:0.0000  
##  Median :0.0000  
##  Mean   :0.3211  
##  3rd Qu.:1.0000  
##  Max.   :1.0000

Preprocessing

Convert Binary to Factor

# Convert the 0/1 indicator columns (and the outcome) to factors so that
# plots treat them as categories and glm() fits them as dummy terms.
cat_vars <- c("anaemia",
              "diabetes",
              "high_blood_pressure",
              "sex",
              "smoking",
              "DEATH_EVENT")
df[cat_vars] <- lapply(df[cat_vars], factor)

# Rename DEATH_EVENT to the shorter `death`. Match by name rather than by
# column position so a reordered or extended CSV cannot rename the wrong column.
names(df)[names(df) == "DEATH_EVENT"] <- "death"
lapply(df, class)
## $age
## [1] "numeric"
## 
## $anaemia
## [1] "factor"
## 
## $creatinine_phosphokinase
## [1] "integer"
## 
## $diabetes
## [1] "factor"
## 
## $ejection_fraction
## [1] "integer"
## 
## $high_blood_pressure
## [1] "factor"
## 
## $platelets
## [1] "numeric"
## 
## $serum_creatinine
## [1] "numeric"
## 
## $serum_sodium
## [1] "integer"
## 
## $sex
## [1] "factor"
## 
## $smoking
## [1] "factor"
## 
## $time
## [1] "integer"
## 
## $death
## [1] "factor"

Check for Missing Values

#check for missing values
# Number of missing values in each column
vapply(df, function(column) sum(is.na(column)), numeric(1))
##                      age                  anaemia creatinine_phosphokinase 
##                        0                        0                        0 
##                 diabetes        ejection_fraction      high_blood_pressure 
##                        0                        0                        0 
##                platelets         serum_creatinine             serum_sodium 
##                        0                        0                        0 
##                      sex                  smoking                     time 
##                        0                        0                        0 
##                    death 
##                        0

Binning the Age Variable for EDA

From the summary, the minimum age is 40 and the maximum is 95, so we divide age into two groups:

  • 40 <= age < 65 - Middle Aged
  • age >= 65 - Old
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Bin age for the EDA plots: 40 <= age < 65 -> "Middle Aged", age >= 65 -> "old".
# `age` is stored as a double, so the upper bound is written as `age < 65`
# rather than `age <= 64`; the latter would leave 64 < age < 65 unbinned (NA).
# Ages below 40 stay NA (the summary above shows none in this data set).
df <- df %>%
  mutate(agegroup = case_when(age >= 40 & age < 65 ~ "Middle Aged",
                              age >= 65 ~ "old"))

EDA

library(ggplot2)
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
# Bar chart of each categorical feature, with bars filled by the death outcome
g1 <- ggplot(data = df, mapping = aes(x = agegroup)) +
  geom_bar(mapping = aes(fill = death))
g2 <- ggplot(data = df, mapping = aes(x = anaemia)) +
  geom_bar(mapping = aes(fill = death))
g3 <- ggplot(data = df, mapping = aes(x = diabetes)) +
  geom_bar(mapping = aes(fill = death))
g4 <- ggplot(data = df, mapping = aes(x = high_blood_pressure)) +
  geom_bar(mapping = aes(fill = death))
g5 <- ggplot(data = df, mapping = aes(x = sex)) +
  geom_bar(mapping = aes(fill = death))
g6 <- ggplot(data = df, mapping = aes(x = smoking)) +
  geom_bar(mapping = aes(fill = death))

# Lay the six charts out on a three-row grid
grid.arrange(g1, g2, g3, g4, g5, g6, nrow = 3)

# Scatter plot of each continuous measurement against age, coloured by outcome
s1 <- ggplot(data = df,
             mapping = aes(x = age, y = creatinine_phosphokinase, color = death)) +
  geom_point()
s2 <- ggplot(data = df,
             mapping = aes(x = age, y = ejection_fraction, color = death)) +
  geom_point()
s3 <- ggplot(data = df,
             mapping = aes(x = age, y = platelets, color = death)) +
  geom_point()
s4 <- ggplot(data = df,
             mapping = aes(x = age, y = serum_sodium, color = death)) +
  geom_point()
s5 <- ggplot(data = df,
             mapping = aes(x = age, y = serum_creatinine, color = death)) +
  geom_point()
s6 <- ggplot(data = df,
             mapping = aes(x = age, y = time, color = death)) +
  geom_point()

# Lay the six scatter plots out on a three-row grid
grid.arrange(s1, s2, s3, s4, s5, s6, nrow = 3)

table(df[["death"]])
## 
##   0   1 
## 203  96
# Bar plot of the count of each level of the death outcome
ggplot(data = df, mapping = aes(x = death)) +
  geom_bar()

The graph above shows a class imbalance in the data: 203 observations have death = 0 and 96 have death = 1.

Train Test Split

# Drop the EDA-only agegroup column so the `death ~ .` model below does not
# also receive a binned copy of age. Removing it by name (instead of by
# position) keeps this correct if the column order ever changes.
df$agegroup <- NULL

## Fraction of rows used for training (80% of the sample size)
train_frac <- 0.8
smp_size <- floor(train_frac * nrow(df))

set.seed(123)
train_ind <- sample(seq_len(nrow(df)), size = smp_size)

train <- df[train_ind, ]
test  <- df[-train_ind, ]

#Logistic Regression

# Logistic regression of death on every remaining column of the training set
lr_fit <- glm(formula = death ~ ., family = binomial, data = train)
summary(object = lr_fit)
## 
## Call:
## glm(formula = death ~ ., family = binomial, data = train)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.4646  -0.4676  -0.1527   0.3352   2.2099  
## 
## Coefficients:
##                            Estimate Std. Error z value Pr(>|z|)    
## (Intercept)               1.169e+01  7.018e+00   1.666  0.09566 .  
## age                       5.643e-02  1.978e-02   2.852  0.00434 ** 
## anaemia1                  6.993e-02  4.300e-01   0.163  0.87080    
## creatinine_phosphokinase  2.160e-04  2.115e-04   1.021  0.30702    
## diabetes1                 3.590e-01  4.274e-01   0.840  0.40100    
## ejection_fraction        -8.354e-02  1.999e-02  -4.180 2.91e-05 ***
## high_blood_pressure1      6.338e-02  4.239e-01   0.150  0.88116    
## platelets                -1.638e-06  2.230e-06  -0.734  0.46273    
## serum_creatinine          8.573e-01  2.154e-01   3.981 6.87e-05 ***
## serum_sodium             -7.818e-02  4.851e-02  -1.612  0.10706    
## sex1                     -4.638e-01  4.953e-01  -0.936  0.34908    
## smoking1                 -1.971e-01  4.933e-01  -0.400  0.68944    
## time                     -2.723e-02  4.231e-03  -6.437 1.22e-10 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 301.89  on 238  degrees of freedom
## Residual deviance: 152.08  on 226  degrees of freedom
## AIC: 178.08
## 
## Number of Fisher Scoring iterations: 6

Performance of Logistic Regression

# Turn predicted probabilities into class labels at the default cutoff of 0.5
# (tuned in the next section). A score >= cutoff is classified as 1 (death),
# the same rule ConfusionMatrixInfo uses. The raw probabilities stay in the
# `predictions` column because the cutoff-tuning helpers read them from there.
lr_cutoff <- 0.5

#on train data
train$predictions <- predict(lr_fit, newdata = train, type = "response")
train_pred_class <- factor(as.integer(train$predictions >= lr_cutoff),
                           levels = c(0L, 1L))
# Rows = predicted class, columns = actual class (the layout caret expects).
# Tabulating the raw probabilities instead would give one row per distinct
# probability value rather than a 2 x 2 confusion table.
preds_act_train <- table(predicted = train_pred_class, actual = train$death)
library(caret)
## Loading required package: lattice
library(e1071)
lr_conf_mat_train <- confusionMatrix(preds_act_train, positive = "1")
lr_conf_mat_train

#on test data
test$predictions <- predict(lr_fit, newdata = test, type = "response")
test_pred_class <- factor(as.integer(test$predictions >= lr_cutoff),
                          levels = c(0L, 1L))
preds_act_test <- table(predicted = test_pred_class, actual = test$death)
lr_conf_mat_test <- confusionMatrix(preds_act_test, positive = "1")
lr_conf_mat_test

With the default cutoff, train accuracy is about 85% and test accuracy is about 76%. Because this is an imbalanced classification problem, the conventional cutoff of 0.5 applied to the predicted probabilities is not necessarily the best choice, so we tune this threshold in the following section.

Tuning Logistic Regression Model

# Useful functions when working with logistic regression
library(ROCR)
library(grid)
library(caret)
library(dplyr)
library(scales)
library(ggplot2)
library(gridExtra)
library(data.table)
## 
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
## 
##     between, first, last
library(tidyr)

# ------------------------------------------------------------------------------------------
# [AccuracyCutoffInfo] : 
# Obtain the accuracy on the training and testing dataset for a grid of
# cutoff values (by default from .4 to .9 in steps of .05).
# @train   : data.table or data.frame training data ( must already contain the predicted score ).
# @test    : data.table or data.frame testing data ( same requirement )
# @predict : prediction's column name (assumes the same for training and testing set)
# @actual  : actual results' column name; expected to hold 0 / 1 (numeric or factor),
#            any other value becomes NA
# @cutoffs : numeric vector of cutoffs to evaluate; default seq( .4, .9, by = .05 )
# returns  : 1. data : a data.table with three columns (cutoff, train, test);
#                      each row holds a cutoff value and the accuracy for the 
#                      train and test set respectively.
#            2. plot : plot that visualizes the data.table

AccuracyCutoffInfo <- function( train, test, predict, actual,
                                cutoffs = seq( .4, .9, by = .05 ) )
{
    # Fixed 0 / 1 levels on both factors: caret's confusionMatrix requires
    # prediction and reference to share levels, which fails otherwise when a
    # high cutoff leaves every observation predicted as 0.
    as_class <- function( x ) factor( x, levels = c( 0L, 1L ) )

    # Accuracy on one data set at one cutoff; a score >= threshold is classified
    # as 1, the same rule ConfusionMatrixInfo and ROCR use.
    accuracy_at <- function( data, threshold )
    {
        predicted <- as_class( as.integer( data[[predict]] >= threshold ) )
        cm <- confusionMatrix( predicted, as_class( data[[actual]] ) )
        cm$overall[["Accuracy"]]
    }

    accuracy <- rbindlist( lapply( cutoffs, function( threshold )
    {
        data.table( cutoff = threshold,
                    train  = accuracy_at( train, threshold ),
                    test   = accuracy_at( test,  threshold ) )
    } ) )

    # visualize the accuracy of the train and test set for different cutoff value 
    # accuracy in percentage.
    accuracy_long <- pivot_longer( accuracy, cols = c( "train", "test" ),
                                   names_to = "data", values_to = "accuracy" )

    plot <- ggplot( accuracy_long, aes( cutoff, accuracy, group = data, color = data ) ) + 
            geom_line( size = 1 ) + geom_point( size = 3 ) +
            scale_y_continuous( labels = percent ) +
            ggtitle( "Train/Test Accuracy for Different Cutoff" )

    list( data = accuracy, plot = plot )
}


# ------------------------------------------------------------------------------------------
# [ConfusionMatrixInfo] : 
# Label every observation of an already-scored dataset as TP / FP / FN / TN
# for a chosen cutoff, and plot the scores split by the actual outcome.
# @data    : data.frame or data.table holding a predicted-score column and an
#            actual-outcome column
# @predict : name of the predicted-score column
# @actual  : name of the actual-outcome column (compared against 1 and 0)
# @cutoff  : score threshold; a score >= cutoff counts as a positive prediction
# return   : list with
#            1. data : data.table with columns actual (factor with "1" as its first
#                      level), predict (the scores) and type (factor of TP/FP/FN/TN;
#                      anything not TP, FP or FN is labelled TN)
#            2. plot : ggplot object visualising `data`

ConfusionMatrixInfo <- function( data, predict, actual, cutoff )
{
    scores   <- data[[ predict ]]
    # relevel so "1" comes first, its usual position in a 2 x 2 confusion matrix
    outcomes <- relevel( as.factor( data[[ actual ]] ), "1" )

    is_tp <- scores >= cutoff & outcomes == 1
    is_fp <- scores >= cutoff & outcomes == 0
    is_fn <- scores <  cutoff & outcomes == 1
    outcome_type <- ifelse( is_tp, "TP",
                    ifelse( is_fp, "FP",
                    ifelse( is_fn, "FN", "TN" ) ) )

    result <- data.table( actual  = outcomes,
                          predict = scores,
                          type    = as.factor( outcome_type ) )

    # white violin of the scores per actual class, individual scores jittered
    # on top, and the cutoff drawn as a horizontal line
    plot <- ggplot( result, aes( actual, predict, color = type ) ) +
        geom_violin( fill = "white", color = NA ) +
        geom_jitter( shape = 1 ) +
        geom_hline( yintercept = cutoff, color = "blue", alpha = 0.6 ) +
        scale_y_continuous( limits = c( 0, 1 ) ) +
        # legend order
        scale_color_discrete( breaks = c( "TP", "FN", "FP", "TN" ) ) +
        # legend laid out in two rows
        guides( col = guide_legend( nrow = 2 ) ) +
        ggtitle( sprintf( "Confusion Matrix with Cutoff at %.2f", cutoff ) )

    list( data = result, plot = plot )
}


# ------------------------------------------------------------------------------------------
# [ROCInfo] : 
# Pass in data that already contains the predicted score and actual outcome
# to obtain the ROC curve and the cutoff that minimises the total
# misclassification cost.
# @data    : your data.table or data.frame type data that consists the column
#            of the predicted score and actual outcome
# @predict : predicted score's column name
# @actual  : actual results' column name (assumed to hold 0 / 1 with 1 as the
#            positive class; negatives / positives are counted as == 0 / == 1)
# @cost.fp : associated cost for a false positive 
# @cost.fn : associated cost for a false negative 
# return   : a list containing  
#            1. plot        : side-by-side ROC and cost plots, titled with the
#                             optimal cutoff, total cost and area under the curve (auc)
#            2. cutoff      : optimal cutoff value according to the specified fp/fn cost 
#            3. totalcost   : total cost according to the specified fp/fn cost
#            4. auc         : area under the curve
#            5. sensitivity : TP / (TP + FN)
#            6. specificity : TN / (FP + TN)

ROCInfo <- function( data, predict, actual, cost.fp, cost.fn )
{
    labels <- data[[actual]]

    # ROC points (false / true positive rate) at every cutoff ROCR evaluates
    pred    <- prediction( data[[predict]], labels )
    perf    <- performance( pred, "tpr", "fpr" )
    fpr     <- perf@x.values[[1]]
    tpr     <- perf@y.values[[1]]
    cutoffs <- pred@cutoffs[[1]]
    roc_dt  <- data.frame( fpr = fpr, tpr = tpr )

    # cost with the specified false positive and false negative cost:
    # false positive rate * number of negative instances * false positive cost + 
    # false negative rate * number of positive instances * false negative cost
    cost <- fpr * cost.fp * sum( labels == 0 ) +
            ( 1 - tpr ) * cost.fn * sum( labels == 1 )
    cost_dt <- data.frame( cutoff = cutoffs, cost = cost )

    # optimal cutoff value, and the corresponding true / false positive rate
    best_index  <- which.min( cost )
    best_cost   <- cost[ best_index ]
    best_tpr    <- tpr[ best_index ]
    best_fpr    <- fpr[ best_index ]
    best_cutoff <- cutoffs[ best_index ]

    # area under the curve
    auc <- performance( pred, "auc" )@y.values[[1]]

    # rescale costs to [0, 1]; a constant cost vector maps to 0 instead of NaN
    normalize <- function( v )
    {
        spread <- diff( range( v ) )
        if ( spread == 0 ) return( rep( 0, length( v ) ) )
        ( v - min( v ) ) / spread
    }

    # 100-colour palette from green to black; each normalised cost picks a
    # colour (higher cost -> darker). Scaling by 99 and adding 1 keeps the
    # index in 1..100, since index 0 would select nothing.
    col_ramp    <- colorRampPalette( c( "green", "orange", "red", "black" ) )( 100 )
    col_by_cost <- col_ramp[ ceiling( normalize( cost ) * 99 ) + 1 ]

    roc_plot <- ggplot( roc_dt, aes( fpr, tpr ) ) + 
                geom_line( color = rgb( 0, 0, 1, alpha = 0.3 ) ) +
                geom_point( color = col_by_cost, size = 4, alpha = 0.2 ) + 
                # annotate() draws the chance diagonal once, not once per data row
                annotate( "segment", x = 0, y = 0, xend = 1, yend = 1,
                          alpha = 0.8, color = "royalblue" ) + 
                labs( title = "ROC", x = "False Positive Rate", y = "True Positive Rate" ) +
                geom_hline( yintercept = best_tpr, alpha = 0.8, linetype = "dashed", color = "steelblue4" ) +
                geom_vline( xintercept = best_fpr, alpha = 0.8, linetype = "dashed", color = "steelblue4" )

    cost_plot <- ggplot( cost_dt, aes( cutoff, cost ) ) +
                 geom_line( color = "blue", alpha = 0.5 ) +
                 geom_point( color = col_by_cost, size = 4, alpha = 0.5 ) +
                 ggtitle( "Cost" ) +
                 scale_y_continuous( labels = comma ) +
                 geom_vline( xintercept = best_cutoff, alpha = 0.8, linetype = "dashed", color = "steelblue4" ) 

    # main title for the two arranged plots; the cost is a double (and
    # non-integer whenever the costs or rates are fractional), so it is
    # formatted as text rather than with %d, which rejects such values
    sub_title <- sprintf( "Cutoff at %.2f - Total Cost = %s, AUC = %.3f", 
                          best_cutoff,
                          format( round( best_cost, 2 ), big.mark = "," ),
                          auc )

    # arranged into a side by side plot
    plot <- arrangeGrob( roc_plot, cost_plot, ncol = 2, 
                         top = textGrob( sub_title, gp = gpar( fontsize = 16, fontface = "bold" ) ) )

    list( plot        = plot, 
          cutoff      = best_cutoff, 
          totalcost   = best_cost, 
          auc         = auc,
          sensitivity = best_tpr, 
          specificity = 1 - best_fpr )
}
# Train/test accuracy across the helper's default grid of cutoffs
accuracy_info <- AccuracyCutoffInfo(
  train   = train,
  test    = test,
  predict = "predictions",
  actual  = "death"
)

accuracy_info[["plot"]]

The plot above shows that test accuracy is highest at a cutoff of 0.8.

# Test-set confusion-matrix plot at the chosen cutoff of 0.8
cm_info <- ConfusionMatrixInfo(
  data    = test,
  predict = "predictions",
  actual  = "death",
  cutoff  = .8
)
cm_info[["plot"]]