# Session setup ----
# Resolve the directory that holds the input workbook. The CFDE_DATA_DIR
# environment variable overrides the original hard-coded location so the script
# can run on other machines without editing the source; when the variable is
# unset or empty, the original path is used and behavior is unchanged.
default_data_dir <- "/Users/cjavitia/Desktop/Internships/CFDE/AustraliaData"
data_dir <- Sys.getenv("CFDE_DATA_DIR", unset = default_data_dir)
if (!nzchar(data_dir)) {
  data_dir <- default_data_dir
}
if (!dir.exists(data_dir)) {
  stop("Data directory not found: ", data_dir, call. = FALSE)
}
setwd(data_dir)

# Close every graphics device that is still open.
graphics.off()

# Start from an empty workspace (this also removes the helper variables above).
rm(list = ls())

library(tidymodels)
library(readxl)
library(stringr)
library(themis)
library(glmnet)
library(synthpop)
library(janitor)
library(ggplot2)

# Patient metadata and antibiotics ----

# Load the "Entero_IMP" sheet of the merged metadata workbook.
meta <- read_excel("Merged_metadata.xlsx", sheet = "Entero_IMP")

# Column names antibiotic_1 ... antibiotic_5.
abx_cols <- paste0("antibiotic_", seq_len(5))

# Keep only the isolate identifier and the antibiotic columns.
abx_data <- select(meta, `Isolate ID`, all_of(abx_cols))

# Reshape to one row per isolate / antibiotic slot.
abx_long <- pivot_longer(
  abx_data,
  cols = -`Isolate ID`,
  names_to = "which_abx",
  values_to = "drug"
)

# Trim whitespace, drop empty or missing slots, and add a 1-valued indicator.
abx_clean <- abx_long %>%
  mutate(drug = str_trim(drug)) %>%
  filter(!is.na(drug), drug != "") %>%
  mutate(flag = 1)

# An isolate may list the same drug in several slots; keep one row per pair.
abx_unique <- distinct(abx_clean, `Isolate ID`, drug, .keep_all = TRUE)

# Spread to one ABX_<drug> column per antibiotic: 1 = recorded, 0 = not.
abx_binary <- pivot_wider(
  abx_unique,
  id_cols = `Isolate ID`,
  names_from = drug,
  names_prefix = "ABX_",
  values_from = flag,
  values_fill = 0
)

# Combine metadata and antibiotic indicators (no sequencing data is joined in this section) ----

# Core clinical columns. `outcome` becomes a two-level factor with
# "Died in hospital" first and "Survived to discharge" second; any other value
# in the sheet turns into NA.
meta_core <- meta %>%
  select(`Isolate ID`, outcome, hospital_1m, icu_adm) %>%
  mutate(
    outcome = factor(
      outcome,
      levels = c("Died in hospital", "Survived to discharge")
    )
  )

# Attach the antibiotic indicators. Isolates with no recorded antibiotic have
# no row in `abx_binary`, so the join leaves their ABX_ columns NA; those mean
# "not given" and are set to 0. The fill is limited to the ABX_ columns so a
# numeric clinical covariate with missing values is never silently zero-filled
# (the recipe's median imputation handles those instead).
model_df <- meta_core %>%
  left_join(abx_binary, by = "Isolate ID") %>%
  mutate(across(starts_with("ABX_"), ~ replace_na(.x, 0))) %>%
  select(-`Isolate ID`)  # the identifier is not a predictor

colnames(model_df)
##  [1] "outcome"                            "hospital_1m"                       
##  [3] "icu_adm"                            "ABX_Ceftriaxone"                   
##  [5] "ABX_Piperacillin/tazobactam"        "ABX_Other"                         
##  [7] "ABX_Meropenem"                      "ABX_Benzylpenicillin"              
##  [9] "ABX_Ciprofloxacin"                  "ABX_Trimethoprim-sulphamethoxazole"
## [11] "ABX_Clarithromycin"                 "ABX_Ceftolozane/tazoabctam"        
## [13] "ABX_Caspofungin"                    "ABX_Amoxicillin"                   
## [15] "ABX_Vancomycin"                     "ABX_Cefepime"                      
## [17] "ABX_Cephazolin"                     "ABX_Tobramycin"                    
## [19] "ABX_Flucloxacillin"                 "ABX_Amoxicillin/clavulanate"       
## [21] "ABX_Trimethoprim"                   "ABX_Cephalexin"                    
## [23] "ABX_Clindamycin"                    "ABX_Fluconazole"
# Modeling ----
# Stratified 75/25 train/test split on the outcome.
set.seed(345)
splits <- initial_split(model_df, prop = 0.75, strata = outcome)
train_data <- training(splits)
test_data  <- testing(splits)

# Preprocessing shared by every model that uses `model_recipe`.
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  step_string2factor(all_nominal_predictors()) %>%
  # missing categories become their own factor level
  step_unknown(all_nominal_predictors()) %>%
  step_zv(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  # drop dummy/indicator columns that are constant in the training data
  step_zv(all_predictors()) %>%
  step_impute_median(all_numeric_predictors()) %>% # imputes numeric NA
  # No nominal predictors remain after step_dummy(), so this step selects no
  # columns; missing categories were already handled by step_unknown().
  step_impute_mode(all_nominal_predictors()) %>%
  # balance the outcome classes by up-sampling
  step_upsample(outcome)

# LASSO (mixture = 1) logistic regression with a tuned penalty.
#
# The engine previously received `weights = ifelse(...)` built from
# `train_data$outcome`. parsnip does not allow `weights` to be set as an engine
# argument and drops it with the warning "The argument `weights` cannot be
# manually modified and was removed.", so those class weights never affected
# any fit. The dead argument is removed here; fitted models are unchanged and
# class imbalance is still addressed by step_upsample() in the recipe.
# NOTE(review): if explicit class weights are wanted, supply them as a
# case-weights column (hardhat::importance_weights()) registered with
# workflows::add_case_weights() rather than as an engine argument.
log_reg <- logistic_reg(penalty = tune(), mixture = 1, mode = "classification") %>%
  set_engine("glmnet")

wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_reg)

# 5-fold stratified cross-validation over 30 log-spaced penalties (1e-4 to 1e-1).
set.seed(678)
folds <- vfold_cv(train_data, v = 5, strata = outcome)
grid  <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

tuned <- tune_grid(
  wflow,
  resamples = folds,
  grid = grid,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_grid(save_pred = TRUE)
)

# Keep the penalty with the best cross-validated ROC AUC and refit on the
# full training set.
best <- select_best(tuned, metric = "roc_auc")

final_fit <- finalize_workflow(wflow, best) %>% fit(train_data)

# Model results ----

# Test-set predictions: true outcome, class probabilities (column names
# normalised by janitor::clean_names(), e.g. pred_survived_to_discharge), and
# the predicted class (.pred_class).
test_pred <- test_data %>%
  select(outcome) %>%
  bind_cols(
    predict(final_fit, test_data, type = "prob") %>% janitor::clean_names(),
    predict(final_fit, test_data)
  )

# Test-set metrics with the second factor level ("Survived to discharge") as
# the event. Note: this object shares its name with yardstick's metrics().
metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  test_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)

confusion <- test_pred %>%
  conf_mat(truth = outcome, estimate = .pred_class)

print(metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.733
## 2 sens     binary         0.692
## 3 spec     binary         1    
## 4 roc_auc  binary         0.808
# Test-set confusion matrix (rows = predicted class, columns = true class).
print(confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     2                     4
##   Survived to discharge                0                     9
# Coefficients the LASSO kept (non-zero), largest magnitude first.
final_fit %>%
  tidy() %>%
  filter(estimate != 0) %>%
  select(term, estimate) %>%
  arrange(desc(abs(estimate))) %>%
  print(n = 30)
## # A tibble: 17 × 2
##    term                                estimate
##    <chr>                                  <dbl>
##  1 ABX_Vancomycin                      1.27e+ 1
##  2 ABX_Cephalexin                      5.12e+ 0
##  3 ABX_Clindamycin                     4.83e+ 0
##  4 ABX_Meropenem                       4.69e+ 0
##  5 ABX_Ciprofloxacin                  -4.36e+ 0
##  6 ABX_Ceftriaxone                    -3.33e+ 0
##  7 ABX_Flucloxacillin                  3.03e+ 0
##  8 ABX_Amoxicillin/clavulanate         2.76e+ 0
##  9 ABX_Other                          -2.51e+ 0
## 10 ABX_Amoxicillin                     2.02e+ 0
## 11 ABX_Ceftolozane/tazoabctam          1.98e+ 0
## 12 hospital_1m_Yes                     1.41e+ 0
## 13 ABX_Fluconazole                     9.72e- 1
## 14 ABX_Clarithromycin                  9.19e- 1
## 15 ABX_Trimethoprim-sulphamethoxazole  9.03e- 1
## 16 (Intercept)                        -1.77e- 1
## 17 ABX_Caspofungin                     8.96e-15
# Cross-validated performance across the penalty grid.
autoplot(tuned) +
  ggtitle("Penalty vs. cross-validated AUC")

# Test-set ROC curve for the LASSO model.
# `pred_survived_to_discharge` is the probability of the SECOND outcome level,
# so roc_curve() must be told event_level = "second". With the default
# ("first") the curve is drawn for the wrong event and comes out mirrored
# relative to the AUC in the title, which was computed with
# event_level = "second".
roc_curve(
  test_pred,
  outcome,
  pred_survived_to_discharge,
  event_level = "second"
) %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2) +
  geom_abline(lty = 2, colour = "grey50") +
  coord_equal() +
  labs(title = glue::glue("ROC curve (AUC = {round(metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity (False-positive rate)", y = "Sensitivity (True-positive rate)")

# Heatmap of the test-set confusion matrix. Overriding the fill scale that
# autoplot() already set is what produces the "Scale for fill is already
# present" message.
test_pred %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  scale_fill_gradient(low = "white", high = "hotpink") +
  theme_minimal()
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.

# Non-zero, non-intercept LASSO coefficients, labelled by sign
# (per the axis label below, positive = increases survival).
coef_df <- final_fit %>%
  tidy() %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  mutate(
    direction = if_else(estimate > 0, "Survival increase", "Survival decrease")
  )

# Horizontal bar chart, terms ordered by absolute coefficient size.
coef_df %>%
  ggplot(aes(x = reorder(term, abs(estimate)), y = estimate, fill = direction)) +
  geom_col(width = 0.75) +
  coord_flip() +
  scale_fill_manual(
    values = c("Survival increase" = "red", "Survival decrease" = "slategray")
  ) +
  labs(
    title = "Features retained by LASSO Metadata Only",
    x = NULL,
    y = "Log-odds coefficient\n(positive = increases survival)"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text = element_text(face = "bold")
  )

# Random Forest ----

library(ranger)
library(vip)  # for variable importance plot
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Random forest classifier (ranger engine, impurity importance): 1000 trees,
# with mtry and min_n left to be tuned.
rf_model <- rand_forest(
  mode = "classification",
  mtry = tune(),
  min_n = tune(),
  trees = 1000
) %>%
  set_engine("ranger", importance = "impurity")

# Workflow: same preprocessing recipe as the LASSO model.
rf_wflow <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(model_recipe)

# Regular 5 x 5 grid over mtry and min_n.
# NOTE(review): the upper mtry bound (30) is larger than the number of
# predictors left after the recipe (18-21 according to the warnings logged
# below), so the largest mtry levels are all reduced to "use every predictor"
# and duplicate each other. Consider capping the range; left unchanged here
# because it would alter the tuned model and the recorded results.
rf_grid <- grid_regular(
  mtry(range = c(5, 30)),
  min_n(range = c(2, 10)),
  levels = 5
)

# Tune on the same CV folds used for the LASSO model.
set.seed(456)
rf_tuned <- tune_grid(
  rf_wflow,
  resamples = folds,
  grid = rf_grid,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  control = control_grid(save_pred = TRUE)
)
## → A | warning: ! 23 columns were requested but there were 18 predictors in the data.
##                ℹ 18 predictors will be used.
## There were issues with some computations   A: x1                                                 → B | warning: ! 30 columns were requested but there were 18 predictors in the data.
##                ℹ 18 predictors will be used.
## There were issues with some computations   A: x1There were issues with some computations   A: x2   B: x1There were issues with some computations   A: x3   B: x2There were issues with some computations   A: x4   B: x3There were issues with some computations   A: x5   B: x4                                                         → C | warning: ! 23 columns were requested but there were 21 predictors in the data.
##                ℹ 21 predictors will be used.
## There were issues with some computations   A: x5   B: x4There were issues with some computations   A: x5   B: x5   C: x1                                                                 → D | warning: ! 30 columns were requested but there were 21 predictors in the data.
##                ℹ 21 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x1There were issues with some computations   A: x5   B: x5   C: x2   D: x1There were issues with some computations   A: x5   B: x5   C: x2   D: x2There were issues with some computations   A: x5   B: x5   C: x3   D: x3There were issues with some computations   A: x5   B: x5   C: x4   D: x3There were issues with some computations   A: x5   B: x5   C: x5   D: x4There were issues with some computations   A: x5   B: x5   C: x6   D: x5There were issues with some computations   A: x5   B: x5   C: x7   D: x6There were issues with some computations   A: x5   B: x5   C: x8   D: x7There were issues with some computations   A: x5   B: x5   C: x9   D: x8There were issues with some computations   A: x5   B: x5   C: x10   D: x9                                                                          → E | warning: ! 23 columns were requested but there were 19 predictors in the data.
##                ℹ 19 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x10   D: x9There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…                                                                                 → F | warning: ! 30 columns were requested but there were 19 predictors in the data.
##                ℹ 19 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x10   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x11   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x12   D: x11   E:…There were issues with some computations   A: x5   B: x5   C: x13   D: x12   E:…There were issues with some computations   A: x5   B: x5   C: x14   D: x13   E:…There were issues with some computations   A: x5   B: x5   C: x15   D: x14   E:…There were issues with some computations   A: x5   B: x5   C: x15   D: x15   E:…
# Hyperparameters with the best cross-validated ROC AUC.
rf_best <- rf_tuned %>%
  select_best(metric = "roc_auc")

# Plug them into the workflow and refit on the full training set.
rf_final <- rf_wflow %>%
  finalize_workflow(rf_best) %>%
  fit(data = train_data)
## Warning: ! 23 columns were requested but there were 21 predictors in the data.
## ℹ 21 predictors will be used.
# Test-set predictions: true outcome, class probabilities (names normalised by
# janitor::clean_names()), and the predicted class.
rf_pred <- test_data %>%
  select(outcome) %>%
  bind_cols(
    predict(rf_final, test_data, type = "prob") %>% janitor::clean_names(),
    predict(rf_final, test_data)
  )

# Test-set metrics with "Survived to discharge" (second level) as the event.
rf_metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  rf_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)

rf_confusion <- rf_pred %>%
  conf_mat(truth = outcome, estimate = .pred_class)

# Show the metrics
print(rf_metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.8  
## 2 sens     binary         0.769
## 3 spec     binary         1    
## 4 roc_auc  binary         0.846
# Test-set confusion matrix (rows = predicted class, columns = true class).
print(rf_confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     2                     3
##   Survived to discharge                0                    10
# ROC Curve
# `pred_survived_to_discharge` is the probability of the SECOND outcome level,
# so roc_curve() needs event_level = "second"; with the default ("first") the
# curve is mirrored relative to the AUC in the title, which rf_metrics
# computed with event_level = "second".
roc_curve(
  rf_pred,
  outcome,
  pred_survived_to_discharge,
  event_level = "second"
) %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2, color = "darkgreen") +
  geom_abline(lty = 2, color = "darkgrey") +
  coord_equal() +
  labs(title = glue::glue("Random Forest ROC (AUC = {round(rf_metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity", y = "Sensitivity")

# Top-10 variable importances from the fitted ranger model (impurity-based, as
# requested in the engine settings), drawn as a bar chart.
vip(
  extract_fit_parsnip(rf_final),
  num_features = 10,
  geom = "col",
  aesthetics = list(fill = "darkturquoise")
) +
  labs(title = "Metadata Only Top 10 Important Features (Random Forest)") +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", color = "black"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text = element_text(face = "bold")
  )