Prerequisites

library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(DT)
library(ggplot2)

Data Wrangling

Data Import

# Data Import
# Build the CSV location with here() so the script does not depend on the
# current working directory, then read it into a tibble.
path <- here("data", "heart_failure_clinical_records_dataset.csv")
df <- read_csv(file = path)

# Inspect the column names and types of the imported data
str(df)
## spec_tbl_df [299 × 13] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ age                     : num [1:299] 75 55 65 50 65 90 75 60 65 80 ...
##  $ anaemia                 : num [1:299] 0 0 0 1 1 1 1 1 0 1 ...
##  $ creatinine_phosphokinase: num [1:299] 582 7861 146 111 160 ...
##  $ diabetes                : num [1:299] 0 0 0 0 1 0 0 1 0 0 ...
##  $ ejection_fraction       : num [1:299] 20 38 20 20 20 40 15 60 65 35 ...
##  $ high_blood_pressure     : num [1:299] 1 0 0 0 0 1 0 0 0 1 ...
##  $ platelets               : num [1:299] 265000 263358 162000 210000 327000 ...
##  $ serum_creatinine        : num [1:299] 1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
##  $ serum_sodium            : num [1:299] 130 136 129 137 116 132 137 131 138 133 ...
##  $ sex                     : num [1:299] 1 1 1 1 0 1 1 1 0 1 ...
##  $ smoking                 : num [1:299] 0 0 1 0 0 1 0 1 0 1 ...
##  $ time                    : num [1:299] 4 6 7 7 8 8 10 10 10 10 ...
##  $ DEATH_EVENT             : num [1:299] 1 1 1 1 1 1 1 1 1 1 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   age = col_double(),
##   ..   anaemia = col_double(),
##   ..   creatinine_phosphokinase = col_double(),
##   ..   diabetes = col_double(),
##   ..   ejection_fraction = col_double(),
##   ..   high_blood_pressure = col_double(),
##   ..   platelets = col_double(),
##   ..   serum_creatinine = col_double(),
##   ..   serum_sodium = col_double(),
##   ..   sex = col_double(),
##   ..   smoking = col_double(),
##   ..   time = col_double(),
##   ..   DEATH_EVENT = col_double()
##   .. )
##  - attr(*, "problems")=<externalptr>

Predictor Variables (\(X\))

  • age: Age of the subject (years).

  • anaemia: Whether or not the subject has anaemia (0 = no, 1 = yes).

  • creatinine_phosphokinase: Level of the CPK enzyme in the blood (mcg/L).

  • diabetes: If the patient has diabetes: (0 = no, 1 = yes).

  • ejection_fraction: Percentage of blood leaving the heart at each contraction (percentage).

  • high_blood_pressure: If the subject has high_blood_pressure (0 = no, 1 = yes).

  • platelets: Platelets in the blood (kiloplatelets/mL).

  • serum_creatinine: Level of serum creatinine in the blood (mg/dL).

  • serum_sodium: Level of serum sodium in the blood (mEq/L).

  • sex: gender of the subject (0 = female, 1 = male).

  • smoking: if the subject smokes or not (0 = no, 1 = yes).

  • time: Follow-up period (days).

Response Variable (\(Y\))

  • DEATH_EVENT: Whether the patient died during the follow-up period (0 = no, 1 = yes).

Creating Factors

# The binary indicator columns and the outcome are read in as numeric 0/1;
# convert them to factors so they are treated as categorical variables.
cols <- c(
  "anaemia", "diabetes", "high_blood_pressure",
  "sex", "smoking", "DEATH_EVENT"
)

# across() + all_of() replaces the superseded mutate_at(); all_of() errors
# if any name in `cols` is missing from the data instead of silently skipping.
df <- df %>%
  mutate(across(all_of(cols), as.factor))

Data Exploration

Total Deaths from Heart Failure

# Count of subjects in each DEATH_EVENT level
table(df[["DEATH_EVENT"]])
## 
##   0   1 
## 203  96

Total Males and Females

# Count of subjects in each sex level
table(df[["sex"]])
## 
##   0   1 
## 105 194

Deaths by Sex

# Two-way contingency table: outcome (rows) by sex (columns)
xtabs(formula = ~ DEATH_EVENT + sex, data = df)
##            sex
## DEATH_EVENT   0   1
##           0  71 132
##           1  34  62

Deaths and Anaemia by sex

# Outcome by anaemia status, one sub-table per sex level
xtabs(formula = ~ DEATH_EVENT + anaemia + sex, data = df)
## , , sex = 0
## 
##            anaemia
## DEATH_EVENT  0  1
##           0 39 32
##           1 14 20
## 
## , , sex = 1
## 
##            anaemia
## DEATH_EVENT  0  1
##           0 81 51
##           1 36 26

Deaths and Diabetes by Sex

# Outcome by diabetes status, one sub-table per sex level
xtabs(formula = ~ DEATH_EVENT + diabetes + sex, data = df)
## , , sex = 0
## 
##            diabetes
## DEATH_EVENT  0  1
##           0 36 35
##           1 14 20
## 
## , , sex = 1
## 
##            diabetes
## DEATH_EVENT  0  1
##           0 82 50
##           1 42 20

Deaths and High Blood Pressure by Sex

# Outcome by high-blood-pressure status, one sub-table per sex level
xtabs(formula = ~ DEATH_EVENT + high_blood_pressure + sex, data = df)
## , , sex = 0
## 
##            high_blood_pressure
## DEATH_EVENT  0  1
##           0 44 27
##           1 17 17
## 
## , , sex = 1
## 
##            high_blood_pressure
## DEATH_EVENT  0  1
##           0 93 39
##           1 40 22

Deaths and Smoking by Sex

# Outcome by smoking status, one sub-table per sex level
xtabs(formula = ~ DEATH_EVENT + smoking + sex, data = df)
## , , sex = 0
## 
##            smoking
## DEATH_EVENT  0  1
##           0 70  1
##           1 31  3
## 
## , , sex = 1
## 
##            smoking
## DEATH_EVENT  0  1
##           0 67 65
##           1 35 27

Age Distribution

# Histogram of subject age, with bars filled by outcome (DEATH_EVENT).
# An explicit 5-year bin width replaces the stat_bin() default of 30 bins,
# which ggplot2 flags with a "Pick better value with `binwidth`" message.
df %>%
  ggplot(aes(x = age)) +
  geom_histogram(aes(fill = DEATH_EVENT), binwidth = 5) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(
    title = "Age Distribution"
  ) +
  theme(
    # Centre the plot title
    plot.title = element_text(hjust = 0.5)
  )
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Machine Learning

Random Forest Model

# Hold out 25% of the rows for testing, stratified on the outcome so both
# partitions keep a similar DEATH_EVENT balance.
set.seed(123)
split <- initial_split(df, prop = 0.75, strata = DEATH_EVENT)
train <- split %>% training()
test <- split %>% testing()

# 5-fold cross-validation on the training partition, also stratified
set.seed(123)
kfold <- train %>%
  vfold_cv(v = 5, strata = DEATH_EVENT)

# Outcome counts in the training partition
table(train[["DEATH_EVENT"]])
## 
##   0   1 
## 152  72
# Outcome counts in the test partition
table(test[["DEATH_EVENT"]])
## 
##  0  1 
## 51 24
# recipe
# Step order matters:
#   1. step_other pools rare factor levels (< 5%) while the predictors are
#      still nominal; placed after step_dummy it has nothing left to select.
#   2. step_YeoJohnson reduces skew in the numeric predictors.
#   3. step_normalize centres and scales afterwards, so the numeric
#      predictors leave the recipe standardised.
#   4. step_dummy encodes the factors as indicator columns last, which keeps
#      those indicators out of the numeric transformations above.
rcp <- recipe(DEATH_EVENT ~ ., data = train) %>%
  step_other(all_nominal_predictors(), threshold = 0.05, other = "other") %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors())

# Random forest specification: 300 trees, with mtry and min_n left as
# tuning placeholders. Permutation importance is requested from ranger.
rf_model <- rand_forest(
  mtry = tune(),
  trees = 300,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger", importance = "permutation")

# Regular grid over mtry in [2, 5] and min_n in [5, 10]
hyper_grid <- grid_regular(
  mtry(range = c(2, 5)),
  min_n(range = c(5, 10)),
  levels = 10
)

# Evaluate every grid candidate on the cross-validation folds
set.seed(123)
results <- rf_model %>%
  tune_grid(
    preprocessor = rcp,
    resamples = kfold,
    grid = hyper_grid
  )
## Warning: package 'ranger' was built under R version 4.2.2
# Top candidates ranked by cross-validated ROC AUC
results %>%
  show_best(metric = "roc_auc")
## # A tibble: 5 × 8
##    mtry min_n .metric .estimator  mean     n std_err .config              
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1     3    10 roc_auc binary     0.909     5  0.0206 Preprocessor1_Model22
## 2     4    10 roc_auc binary     0.907     5  0.0208 Preprocessor1_Model23
## 3     4     7 roc_auc binary     0.906     5  0.0230 Preprocessor1_Model11
## 4     2    10 roc_auc binary     0.906     5  0.0233 Preprocessor1_Model21
## 5     3     8 roc_auc binary     0.905     5  0.0200 Preprocessor1_Model14
# Plot the tuning metrics across the hyperparameter grid
results %>%
  autoplot()

# Pick the tuning candidate with the best cross-validated ROC AUC
best_hyperparameters <- select_best(results, metric = "roc_auc")

# Bundle the model and recipe, then plug in the selected hyperparameters
final_rf_wf <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

# Seed immediately before fitting: growing the forest is the stochastic
# step, whereas predicting from an already-fitted model is not.
set.seed(123)
final_fit <- final_rf_wf %>%
  fit(data = train)

# Predicted classes for the held-out test set, alongside the true outcome
final_test <- final_fit %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))

# metrics
# Share of test-set subjects whose predicted class matches the true outcome
accuracy(final_test, truth = DEATH_EVENT, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.867
# Confusion matrix of predicted class against true outcome on the test set
conf_mat(final_test, truth = DEATH_EVENT, estimate = .pred_class)
##           Truth
## Prediction  0  1
##          0 47  6
##          1  4 18