Load Data

The data is taken from Kaggle and can be accessed at the link listed under Sources. It contains the medical records of 5,000 patients who had heart failure, collected during their follow-up period.

Each patient profile contains information for the following clinical features:

| Variable | Description |
| --- | --- |
| age | age of the patient (years) |
| anaemia | decrease of red blood cells or hemoglobin (boolean) |
| creatinine_phosphokinase | level of the CPK enzyme in the blood (mcg/L) |
| diabetes | if the patient has diabetes (boolean) |
| ejection_fraction | percentage of blood leaving the heart at each contraction (percentage) |
| high_blood_pressure | if the patient has hypertension (boolean) |
| platelets | platelets in the blood (kiloplatelets/mL) |
| sex | woman or man (binary) |
| serum_creatinine | level of serum creatinine in the blood (mg/dL) |
| serum_sodium | level of serum sodium in the blood (mEq/L) |
| smoking | if the patient smokes or not (boolean) |
| time | follow-up period (days) |
| DEATH_EVENT | if the patient died during the follow-up period (boolean) |

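# packages used in this analysis (caret's "nnet" and "rf" methods also need the nnet and randomForest packages installed)
pkgs <- c("dplyr", "tidyr", "ggplot2", "cowplot", "DataExplorer", "caret", "corrplot")
invisible(lapply(pkgs, library, character.only = TRUE))
disease_data <- read.csv("https://raw.githubusercontent.com/ShanaFarber/cuny-sps/master/DATA_622/data/heart_failure_clinical_records.csv")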

glimpse(disease_data)
## Rows: 5,000
## Columns: 13
## $ age                      <dbl> 55.000, 65.000, 45.000, 60.000, 95.000, 70.00…
## $ anaemia                  <int> 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, …
## $ creatinine_phosphokinase <int> 748, 56, 582, 754, 582, 232, 122, 171, 482, 4…
## $ diabetes                 <int> 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, …
## $ ejection_fraction        <int> 45, 25, 38, 40, 30, 30, 60, 50, 30, 45, 40, 2…
## $ high_blood_pressure      <int> 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, …
## $ platelets                <dbl> 263358, 305000, 319000, 328000, 461000, 30200…
## $ serum_creatinine         <dbl> 1.3, 5.0, 0.9, 1.2, 2.0, 1.2, 1.2, 0.9, 0.9, …
## $ serum_sodium             <int> 137, 130, 140, 126, 132, 132, 145, 141, 132, …
## $ sex                      <int> 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, …
## $ smoking                  <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, …
## $ time                     <int> 88, 207, 244, 90, 50, 210, 147, 196, 109, 215…
## $ DEATH_EVENT              <int> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …

The data has 5,000 rows and 13 variables. All of the variables are numeric.

Some variables are continuous, while others are binary (0/1) indicators of a categorical feature.
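One quick way to separate the binary columns from the continuous ones is to count the distinct values in each column. The sketch below uses a small hypothetical data frame, toy, in place of disease_data. It assumes that a column with exactly two distinct values is binary, which would misclassify a binary column that happens to contain only one value.

```r
# Hypothetical stand-in for disease_data: two continuous and two binary columns
toy <- data.frame(
  age         = c(55, 65, 45, 60, 95),
  platelets   = c(263358, 305000, 319000, 328000, 461000),
  anaemia     = c(0L, 0L, 0L, 1L, 1L),
  DEATH_EVENT = c(0L, 0L, 0L, 0L, 1L)
)

# Treat a column as binary when it takes exactly two distinct values
is_binary <- sapply(toy, function(col) length(unique(col)) == 2)

binary_vars     <- names(toy)[is_binary]   # c("anaemia", "DEATH_EVENT")
continuous_vars <- names(toy)[!is_binary]  # c("age", "platelets")

print(binary_vars)
print(continuous_vars)
```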

Goal of Analysis

The goal of this analysis is to predict whether a patient died during the follow-up period based on the recorded attributes of each patient. Such a model could help hospitals estimate the likelihood of survival for heart failure patients and identify the patients most in need of attention. Alternatively, when several heart failure patients need treatment and there are not enough resources to treat all of them, hospitals could use it to identify the patients with the highest likelihood of survival and treat them first.

The analysis will use Random Forest and Neural Network models to predict death events from the recorded features of patients who had heart failure.

Exploratory Analysis

First, we will check if the data is missing any values.

plot_missing(disease_data)

None of the values are missing from the dataset.
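The same check can be made without a plot by counting NA values per column in base R. The data frame below is a hypothetical example with one value deliberately set to missing. Given the plot above, running colSums(is.na(disease_data)) on the real data should return zero for every column.

```r
# Hypothetical example data with one deliberately missing value
toy <- data.frame(
  age          = c(55, 65, NA, 60),
  serum_sodium = c(137L, 130L, 140L, 126L)
)

# Number of missing values in each column
na_counts <- colSums(is.na(toy))
print(na_counts)

# TRUE only if no column has any missing values
no_missing <- all(na_counts == 0)
print(no_missing)
```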

Are there any near zero variance predictors?

nearZeroVar(disease_data |> select(age, creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, time), saveMetrics = TRUE) |>
  knitr::kable()
| Variable | freqRatio | percentUnique | zeroVar | nzv |
| --- | --- | --- | --- | --- |
| age | 1.171371 | 0.96 | FALSE | FALSE |
| creatinine_phosphokinase | 10.534247 | 5.80 | FALSE | FALSE |
| ejection_fraction | 1.087447 | 0.34 | FALSE | FALSE |
| platelets | 4.817073 | 4.06 | FALSE | FALSE |
| serum_creatinine | 1.437607 | 0.86 | FALSE | FALSE |
| serum_sodium | 1.027244 | 0.54 | FALSE | FALSE |
| time | 1.122807 | 3.10 | FALSE | FALSE |

None of the predictors have zero or near zero variance.
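caret's documentation defines the two statistics in the table as follows:

- freqRatio is the count of the most common value divided by the count of the second most common value.
- percentUnique is the number of distinct values as a percentage of all rows.

A predictor is flagged when freqRatio exceeds freqCut (default 95/5) and percentUnique is below uniqueCut (default 10). The nzv_metrics() function below is a hypothetical, simplified re-implementation of that rule for a single vector. caret's own handling of edge cases, such as missing values, may differ.

```r
# Simplified version of the rule described for caret::nearZeroVar()
nzv_metrics <- function(x, freq_cut = 95 / 5, unique_cut = 10) {
  counts <- sort(table(x), decreasing = TRUE)
  zero_var <- length(counts) == 1
  # count of the most common value divided by the count of the second most common
  freq_ratio <- if (zero_var) 0 else counts[[1]] / counts[[2]]
  # distinct values as a percentage of all observations
  percent_unique <- 100 * length(counts) / length(x)
  data.frame(
    freqRatio     = freq_ratio,
    percentUnique = percent_unique,
    zeroVar       = zero_var,
    # constant columns are flagged as well
    nzv           = (freq_ratio > freq_cut && percent_unique < unique_cut) || zero_var
  )
}

# Hypothetical vector: 97 zeros and 3 ones
rare <- c(rep(0, 97), rep(1, 3))
nzv_metrics(rare)  # freqRatio 32.33, percentUnique 2, nzv TRUE
```

By this rule, the age row in the table above (freqRatio 1.17, percentUnique 0.96) is not flagged. Its frequency ratio is far below the cut-off of 19.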

Let’s visualize the distribution of the response variable.

# calculate counts and percents
counts <- disease_data |>
  count(DEATH_EVENT) |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  mutate(total = nrow(disease_data),
         perc = round(n / total * 100, 0))

# plot
counts |>
  ggplot(aes(x = DEATH_EVENT, y = n)) +
  geom_col(aes(fill = DEATH_EVENT)) +
  geom_text(aes(label = paste0(perc, '%')), vjust = 2, color = "white", fontface = 'bold') +
  theme(legend.position = "none") +
  labs(title = "Distribution of Response Variable", x = "Death Event", y = "Count") +
  scale_y_continuous(labels = scales::comma)

The response variable is imbalanced. More than twice as many patients did not experience a death event as did.
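With an imbalanced response, accuracy figures are best compared against a trivial model that always predicts the majority class. caret's confusionMatrix() reports this baseline as the No Information Rate. The base-R sketch below computes the class proportions, the imbalance ratio, and that baseline for a hypothetical 0/1 vector with a 7:3 split.

```r
# Hypothetical response vector: 7 survivors (0) and 3 deaths (1)
death_event <- c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1)

# Share of patients in each outcome class
props <- prop.table(table(death_event))
print(round(100 * props))

# Number of survivors per death
imbalance_ratio <- props[["0"]] / props[["1"]]

# Accuracy of a model that always predicts the majority class
baseline_accuracy <- max(props)
print(c(imbalance_ratio = imbalance_ratio, baseline_accuracy = baseline_accuracy))
```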

Let’s check the distributions of the predictor variables.

Continuous Variables

First, let’s print the summary statistics for each continuous numeric variable.

disease_data |>
  select(age, creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, time) |>
  summary()
##       age        creatinine_phosphokinase ejection_fraction   platelets     
##  Min.   :40.00   Min.   :  23.0           Min.   :14.00     Min.   : 25100  
##  1st Qu.:50.00   1st Qu.: 121.0           1st Qu.:30.00     1st Qu.:215000  
##  Median :60.00   Median : 248.0           Median :38.00     Median :263358  
##  Mean   :60.29   Mean   : 586.8           Mean   :37.73     Mean   :265075  
##  3rd Qu.:68.00   3rd Qu.: 582.0           3rd Qu.:45.00     3rd Qu.:310000  
##  Max.   :95.00   Max.   :7861.0           Max.   :80.00     Max.   :850000  
##  serum_creatinine  serum_sodium        time      
##  Min.   :0.500    Min.   :113.0   Min.   :  4.0  
##  1st Qu.:0.900    1st Qu.:134.0   1st Qu.: 74.0  
##  Median :1.100    Median :137.0   Median :113.0  
##  Mean   :1.369    Mean   :136.8   Mean   :130.7  
##  3rd Qu.:1.400    3rd Qu.:140.0   3rd Qu.:201.0  
##  Max.   :9.400    Max.   :148.0   Max.   :285.0

Now, let’s check the breakdown of each continuous variable by the response variable.

Age
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = age, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Age by Response Variable", x = "Age")

Those who experienced a death event are generally older than those who did not.

CPK
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = creatinine_phosphokinase, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of CPK by Response Variable", x = "CPK Level")

The CPK variable is right-skewed and has many outliers. However, its distribution is similar for patients who did and did not experience a death event.

Ejection Fraction
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = ejection_fraction, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Ejection Fraction by Response Variable", x = "Ejection Fraction")

There are a few outliers for both distributions. Those who experienced a death event had much lower ejection fractions than those who didn’t.

Platelets
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = platelets, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Platelet Count by Response Variable", x = "Platelets")

Both distributions have a number of outliers. The spread is wider for those who experienced a death event, but the medians are about the same.

Serum Creatinine
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = serum_creatinine, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Serum Creatinine by Response Variable", x = "Serum Creatinine")

Both distributions have a number of outliers. Those who experienced a death event had higher serum creatinine levels and a larger spread.

Serum Sodium
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = serum_sodium, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Serum Sodium by Response Variable", x = "Serum Sodium")

There are a few outliers. Serum sodium levels are lower for those who experienced a death event.

Time
disease_data |>
  mutate(DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = time, fill=DEATH_EVENT)) +
  geom_boxplot() +
  labs(title = "Distribution of Time by Response Variable", x = "Time")

Patients who experienced a death event had a much shorter time between the initial encounter and the follow-up period.

Categorical Variables

Anaemia
plot1 <- disease_data |>
  mutate(anaemia = ifelse(anaemia == 1, "Yes", "No")) |>
  ggplot(aes(x = anaemia)) +
  geom_bar() +
  theme(legend.position = "none") +
  labs(y = "Count", x = "Anaemia Indicator")

plot2 <- disease_data |>
  mutate(anaemia = ifelse(anaemia == 1, "Yes", "No"),
         DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = anaemia)) +
  geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
  theme(legend.position = "bottom") +
  labs(y = "Count", x = "Anaemia Indicator")

title <- ggdraw() + 
  draw_label(
    "Death Event by Anaemia",
    fontface = 'bold',
    x = 0,
    hjust = -0.05
  )

plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))

Diabetes
plot1 <- disease_data |>
  mutate(diabetes = ifelse(diabetes == 1, "Yes", "No")) |>
  ggplot(aes(x = diabetes)) +
  geom_bar() +
  theme(legend.position = "none") +
  labs(y = "Count", x = "Diabetes Indicator")

plot2 <- disease_data |>
  mutate(diabetes = ifelse(diabetes == 1, "Yes", "No"),
         DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = diabetes)) +
  geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
  theme(legend.position = "bottom") +
  labs(y = "Count", x = "Diabetes Indicator")

title <- ggdraw() + 
  draw_label(
    "Death Event by Diabetes",
    fontface = 'bold',
    x = 0,
    hjust = -0.05
  )

plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))

High Blood Pressure
plot1 <- disease_data |>
  mutate(high_blood_pressure = ifelse(high_blood_pressure == 1, "Yes", "No")) |>
  ggplot(aes(x = high_blood_pressure)) +
  geom_bar() +
  theme(legend.position = "none") +
  labs(y = "Count", x = "High Blood Pressure Indicator")

plot2 <- disease_data |>
  mutate(high_blood_pressure = ifelse(high_blood_pressure == 1, "Yes", "No"),
         DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = high_blood_pressure)) +
  geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
  theme(legend.position = "bottom") +
  labs(y = "Count", x = "High Blood Pressure Indicator")

title <- ggdraw() + 
  draw_label(
    "Death Event by High Blood Pressure",
    fontface = 'bold',
    x = 0,
    hjust = -0.05
  )

plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))

Smoking
plot1 <- disease_data |>
  mutate(smoking = ifelse(smoking == 1, "Yes", "No")) |>
  ggplot(aes(x = smoking)) +
  geom_bar() +
  theme(legend.position = "none") +
  labs(y = "Count", x = "Smoking Indicator")

plot2 <- disease_data |>
  mutate(smoking = ifelse(smoking == 1, "Yes", "No"),
         DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = smoking)) +
  geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
  theme(legend.position = "bottom") +
  labs(y = "Count", x = "Smoking Indicator")

title <- ggdraw() + 
  draw_label(
    "Death Event by Smoking",
    fontface = 'bold',
    x = 0,
    hjust = -0.05
  )

plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))

Sex
plot1 <- disease_data |>
  mutate(sex = ifelse(sex == 1, "Male", "Female")) |>
  ggplot(aes(x = sex)) +
  geom_bar() +
  theme(legend.position = "none") +
  labs(y = "Count", x = "Sex Indicator")

plot2 <- disease_data |>
  mutate(sex = ifelse(sex == 1, "Male", "Female"),
         DEATH_EVENT = ifelse(DEATH_EVENT == 1, "Yes", "No")) |>
  ggplot(aes(x = sex)) +
  geom_bar(aes(fill = DEATH_EVENT), position="dodge") +
  theme(legend.position = "bottom") +
  labs(y = "Count", x = "Sex Indicator")

title <- ggdraw() + 
  draw_label(
    "Death Event by Sex",
    fontface = 'bold',
    x = 0,
    hjust = -0.05
  )

plot_grid(title, plot1, plot2, ncol=1, rel_heights = c(0.15, 1, 1))

None of the categorical variables seem to have much of an influence on whether or not the individual experienced a death event.
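The dodged bar charts compare raw counts, which are hard to read when the group sizes differ. Comparing the death rate within each group makes the comparison more direct. The sketch below uses hypothetical data. Swapping toy for disease_data, and smoking for any of the other binary columns, would give the same kind of table for the real data.

```r
# Hypothetical data: a binary predictor and the 0/1 death indicator
toy <- data.frame(
  smoking     = c(0, 0, 0, 0, 1, 1, 1, 1),
  DEATH_EVENT = c(0, 1, 0, 0, 1, 1, 0, 0)
)

# Row-wise proportions: share of deaths within each smoking group
tab <- table(smoking = toy$smoking, death = toy$DEATH_EVENT)
death_rate <- prop.table(tab, margin = 1)
print(round(death_rate, 2))

# Difference in death rate between smokers and non-smokers
rate_gap <- death_rate["1", "1"] - death_rate["0", "1"]
print(rate_gap)
```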

Correlation

disease_data |>
  cor() |>
  corrplot(method="color", 
           diag=FALSE,
           type="lower",
           addCoef.col = "black",
           number.cex=0.35,
           tl.cex=0.5)

The variables that are most correlated with DEATH_EVENT are age, ejection_fraction, serum_creatinine, serum_sodium, and time.
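DEATH_EVENT is coded 0/1, so its Pearson correlation with each predictor is the point-biserial correlation. Predictors can therefore be ranked by the absolute value of that correlation. The base-R sketch below does this on hypothetical data. Applied to disease_data, the same lines would rank the real predictors.

```r
# Hypothetical numeric data with a 0/1 outcome column
toy <- data.frame(
  age               = c(50, 60, 70, 80, 55, 65),
  ejection_fraction = c(45, 40, 25, 20, 50, 30),
  platelets         = c(250, 300, 260, 310, 280, 270),
  DEATH_EVENT       = c(0, 0, 1, 1, 0, 1)
)

# Correlation of every column with the outcome, excluding the outcome itself
r <- cor(toy)[, "DEATH_EVENT"]
r <- r[names(r) != "DEATH_EVENT"]

# Rank by strength of association while keeping the sign visible
ranked <- r[order(abs(r), decreasing = TRUE)]
print(round(ranked, 2))
```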

Data Preparation

First, code the target variable as a factor:

disease_data$DEATH_EVENT <- as.factor(disease_data$DEATH_EVENT)

Next, we will split the data for training and testing:

set.seed(613)

# Sample row indices for the training set
train_index <- sample(nrow(disease_data), 0.8 * nrow(disease_data))

# Create training and testing sets
train <- disease_data[train_index, ]
test <- disease_data[-train_index, ]
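sample() draws rows without regard to the outcome, so the class balance of the training and test sets can drift slightly from that of the full data. A stratified split samples within each class instead. caret::createDataPartition() does this for a factor outcome.

The sketch below is a base-R version of the same idea. stratified_index() is a hypothetical helper, and the class counts are made up to mimic the roughly 2:1 imbalance.

```r
set.seed(613)

# Hypothetical outcome with roughly a 2:1 class imbalance
y <- factor(rep(c(0, 1), times = c(680, 320)))

# Stratified split: sample a proportion p of the row indices within each class
stratified_index <- function(y, p = 0.8) {
  idx_by_class <- split(seq_along(y), y)
  picked <- lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), floor(p * length(idx)))]
  })
  sort(unlist(picked, use.names = FALSE))
}

train_index <- stratified_index(y)

# The class proportions match in the two partitions
prop.table(table(y[train_index]))
prop.table(table(y[-train_index]))
```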

We can now train our models.

Neural Network Model

We will start by training a Neural Network model using 10-fold cross-validation.
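In 10-fold cross-validation, the training rows are divided into 10 folds. Each fold is held out once while the model is fit on the other nine, and the reported accuracy is the average over the 10 held-out folds. The base-R sketch below shows the fold assignment for 4,000 training rows, the training-set size reported by caret. It uses simple random folds of exactly 400 rows. caret stratifies its folds by the outcome, which is likely why its fit sizes in the output are slightly uneven (e.g. 3599 and 3601).

```r
set.seed(613)

n <- 4000  # number of training rows
k <- 10    # number of folds

# Randomly assign each row to one of k folds of equal size
fold_id <- sample(rep(seq_len(k), length.out = n))

# In fold i, rows with fold_id == i are held out and the rest are used for fitting
held_out_sizes <- as.vector(table(fold_id))
fit_sizes <- n - held_out_sizes
print(fit_sizes)  # each 3600 here
```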

set.seed(613)
nnet_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine + serum_sodium + time, 
                    data = train, 
                    method = "nnet", 
                    trControl = trainControl(method = "cv", number = 10),
                    trace=0)

nnet_model
## Neural Network 
## 
## 4000 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ... 
## Resampling results across tuning parameters:
## 
##   size  decay  Accuracy   Kappa    
##   1     0e+00  0.7210439  0.1312748
##   1     1e-04  0.7410872  0.2041227
##   1     1e-01  0.8457825  0.5948862
##   3     0e+00  0.7697829  0.3165491
##   3     1e-04  0.7750241  0.3326443
##   3     1e-01  0.8799914  0.7147496
##   5     0e+00  0.8487213  0.6049868
##   5     1e-04  0.8416725  0.5810650
##   5     1e-01  0.8767308  0.7153841
## 
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were size = 3 and decay = 0.1.

With size = 3 and decay = 0.1, the model's cross-validated accuracy is about 88%.
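The caret output reports "No pre-processing" for this model. Neural networks are generally sensitive to the scale of their inputs, and here serum_sodium is around 137 while serum_creatinine is around 1. caret's train() accepts preProcess = c("center", "scale") to standardize the predictors within each resample.

The base-R sketch below shows what centering and scaling does, using hypothetical values. The means and standard deviations are estimated on the training rows only, and the same constants are then applied to the test rows. No model results are implied.

```r
# Hypothetical training and test predictors on very different scales
train_x <- data.frame(serum_sodium     = c(137, 130, 140, 126),
                      serum_creatinine = c(1.3, 5.0, 0.9, 1.2))
test_x  <- data.frame(serum_sodium     = c(132, 145),
                      serum_creatinine = c(2.0, 1.2))

# Estimate centering and scaling constants from the training data only
mu    <- colMeans(train_x)
sigma <- apply(train_x, 2, sd)

# Apply the same constants to both sets so no test information leaks in
train_scaled <- scale(train_x, center = mu, scale = sigma)
test_scaled  <- scale(test_x,  center = mu, scale = sigma)

round(colMeans(train_scaled), 10)  # 0 for every column, up to rounding
apply(train_scaled, 2, sd)         # 1 for every column
```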

nnet_pred <- predict(nnet_model, test)

# confusion matrix
nnet_confusion_matrix <- confusionMatrix(nnet_pred, test$DEATH_EVENT)

nnet_conf_matrix_df <- as.data.frame(nnet_confusion_matrix$table) 
nnet_conf_matrix_df$Prediction <- factor(nnet_conf_matrix_df$Prediction, levels = c(1, 0))
nnet_conf_matrix_df$Reference <- factor(nnet_conf_matrix_df$Reference, levels = c(0, 1))
  
nnet_conf_matrix_df |>
  ggplot(aes(x = Prediction, y = as.factor(Reference))) +
  geom_tile(aes(fill = Freq), color = "white") +
  scale_fill_gradient(low = "white", high = "palegreen3") +
  labs(title = "NNET Model",
       x = "Predicted",
       y = "Actual") +
  geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
  theme_bw() +
  theme(legend.position = "none")

Treating a death event as the positive class, the Neural Network model produces 267 true positives and 635 true negatives on the test data. It has slightly more false positives (53) than false negatives (45).

Random Forest Model

We will now build a Random Forest model, also using 10-fold cross-validation.

set.seed(613)
rf_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine + serum_sodium + time, 
                  data = train, 
                  method = "rf", 
                  trControl = trainControl(method = "cv", number = 10))

rf_model
## Random Forest 
## 
## 4000 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##   2     0.9924968  0.9825220
##   3     0.9919968  0.9813709
##   5     0.9914975  0.9802279
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.

The best Random Forest model (mtry = 2) has a cross-validated accuracy of about 99%.

rf_pred <- predict(rf_model, test)

# confusion matrix
rf_confusion_matrix <- confusionMatrix(rf_pred, test$DEATH_EVENT)

rf_conf_matrix_df <- as.data.frame(rf_confusion_matrix$table) 
rf_conf_matrix_df$Prediction <- factor(rf_conf_matrix_df$Prediction, levels = c(1, 0))
rf_conf_matrix_df$Reference <- factor(rf_conf_matrix_df$Reference, levels = c(0, 1))
  
rf_conf_matrix_df |>
  ggplot(aes(x = Prediction, y = as.factor(Reference))) +
  geom_tile(aes(fill = Freq), color = "white") +
  scale_fill_gradient(low = "white", high = "palegreen3") +
  labs(title = "RF Model",
       x = "Predicted",
       y = "Actual") +
  geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
  theme_bw() +
  theme(legend.position = "none")

When applied to the testing data, the Random Forest model produces only 5 false negatives and 1 false positive. It clearly outperforms the Neural Network model.

Let’s compare the accuracy metrics of these two models.

keep <- c("Balanced Accuracy", "F1", "Specificity", "Precision", "Recall")

# random forest model
rf_mod_metrics <- data.frame("RandomForest" = rf_confusion_matrix$byClass)

rf_mod_metrics$metric <- rownames(rf_mod_metrics)

rf_mod_metrics <- rf_mod_metrics |>
  pivot_wider(names_from = metric,
              values_from = c("RandomForest")) |>
  dplyr::select(all_of(keep))

# nnet model
nnet_mod_metrics <- data.frame("NeuralNetwork" = nnet_confusion_matrix$byClass)

nnet_mod_metrics$metric <- rownames(nnet_mod_metrics)

nnet_mod_metrics <- nnet_mod_metrics |>
  pivot_wider(names_from = metric,
              values_from = c("NeuralNetwork")) |>
  dplyr::select(all_of(keep))

# combine
metrics <- data.frame(rbind(rf_mod_metrics, nnet_mod_metrics))
rownames(metrics) <- c("RandomForest", "NeuralNetwork")

metrics |>
  knitr::kable()
| Model | Balanced Accuracy | F1 | Specificity | Precision | Recall |
| --- | --- | --- | --- | --- | --- |
| RandomForest | 0.9912604 | 0.9956522 | 0.9839744 | 0.9927746 | 0.9985465 |
| NeuralNetwork | 0.8893672 | 0.9283626 | 0.8557692 | 0.9338235 | 0.9229651 |

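The Random Forest model has better balanced accuracy, F1, specificity, precision, and recall than the Neural Network model. It is therefore the more appropriate model for predicting DEATH_EVENT from the available variables.

Note that confusionMatrix() treats the first factor level, "0" (survival), as the positive class by default. Precision and Recall here therefore describe the prediction of survival, while Specificity is the recall for death events.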
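As a check, the values in the metrics table can be recomputed from the confusion-matrix counts. The sketch below does this for the Neural Network using the test-set counts reported earlier (635, 53, 45, 267). It computes the metrics twice: first with survival ("0", caret's default first level) as the positive class, then with death ("1"). class_metrics() is a hypothetical helper written for illustration, not a caret function. In caret, the same switch is made by passing positive = "1" to confusionMatrix().

```r
# Neural Network test-set counts from the text; rows = predicted, columns = actual
cm <- matrix(c(635,  45,
                53, 267),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

# Per-class metrics for a 2x2 confusion matrix, given the positive class label
class_metrics <- function(cm, positive) {
  negative <- setdiff(colnames(cm), positive)
  tp <- cm[positive, positive]
  fn <- cm[negative, positive]
  fp <- cm[positive, negative]
  tn <- cm[negative, negative]
  recall      <- tp / (tp + fn)
  precision   <- tp / (tp + fp)
  specificity <- tn / (tn + fp)
  c(Recall = recall, Precision = precision, Specificity = specificity,
    F1 = 2 * precision * recall / (precision + recall),
    `Balanced Accuracy` = (recall + specificity) / 2)
}

round(class_metrics(cm, positive = "0"), 7)  # survival as positive (caret's default)
round(class_metrics(cm, positive = "1"), 7)  # death as positive
```

With survival as the positive class, the function reproduces the Neural Network row of the table (for example, Recall = 635/688 ≈ 0.923). With death as the positive class, Recall becomes 267/312 ≈ 0.856, which is the value reported as Specificity in the table.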

Let’s see which are the most important variables for this model.

plot(varImp(rf_model))

The most important variable in this model is time and the least important variable is serum_sodium.

Recall that the purpose of this analysis is to predict the likelihood of survival of heart failure patients at their initial admission to the hospital. According to the data description, time is measured in days and is the length of time from the initial encounter until the follow-up period. Although time is a very important feature for the model, it is of limited use to hospitals seeking to identify at-risk patients at initial admission, because it cannot be measured on the initial encounter date.

Let’s see how the model performs when we remove the time and serum_sodium variables.

set.seed(613)
rf_model <- train(DEATH_EVENT ~ age + ejection_fraction + serum_creatinine, 
                  data = train, 
                  method = "rf", 
                  trControl = trainControl(method = "cv", number = 10))
## note: only 2 unique complexity parameters in default grid. Truncating the grid to 2 .
rf_model
## Random Forest 
## 
## 4000 samples
##    3 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3599, 3599, 3601, 3601, 3600, 3601, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##   2     0.9794899  0.9517692
##   3     0.9787405  0.9499829
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.

The new Random Forest model has a cross-validated accuracy of almost 98%.

rf_pred <- predict(rf_model, test)

# confusion matrix
rf_confusion_matrix <- confusionMatrix(rf_pred, test$DEATH_EVENT)

rf_conf_matrix_df <- as.data.frame(rf_confusion_matrix$table) 
rf_conf_matrix_df$Prediction <- factor(rf_conf_matrix_df$Prediction, levels = c(1, 0))
rf_conf_matrix_df$Reference <- factor(rf_conf_matrix_df$Reference, levels = c(0, 1))
  
rf_conf_matrix_df |>
  ggplot(aes(x = Prediction, y = as.factor(Reference))) +
  geom_tile(aes(fill = Freq), color = "white") +
  scale_fill_gradient(low = "white", high = "palegreen3") +
  labs(title = "RF Model: Age, Ejection Fraction, Serum Creatinine",
       x = "Predicted",
       y = "Actual") +
  geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
  theme_bw() +
  theme(legend.position = "none")

When applied to the testing data, this model has more false negatives (18 vs. 5) and more false positives (5 vs. 1) than the previous Random Forest model. However, it still outperforms the Neural Network model.

# random forest model
rf_mod_metrics <- data.frame("RandomForest" = rf_confusion_matrix$byClass)

rf_mod_metrics$metric <- rownames(rf_mod_metrics)

rf_mod_metrics <- rf_mod_metrics |>
  pivot_wider(names_from = metric,
              values_from = c("RandomForest")) |>
  dplyr::select(all_of(keep))

rf_mod_metrics |>
  knitr::kable()
| Balanced Accuracy | F1 | Specificity | Precision | Recall |
| --- | --- | --- | --- | --- |
| 0.9675201 | 0.9834413 | 0.9423077 | 0.9743224 | 0.9927326 |

The model has a balanced accuracy of about 97% on the testing data.

Conclusions

The goal of this analysis was to predict the survival of heart failure patients based on the available recorded features.

The clinical features in the data include:

- the age and sex of the patient
- whether the patient smokes, has diabetes, anaemia, or high blood pressure
- several blood measurements and the ejection fraction
- the length of time from the initial encounter until the follow-up period

The response variable was DEATH_EVENT, which indicates whether or not the patient died during the follow-up period.

The predictor variables were analyzed to determine their relationships with the response variable.

The age of each patient was recorded as a continuous variable, age. The youngest patient was 40 and the oldest was 95, and the median age over the entire dataset is 60. Based on the distribution of ages by the response variable, older patients tend to have a lower chance of survival.

The sex of each patient was recorded as a binary variable, sex. There were twice as many males as females in the dataset. However, the proportions of patients who survived and who did not were similar for both groups.

The other binary features were anaemia, diabetes, high_blood_pressure, and smoking, indicating whether the patient had anaemia, diabetes, or hypertension, or smoked. As with sex, there did not seem to be much difference in outcomes between these groups.

Other than age, the continuous numeric features were creatinine_phosphokinase, ejection_fraction, platelets, serum_creatinine, serum_sodium, and time. Their relationships with the response variable were:

- creatinine_phosphokinase measures the CPK level in the blood. The boxplots split by the response variable showed no clear difference between patients who died and patients who survived.
- ejection_fraction is the percentage of blood leaving the heart at each contraction. It was much lower for patients who died.
- platelets measures the platelet level in the blood. It did not seem to differ between groups.
- serum_creatinine measures the creatinine level in the blood. It was elevated for patients who died.
- serum_sodium measures the sodium level in the blood. It was decreased for patients who died.
- time is the number of days from the initial encounter to the follow-up period. It was much lower for patients who died.

The correlation plot confirmed the relationships between the response variable and the predictors. age, ejection_fraction, serum_creatinine, serum_sodium, and time had the largest correlations. age and serum_creatinine were both positively correlated with the response variable. ejection_fraction, serum_sodium, and time were all negatively correlated with the response variable. The remaining variables were barely correlated.

The analysis was conducted using a Neural Network and two Random Forest models. First, the data was split into a training set and a testing set (80:20 split).

A Neural Network model was then trained on the previously mentioned predictors using 10-fold cross-validation. It had a cross-validated accuracy of 88%. On the testing set, it correctly predicted 267 positive and 635 negative instances, with 53 false positives and 45 false negatives.

Next, a Random Forest model was trained using the same variables and cross-validation. It significantly outperformed the Neural Network model, with a cross-validated accuracy of 99%. On the testing set it also reached about 99% accuracy, correctly predicting 307 true positives and 687 true negatives, with only 1 false positive and 5 false negatives.

The most important features for the Random Forest model were time, serum_creatinine, ejection_fraction, and age. serum_sodium was the least important.

Because the goal of this analysis was to predict the likelihood of survival from the time of the initial encounter, time was removed from the model: it cannot be measured at admission to the hospital. serum_sodium was also removed, as it was the least important feature.

The resulting model has a cross-validated accuracy of about 98% and a balanced accuracy of about 97% on the testing set. It correctly predicts 294 true positives and 683 true negatives. It has more than three times as many false negatives (18) as false positives (5). Even so, the model is highly accurate while using only three variables.

Based on the results of the analysis, hospitals could predict a patient’s likelihood of survival with high accuracy from their age, serum creatinine level, and ejection fraction. This could help hospitals prioritize which patients to treat. Based on the predicted likelihood of death, they could identify patients who are more likely to die and therefore need more urgent attention. Alternatively, they could identify those more likely to survive with treatment (i.e. prioritize a patient who is more likely to survive over one who is likely to die anyway).

Sources

Kaggle Heart Failure Dataset: https://www.kaggle.com/datasets/aadarshvelu/heart-failure-prediction-clinical-records