Heart Failure Prediction

Daniel Plotkin

2022-11-10

Introduction

In this project, I focus on predicting heart failure mortality, framed as a binary classification problem, from 12 predictor variables.

Prerequisites

library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(ggplot2)
library(corrplot)
library(pdp)
library(DALEXtra)

Data Exploration

Below, we import the data set, describe the variables, and explore their relationships and distributions.

Data Import

# Data Import
path <- here('data', 'heart_failure_clinical_records_dataset.csv')
df <- read_csv(path)

DT::datatable(head(df))
str(df)
## spec_tbl_df [299 × 13] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ age                     : num [1:299] 75 55 65 50 65 90 75 60 65 80 ...
##  $ anaemia                 : num [1:299] 0 0 0 1 1 1 1 1 0 1 ...
##  $ creatinine_phosphokinase: num [1:299] 582 7861 146 111 160 ...
##  $ diabetes                : num [1:299] 0 0 0 0 1 0 0 1 0 0 ...
##  $ ejection_fraction       : num [1:299] 20 38 20 20 20 40 15 60 65 35 ...
##  $ high_blood_pressure     : num [1:299] 1 0 0 0 0 1 0 0 0 1 ...
##  $ platelets               : num [1:299] 265000 263358 162000 210000 327000 ...
##  $ serum_creatinine        : num [1:299] 1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
##  $ serum_sodium            : num [1:299] 130 136 129 137 116 132 137 131 138 133 ...
##  $ sex                     : num [1:299] 1 1 1 1 0 1 1 1 0 1 ...
##  $ smoking                 : num [1:299] 0 0 1 0 0 1 0 1 0 1 ...
##  $ time                    : num [1:299] 4 6 7 7 8 8 10 10 10 10 ...
##  $ DEATH_EVENT             : num [1:299] 1 1 1 1 1 1 1 1 1 1 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   age = col_double(),
##   ..   anaemia = col_double(),
##   ..   creatinine_phosphokinase = col_double(),
##   ..   diabetes = col_double(),
##   ..   ejection_fraction = col_double(),
##   ..   high_blood_pressure = col_double(),
##   ..   platelets = col_double(),
##   ..   serum_creatinine = col_double(),
##   ..   serum_sodium = col_double(),
##   ..   sex = col_double(),
##   ..   smoking = col_double(),
##   ..   time = col_double(),
##   ..   DEATH_EVENT = col_double()
##   .. )
##  - attr(*, "problems")=<externalptr>

About the Data

Predictor Variables (\(X\))

  • age: Age of the subject (years).

  • anaemia: Whether the subject has anaemia (0 = no, 1 = yes).

  • creatinine_phosphokinase: Level of the CPK enzyme in the blood (mcg/L).

  • diabetes: Whether the subject has diabetes (0 = no, 1 = yes).

  • ejection_fraction: Percentage of blood leaving the heart at each contraction (%).

  • high_blood_pressure: Whether the subject has hypertension (0 = no, 1 = yes).

  • platelets: Platelets in the blood (kiloplatelets/mL).

  • serum_creatinine: Level of serum creatinine in the blood (mg/dL).

  • serum_sodium: Level of serum sodium in the blood (mEq/L).

  • sex: Sex of the subject (0 = female, 1 = male).

  • smoking: Whether the subject smokes (0 = no, 1 = yes).

  • time: Follow-up period (days).

Response Variable (\(Y\))

  • DEATH_EVENT: Whether the patient died during the follow-up period (0 = no, 1 = yes).

Data Summary

summary(df)
##       age           anaemia       creatinine_phosphokinase    diabetes     
##  Min.   :40.00   Min.   :0.0000   Min.   :  23.0           Min.   :0.0000  
##  1st Qu.:51.00   1st Qu.:0.0000   1st Qu.: 116.5           1st Qu.:0.0000  
##  Median :60.00   Median :0.0000   Median : 250.0           Median :0.0000  
##  Mean   :60.83   Mean   :0.4314   Mean   : 581.8           Mean   :0.4181  
##  3rd Qu.:70.00   3rd Qu.:1.0000   3rd Qu.: 582.0           3rd Qu.:1.0000  
##  Max.   :95.00   Max.   :1.0000   Max.   :7861.0           Max.   :1.0000  
##  ejection_fraction high_blood_pressure   platelets      serum_creatinine
##  Min.   :14.00     Min.   :0.0000      Min.   : 25100   Min.   :0.500   
##  1st Qu.:30.00     1st Qu.:0.0000      1st Qu.:212500   1st Qu.:0.900   
##  Median :38.00     Median :0.0000      Median :262000   Median :1.100   
##  Mean   :38.08     Mean   :0.3512      Mean   :263358   Mean   :1.394   
##  3rd Qu.:45.00     3rd Qu.:1.0000      3rd Qu.:303500   3rd Qu.:1.400   
##  Max.   :80.00     Max.   :1.0000      Max.   :850000   Max.   :9.400   
##   serum_sodium        sex            smoking            time      
##  Min.   :113.0   Min.   :0.0000   Min.   :0.0000   Min.   :  4.0  
##  1st Qu.:134.0   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.: 73.0  
##  Median :137.0   Median :1.0000   Median :0.0000   Median :115.0  
##  Mean   :136.6   Mean   :0.6488   Mean   :0.3211   Mean   :130.3  
##  3rd Qu.:140.0   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:203.0  
##  Max.   :148.0   Max.   :1.0000   Max.   :1.0000   Max.   :285.0  
##   DEATH_EVENT    
##  Min.   :0.0000  
##  1st Qu.:0.0000  
##  Median :0.0000  
##  Mean   :0.3211  
##  3rd Qu.:1.0000  
##  Max.   :1.0000

Correlation Matrix

corrplot(cor(df), tl.cex = 1)

We can derive a few insights from this correlation matrix:

  • There is a strong negative correlation between DEATH_EVENT and time (the follow-up period).

  • There is a positive correlation between sex and smoking: a larger share of males than females smoke.

Data Exploration

# function to create bar charts of a binary variable, filled by outcome
# (defaults are NULL; `none` is not defined in R)
create_binary_distribution <- function(var, title = NULL, x = NULL){
  ggplot(df, aes(x = as.factor(var))) +
    geom_bar(aes(fill = as.factor(DEATH_EVENT))) +
    labs(
      title = title,
      x = x,
      fill = 'DEATH_EVENT'
    ) +
    scale_fill_manual(values = c('darkgreen', 'brown'))
}
create_binary_distribution(
  var = df$DEATH_EVENT,
  title = 'Total Death Distribution',
  x = 'DEATH_EVENT'
)

There were more subjects that survived than died in our sample population.
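Because DEATH_EVENT is coded 0/1, its mean in the summary above (0.3211 over 299 rows) is the death proportion. A quick base-R check of the implied class counts, and of the accuracy a trivial model would reach by always predicting survival, gives a floor against which later model accuracies can be judged. The numbers are copied from the output above; the variable names are only illustrative.

```r
# Values copied from the str() and summary() output above
n_subjects <- 299
death_rate <- 0.3211  # mean of the 0/1 DEATH_EVENT column

# For a 0/1 variable, mean * n recovers the number of 1s (deaths)
n_deaths    <- round(death_rate * n_subjects)
n_survivors <- n_subjects - n_deaths

# Accuracy of a trivial classifier that always predicts survival (0)
baseline_accuracy <- n_survivors / n_subjects

cat(sprintf("deaths: %d, survivors: %d, baseline accuracy: %.1f%%\n",
            n_deaths, n_survivors, 100 * baseline_accuracy))
```

This prints `deaths: 96, survivors: 203, baseline accuracy: 67.9%`.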

create_binary_distribution(
  var = df$sex,
  x = 'Gender',
  title = 'Total Gender Distribution'
)

There were more males sampled than females, but both sexes had similar death-to-survival ratios.

create_binary_distribution(
  var = df$anaemia,
  x = 'anaemia',
  title = 'Total Anaemia Distribution'
)

Anaemia did not play a big part in whether or not someone survived.

create_binary_distribution(
  var = df$diabetes,
  x = 'diabetes',
  title = 'Total Diabetes Distribution'
)

Diabetes did not play a large role in whether or not someone died.

create_binary_distribution(
  var = df$high_blood_pressure,
  x = 'high_blood_pressure',
  title = 'High Blood Pressure Distribution'
  )

High blood pressure did not play a large role in whether or not someone died.

create_binary_distribution(
  var = df$smoking,
  x = 'smoking',
  title = 'Total Smoking Distribution'
)

Smoking did not play a large role in whether or not someone died.

df %>%
  ggplot(aes(age)) +
  geom_histogram(aes(fill = as.factor(DEATH_EVENT))) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(
    title = 'Age Distribution',
    fill = 'DEATH_EVENT'
  ) +
  theme(
    plot.title = element_text(hjust = 0.5)
  ) +
  scale_fill_manual(values = c('darkgreen', 'brown'))

Subjects' ages ranged from 40 to 95, with a mean of about 61. The histogram bin around age 60 contains the most deaths (13).
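Histogram bins depend on the bin width (`geom_histogram()` uses 30 bins by default), so age-group statements read off the chart are only loosely defined. Counting deaths and death rates in explicit age bands is a more precise alternative. The sketch below runs on hypothetical toy data; on the real data one would substitute `df$age` and `df$DEATH_EVENT` while DEATH_EVENT is still numeric.

```r
# Hypothetical toy data standing in for df$age and df$DEATH_EVENT
toy <- data.frame(
  age         = c(42, 48, 55, 58, 61, 63, 67, 72, 75, 81, 86, 90),
  DEATH_EVENT = c( 0,  0,  0,  1,  0,  1,  0,  1,  1,  0,  1,  1)
)

# Bin ages into 10-year bands: [40,50), [50,60), ..., [90,100)
toy$age_band <- cut(toy$age, breaks = seq(40, 100, by = 10), right = FALSE)

# Number of deaths and death rate within each band
deaths_by_band <- tapply(toy$DEATH_EVENT, toy$age_band, sum)
rate_by_band   <- tapply(toy$DEATH_EVENT, toy$age_band, mean)
print(cbind(deaths = deaths_by_band, rate = round(rate_by_band, 2)))
```

Reporting the rate alongside the count matters because bands with more subjects will naturally contain more deaths.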

Next, we look at the numeric features to better understand how their distributions differ between patients who died and patients who survived.

# numeric features to compare across the two outcome classes
num_vars <- c(
  'ejection_fraction', 'creatinine_phosphokinase',
  'serum_creatinine', 'serum_sodium', 'time'
)
# one fill colour per feature (a character vector, so color_list[n] is a string)
color_list <- c('darkgreen', 'blue', 'purple', 'lightblue', 'red')

for (n in seq_along(num_vars)) {
  print(
    ggplot(df, aes(x = as.factor(DEATH_EVENT), y = .data[[num_vars[n]]])) +
      geom_boxplot(fill = color_list[n], alpha = 0.5) +
      labs(
        x = 'DEATH_EVENT',
        y = num_vars[n]
      )
  )
}

Some insights we can draw:

  • Subjects with a shorter follow-up time were more likely to die.

  • Subjects with a lower ejection fraction were more likely to die.

  • Subjects with a higher serum creatinine level were more likely to die.

  • Subjects with a lower serum sodium level were more likely to die.

Data Processing

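Below, we convert the categorical features and the response variable into factors. We then split the data into a training set (75%) and a test set (25%), stratified by DEATH_EVENT, and create 10-fold cross-validation folds from the training set.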

cols <- c(
  'anaemia', 'diabetes', 'high_blood_pressure',
  'sex', 'smoking', 'DEATH_EVENT'
)

# convert the binary columns to factors
df <- df %>%
  mutate(across(all_of(cols), as.factor))

# stratified 75/25 train/test split
set.seed(123)
split <- initial_split(df, prop = 0.75, strata = DEATH_EVENT)
train <- training(split)
test <- testing(split)

# kfold validation
set.seed(123)
kfold <- vfold_cv(train, v = 10, strata = DEATH_EVENT)
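Stratifying on DEATH_EVENT means sampling the 75% separately within each outcome class, so the training and test sets keep roughly the same death rate. A minimal base-R sketch of that idea follows; `stratified_split()` is a hypothetical helper that illustrates the principle and is not rsample's actual implementation. The outcome vector is built from the class counts implied by the summary (203 survivors, 96 deaths).

```r
# Hypothetical helper: sample `prop` of the rows within each class of y
stratified_split <- function(y, prop = 0.75, seed = 123) {
  set.seed(seed)  # note: this resets the global RNG state
  idx_by_class <- split(seq_along(y), y)
  train_idx <- lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), size = floor(prop * length(idx)))]
  })
  sort(unlist(train_idx, use.names = FALSE))
}

# Outcome vector with the same class counts as DEATH_EVENT
y <- c(rep(0, 203), rep(1, 96))
train_idx <- stratified_split(y)

# Death rate is nearly identical in the training and test portions
c(train = mean(y[train_idx]), test = mean(y[-train_idx]))
```

Flooring within each class gives 152 + 72 = 224 training rows and 75 test rows, the same sizes as the 224-row training set reported by the explainer later on.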

Feature Engineering

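Next, we create a recipe that preprocesses the predictors: the numeric predictors are normalized and then Yeo-Johnson transformed, and the nominal predictors are converted into dummy variables.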

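# recipe
rcp <- recipe(DEATH_EVENT ~ ., data = train) %>%
  # center and scale the numeric predictors
  step_normalize(all_numeric_predictors()) %>%
  # reduce skewness of the numeric predictors
  step_YeoJohnson(all_numeric_predictors()) %>%
  # pool rare factor levels; must run before step_dummy(), which replaces the factors
  step_other(all_nominal_predictors(), threshold = 0.05, other = 'other') %>%
  # one indicator column per non-reference level
  step_dummy(all_nominal_predictors())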
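The two numeric steps can be illustrated outside of recipes. The sketch below applies a z-score normalization and then the Yeo-Johnson transform in the same order as the recipe. One assumption to flag: `step_YeoJohnson()` estimates lambda for each column from the training data, whereas here lambda is passed in by hand. The helper names `yeo_johnson()` and `normalize()` are hypothetical.

```r
# Yeo-Johnson transform with a fixed lambda (recipes estimates lambda instead)
yeo_johnson <- function(y, lambda) {
  out <- numeric(length(y))
  pos <- y >= 0
  # Branch for non-negative values
  out[pos] <- if (abs(lambda) > 1e-8) ((y[pos] + 1)^lambda - 1) / lambda else log1p(y[pos])
  # Branch for negative values
  out[!pos] <- if (abs(lambda - 2) > 1e-8) {
    -((-y[!pos] + 1)^(2 - lambda) - 1) / (2 - lambda)
  } else {
    -log1p(-y[!pos])
  }
  out
}

# z-score normalization, as in step_normalize()
normalize <- function(x) (x - mean(x)) / sd(x)

# First ten serum_creatinine values shown by str() above
x <- c(1.9, 1.1, 1.3, 1.9, 2.7, 2.1, 1.2, 1.1, 1.5, 9.4)
z <- normalize(x)
yeo_johnson(z, lambda = 0.5)
```

Unlike Box-Cox, Yeo-Johnson is defined for negative values, which is why it can follow normalization (normalized features are centred on zero). With lambda = 1 the transform is the identity.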

Machine Learning

Below, we fit two supervised classification algorithms to predict heart failure mortality:

  • Logistic Regression

  • Random Forest

We will build both models and compare which of the two predicts mortality more accurately.

Logistic Regression

Tuning Model

logit_model <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine('glmnet')

logit_grid <- grid_regular(
  penalty(range = c(-3, -1)),
  mixture(),
  levels = 10
  )

tuning <- workflow() %>%
  add_recipe(rcp) %>%
  add_model(logit_model) %>%
  tune_grid(resamples = kfold, grid = logit_grid, control = control_grid(save_pred = TRUE))

show_best(tuning, metric = 'roc_auc')
## # A tibble: 5 × 8
##   penalty mixture .metric .estimator  mean     n std_err .config               
##     <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1  0.1      0.333 roc_auc binary     0.901    10  0.0252 Preprocessor1_Model040
## 2  0.0599   0.556 roc_auc binary     0.900    10  0.0265 Preprocessor1_Model059
## 3  0.1      0.444 roc_auc binary     0.899    10  0.0275 Preprocessor1_Model050
## 4  0.0599   0.667 roc_auc binary     0.899    10  0.0276 Preprocessor1_Model069
## 5  0.0359   1     roc_auc binary     0.899    10  0.0276 Preprocessor1_Model098
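The grid can be reconstructed in base R to see where the penalty values in this table come from. `penalty()` works on a log10 scale, so `range = c(-3, -1)` spans 0.001 to 0.1, and `mixture()` spans 0 to 1; ten levels of each give 100 candidates. In glmnet, mixture (alpha) blends the penalties as lambda * ((1 - alpha)/2 * ||beta||^2 + alpha * ||beta||_1), so mixture = 0 is ridge and mixture = 1 is lasso. The mapping from `.config` numbers to grid rows below is an assumption, but it agrees with all five rows shown above.

```r
# Ten penalty values, evenly spaced on the log10 scale from -3 to -1
penalty_values <- 10^seq(-3, -1, length.out = 10)
# Ten mixture values from 0 (ridge) to 1 (lasso)
mixture_values <- seq(0, 1, length.out = 10)

# Full regular grid: penalty varies fastest, then mixture
grid <- expand.grid(penalty = penalty_values, mixture = mixture_values)
nrow(grid)

# Row 40 (assumed to be Preprocessor1_Model040): penalty 0.1, mixture 1/3
signif(grid[40, ], 3)
```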

ROC Curve

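# ROC curves per fold, using only the best hyperparameter combination
collect_predictions(tuning, parameters = select_best(tuning, metric = 'roc_auc')) %>%
  group_by(id) %>%
  roc_curve(DEATH_EVENT, .pred_0) %>%
  autoplot()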
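The roc_auc metric used for tuning is the area under these curves. It equals the probability that a randomly chosen event case receives a higher score than a randomly chosen non-event case, which can be computed directly from ranks (the Mann-Whitney formula). Here yardstick's default treats the first level ("0") as the event and scores it with `.pred_0`; because `.pred_0 = 1 - .pred_1`, this yields the same AUC as scoring deaths with `.pred_1`. The helper `auc_rank()` and the scores below are hypothetical.

```r
# ROC AUC via the rank (Mann-Whitney) formula; average ranks handle ties
auc_rank <- function(score, is_event) {
  r  <- rank(score)
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Hypothetical predicted probabilities of the event class and true labels
score    <- c(0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2)
is_event <- c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE, FALSE)
auc_rank(score, is_event)
```

Of the 16 event/non-event pairs, 12 are ranked correctly, so this returns 0.75.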

Finalizing Model

best_hyperparameters_lg <- select_best(tuning, metric = 'roc_auc')

final_lg_wf <- workflow() %>%
  add_model(logit_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters_lg)

set.seed(123)
final_fit_lg <- final_lg_wf %>%
  fit(data = train)
  
final_test_lg <- final_fit_lg %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT)) 

# metrics
logit_acc <- final_test_lg %>%
  accuracy(DEATH_EVENT, .pred_class)

final_test_lg %>%
  conf_mat(DEATH_EVENT, .pred_class) 
##           Truth
## Prediction  0  1
##          0 48  8
##          1  3 16

Our Logistic Regression model reached 85.3% accuracy on the held-out test set (83.9% accuracy in cross-validation). Next, we evaluate which features were most important for predicting mortality in this model.

Feature Importance

final_fit_lg %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance') 

Time and serum_creatinine were the two most important features in our Logistic Regression model. Next, we look at the partial dependence plots for these two variables.

First, we define a helper function that builds a partial dependence plot for a given variable and fitted model.

predict_plot <- function(var, model) {
  # DALEX explainer around the fitted workflow; y should be numeric 0/1,
  # so convert the factor via its labels (as.integer() alone gives 1/2)
  explainer <- explain_tidymodels(
    model,
    data = select(train, -DEATH_EVENT),
    y = as.integer(as.character(train$DEATH_EVENT))
  )
  
  # partial dependence profile computed on all training rows (N = NULL)
  pdp <- model_profile(
    explainer = explainer,
    variables = var,
    N = NULL
  )
  
  pdp_df <- as_tibble(pdp$agr_profiles)
  print(
    ggplot(pdp_df, aes(x = `_x_`, y = `_yhat_`)) +
      geom_smooth(color = 'lightblue', se = FALSE) +
      ylim(0, 1) +
      labs(
        title = paste(str_to_title(var), 'Partial Prediction'),
        x = var
      ) +
      theme_dark() +
      theme(plot.title = element_text(hjust = 0.5))
  )
}
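What `model_profile()` estimates is a partial dependence profile: fix one feature at a grid value for every row, predict, and average the predictions; repeating this across the grid traces the curve. The from-scratch sketch below shows that computation with a base-R `glm` on simulated data; the data, the model, and the helper `partial_dependence()` are all hypothetical stand-ins for the tidymodels workflow.

```r
set.seed(1)
n <- 500
# Simulated data: the outcome depends positively on x1 and negatively on x2
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- rbinom(n, 1, plogis(-0.5 + 2 * sim$x1 - sim$x2))

fit <- glm(y ~ x1 + x2, data = sim, family = binomial)

partial_dependence <- function(model, data, var, grid) {
  sapply(grid, function(v) {
    data_v <- data
    data_v[[var]] <- v  # set the feature to v for every row
    mean(predict(model, newdata = data_v, type = "response"))
  })
}

x1_grid <- seq(-2, 2, length.out = 9)
pd <- partial_dependence(fit, sim, "x1", x1_grid)
round(pd, 3)
```

Because the averaging runs over all rows, the curve reflects the model's average response to the feature rather than any single patient's.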

Below is the partial dependence plot for the time feature:

predict_plot(var = 'time', model = final_fit_lg)
## Preparation of a new explainer is initiated
##   -> model label       :  workflow  (  default  )
##   -> data              :  224  rows  12  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  224  values 
##   -> predict function  :  yhat.workflow  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.02985122 , mean =  0.3214246 , max =  0.85468  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  0.2799684 , mean =  1.000004 , max =  1.9325  
##   A new explainer has been created!

As shown, the longer the follow-up period, the lower the predicted probability of death. This is likely because follow-up ends when a patient dies: patients who died early necessarily have short follow-up times, so time partly reflects the outcome itself rather than only a patient's severity.

Below is the partial dependence plot for the serum_creatinine feature:

predict_plot(var = 'serum_creatinine', model = final_fit_lg)
## Preparation of a new explainer is initiated
##   -> model label       :  workflow  (  default  )
##   -> data              :  224  rows  12  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  224  values 
##   -> predict function  :  yhat.workflow  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.02985122 , mean =  0.3214246 , max =  0.85468  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  0.2799684 , mean =  1.000004 , max =  1.9325  
##   A new explainer has been created!

The predicted probability of death increases as serum creatinine increases, but each further increase in serum creatinine raises the probability by less (the curve flattens at higher levels).

Random Forest

Tuning Model

rf_model <- rand_forest(
  mode = 'classification',
  trees = 500,
  mtry = tune(),
  min_n = tune()
  ) %>%
  set_engine("ranger", importance = "permutation")

hyper_grid <- grid_regular(
   mtry(range = c(2, 12)),
   min_n(range = c(1, 10)),        
   levels = 10
   )

# train our model across the hyper parameter grid
set.seed(123)
results <- tune_grid(
  rf_model, 
  rcp, 
  resamples = kfold, 
  grid = hyper_grid,
  control = control_grid(save_pred = TRUE)
  )

# model results
show_best(results, metric = "roc_auc")
## # A tibble: 5 × 8
##    mtry min_n .metric .estimator  mean     n std_err .config               
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1     2     2 roc_auc binary     0.922    10  0.0192 Preprocessor1_Model011
## 2     2     4 roc_auc binary     0.919    10  0.0194 Preprocessor1_Model031
## 3     2     8 roc_auc binary     0.918    10  0.0187 Preprocessor1_Model071
## 4     2     9 roc_auc binary     0.917    10  0.0193 Preprocessor1_Model081
## 5     3     9 roc_auc binary     0.916    10  0.0208 Preprocessor1_Model082

ROC Curve

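# ROC curves per fold, using only the best hyperparameter combination
collect_predictions(results, parameters = select_best(results, metric = 'roc_auc')) %>%
  group_by(id) %>%
  roc_curve(DEATH_EVENT, .pred_0) %>%
  autoplot()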

Finalizing Model

best_hyperparameters <- select_best(results, metric = 'roc_auc')

final_rf_wf <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

# set the seed before fitting, since random forest training is random
set.seed(123)
final_fit_rf <- final_rf_wf %>%
  fit(data = train)

final_test_rf <- final_fit_rf %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))

# metrics
rf_acc <- final_test_rf %>%
  accuracy(DEATH_EVENT, .pred_class)

final_test_rf %>%
  conf_mat(DEATH_EVENT, .pred_class) 
##           Truth
## Prediction  0  1
##          0 47  5
##          1  4 19

Our Random Forest model reached 88% accuracy on the held-out test set (84.8% accuracy in cross-validation). Now let's look at which features were most important for determining the mortality classification.

Feature Importance

final_fit_rf %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance')

We can see that time is by far the most important feature. Let's explore the partial dependence plot for this feature in our Random Forest model.

predict_plot(var = 'time', model = final_fit_rf)
## Preparation of a new explainer is initiated
##   -> model label       :  workflow  (  default  )
##   -> data              :  224  rows  12  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  224  values 
##   -> predict function  :  yhat.workflow  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.005565207 , mean =  0.3169936 , max =  0.973857  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  0.6271567 , mean =  1.004435 , max =  1.391136  
##   A new explainer has been created!

The plot shows the same overall relationship for time as in our Logistic Regression model. However, the predicted probability initially drops more steeply, then flattens and levels off at around 110 days.
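The Random Forest importance scores come from ranger's permutation importance (`importance = "permutation"` in `set_engine()`). The idea: shuffle one feature to break its link with the outcome and measure how much predictive performance drops. ranger computes its own version on out-of-bag observations. The sketch below only illustrates the principle, using a base-R `glm` on simulated data, accuracy as the metric, and hypothetical helpers `accuracy_of()` and `permutation_importance()`.

```r
set.seed(42)
n <- 400
# Simulated data: only `signal` is related to the outcome
sim <- data.frame(signal = rnorm(n), noise = rnorm(n))
sim$y <- rbinom(n, 1, plogis(3 * sim$signal))

fit <- glm(y ~ signal + noise, data = sim, family = binomial)

# Accuracy of 0/1 predictions using a 0.5 probability cutoff
accuracy_of <- function(model, data) {
  pred <- as.integer(predict(model, newdata = data, type = "response") > 0.5)
  mean(pred == data$y)
}

permutation_importance <- function(model, data, vars, n_rep = 20) {
  base <- accuracy_of(model, data)
  sapply(vars, function(v) {
    drops <- replicate(n_rep, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and y
      base - accuracy_of(model, shuffled)
    })
    mean(drops)  # average accuracy drop over repeated shuffles
  })
}

imp <- permutation_importance(fit, sim, c("signal", "noise"))
round(imp, 3)
```

Because importance is measured as lost performance, a feature like time can dominate simply because the model leans on it heavily, regardless of the direction of its effect.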

Results

acc_df <- data.frame(
  algorithm = c('Logistic Regression', 'Random Forest'),
  accuracy = c(logit_acc$.estimate, rf_acc$.estimate)
) 

acc_percent <- paste0(
  round(acc_df$accuracy, 3) * 100,
  '%'
)

acc_df %>%
  ggplot(aes(algorithm, accuracy)) +
    geom_segment(
      aes(
        xend = algorithm,
        x = algorithm,
        yend = accuracy,
        y = 0
        ),
      color = 'grey'
      ) +
    geom_point(size = 4, aes(color = algorithm)) +
  scale_color_manual(values = c("grey", "darkgreen")) +
  theme_classic() +
  coord_flip() +
  geom_text(label = acc_percent, nudge_y = 0.1) +
  labs(
    title = 'Model Accuracy'
  ) +
  theme(
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    legend.title = element_blank(),
    legend.text = element_blank(),
    legend.position = 'none',
    plot.title = element_text(hjust = 0.5, face = 'bold')
    ) 

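As shown, our Random Forest model (88%) is more accurate on the test set than our Logistic Regression model (85.3%), although with only 75 test patients this gap amounts to two additional correct predictions (66 vs. 64). Time was the most important feature in determining the mortality classification.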
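Overall accuracy hides how each model does on the two classes. The confusion matrices printed earlier can be summarised into per-class rates in base R, alongside the no-information rate (the test accuracy of always predicting survival). The matrices are copied from the `conf_mat()` output; the helper `summarise_cm()` is hypothetical.

```r
# Confusion matrices from the conf_mat() output (rows = prediction, cols = truth)
cm_lr <- matrix(c(48, 3, 8, 16), nrow = 2,
                dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))
cm_rf <- matrix(c(47, 4, 5, 19), nrow = 2,
                dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))

summarise_cm <- function(cm) {
  c(
    accuracy     = sum(diag(cm)) / sum(cm),
    recall_death = cm["1", "1"] / sum(cm[, "1"]),  # share of actual deaths caught
    specificity  = cm["0", "0"] / sum(cm[, "0"]),  # share of survivors predicted to survive
    n_correct    = sum(diag(cm))
  )
}

# Test accuracy of always predicting survival (51 of 75 survived)
no_info_rate <- sum(cm_lr[, "0"]) / sum(cm_lr)

round(rbind(`Logistic Regression` = summarise_cm(cm_lr),
            `Random Forest`       = summarise_cm(cm_rf)), 3)
no_info_rate
```

Most of the Random Forest's advantage comes from the death class: it catches 19 of 24 deaths versus 16 of 24 for Logistic Regression, while both models clearly beat the 68% no-information rate.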