Introduction

In this analysis, I focus on predicting heart failure mortality as a classification problem, using 12 predictor variables.

Prerequisites

library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(ggplot2)
library(corrplot)

Data Exploration

Data Import

# Data Import
path <- here('data', 'heart_failure_clinical_records_dataset.csv')
df <- read_csv(path)

DT::datatable(head(df))
str(df)
## spec_tbl_df [299 × 13] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ age                     : num [1:299] 75 55 65 50 65 90 75 60 65 80 ...
##  $ anaemia                 : num [1:299] 0 0 0 1 1 1 1 1 0 1 ...
##  $ creatinine_phosphokinase: num [1:299] 582 7861 146 111 160 ...
##  $ diabetes                : num [1:299] 0 0 0 0 1 0 0 1 0 0 ...
##  $ ejection_fraction       : num [1:299] 20 38 20 20 20 40 15 60 65 35 ...
##  $ high_blood_pressure     : num [1:299] 1 0 0 0 0 1 0 0 0 1 ...
##  $ platelets               : num [1:299] 265000 263358 162000 210000 327000 ...
##  $ serum_creatinine        : num [1:299] 1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
##  $ serum_sodium            : num [1:299] 130 136 129 137 116 132 137 131 138 133 ...
##  $ sex                     : num [1:299] 1 1 1 1 0 1 1 1 0 1 ...
##  $ smoking                 : num [1:299] 0 0 1 0 0 1 0 1 0 1 ...
##  $ time                    : num [1:299] 4 6 7 7 8 8 10 10 10 10 ...
##  $ DEATH_EVENT             : num [1:299] 1 1 1 1 1 1 1 1 1 1 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   age = col_double(),
##   ..   anaemia = col_double(),
##   ..   creatinine_phosphokinase = col_double(),
##   ..   diabetes = col_double(),
##   ..   ejection_fraction = col_double(),
##   ..   high_blood_pressure = col_double(),
##   ..   platelets = col_double(),
##   ..   serum_creatinine = col_double(),
##   ..   serum_sodium = col_double(),
##   ..   sex = col_double(),
##   ..   smoking = col_double(),
##   ..   time = col_double(),
##   ..   DEATH_EVENT = col_double()
##   .. )
##  - attr(*, "problems")=<externalptr>

About the Data

Predictor Variables (X)

  • age: Age of the subject in years.

  • anaemia: Whether the subject has anaemia (0 = no, 1 = yes).

  • creatinine_phosphokinase: Level of the CPK enzyme in the blood (mcg/L).

  • diabetes: Whether the patient has diabetes (0 = no, 1 = yes).

  • ejection_fraction: Percentage of blood leaving the heart at each contraction.

  • high_blood_pressure: Whether the subject has high blood pressure (0 = no, 1 = yes).

  • platelets: Platelets in the blood (kiloplatelets/mL).

  • serum_creatinine: Level of serum creatinine in the blood (mg/dL).

  • serum_sodium: Level of serum sodium in the blood (mEq/L).

  • sex: Sex of the subject (0 = female, 1 = male).

  • smoking: Whether the subject smokes (0 = no, 1 = yes).

  • time: Follow-up period (days).

Response Variable (Y)

  • DEATH_EVENT: Whether the patient died during the follow-up period (0 = no, 1 = yes).
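Since the outcome drives everything downstream, a quick look at its class balance is useful before modeling. A minimal check with dplyr's count() on the imported df (roughly a third of the patients in this dataset died during follow-up):

# class balance of the outcome
df %>%
  count(DEATH_EVENT) %>%
  mutate(prop = n / sum(n))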

Data Summary

summary(df)
##       age           anaemia       creatinine_phosphokinase    diabetes     
##  Min.   :40.00   Min.   :0.0000   Min.   :  23.0           Min.   :0.0000  
##  1st Qu.:51.00   1st Qu.:0.0000   1st Qu.: 116.5           1st Qu.:0.0000  
##  Median :60.00   Median :0.0000   Median : 250.0           Median :0.0000  
##  Mean   :60.83   Mean   :0.4314   Mean   : 581.8           Mean   :0.4181  
##  3rd Qu.:70.00   3rd Qu.:1.0000   3rd Qu.: 582.0           3rd Qu.:1.0000  
##  Max.   :95.00   Max.   :1.0000   Max.   :7861.0           Max.   :1.0000  
##  ejection_fraction high_blood_pressure   platelets      serum_creatinine
##  Min.   :14.00     Min.   :0.0000      Min.   : 25100   Min.   :0.500   
##  1st Qu.:30.00     1st Qu.:0.0000      1st Qu.:212500   1st Qu.:0.900   
##  Median :38.00     Median :0.0000      Median :262000   Median :1.100   
##  Mean   :38.08     Mean   :0.3512      Mean   :263358   Mean   :1.394   
##  3rd Qu.:45.00     3rd Qu.:1.0000      3rd Qu.:303500   3rd Qu.:1.400   
##  Max.   :80.00     Max.   :1.0000      Max.   :850000   Max.   :9.400   
##   serum_sodium        sex            smoking            time      
##  Min.   :113.0   Min.   :0.0000   Min.   :0.0000   Min.   :  4.0  
##  1st Qu.:134.0   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.: 73.0  
##  Median :137.0   Median :1.0000   Median :0.0000   Median :115.0  
##  Mean   :136.6   Mean   :0.6488   Mean   :0.3211   Mean   :130.3  
##  3rd Qu.:140.0   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:203.0  
##  Max.   :148.0   Max.   :1.0000   Max.   :1.0000   Max.   :285.0  
##   DEATH_EVENT    
##  Min.   :0.0000  
##  1st Qu.:0.0000  
##  Median :0.0000  
##  Mean   :0.3211  
##  3rd Qu.:1.0000  
##  Max.   :1.0000

Correlation Matrix

corrplot(cor(df), tl.cex = 1)
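The corrplot shows the full matrix at a glance; to rank how strongly each predictor is linearly associated with the outcome, one option is to pull out the DEATH_EVENT column of the correlation matrix (a quick sketch, relying on df still being all numeric at this point):

# correlations with the outcome, strongest absolute association first
cor(df)[, 'DEATH_EVENT'] %>%
  enframe(name = 'variable', value = 'correlation') %>%
  filter(variable != 'DEATH_EVENT') %>%
  arrange(desc(abs(correlation)))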

Distribution Plots

create_binary_distribution <- function(var, title = NULL, x = NULL, col1, col2){
  # bar chart of a binary variable, filled by its two levels
  ggplot(df, aes(x = as.factor(var))) +
    geom_bar(aes(fill = as.factor(var))) +
    labs(
      title = title,
      x = x,
      fill = x
    ) +
    scale_fill_manual(values = c(col1, col2))
}
create_binary_distribution(
  var = df$DEATH_EVENT,
  title = 'Total Death Distribution',
  x = 'DEATH_EVENT',
  col1 = 'darkgreen',
  col2 = 'darkred'
)

create_binary_distribution(
  var = df$sex,
  x = 'Gender',
  title = 'Total Gender Distribution',
  col1 = 'pink',
  col2 = 'lightblue'
)

create_binary_distribution(
  var = df$anaemia,
  x = 'anaemia',
  title = 'Total Anaemia Distribution',
  col1 = 'navy',
  col2 = 'orange'
)

create_binary_distribution(
  var = df$diabetes,
  x = 'diabetes',
  title = 'Total Diabetes Distribution',
  col1 = 'lightgreen',
  col2 = 'maroon'
)

create_binary_distribution(
  var = df$high_blood_pressure,
  x = 'high_blood_pressure',
  title = 'High Blood Pressure Distribution',
  col1 = 'orange',
  col2 = 'violet'
  )

create_binary_distribution(
  var = df$smoking,
  x = 'smoking',
  title = 'Total Smoking Distribution',
  col1 = 'darkgreen',
  col2 = 'purple'
)

df %>%
  ggplot(aes(age)) +
  geom_histogram(aes(fill = as.factor(DEATH_EVENT))) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(
    title = 'Age Distribution',
    fill = 'DEATH_EVENT'
  ) +
  theme(
    plot.title = element_text(hjust = 0.5)
  )
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
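The message above suggests choosing a binwidth explicitly; with ages between roughly 40 and 95, 5-year bins are one reasonable choice (the value 5 is my assumption, not part of the original chunk):

# same histogram with an explicit 5-year binwidth
df %>%
  ggplot(aes(age)) +
  geom_histogram(aes(fill = as.factor(DEATH_EVENT)), binwidth = 5) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(title = 'Age Distribution', fill = 'DEATH_EVENT') +
  theme(plot.title = element_text(hjust = 0.5))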

lst <- list(
  ejection_fraction = df$ejection_fraction,
  creatinine_phosphokinase = df$creatinine_phosphokinase,
  serum_creatinine = df$serum_creatinine,
  serum_sodium = df$serum_sodium,
  time = df$time
  )
color_list <- c('darkgreen', 'blue', 'purple', 'lightblue', 'red')
n <- 1
for (i in lst) {
  # boxplot of each continuous predictor split by the outcome
  print(
    ggplot(df) +
      geom_boxplot(aes(as.factor(DEATH_EVENT), i), fill = color_list[n], alpha = 0.5) +
      labs(
        x = 'DEATH_EVENT',
        y = names(lst)[n]
      )
    )
  n <- n + 1
}
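The same comparison can also be drawn as a single faceted plot instead of a loop; a sketch using pivot_longer() (the facet layout and free y scales are my choices, not part of the original analysis):

# continuous predictors vs DEATH_EVENT in one faceted plot
df %>%
  pivot_longer(
    cols = c(ejection_fraction, creatinine_phosphokinase,
             serum_creatinine, serum_sodium, time),
    names_to = 'variable',
    values_to = 'value'
  ) %>%
  ggplot(aes(as.factor(DEATH_EVENT), value)) +
  geom_boxplot(alpha = 0.5) +
  facet_wrap(~ variable, scales = 'free_y') +
  labs(x = 'DEATH_EVENT', y = NULL)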

Machine Learning

Data Processing

cols <- c(
  'anaemia', 'diabetes', 'high_blood_pressure',
  'sex', 'smoking', 'DEATH_EVENT'
  )

# convert the binary columns and the outcome to factors
df <- df %>%
  mutate(across(all_of(cols), as.factor))

set.seed(123)
split <- initial_split(df, prop = 0.75, strata = DEATH_EVENT)
train <- training(split)
test <- testing(split)

# kfold validation
set.seed(123)
kfold <- vfold_cv(train, v = 10, strata = DEATH_EVENT)
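Because the split and the folds are both stratified on DEATH_EVENT, the outcome proportions should be close to identical in the training and test sets; a quick sanity check (not required for the modeling that follows):

# outcome proportions in train vs test
bind_rows(
  train %>% count(DEATH_EVENT) %>% mutate(set = 'train'),
  test %>% count(DEATH_EVENT) %>% mutate(set = 'test')
) %>%
  group_by(set) %>%
  mutate(prop = n / sum(n))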

Feature Engineering

# recipe
rcp <- recipe(DEATH_EVENT ~ ., data = train) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_other(all_nominal_predictors(), threshold = 0.05, other = 'other') %>%
  step_dummy(all_nominal_predictors())
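To see what the recipe actually produces, it can be prepped on the training set and baked back out; a quick inspection sketch using rcp and train from above:

# prep the recipe on the training data and inspect the engineered predictors
rcp %>%
  prep(training = train) %>%
  bake(new_data = NULL) %>%
  glimpse()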

Modeling

Logistic Regression

Tuning Model

logit_model <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine('glmnet')

# penalty() is tuned on the log10 scale, so this range spans 0.001 to 0.1
logit_grid <- grid_regular(
  penalty(range = c(-3, -1)),
  mixture(),
  levels = 10
  )

tuning <- workflow() %>%
  add_recipe(rcp) %>%
  add_model(logit_model) %>%
  tune_grid(resamples = kfold, grid = logit_grid) 
## Warning: package 'glmnet' was built under R version 4.2.2
show_best(tuning, metric = 'roc_auc')
## # A tibble: 5 × 8
##   penalty mixture .metric .estimator  mean     n std_err .config               
##     <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1  0.1      0.333 roc_auc binary     0.901    10  0.0252 Preprocessor1_Model040
## 2  0.0599   0.556 roc_auc binary     0.900    10  0.0265 Preprocessor1_Model059
## 3  0.1      0.444 roc_auc binary     0.899    10  0.0275 Preprocessor1_Model050
## 4  0.0599   0.667 roc_auc binary     0.899    10  0.0276 Preprocessor1_Model069
## 5  0.0359   1     roc_auc binary     0.899    10  0.0276 Preprocessor1_Model098
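Beyond the table, the tuning results can be inspected visually; autoplot() on the tuning object plots performance across the penalty and mixture grid (a quick look, nothing else in the workflow depends on it):

# ROC AUC across the regularization grid
autoplot(tuning, metric = 'roc_auc')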

Finalizing Model

best_hyperparameters <- select_best(tuning, metric = 'roc_auc')

final_lg_wf <- workflow() %>%
  add_model(logit_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

set.seed(123)
final_fit <- final_lg_wf %>%
  fit(data = train)
  
final_test <- final_fit %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT)) 

# metrics
final_test %>%
  accuracy(DEATH_EVENT, .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.853
final_test %>%
  conf_mat(DEATH_EVENT, .pred_class) 
##           Truth
## Prediction  0  1
##          0 48  8
##          1  3 16

Our logistic regression model achieved an accuracy of 85.3% on the test set.
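Accuracy alone can be misleading with a roughly 2:1 class split, so ROC AUC, sensitivity, and specificity are worth a look as well. A sketch using final_fit, test, and final_test from above; note that yardstick treats the first factor level ('0', survival) as the event by default:

# class-probability predictions for ROC AUC
final_fit %>%
  predict(test, type = 'prob') %>%
  bind_cols(select(test, DEATH_EVENT)) %>%
  roc_auc(DEATH_EVENT, .pred_0)

# sensitivity and specificity of the hard class predictions
final_test %>%
  sensitivity(DEATH_EVENT, .pred_class)

final_test %>%
  specificity(DEATH_EVENT, .pred_class)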

Feature Importance

final_fit %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance') 
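For a penalized logistic model, the vip scores are based on the size of the coefficients, so pulling out the coefficients at the selected penalty makes the ranking easier to read (a sketch using tidy() on the fitted parsnip model):

# coefficients at the selected penalty, largest absolute effect first
final_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  arrange(desc(abs(estimate)))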

Random Forest

Tuning Model

rf_model <- rand_forest(
  mode = 'classification',
  trees = 500,
  mtry = tune(),
  min_n = tune()
  ) %>%
  set_engine("ranger", importance = "permutation")

hyper_grid <- grid_regular(
   mtry(range = c(2, 12)),
   min_n(range = c(1, 10)),        
   levels = 10
   )

# train our model across the hyper parameter grid
set.seed(123)
results <- tune_grid(
  rf_model, 
  rcp, 
  resamples = kfold, 
  grid = hyper_grid)
## Warning: package 'ranger' was built under R version 4.2.2
# model results
show_best(results, metric = "roc_auc")
## # A tibble: 5 × 8
##    mtry min_n .metric .estimator  mean     n std_err .config               
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1     2     2 roc_auc binary     0.922    10  0.0192 Preprocessor1_Model011
## 2     2     4 roc_auc binary     0.919    10  0.0194 Preprocessor1_Model031
## 3     2     8 roc_auc binary     0.918    10  0.0187 Preprocessor1_Model071
## 4     2     9 roc_auc binary     0.917    10  0.0193 Preprocessor1_Model081
## 5     3     9 roc_auc binary     0.916    10  0.0208 Preprocessor1_Model082

Finalizing Model

best_hyperparameters <- select_best(results, metric = 'roc_auc')

final_rf_wf <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

set.seed(123)
final_fit <- final_rf_wf %>%
  fit(data = train)

final_test <- final_fit %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))

# metrics
final_test %>%
  accuracy(DEATH_EVENT, .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary          0.88
final_test %>%
  conf_mat(DEATH_EVENT, .pred_class) 
##           Truth
## Prediction  0  1
##          0 47  5
##          1  4 19

Our random forest model achieved an accuracy of 88% on the test set.
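To compare the two finalized workflows on an equal footing, both can be refit on the training set and evaluated on the same test split with last_fit(); a comparison sketch using final_lg_wf, final_rf_wf, and split from above:

# fit each finalized workflow on train and evaluate once on test
set.seed(123)
lg_last <- last_fit(final_lg_wf, split)

set.seed(123)
rf_last <- last_fit(final_rf_wf, split)

bind_rows(
  collect_metrics(lg_last) %>% mutate(model = 'logistic regression'),
  collect_metrics(rf_last) %>% mutate(model = 'random forest')
)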

Feature Interpretation

final_fit %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance')