library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(DT)
library(ggplot2)
# Data Import ----
# Locate the raw CSV relative to the project root, load it as a tibble and
# print its column structure.
path <- here("data", "heart_failure_clinical_records_dataset.csv")
df <- read_csv(file = path)
str(df)
## spec_tbl_df [299 × 13] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ age : num [1:299] 75 55 65 50 65 90 75 60 65 80 ...
## $ anaemia : num [1:299] 0 0 0 1 1 1 1 1 0 1 ...
## $ creatinine_phosphokinase: num [1:299] 582 7861 146 111 160 ...
## $ diabetes : num [1:299] 0 0 0 0 1 0 0 1 0 0 ...
## $ ejection_fraction : num [1:299] 20 38 20 20 20 40 15 60 65 35 ...
## $ high_blood_pressure : num [1:299] 1 0 0 0 0 1 0 0 0 1 ...
## $ platelets : num [1:299] 265000 263358 162000 210000 327000 ...
## $ serum_creatinine : num [1:299] 1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
## $ serum_sodium : num [1:299] 130 136 129 137 116 132 137 131 138 133 ...
## $ sex : num [1:299] 1 1 1 1 0 1 1 1 0 1 ...
## $ smoking : num [1:299] 0 0 1 0 0 1 0 1 0 1 ...
## $ time : num [1:299] 4 6 7 7 8 8 10 10 10 10 ...
## $ DEATH_EVENT : num [1:299] 1 1 1 1 1 1 1 1 1 1 ...
## - attr(*, "spec")=
## .. cols(
## .. age = col_double(),
## .. anaemia = col_double(),
## .. creatinine_phosphokinase = col_double(),
## .. diabetes = col_double(),
## .. ejection_fraction = col_double(),
## .. high_blood_pressure = col_double(),
## .. platelets = col_double(),
## .. serum_creatinine = col_double(),
## .. serum_sodium = col_double(),
## .. sex = col_double(),
## .. smoking = col_double(),
## .. time = col_double(),
## .. DEATH_EVENT = col_double()
## .. )
## - attr(*, "problems")=<externalptr>
Predictor Variables (\(X\))
age: Age of the subject (years).
anaemia: Whether the subject has anaemia (0 = no, 1 = yes).
creatinine_phosphokinase: Level of the CPK enzyme in the blood (mcg/L).
diabetes: Whether the subject has diabetes (0 = no, 1 = yes).
ejection_fraction: Percentage of blood leaving the heart at each contraction (%).
high_blood_pressure: Whether the subject has high blood pressure (0 = no, 1 = yes).
platelets: Platelets in the blood (kiloplatelets/mL).
serum_creatinine: Level of serum creatinine in the blood (mg/dL).
serum_sodium: Level of serum sodium in the blood (mEq/L).
sex: Sex of the subject (0 = female, 1 = male).
smoking: Whether the subject smokes (0 = no, 1 = yes).
time: Follow-up period (days).
Response Variable (\(Y\))
DEATH_EVENT: Whether the subject died during the follow-up period (0 = no, 1 = yes).
# The binary indicators and the outcome are read in as 0/1 doubles (see str()
# above). Convert them to factors so the recipe treats them as nominal
# predictors and DEATH_EVENT is modelled as a classification outcome.
cols <- c(
  "anaemia", "diabetes", "high_blood_pressure",
  "sex", "smoking", "DEATH_EVENT"
)
df <- df %>%
  mutate(across(all_of(cols), as.factor))
Total Deaths from Heart Failure
table(df[["DEATH_EVENT"]])  # counts per DEATH_EVENT level
##
## 0 1
## 203 96
Total Males and Females
table(df[["sex"]])  # counts per sex level
##
## 0 1
## 105 194
Deaths by Sex
xtabs(formula = ~ DEATH_EVENT + sex, data = df)  # outcome x sex
## sex
## DEATH_EVENT 0 1
## 0 71 132
## 1 34 62
Deaths and Anaemia by Sex
xtabs(formula = ~ DEATH_EVENT + anaemia + sex, data = df)  # one slice per sex
## , , sex = 0
##
## anaemia
## DEATH_EVENT 0 1
## 0 39 32
## 1 14 20
##
## , , sex = 1
##
## anaemia
## DEATH_EVENT 0 1
## 0 81 51
## 1 36 26
Deaths and Diabetes by Sex
xtabs(formula = ~ DEATH_EVENT + diabetes + sex, data = df)  # one slice per sex
## , , sex = 0
##
## diabetes
## DEATH_EVENT 0 1
## 0 36 35
## 1 14 20
##
## , , sex = 1
##
## diabetes
## DEATH_EVENT 0 1
## 0 82 50
## 1 42 20
Deaths and High Blood Pressure by Sex
xtabs(
  formula = ~ DEATH_EVENT + high_blood_pressure + sex,  # one slice per sex
  data = df
)
## , , sex = 0
##
## high_blood_pressure
## DEATH_EVENT 0 1
## 0 44 27
## 1 17 17
##
## , , sex = 1
##
## high_blood_pressure
## DEATH_EVENT 0 1
## 0 93 39
## 1 40 22
Deaths and Smoking by Sex
xtabs(formula = ~ DEATH_EVENT + smoking + sex, data = df)  # one slice per sex
## , , sex = 0
##
## smoking
## DEATH_EVENT 0 1
## 0 70 1
## 1 31 3
##
## , , sex = 1
##
## smoking
## DEATH_EVENT 0 1
## 0 67 65
## 1 35 27
Age Distribution
# Histogram of age, with bars stacked by DEATH_EVENT.
# An explicit 5-year binwidth with edges on multiples of 5 replaces
# stat_bin()'s default of 30 bins. That default prints a "Pick better value
# with `binwidth`" message and places bin edges arbitrarily relative to the
# 10-year axis breaks.
df %>%
  ggplot(aes(x = age)) +
  geom_histogram(aes(fill = DEATH_EVENT), binwidth = 5, boundary = 0) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(
    title = "Age Distribution"
  ) +
  theme(
    plot.title = element_text(hjust = 0.5)
  )
# Train/test split ----
# 75% training / 25% test, stratified on the outcome so both sets keep a
# similar class balance.
set.seed(123)
split <- initial_split(df, prop = 0.75, strata = DEATH_EVENT)
train <- training(split)
test <- testing(split)

# 5-fold cross-validation on the training set, also stratified on the outcome
set.seed(123)
kfold <- vfold_cv(train, v = 5, strata = DEATH_EVENT)

table(train[["DEATH_EVENT"]])  # outcome counts in the training set
##
## 0 1
## 152 72
table(test[["DEATH_EVENT"]])  # outcome counts in the test set
##
## 0 1
## 51 24
# recipe ----
# Preprocessing estimated on `train` and applied inside every resample:
#   1. pool factor levels seen in < 5% of rows into "other". This must run
#      while the predictors are still nominal; after step_dummy() the
#      all_nominal_predictors() selector matches nothing, so step_other()
#      would have no effect.
#   2. Yeo-Johnson transform numeric predictors to reduce skew.
#   3. centre/scale the numeric predictors. This runs after the transform so
#      the values passed to the model are actually standardised, and before
#      dummy encoding so the 0/1 indicator columns are not rescaled.
#   4. dummy-encode the nominal predictors.
# NOTE(review): `time` is the follow-up period (days). If follow-up ends at
# death, it partly encodes the outcome and will inflate test metrics. Consider
# step_rm(time) when the goal is prediction at baseline.
rcp <- recipe(DEATH_EVENT ~ ., data = train) %>%
  step_other(all_nominal_predictors(), threshold = 0.05, other = "other") %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors())
# Random forest model ----
# Classification forest with 300 trees. mtry and min_n are tune()
# placeholders filled in by the grid search. importance = "permutation" is
# passed through to the ranger engine.
rf_model <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 300,
  min_n = tune()
) %>%
  set_engine(engine = "ranger", importance = "permutation")
# Regular tuning grid over mtry x min_n.
# Both parameters are integers, so a range can yield only as many distinct
# values as it contains: mtry 2..5 gives 4 and min_n 5..10 gives 6. Asking
# for more levels (e.g. `levels = 10`) collapses back to these same values.
# The counts are therefore given explicitly, one per parameter in argument
# order, for 4 x 6 = 24 candidates.
# NOTE(review): the top roc_auc candidates sit at min_n = 10, the upper edge
# of this range. Consider widening it.
hyper_grid <- grid_regular(
  mtry(range = c(2, 5)),
  min_n(range = c(5, 10)),
  levels = c(4, 6)
)
# Tune ----
# Fit and evaluate every candidate in `hyper_grid` on each CV fold in `kfold`,
# with `rcp` as the preprocessor. The seed is set immediately beforehand so
# the resampled fits are reproducible.
set.seed(123)
results <- tune_grid(
  object = rf_model,
  preprocessor = rcp,
  resamples = kfold,
  grid = hyper_grid
)
## Warning: package 'ranger' was built under R version 4.2.2
# Top candidates ranked by mean cross-validated ROC AUC
results %>%
  show_best(metric = "roc_auc")
## # A tibble: 5 × 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 10 roc_auc binary 0.909 5 0.0206 Preprocessor1_Model22
## 2 4 10 roc_auc binary 0.907 5 0.0208 Preprocessor1_Model23
## 3 4 7 roc_auc binary 0.906 5 0.0230 Preprocessor1_Model11
## 4 2 10 roc_auc binary 0.906 5 0.0233 Preprocessor1_Model21
## 5 3 8 roc_auc binary 0.905 5 0.0200 Preprocessor1_Model14
autoplot(results)

# Lock in the mtry/min_n pair with the best mean CV ROC AUC and build the
# final workflow from the same model spec and recipe used during tuning.
best_hyperparameters <- select_best(results, metric = "roc_auc")
final_rf_wf <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

# A random forest fit is stochastic (bootstrap samples, random candidate
# splits). Seed immediately before fit() so the final model is reproducible
# even when this chunk is re-run on its own.
set.seed(123)
final_fit <- final_rf_wf %>%
  fit(data = train)

# Hold-out class predictions paired with the true outcome. The seed is also
# set here so any randomness at prediction time is reproducible.
set.seed(123)
final_test <- final_fit %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))
# metrics ----
# Share of hold-out rows whose predicted class matches DEATH_EVENT
accuracy(final_test, truth = DEATH_EVENT, estimate = .pred_class)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.867
# Confusion matrix of hold-out predictions (rows = Prediction, cols = Truth)
conf_mat(final_test, truth = DEATH_EVENT, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 47 6
## 1 4 18