The data is taken from Kaggle and can be accessed here. The data contains medical records of 5,000 patients who had heart failure, collected during their follow-up period.
Each patient profile contains information for the following clinical features:
| Variable | Description |
|---|---|
| age | age of the patient (years) |
| anaemia | decrease of red blood cells or hemoglobin (boolean) |
| creatinine_phosphokinase | level of the CPK enzyme in the blood (mcg/L) |
| diabetes | if the patient has diabetes (boolean) |
| ejection_fraction | percentage of blood leaving the heart at each contraction (percentage) |
| high_blood_pressure | if the patient has hypertension (boolean) |
| platelets | platelets in the blood (kiloplatelets/mL) |
| sex | woman or man (binary) |
| serum_creatinine | level of serum creatinine in the blood (mg/dL) |
| serum_sodium | level of serum sodium in the blood (mEq/L) |
| smoking | if the patient smokes or not (boolean) |
| time | follow-up period (days) |
| DEATH_EVENT | if the patient died during the follow-up period (boolean) |
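The library calls are not shown in the report; the setup chunk below is a minimal sketch of the packages implied by the functions used throughout (assumed, since the original setup chunk was not included).
# Assumed setup (not shown in the original): packages implied by the calls below
library(tidyverse)    # dplyr, tidyr, ggplot2, pipes
library(DataExplorer) # plot_missing()
library(caret)        # nearZeroVar(), train(), confusionMatrix(), varImp()
library(corrplot)     # corrplot()
library(cowplot)      # ggdraw(), draw_label(), plot_grid()
library(nnet)         # back end for method = "nnet"
library(randomForest) # back end for method = "rf"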
disease_data <- read.csv("https://raw.githubusercontent.com/ShanaFarber/cuny-sps/master/DATA_622/data/heart_failure_clinical_records.csv")
glimpse(disease_data)
## Rows: 5,000
## Columns: 13
## $ age <dbl> 55.000, 65.000, 45.000, 60.000, 95.000, 70.00…
## $ anaemia <int> 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, …
## $ creatinine_phosphokinase <int> 748, 56, 582, 754, 582, 232, 122, 171, 482, 4…
## $ diabetes <int> 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, …
## $ ejection_fraction <int> 45, 25, 38, 40, 30, 30, 60, 50, 30, 45, 40, 2…
## $ high_blood_pressure <int> 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, …
## $ platelets <dbl> 263358, 305000, 319000, 328000, 461000, 30200…
## $ serum_creatinine <dbl> 1.3, 5.0, 0.9, 1.2, 2.0, 1.2, 1.2, 0.9, 0.9, …
## $ serum_sodium <int> 137, 130, 140, 126, 132, 132, 145, 141, 132, …
## $ sex <int> 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, …
## $ smoking <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, …
## $ time <int> 88, 207, 244, 90, 50, 210, 147, 196, 109, 215…
## $ DEATH_EVENT <int> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
The data has 5,000 rows and 13 variables, all stored as numeric types: some are continuous measurements, while others are binary indicators of categorical features.
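One quick way to confirm which columns are binary indicators and which are continuous measurements is to count the distinct values per column (a small sketch, not part of the original workflow):
# Binary indicators have exactly two distinct values
sapply(disease_data, function(x) length(unique(x)))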
The goal of this analysis is to predict whether a patient died during the follow-up period based on the recorded attributes of each patient. Such a model could help hospitals estimate the likelihood of survival for heart failure patients and thereby identify the patients most in need of attention. Alternatively, when there are multiple heart failure patients and not enough resources to treat them all, hospitals could identify the patients with the highest likelihood of survival and prioritize treating them.
The analysis will use Random Forest and Neural Network models to predict death events based on the recorded features of patients who experienced heart failure.
First, we will check if the data is missing any values.
plot_missing(disease_data)
None of the values are missing from the dataset.
Are there any near zero variance predictors?
nearZeroVar(disease_data |> select(age, creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, time), saveMetrics = TRUE) |>
knitr::kable()
| | freqRatio | percentUnique | zeroVar | nzv |
|---|---|---|---|---|
| age | 1.171371 | 0.96 | FALSE | FALSE |
| creatinine_phosphokinase | 10.534247 | 5.80 | FALSE | FALSE |
| ejection_fraction | 1.087447 | 0.34 | FALSE | FALSE |
| platelets | 4.817073 | 4.06 | FALSE | FALSE |
| serum_creatinine | 1.437607 | 0.86 | FALSE | FALSE |
| serum_sodium | 1.027244 | 0.54 | FALSE | FALSE |
| time | 1.122807 | 3.10 | FALSE | FALSE |
None of the predictors have zero or near zero variance.
Let’s visualize the distribution of the response variable.
# calculate counts and percents
counts <- disease_data |>
count(DEATH_EVENT) |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
mutate(total = nrow(disease_data),
perc = round(n / total * 100, 0))
# plot
counts %>%
ggplot(aes(x = DEATH_EVENT, y = n)) +
geom_bar(aes(fill = DEATH_EVENT), stat="identity") +
geom_text(aes(label = paste0(perc, '%')), vjust = 2, color = "white", fontface = 'bold') +
theme(legend.position = "none") +
labs(title = "Distribution of Response Variable", x = "Death Event", y = "Count") +
scale_y_continuous(label=scales::comma)
The response variable is unbalanced. There are more than twice as many patients who did not experience a death event.
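The exact class balance can be confirmed with a simple frequency table (a quick check using the raw 0/1 coding of DEATH_EVENT):
# Share of patients in each response class (0 = survived, 1 = died)
round(prop.table(table(disease_data$DEATH_EVENT)), 3)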
Let’s check the distributions of the predictor variables.
First, let’s print the summary statistics for each continuous numeric variable.
disease_data |>
select(age, creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, time) |>
summary()
## age creatinine_phosphokinase ejection_fraction platelets
## Min. :40.00 Min. : 23.0 Min. :14.00 Min. : 25100
## 1st Qu.:50.00 1st Qu.: 121.0 1st Qu.:30.00 1st Qu.:215000
## Median :60.00 Median : 248.0 Median :38.00 Median :263358
## Mean :60.29 Mean : 586.8 Mean :37.73 Mean :265075
## 3rd Qu.:68.00 3rd Qu.: 582.0 3rd Qu.:45.00 3rd Qu.:310000
## Max. :95.00 Max. :7861.0 Max. :80.00 Max. :850000
## serum_creatinine serum_sodium time
## Min. :0.500 Min. :113.0 Min. : 4.0
## 1st Qu.:0.900 1st Qu.:134.0 1st Qu.: 74.0
## Median :1.100 Median :137.0 Median :113.0
## Mean :1.369 Mean :136.8 Mean :130.7
## 3rd Qu.:1.400 3rd Qu.:140.0 3rd Qu.:201.0
## Max. :9.400 Max. :148.0 Max. :285.0
Now, let’s check the breakdown of each continuous variable by the response variable.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = age, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Age by Response Variable", x = "Age")
Those who experienced a death event are generally older than those who did not.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = creatinine_phosphokinase, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of CPK by Response Variable", x = "CPK Level")
The variable for CPK levels has many outliers and is skewed. However, the distributions for those who experienced a death event and those who didn’t are relatively the same.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = ejection_fraction, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Ejection Fraction by Response Variable", x = "Ejection Fraction")
There are a few outliers for both distributions. Those who experienced a death event had much lower ejection fractions than those who didn’t.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = platelets, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Platelet Count by Response Variable", x = "Platelets")
There are a number of outliers for both distributions. The spread for those who experienced a death event is wider than for those who didn’t, but the medians are about the same.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = serum_creatinine, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Serum Creatinine by Response Variable", x = "Serum Creatinine")
There are a number of outliers for both groups. Serum creatinine levels are higher, and more spread out, for those who experienced a death event.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = serum_sodium, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Serum Sodium by Response Variable", x = "Serum Sodium")
There are a few outliers. Serum sodium levels are lower for those who experienced a death event.
disease_data |>
mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = time, fill=DEATH_EVENT)) +
geom_boxplot() +
labs(title = "Distribution of Time by Response Variable", x = "Time")
The follow-up window (time from the initial encounter to the follow-up) is much shorter for those who experienced a death event.
plot1 <- disease_data |>
mutate(anaemia = ifelse(anaemia == 1, "Yes", "No")) |>
ggplot(aes(x = anaemia)) +
geom_bar() +
theme(legend.position = "none") +
labs(y = "Count", x = "Anaemia Indicator")
plot2 <- disease_data |>
mutate(anaemia = ifelse(anaemia == 1, "Yes", "No"),
DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = anaemia)) +
geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
theme(legend.position = "bottom") +
labs(y = "Count", x = "Anaemia Indicator")
title <- ggdraw() +
draw_label(
"Death Event by Anaemia",
fontface = 'bold',
x = 0,
hjust = -0.05
)
plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))
plot1 <- disease_data |>
mutate(diabetes = ifelse(diabetes == 1, "Yes", "No")) |>
ggplot(aes(x = diabetes)) +
geom_bar() +
theme(legend.position = "none") +
labs(y = "Count", x = "Diabetes Indicator")
plot2 <- disease_data |>
mutate(diabetes = ifelse(diabetes == 1, "Yes", "No"),
DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = diabetes)) +
geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
theme(legend.position = "bottom") +
labs(y = "Count", x = "Diabetes Indicator")
title <- ggdraw() +
draw_label(
"Death Event by Diabetes",
fontface = 'bold',
x = 0,
hjust = -0.05
)
plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))
plot1 <- disease_data |>
mutate(high_blood_pressure = ifelse(high_blood_pressure == 1, "Yes", "No")) |>
ggplot(aes(x = high_blood_pressure)) +
geom_bar() +
theme(legend.position = "none") +
labs(y = "Count", x = "High Blood Pressure Indicator")
plot2 <- disease_data |>
mutate(high_blood_pressure = ifelse(high_blood_pressure == 1, "Yes", "No"),
DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = high_blood_pressure)) +
geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
theme(legend.position = "bottom") +
labs(y = "Count", x = "High Blood Pressure Indicator")
title <- ggdraw() +
draw_label(
"Death Event by High Blood Pressure",
fontface = 'bold',
x = 0,
hjust = -0.05
)
plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))
plot1 <- disease_data |>
mutate(smoking = ifelse(smoking == 1, "Yes", "No")) |>
ggplot(aes(x = smoking)) +
geom_bar() +
theme(legend.position = "none") +
labs(y = "Count", x = "Smoking Indicator")
plot2 <- disease_data |>
mutate(smoking = ifelse(smoking == 1, "Yes", "No"),
DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = smoking)) +
geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
theme(legend.position = "bottom") +
labs(y = "Count", x = "Smoking Indicator")
title <- ggdraw() +
draw_label(
"Death Event by Smoking",
fontface = 'bold',
x = 0,
hjust = -0.05
)
plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))
plot1 <- disease_data |>
mutate(sex = ifelse(sex == 1, "Male", "Female")) |>
ggplot(aes(x = sex)) +
geom_bar() +
theme(legend.position = "none") +
labs(y = "Count", x = "Sex Indicator")
plot2 <- disease_data |>
mutate(sex = ifelse(sex == 1, "Male", "Female"),
DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
ggplot(aes(x = sex)) +
geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
theme(legend.position = "bottom") +
labs(y = "Count", x = "Sex Indicator")
title <- ggdraw() +
draw_label(
"Death Event by Sex",
fontface = 'bold',
x = 0,
hjust = -0.05
)
plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))
None of the categorical variables seem to have much of an influence on whether or not the individual experienced a death event.
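To quantify that impression, the death rate within each level of the binary features can be tabulated directly. The sketch below is not part of the original analysis; it reshapes the indicators into long form and computes the mean of the 0/1 response within each group.
# Death rate within each level of the binary predictors
disease_data |>
  pivot_longer(c(anaemia, diabetes, high_blood_pressure, smoking, sex),
               names_to = "feature", values_to = "level") |>
  group_by(feature, level) |>
  summarise(death_rate = mean(DEATH_EVENT), .groups = "drop")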
disease_data |>
cor() |>
corrplot(method="color",
diag=FALSE,
type="lower",
addCoef.col = "black",
number.cex=0.35,
tl.cex=0.5)
The variables that are most correlated with DEATH_EVENT are age, ejection_fraction, serum_creatinine, serum_sodium, and time.
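The individual correlations with the response can also be extracted from the correlation matrix and sorted, which makes the ranking easier to read than the full plot (a small sketch, computed while DEATH_EVENT is still numeric):
# Correlation of each predictor with DEATH_EVENT, sorted
cors <- cor(disease_data)[, "DEATH_EVENT"]
sort(cors[names(cors) != "DEATH_EVENT"], decreasing = TRUE)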
First, code the target variable as a factor:
disease_data$DEATH_EVENT <- as.factor(disease_data$DEATH_EVENT)
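One caveat: caret::confusionMatrix() treats the first factor level ("0", i.e. no death) as the positive class by default, so the sensitivity/recall values reported later refer to the no-death class. A hypothetical alternative, not used in this report, is to relabel the factor and name the positive class explicitly:
# Hypothetical alternative (not applied here, so the rest of the report is unchanged)
death_labeled <- factor(disease_data$DEATH_EVENT,
                        levels = c(0, 1), labels = c("No", "Yes"))
# confusionMatrix(predicted, actual, positive = "Yes")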
Next, we will split the data for training and testing:
set.seed(613)
# Sample row indices for the training set
train_index <- sample(nrow(disease_data), 0.8 * nrow(disease_data))
# Create training and testing sets
train <- disease_data[train_index, ]
test <- disease_data[-train_index, ]
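Because the response is imbalanced, a stratified split that preserves the death-event proportion in both sets is another reasonable choice; caret's createDataPartition() does this. The sketch below shows that alternative with hypothetical object names (the results in this report use the simple random split above).
# Alternative: stratified 80/20 split preserving class proportions
set.seed(613)
strat_index <- createDataPartition(disease_data$DEATH_EVENT, p = 0.8, list = FALSE)
train_strat <- disease_data[strat_index, ]
test_strat  <- disease_data[-strat_index, ]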
We can now train our models.
We will start by training a Neural Network model with 10-fold cross-validation.
set.seed(613)
nnet_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine + serum_sodium + time,
data = train,
method = "nnet",
trControl = trainControl(method = "cv", number = 10),
trace=0)
nnet_model
## Neural Network
##
## 4000 samples
## 5 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ...
## Resampling results across tuning parameters:
##
## size decay Accuracy Kappa
## 1 0e+00 0.7210439 0.1312748
## 1 1e-04 0.7410872 0.2041227
## 1 1e-01 0.8457825 0.5948862
## 3 0e+00 0.7697829 0.3165491
## 3 1e-04 0.7750241 0.3326443
## 3 1e-01 0.8799914 0.7147496
## 5 0e+00 0.8487213 0.6049868
## 5 1e-04 0.8416725 0.5810650
## 5 1e-01 0.8767308 0.7153841
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were size = 3 and decay = 0.1.
At size = 3 and decay = 0.1, the accuracy of the model is about 88%.
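Neural networks are sensitive to the scale of their inputs, so centering and scaling the predictors often improves the fit; caret can apply this inside each resampling fold via the preProcess argument. The variant below is only a sketch and was not run for the results reported here.
# Variant (not run above): standardize the predictors within each CV fold
set.seed(613)
nnet_scaled <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine +
                       serum_sodium + time,
                     data = train,
                     method = "nnet",
                     preProcess = c("center", "scale"),
                     trControl = trainControl(method = "cv", number = 10),
                     trace = 0)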
nnet_pred <- predict(nnet_model, test)
# confusion matrix
nnet_confusion_matrix <- confusionMatrix(nnet_pred, test$DEATH_EVENT)
nnet_conf_matrix_df <- as.data.frame(nnet_confusion_matrix$table)
nnet_conf_matrix_df$Prediction <- factor(nnet_conf_matrix_df$Prediction, levels = c(1, 0))
nnet_conf_matrix_df$Reference <- factor(nnet_conf_matrix_df$Reference, levels = c(0, 1))
nnet_conf_matrix_df |>
ggplot(aes(x = Prediction, y = as.factor(Reference))) +
geom_tile(aes(fill = Freq), color = "white") +
scale_fill_gradient(low = "white", high = "palegreen3") +
labs(title = "NNET Model",
x = "Predicted",
y = "Actual") +
geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
theme_bw() +
theme(legend.position = "none")
When we apply the Neural Network model to the test data, there are 267 instances of true positives and 635 instances of true negatives. The model has a few more instances of false positives (53) than false negatives (45).
We will now build a Random Forest model, also using 10-fold cross-validation.
set.seed(613)
rf_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine + serum_sodium + time,
data = train,
method = "rf",
trControl = trainControl(method = "cv", number = 10))
rf_model
## Random Forest
##
## 4000 samples
## 5 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.9924968 0.9825220
## 3 0.9919968 0.9813709
## 5 0.9914975 0.9802279
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
The best Random Forest model has an accuracy of 99%.
rf_pred <- predict(rf_model, test)
# confusion matrix
rf_confusion_matrix <- confusionMatrix(rf_pred, test$DEATH_EVENT)
rf_conf_matrix_df <- as.data.frame(rf_confusion_matrix$table)
rf_conf_matrix_df$Prediction <- factor(rf_conf_matrix_df$Prediction, levels = c(1, 0))
rf_conf_matrix_df$Reference <- factor(rf_conf_matrix_df$Reference, levels = c(0, 1))
rf_conf_matrix_df |>
ggplot(aes(x = Prediction, y = as.factor(Reference))) +
geom_tile(aes(fill = Freq), color = "white") +
scale_fill_gradient(low = "white", high = "palegreen3") +
labs(title = "RF Model",
x = "Predicted",
y = "Actual") +
geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
theme_bw() +
theme(legend.position = "none")
When we apply the Random Forest model to the testing data, there are only 5 false negatives and 1 false positive. The Random Forest model clearly outperforms the Neural Network model.
Let’s compare the accuracy metrics of these two models.
keep <- c("Balanced Accuracy", "F1", "Specificity", "Precision", "Recall")
# random forest model
rf_mod_metrics <- data.frame("RandomForest" = rf_confusion_matrix$byClass)
rf_mod_metrics$metric <- rownames(rf_mod_metrics)
rf_mod_metrics <- rf_mod_metrics |>
pivot_wider(names_from = metric,
values_from = c("RandomForest")) |>
dplyr::select(all_of(keep))
# nnet model
nnet_mod_metrics <- data.frame("NeuralNetwork" = nnet_confusion_matrix$byClass)
nnet_mod_metrics$metric <- rownames(nnet_mod_metrics)
nnet_mod_metrics <- nnet_mod_metrics |>
pivot_wider(names_from = metric,
values_from = c("NeuralNetwork")) |>
dplyr::select(all_of(keep))
# combine
metrics <- data.frame(rbind(rf_mod_metrics, nnet_mod_metrics))
rownames(metrics) <- c("RandomForest", "NeuralNetwork")
metrics |>
knitr::kable()
| | Balanced.Accuracy | F1 | Specificity | Precision | Recall |
|---|---|---|---|---|---|
| RandomForest | 0.9912604 | 0.9956522 | 0.9839744 | 0.9927746 | 0.9985465 |
| NeuralNetwork | 0.8893672 | 0.9283626 | 0.8557692 | 0.9338235 | 0.9229651 |
The Random Forest model has better balanced accuracy, F1, specificity, precision, and recall than the Neural Network model, making it the more appropriate model for predicting DEATH_EVENT from the available variables.
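Given the class imbalance, comparing the models on ROC/AUC rather than accuracy alone can also be informative. The sketch below assumes the pROC package (an extra dependency not loaded elsewhere in this report):
# Compare test-set AUC using the predicted probability of class "1" (death)
library(pROC)
nnet_probs <- predict(nnet_model, test, type = "prob")[, "1"]
rf_probs   <- predict(rf_model, test, type = "prob")[, "1"]
auc(roc(test$DEATH_EVENT, nnet_probs))
auc(roc(test$DEATH_EVENT, rf_probs))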
Let’s see which are the most important variables for this model.
plot(varImp(rf_model))
The most important variable in this model is time and the least important is serum_sodium.
Recall that the purpose of this analysis is to predict the likelihood of survival of heart failure patients at their initial admittance to the hospital. According to the data description, time is the length of the follow-up period in days, measured from the initial encounter. While this is a very important feature for the model, it is not as useful for hospitals seeking to identify at-risk patients at admittance, since it cannot be measured at the initial encounter.
Let’s see how the model performs when we remove the time and serum_sodium variables.
set.seed(613)
rf_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine,
data = train,
method = "rf",
trControl = trainControl(method = "cv", number = 10))
## note: only 2 unique complexity parameters in default grid. Truncating the grid to 2 .
rf_model
## Random Forest
##
## 4000 samples
## 3 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.9794899 0.9517692
## 3 0.9787405 0.9499829
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
The new Random Forest model has an accuracy of almost 98%.
rf_pred <- predict(rf_model, test)
# confusion matrix
rf_confusion_matrix <- confusionMatrix(rf_pred, test$DEATH_EVENT)
rf_conf_matrix_df <- as.data.frame(rf_confusion_matrix$table)
rf_conf_matrix_df$Prediction <- factor(rf_conf_matrix_df$Prediction, levels = c(1, 0))
rf_conf_matrix_df$Reference <- factor(rf_conf_matrix_df$Reference, levels = c(0, 1))
rf_conf_matrix_df |>
ggplot(aes(x = Prediction, y = as.factor(Reference))) +
geom_tile(aes(fill = Freq), color = "white") +
scale_fill_gradient(low = "white", high = "palegreen3") +
labs(title = "RF Model: Age, Ejection Fraction, Serum Creatinine",
x = "Predicted",
y = "Actual") +
geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
theme_bw() +
theme(legend.position = "none")
When applied to the testing data, this model has a few more false negatives and false positives than the previous Random Forest model. However, it still outperforms the Neural Network model.
# random forest model
rf_mod_metrics <- data.frame("RandomForest" = rf_confusion_matrix$byClass)
rf_mod_metrics$metric <- rownames(rf_mod_metrics)
rf_mod_metrics <- rf_mod_metrics |>
pivot_wider(names_from = metric,
values_from = c("RandomForest")) |>
dplyr::select(all_of(keep))
rf_mod_metrics |>
knitr::kable()
| Balanced Accuracy | F1 | Specificity | Precision | Recall |
|---|---|---|---|---|
| 0.9675201 | 0.9834413 | 0.9423077 | 0.9743224 | 0.9927326 |
The model achieves about 97% balanced accuracy on the testing data.
The goal of this analysis was to predict the survival of heart failure patients based on the available recorded features.
The clinical features recorded in the data include the age and sex of the patient; whether the patient is a smoker, diabetic, anaemic, or hypertensive; measured blood levels; and the length of time from the initial encounter until the follow-up period.
The response variable was DEATH_EVENT, an indicator of whether or not the patient died during the follow-up period.
The predictor variables were analyzed to determine their relationships with the response variable.
The age of each patient was recorded as a continuous variable, age. The youngest patient was 40 and the oldest was 95, with a median age of 60 across the dataset. Based on the distribution of ages by the response variable, older patients tend to have a lower chance of survival.
The sex of each patient was recorded as a binary variable, sex. There were roughly twice as many males as females in the dataset, but the breakdown of those who survived and those who did not was similar across the two groups.
The other binary features were anaemia, diabetes, high_blood_pressure, and smoking, indicating whether the patient had anaemia, diabetes, or hypertension, or smoked. As with sex, there did not appear to be much difference in outcomes between these groups.
Other than age, the continuous numeric features were creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, and time.
creatinine_phosphokinase measures the CPK level in the blood; based on the boxplots split by the response variable, there did not seem to be a difference between patients who died and those who survived. ejection_fraction measures the percentage of blood leaving the heart at each contraction and was much lower for patients who died. platelets measures the platelet count in the blood and did not seem to differ between groups. serum_creatinine measures the creatinine level in the blood and was elevated for patients who died. serum_sodium measures the sodium level in the blood and was lower for those who died. time is the number of days from the initial encounter to the follow-up and was much shorter for those who died.
The correlation plot confirmed the relationships between the response
variable and the predictors. age,
ejection_fraction, serum_creatinine,
serum_sodium, and time had the largest
correlations. age and serum_creatinine were
both positively correlated with the response variable.
ejection_fraction, serum_sodium, and
time were all negatively correlated with the response
variable. The remaining variables were barely correlated.
The analysis was conducted using a Neural Network and two Random Forest models. First, the data was split into training and testing sets (80:20). A Neural Network model was trained on the previously mentioned predictive features using 10-fold cross-validation. It achieved a cross-validated accuracy of about 88% and, on the testing set, correctly predicted 267 positive instances and 635 negative instances, with 53 false positives and 45 false negatives. Next, a Random Forest model was trained using the same variables and cross-validation. This model substantially outperformed the Neural Network, with a cross-validated accuracy of about 99%; on the testing set it correctly predicted 307 true positives and 687 true negatives, with only 1 false positive and 5 false negatives.
The most important features for the Random Forest model were time, serum_creatinine, ejection_fraction, and age; serum_sodium was the least important.
Recalling that the goal of this analysis was to predict the likelihood of survival from the time of the initial encounter, time was removed from the model, as it is not a feature that can be measured at the time of admittance to the hospital. serum_sodium was also removed because it contributed the least to the model.
The resulting model has a cross-validated accuracy of about 98% and about 97% balanced accuracy on the testing set, with 294 correctly predicted true positives and 683 true negatives. The model produces more than three times as many false negatives (18) as false positives (5), but it remains highly accurate using only a few variables.
Based on the results of this analysis, hospitals could predict a patient’s likelihood of survival quite accurately from their age, serum creatinine level, and ejection fraction. This could help hospitals prioritize which patients to treat: depending on the predicted risk of death, they could either direct more urgent attention to the patients most likely to die, or, when resources are limited, prioritize patients who are more likely to survive with treatment.
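In practice, the reduced model could be applied at admission by passing the measurements recorded at intake to predict() with type = "prob". The example below uses made-up values for a hypothetical patient; it is only an illustration of how such a score could be generated.
# Hypothetical patient measured at admission (illustrative values only)
new_patient <- data.frame(age = 72,
                          ejection_fraction = 25,
                          serum_creatinine = 1.8)
# Predicted probability of each outcome (column "1" = probability of death)
predict(rf_model, new_patient, type = "prob")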
Kaggle Heart Failure Dataset: https://www.kaggle.com/datasets/aadarshvelu/heart-failure-prediction-clinical-records