Introduction
In this project, I focus on predicting heart failure mortality (a binary classification task) from 12 predictor variables.
Prerequisites
library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(ggplot2)
library(corrplot)
library(pdp)
library(DALEXtra)
Data Exploration
Below we import our data set, describe its variables, and explore relationships and distributions.
Data Import
# Data Import
path <- here('data', 'heart_failure_clinical_records_dataset.csv')
df <- read_csv(path)
DT::datatable(head(df))
str(df)
## spec_tbl_df [299 × 13] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ age : num [1:299] 75 55 65 50 65 90 75 60 65 80 ...
## $ anaemia : num [1:299] 0 0 0 1 1 1 1 1 0 1 ...
## $ creatinine_phosphokinase: num [1:299] 582 7861 146 111 160 ...
## $ diabetes : num [1:299] 0 0 0 0 1 0 0 1 0 0 ...
## $ ejection_fraction : num [1:299] 20 38 20 20 20 40 15 60 65 35 ...
## $ high_blood_pressure : num [1:299] 1 0 0 0 0 1 0 0 0 1 ...
## $ platelets : num [1:299] 265000 263358 162000 210000 327000 ...
## $ serum_creatinine : num [1:299] 1.9 1.1 1.3 1.9 2.7 2.1 1.2 1.1 1.5 9.4 ...
## $ serum_sodium : num [1:299] 130 136 129 137 116 132 137 131 138 133 ...
## $ sex : num [1:299] 1 1 1 1 0 1 1 1 0 1 ...
## $ smoking : num [1:299] 0 0 1 0 0 1 0 1 0 1 ...
## $ time : num [1:299] 4 6 7 7 8 8 10 10 10 10 ...
## $ DEATH_EVENT : num [1:299] 1 1 1 1 1 1 1 1 1 1 ...
## - attr(*, "spec")=
## .. cols(
## .. age = col_double(),
## .. anaemia = col_double(),
## .. creatinine_phosphokinase = col_double(),
## .. diabetes = col_double(),
## .. ejection_fraction = col_double(),
## .. high_blood_pressure = col_double(),
## .. platelets = col_double(),
## .. serum_creatinine = col_double(),
## .. serum_sodium = col_double(),
## .. sex = col_double(),
## .. smoking = col_double(),
## .. time = col_double(),
## .. DEATH_EVENT = col_double()
## .. )
## - attr(*, "problems")=<externalptr>
About the Data
Predictor Variables (\(X\))
age: Age of the subject in years.
anaemia: Whether the subject has anaemia (0 = no, 1 = yes).
creatinine_phosphokinase: Level of the CPK enzyme in the blood (mcg/L).
diabetes: Whether the patient has diabetes (0 = no, 1 = yes).
ejection_fraction: Percentage of blood leaving the heart at each contraction.
high_blood_pressure: Whether the subject has high blood pressure (0 = no, 1 = yes).
platelets: Platelets in the blood (kiloplatelets/mL).
serum_creatinine: Level of serum creatinine in the blood (mg/dL).
serum_sodium: Level of serum sodium in the blood (mEq/L).
sex: Sex of the subject (0 = female, 1 = male).
smoking: Whether the subject smokes (0 = no, 1 = yes).
time: Follow-up period (days).
Response Variable (\(Y\))
DEATH_EVENT: Whether the patient died during the follow-up period (0 = no, 1 = yes).
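Before summarizing, it is also worth a quick check that the data set contains no missing values (a one-line base-R sketch):
# count NA values in each column
colSums(is.na(df))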
Data Summary
summary(df)
## age anaemia creatinine_phosphokinase diabetes
## Min. :40.00 Min. :0.0000 Min. : 23.0 Min. :0.0000
## 1st Qu.:51.00 1st Qu.:0.0000 1st Qu.: 116.5 1st Qu.:0.0000
## Median :60.00 Median :0.0000 Median : 250.0 Median :0.0000
## Mean :60.83 Mean :0.4314 Mean : 581.8 Mean :0.4181
## 3rd Qu.:70.00 3rd Qu.:1.0000 3rd Qu.: 582.0 3rd Qu.:1.0000
## Max. :95.00 Max. :1.0000 Max. :7861.0 Max. :1.0000
## ejection_fraction high_blood_pressure platelets serum_creatinine
## Min. :14.00 Min. :0.0000 Min. : 25100 Min. :0.500
## 1st Qu.:30.00 1st Qu.:0.0000 1st Qu.:212500 1st Qu.:0.900
## Median :38.00 Median :0.0000 Median :262000 Median :1.100
## Mean :38.08 Mean :0.3512 Mean :263358 Mean :1.394
## 3rd Qu.:45.00 3rd Qu.:1.0000 3rd Qu.:303500 3rd Qu.:1.400
## Max. :80.00 Max. :1.0000 Max. :850000 Max. :9.400
## serum_sodium sex smoking time
## Min. :113.0 Min. :0.0000 Min. :0.0000 Min. : 4.0
## 1st Qu.:134.0 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.: 73.0
## Median :137.0 Median :1.0000 Median :0.0000 Median :115.0
## Mean :136.6 Mean :0.6488 Mean :0.3211 Mean :130.3
## 3rd Qu.:140.0 3rd Qu.:1.0000 3rd Qu.:1.0000 3rd Qu.:203.0
## Max. :148.0 Max. :1.0000 Max. :1.0000 Max. :285.0
## DEATH_EVENT
## Min. :0.0000
## 1st Qu.:0.0000
## Median :0.0000
## Mean :0.3211
## 3rd Qu.:1.0000
## Max. :1.0000
Correlation Matrix
corrplot(cor(df), tl.cex = 1)
We can derive a few insights from this correlation matrix:
There is a strong negative correlation between DEATH_EVENT and time (the follow-up period).
There is a positive correlation between sex and smoking: a larger share of the male subjects smoke.
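To put numbers on these two observations, we can pull the pairwise correlations directly (a quick check; at this point df is still entirely numeric):
# correlations behind the two insights above
cor(df$DEATH_EVENT, df$time) # negative: longer follow-up, fewer deaths
cor(df$sex, df$smoking)      # positive: more of the male subjects smoke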
Distributions
# function to create stacked bar charts for a binary variable, colored by DEATH_EVENT
create_binary_distribution <- function(var, title = NULL, x = NULL){
  ggplot(df, aes(x = as.factor(var))) +
    geom_bar(aes(fill = as.factor(DEATH_EVENT))) +
    labs(
      title = title,
      x = x,
      fill = 'DEATH_EVENT'
    ) +
    scale_fill_manual(values = c('darkgreen', 'brown'))
}
create_binary_distribution(
  var = df$DEATH_EVENT,
  title = 'Total Death Distribution',
  x = 'DEATH_EVENT'
)
More subjects in our sample survived than died.
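A quick tabulation quantifies this class balance (DEATH_EVENT is still numeric here, and its mean of 0.3211 in the summary above implies roughly 32% of subjects died):
df %>%
  count(DEATH_EVENT) %>%
  mutate(prop = n / sum(n))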
create_binary_distribution(
  var = df$sex,
  x = 'Gender',
  title = 'Total Gender Distribution'
)
There were more males sampled than females, but both sexes had similar death-to-survival ratios.
create_binary_distribution(
  var = df$anaemia,
  x = 'anaemia',
  title = 'Total Anaemia Distribution'
)
Anaemia did not appear to play a large role in whether or not someone survived.
create_binary_distribution(
  var = df$diabetes,
  x = 'diabetes',
  title = 'Total Diabetes Distribution'
)
Diabetes did not appear to play a large role in whether or not someone died during the follow-up period.
create_binary_distribution(
  var = df$high_blood_pressure,
  x = 'high_blood_pressure',
  title = 'High Blood Pressure Distribution'
)
High blood pressure did not appear to play a large role in mortality.
create_binary_distribution(
  var = df$smoking,
  x = 'smoking',
  title = 'Total Smoking Distribution'
)
Smoking did not appear to play a large role in mortality.
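To back the last few observations with numbers, here is a sketch that computes the death rate within each level of every binary feature (it relies on DEATH_EVENT still being numeric at this point):
# death rate by level of each binary feature
df %>%
  pivot_longer(
    c(anaemia, diabetes, high_blood_pressure, smoking, sex),
    names_to = 'feature', values_to = 'level'
  ) %>%
  group_by(feature, level) %>%
  summarise(death_rate = mean(DEATH_EVENT), .groups = 'drop')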
df %>%
  ggplot(aes(age)) +
  geom_histogram(aes(fill = as.factor(DEATH_EVENT))) +
  scale_x_continuous(breaks = seq(30, 100, 10)) +
  labs(
    title = 'Age Distribution',
    fill = 'DEATH_EVENT'
  ) +
  theme(
    plot.title = element_text(hjust = 0.5)
  ) +
  scale_fill_manual(values = c('darkgreen', 'brown'))
While the ages of our subjects ranged from 40 to 95, the mean age was about 61. The most deaths (13) fall in the histogram bin around age 60.
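That count can be sanity-checked with a by-decade tally (a sketch; note the histogram uses ggplot's default 30 bins, so its bin counts may differ slightly from a decade grouping):
# deaths counted by decade of age
df %>%
  filter(DEATH_EVENT == 1) %>%
  count(age_decade = floor(age / 10) * 10) %>%
  arrange(desc(n))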
Next, we look at the numeric features to better understand their relationship with the response variable.
# boxplots of each numeric feature against DEATH_EVENT
lst <- list(
  ejection_fraction = df$ejection_fraction,
  creatinine_phosphokinase = df$creatinine_phosphokinase,
  serum_creatinine = df$serum_creatinine,
  serum_sodium = df$serum_sodium,
  time = df$time
)
color_list <- list('darkgreen', 'blue', 'purple', 'lightblue', 'red')
n <- 1
for (i in lst) {
  print(
    ggplot(df) +
      geom_boxplot(aes(as.factor(DEATH_EVENT), i), fill = color_list[[n]], alpha = 0.5) +
      labs(
        x = 'DEATH_EVENT',
        y = names(lst)[n]
      )
  )
  n <- n + 1
}
A few insights from these boxplots:
Subjects with a shorter follow-up time were more likely to die during the study.
Subjects with a lower ejection fraction were more likely to die.
Subjects with a higher serum creatinine level were more likely to die.
Subjects with a lower serum sodium level were more likely to die.
Data Processing
Below we convert our binary features and the response variable to factors. After that, we split the data into a 75% training set and a 25% test set, stratified by DEATH_EVENT, and set up 10-fold cross-validation on the training set.
cols <- c(
  'anaemia', 'diabetes', 'high_blood_pressure',
  'sex', 'smoking', 'DEATH_EVENT'
)
df <- df %>%
  mutate(across(all_of(cols), as.factor))

set.seed(123)
split <- initial_split(df, prop = 0.75, strata = DEATH_EVENT)
train <- training(split)
test <- testing(split)

# kfold validation
set.seed(123)
kfold <- vfold_cv(train, v = 10, strata = DEATH_EVENT)
Feature Engineering
Now we create a recipe that normalizes the numeric predictors, applies a Yeo-Johnson transformation, and dummy-encodes the nominal predictors.
# recipe: transform numeric predictors, then dummy-encode nominal ones
rcp <- recipe(DEATH_EVENT ~ ., data = train) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  # note: because this comes after step_dummy, no nominal predictors remain for it to pool
  step_other(all_nominal_predictors(), threshold = 0.05, other = 'other')
Machine Learning
Below, we train two different supervised classification algorithms to predict heart failure mortality:
Logistic Regression
Random Forest
We will build both models and see which of the two algorithms is more accurate.
Logistic Regression
Tuning Model
logit_model <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine('glmnet')

logit_grid <- grid_regular(
  penalty(range = c(-3, -1)),
  mixture(),
  levels = 10
)

tuning <- workflow() %>%
  add_recipe(rcp) %>%
  add_model(logit_model) %>%
  tune_grid(
    resamples = kfold,
    grid = logit_grid,
    control = control_grid(save_pred = TRUE)
  )

show_best(tuning, metric = 'roc_auc')
## # A tibble: 5 × 8
## penalty mixture .metric .estimator mean n std_err .config
## <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.1 0.333 roc_auc binary 0.901 10 0.0252 Preprocessor1_Model040
## 2 0.0599 0.556 roc_auc binary 0.900 10 0.0265 Preprocessor1_Model059
## 3 0.1 0.444 roc_auc binary 0.899 10 0.0275 Preprocessor1_Model050
## 4 0.0599 0.667 roc_auc binary 0.899 10 0.0276 Preprocessor1_Model069
## 5 0.0359 1 roc_auc binary 0.899 10 0.0276 Preprocessor1_Model098
ROC Curve
collect_predictions(tuning) %>%
  group_by(id) %>%
  roc_curve(DEATH_EVENT, .pred_0) %>%
  autoplot()
Finalizing Model
best_hyperparameters_lg <- select_best(tuning, metric = 'roc_auc')

final_lg_wf <- workflow() %>%
  add_model(logit_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters_lg)

set.seed(123)
final_fit_lg <- final_lg_wf %>%
  fit(data = train)

final_test_lg <- final_fit_lg %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))

# metrics
logit_acc <- final_test_lg %>%
  accuracy(DEATH_EVENT, .pred_class)

final_test_lg %>%
  conf_mat(DEATH_EVENT, .pred_class)
## Truth
## Prediction 0 1
## 0 48 8
## 1 3 16
Our Logistic Regression model achieved 85.3% accuracy on the test set (83.9% accuracy in our cross-validation procedure). Now we will evaluate which features were most important for predicting heart failure mortality in this model.
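Accuracy alone can hide how a model errs on each class, so as an optional check we can also compute sensitivity and specificity on the test set (a sketch using yardstick's metric_set; by default sens and spec treat the first factor level, 0, as the event):
# class-level metrics alongside accuracy
class_metrics <- metric_set(accuracy, sens, spec)
final_test_lg %>%
  class_metrics(truth = DEATH_EVENT, estimate = .pred_class)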
Feature Importance
final_fit_lg %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance')
time and serum_creatinine were the two most important features in our Logistic Regression model. Now we will look at the partial dependence plots for these two variables.
First, we define a function that creates these partial dependence plots.
predict_plot <- function(var, model) {
  # build a DALEX explainer around the fitted workflow
  explainer <- explain_tidymodels(
    model,
    data = select(train, -DEATH_EVENT),
    y = as.integer(train$DEATH_EVENT) - 1 # factor levels 1/2 -> numeric 0/1
  )
  # compute the partial dependence profile for the requested variable
  pdp <- model_profile(
    explainer = explainer,
    variables = var,
    N = NULL
  )
  pdp_df <- as_tibble(pdp$agr_profiles)
  print(
    ggplot(pdp_df, aes(x = `_x_`, y = `_yhat_`)) +
      geom_smooth(color = 'lightblue', se = FALSE) +
      ylim(0, 1) +
      labs(
        title = paste(str_to_title(var), 'Partial Prediction'),
        x = var
      ) +
      theme_dark() +
      theme(plot.title = element_text(hjust = 0.5))
  )
}
Below is the partial dependence plot for the time feature:
predict_plot(var = 'time', model = final_fit_lg)
## Preparation of a new explainer is initiated
## -> model label : workflow ( default )
## -> data : 224 rows 12 cols
## -> data : tibble converted into a data.frame
## -> target variable : 224 values
## -> predict function : yhat.workflow will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package tidymodels , ver. 1.0.0 , task classification ( default )
## -> predicted values : numerical, min = 0.02985122 , mean = 0.3214246 , max = 0.85468
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = 0.2799684 , mean = 1.000004 , max = 1.9325
## A new explainer has been created!
As shown, the longer the follow-up period, the lower the predicted probability of death. This could be because patients with less severe conditions are scheduled at longer intervals; note too that patients who died necessarily have shorter follow-up times, so some of this feature's predictive power is built into how the data were collected.
Below is the partial dependence plot for the feature serum_creatinine:
predict_plot(var = 'serum_creatinine', model = final_fit_lg)
The predicted probability of death increases as serum creatinine levels increase, but at a diminishing rate.
Random Forest
Tuning Model
rf_model <- rand_forest(
  mode = 'classification',
  trees = 500,
  mtry = tune(),
  min_n = tune()
) %>%
  set_engine("ranger", importance = "permutation")

hyper_grid <- grid_regular(
  mtry(range = c(2, 12)),
  min_n(range = c(1, 10)),
  levels = 10
)

# train our model across the hyperparameter grid
set.seed(123)
results <- tune_grid(
  rf_model,
  rcp,
  resamples = kfold,
  grid = hyper_grid,
  control = control_grid(save_pred = TRUE)
)

# model results
show_best(results, metric = "roc_auc")
## # A tibble: 5 × 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 2 roc_auc binary 0.922 10 0.0192 Preprocessor1_Model011
## 2 2 4 roc_auc binary 0.919 10 0.0194 Preprocessor1_Model031
## 3 2 8 roc_auc binary 0.918 10 0.0187 Preprocessor1_Model071
## 4 2 9 roc_auc binary 0.917 10 0.0193 Preprocessor1_Model081
## 5 3 9 roc_auc binary 0.916 10 0.0208 Preprocessor1_Model082
ROC Curve
collect_predictions(results) %>%
  group_by(id) %>%
  roc_curve(DEATH_EVENT, .pred_0) %>%
  autoplot()
Finalizing Model
best_hyperparameters <- select_best(results, metric = 'roc_auc')

final_rf_wf <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(rcp) %>%
  finalize_workflow(best_hyperparameters)

# set the seed before fitting, since random forest training is stochastic
set.seed(123)
final_fit_rf <- final_rf_wf %>%
  fit(data = train)

final_test_rf <- final_fit_rf %>%
  predict(test) %>%
  bind_cols(select(test, DEATH_EVENT))

# metrics
rf_acc <- final_test_rf %>%
  accuracy(DEATH_EVENT, .pred_class)

final_test_rf %>%
  conf_mat(DEATH_EVENT, .pred_class)
## Truth
## Prediction 0 1
## 0 47 5
## 1 4 19
Our Random Forest model achieved 88% accuracy on the test set (84.8% accuracy in our cross-validation set). Now let's look at which features were most important for determining mortality classification.
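Using the same class_metrics set defined for the Logistic Regression model, we can inspect class-level performance for the Random Forest as well:
final_test_rf %>%
  class_metrics(truth = DEATH_EVENT, estimate = .pred_class)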
Feature Importance
final_fit_rf %>%
  extract_fit_parsnip() %>%
  vip() +
  ggtitle('Feature Importance')
We can see that time is by far the most important feature. Let's explore the partial dependence plot for this feature in our Random Forest model.
predict_plot(var = 'time', model = final_fit_rf)
## Preparation of a new explainer is initiated
## -> model label : workflow ( default )
## -> data : 224 rows 12 cols
## -> data : tibble converted into a data.frame
## -> target variable : 224 values
## -> predict function : yhat.workflow will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package tidymodels , ver. 1.0.0 , task classification ( default )
## -> predicted values : numerical, min = 0.005565207 , mean = 0.3169936 , max = 0.973857
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = 0.6271567 , mean = 1.004435 , max = 1.391136
## A new explainer has been created!
The plot shows the same overall relationship as the time predictor in our Logistic Regression model, but the predicted probability drops more steeply at first and then levels off around 110 days.
Results
acc_df <- data.frame(
  algorithm = c('Logistic Regression', 'Random Forest'),
  accuracy = c(logit_acc$.estimate, rf_acc$.estimate)
)

acc_percent <- paste(
  as.character(round(acc_df$accuracy, 3) * 100),
  '%'
)

acc_df %>%
  ggplot(aes(algorithm, accuracy)) +
  geom_segment(
    aes(
      xend = algorithm,
      x = algorithm,
      yend = accuracy,
      y = 0
    ),
    color = 'grey'
  ) +
  geom_point(size = 4, aes(color = algorithm)) +
  scale_color_manual(values = c("grey", "darkgreen")) +
  theme_classic() +
  coord_flip() +
  geom_text(label = acc_percent, nudge_y = 0.1) +
  labs(
    title = 'Model Accuracy'
  ) +
  theme(
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    legend.title = element_blank(),
    legend.text = element_blank(),
    legend.position = 'none',
    plot.title = element_text(hjust = 0.5, face = 'bold')
  )
As shown, our Random Forest model (88%) achieved a higher test accuracy than our Logistic Regression model (85.3%). In both models, time was the most important feature for determining mortality classification.