PREDICTIVE MODEL WITH METADATA AND SEQUENCING DATA

# Galaxy pipeline export: a single workbook with one sheet per tool ----
galaxy_file <- "Galaxy_results.xlsx"

res_df   <- read_excel(path = galaxy_file, sheet = "ResFinder")     # AMR gene hits
plasm_df <- read_excel(path = galaxy_file, sheet = "PlasmidFinder") # plasmid replicon hits
sum_df   <- read_excel(path = galaxy_file, sheet = "Summary")       # summary sheet (source of Sequence Type)

# Sequence type (ST) features ----
# Keep only rows that actually carry an ST value.
st_df <- sum_df %>%
  select(`Isolate ID`, ST = `Sequence Type`) %>%
  filter(!is.na(ST), ST != "")

# STs appearing in at least 3 rows are frequent enough to become features.
keep_st <- st_df %>%
  count(ST, sort = TRUE) %>%
  filter(n >= 3) %>%
  pull(ST)

# One-hot encode: ST_<type> = 1 when the isolate has that ST, otherwise 0.
st_wide <- st_df %>%
  filter(ST %in% keep_st) %>%
  mutate(flag = 1) %>%
  pivot_wider(
    id_cols      = `Isolate ID`,
    names_from   = ST,
    names_prefix = "ST_",
    values_from  = flag,
    values_fn    = max,
    values_fill  = 0
  )


##### AMR genes extraction #######
# Gene prevalence is counted per isolate (distinct Isolate ID x Gene pairs),
# not per ResFinder row. If an isolate has several rows for the same gene
# (the `values_fn = max` below anticipates this), counting rows would inflate
# that gene's prevalence toward the `n >= 5` threshold.
keep_genes <- res_df %>%
  distinct(`Isolate ID`, Gene) %>%
  count(Gene, sort = TRUE) %>%
  filter(n >= 5) %>%
  pull(Gene)

# One-hot encode retained genes: AMR_<gene> = 1 if the isolate carries it.
arg_wide <- res_df %>%
  filter(Gene %in% keep_genes) %>%
  mutate(flag = 1) %>%
  pivot_wider(id_cols = `Isolate ID`, names_from = Gene, values_from = flag,
              values_fill = 0, values_fn = max, names_prefix = "AMR_")

# Number of distinct AMR genes per isolate (all genes, not only retained ones).
arg_counts <- res_df %>%
  group_by(`Isolate ID`) %>%
  summarise(AMR_count = n_distinct(Gene), .groups = "drop")

###### Plasmid extraction ######
# Replicon prevalence is counted per isolate (distinct Isolate ID x Plasmid
# pairs) so repeated rows for the same replicon in one isolate (anticipated by
# `values_fn = max` below) do not inflate it toward the `n >= 3` threshold.
keep_rep <- plasm_df %>%
  distinct(`Isolate ID`, Plasmid) %>%
  count(Plasmid, sort = TRUE) %>%
  filter(n >= 3) %>%
  pull(Plasmid)

# One-hot encode retained replicons: Plasmid_<replicon> = 1 if present.
plasm_wide <- plasm_df %>%
  filter(Plasmid %in% keep_rep) %>%
  mutate(flag = 1) %>%
  pivot_wider(id_cols = `Isolate ID`, names_from = Plasmid, values_from = flag,
              values_fill = 0, values_fn = max, names_prefix = "Plasmid_")

# Number of distinct replicons per isolate (all replicons, not only retained).
plasm_counts <- plasm_df %>%
  group_by(`Isolate ID`) %>%
  summarise(Plasmid_count = n_distinct(Plasmid), .groups = "drop")

###### Virulence genes extraction ######
# Base table of isolate IDs from the Summary sheet; every sequencing feature
# block is left-joined onto it.
seq_feats <- sum_df %>%
  select(`Isolate ID`)

# The VirulenceFinder sheet is optional: when the workbook lacks it, no VIR_
# columns and no Virulence_count are added.
if ("VirulenceFinder" %in% excel_sheets(galaxy_file)) {
  vir_df <- read_excel(galaxy_file, sheet = "VirulenceFinder")

  # Prevalence is counted per isolate (distinct Isolate ID x Gene pairs) so
  # repeated rows for one gene in one isolate do not inflate the count.
  keep_vir <- vir_df %>%
    distinct(`Isolate ID`, Gene) %>%
    count(Gene, sort = TRUE) %>%
    filter(n >= 3) %>%
    pull(Gene)

  # One-hot encode retained genes: VIR_<gene> = 1 if present.
  vir_wide <- vir_df %>%
    filter(Gene %in% keep_vir) %>%
    mutate(flag = 1) %>%
    pivot_wider(id_cols = `Isolate ID`, names_from = Gene, values_from = flag,
                values_fill = 0, values_fn = max, names_prefix = "VIR_")

  # Number of distinct virulence genes per isolate.
  vir_counts <- vir_df %>%
    group_by(`Isolate ID`) %>%
    summarise(Virulence_count = n_distinct(Gene), .groups = "drop")

  seq_feats <- seq_feats %>%
    left_join(vir_wide, by = "Isolate ID") %>%
    left_join(vir_counts, by = "Isolate ID")
}

###### Merge all sequencing feature blocks #######
# Left-fold the feature tables onto seq_feats by isolate, in this order:
# AMR flags, AMR count, plasmid flags, plasmid count, ST flags.
seq_feats <- Reduce(
  function(acc, block) left_join(acc, block, by = "Isolate ID"),
  list(arg_wide, arg_counts, plasm_wide, plasm_counts, st_wide),
  seq_feats
)

####### Patient metadata and antibiotics #########
meta <- read_excel("Merged_metadata.xlsx", sheet = "Entero_IMP")

# Attach each isolate's ST. st_df is de-duplicated on Isolate ID first, so a
# repeated ID in the Summary sheet cannot multiply metadata rows. Multiplied
# rows would duplicate isolates in model_df and could place copies in both
# the train and test sets. If an ID has conflicting STs, the first one is kept.
meta <- meta %>%
  left_join(distinct(st_df, `Isolate ID`, .keep_all = TRUE), by = "Isolate ID")

# Antibiotic exposure: one binary ABX_<drug> column per distinct drug name ----
abx_cols <- sprintf("antibiotic_%d", 1:5)

abx_data <- select(meta, `Isolate ID`, all_of(abx_cols))

# Long format: one row per (isolate, antibiotic slot).
abx_long <- pivot_longer(
  abx_data,
  cols      = !`Isolate ID`,
  names_to  = "which_abx",
  values_to = "drug"
)

# Trim whitespace, drop empty slots, and mark each remaining exposure with 1.
abx_clean <- abx_long %>%
  mutate(drug = str_trim(drug), flag = 1) %>%
  filter(!is.na(drug), nzchar(drug))

# The same drug listed in two slots counts once per isolate.
abx_unique <- distinct(abx_clean, `Isolate ID`, drug, .keep_all = TRUE)

abx_binary <- pivot_wider(
  abx_unique,
  id_cols      = `Isolate ID`,
  names_from   = drug,
  names_prefix = "ABX_",
  values_from  = flag,
  values_fill  = 0
)

####### Combine metadata and Sequencing data #######

meta_core <- meta %>%
  select(`Isolate ID`, outcome, hospital_1m, icu_adm) %>%
  mutate(outcome = factor(outcome, levels = c("Died in hospital", "Survived to discharge")))

# factor() silently turns any label other than the two levels above (and any
# blank) into NA. Such rows cannot train or evaluate a supervised model, so
# drop them with a visible warning.
n_no_outcome <- sum(is.na(meta_core$outcome))
if (n_no_outcome > 0) {
  warning(n_no_outcome, " isolate(s) with missing or unrecognised outcome dropped.",
          call. = FALSE)
  meta_core <- meta_core %>% filter(!is.na(outcome))
}

# Columns contributed by the antibiotic and sequencing tables. An NA there
# means the isolate had no row in that table, so it is encoded as 0 (absent).
# Clinical covariates (hospital_1m, icu_adm) are deliberately NOT zero-filled:
# their NAs are left for the recipe's training-set imputation.
indicator_cols <- setdiff(c(names(abx_binary), names(seq_feats)), "Isolate ID")

model_df <- meta_core %>%
  left_join(abx_binary, by = "Isolate ID") %>%
  left_join(seq_feats,  by = "Isolate ID") %>%
  mutate(across(any_of(indicator_cols) & where(is.numeric), ~ replace_na(.x, 0)))

# The ID is a row key, not a predictor.
model_df <- model_df %>%
  select(-`Isolate ID`)

# Final feature set passed to both models.
print(names(model_df))
##  [1] "outcome"                            "hospital_1m"                       
##  [3] "icu_adm"                            "ABX_Ceftriaxone"                   
##  [5] "ABX_Piperacillin/tazobactam"        "ABX_Other"                         
##  [7] "ABX_Meropenem"                      "ABX_Benzylpenicillin"              
##  [9] "ABX_Ciprofloxacin"                  "ABX_Trimethoprim-sulphamethoxazole"
## [11] "ABX_Clarithromycin"                 "ABX_Ceftolozane/tazoabctam"        
## [13] "ABX_Caspofungin"                    "ABX_Amoxicillin"                   
## [15] "ABX_Vancomycin"                     "ABX_Cefepime"                      
## [17] "ABX_Cephazolin"                     "ABX_Tobramycin"                    
## [19] "ABX_Flucloxacillin"                 "ABX_Amoxicillin/clavulanate"       
## [21] "ABX_Trimethoprim"                   "ABX_Cephalexin"                    
## [23] "ABX_Clindamycin"                    "ABX_Fluconazole"                   
## [25] "AMR_ARR-3"                          "AMR_aac(3)-IId"                    
## [27] "AMR_aac(6')-Ib-cr"                  "AMR_aph(3'')-Ib"                   
## [29] "AMR_aph(6)-Id"                      "AMR_blaACT-15"                     
## [31] "AMR_blaIMP-4"                       "AMR_blaOXA-1"                      
## [33] "AMR_blaTEM-1B"                      "AMR_catA2"                         
## [35] "AMR_catB3"                          "AMR_dfrA19"                        
## [37] "AMR_mph(A)"                         "AMR_qnrB2"                         
## [39] "AMR_sul1"                           "AMR_tet(D)"                        
## [41] "AMR_aac(6')-Ib3"                    "AMR_blaACT-7"                      
## [43] "AMR_fosA"                           "AMR_blaSHV-12"                     
## [45] "AMR_qnrA1"                          "AMR_blaACT-16"                     
## [47] "AMR_aac(6')-Ib-Hangzhou"            "AMR_count"                         
## [49] "Plasmid_Col(pHAD28)"                "Plasmid_ColE10"                    
## [51] "Plasmid_IncFII(pECLA)"              "Plasmid_IncHI2"                    
## [53] "Plasmid_IncHI2A"                    "Plasmid_IncM2"                     
## [55] "Plasmid_IncFIB(pB171)"              "Plasmid_IncFIB(pECLA)"             
## [57] "Plasmid_IncR"                       "Plasmid_count"                     
## [59] "ST_90"                              "ST_415"                            
## [61] "ST_66"                              "ST_108"                            
## [63] "ST_133"
Modeling
set.seed(345)
# 75/25 train/test split, stratified on outcome so both sets keep the class mix.
splits <- initial_split(model_df, prop = 0.75, strata = outcome)
train_data <- training(splits)
test_data  <- testing(splits)

# Preprocessing, estimated on the training data only:
#   1. character predictors -> factors
#   2. median-impute numeric NAs (before dummy coding, so only genuine numeric
#      predictors are touched)
#   3. factor NAs -> explicit "unknown" level; drop single-level factors;
#      dummy-encode
#   4. drop any remaining zero-variance columns
#   5. up-sample the minority outcome class (themis steps default to
#      skip = TRUE, so this is not applied to data being predicted)
# The former step_impute_mode() sat after step_dummy(), where no nominal
# predictors remain, so it never did anything. step_unknown() already handles
# factor NAs, so the step is dropped.
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  step_string2factor(all_nominal_predictors()) %>%
  step_impute_median(all_numeric_predictors()) %>%
  step_unknown(all_nominal_predictors()) %>%
  step_zv(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_upsample(outcome)

# L1-penalised (LASSO, mixture = 1) logistic regression via glmnet; the
# penalty is tuned by cross-validation.
# Class imbalance is handled by step_upsample() in the recipe. The engine-level
# `weights` argument that used to be passed here is gone: parsnip discarded it
# ("The argument `weights` cannot be manually modified and was removed"). It
# was also built from the full train_data, so it could not line up with CV
# analysis sets or up-sampled rows. If case weights are ever wanted, use
# hardhat::importance_weights() with workflows::add_case_weights() instead.
log_reg <- logistic_reg(penalty = tune(), mixture = 1, mode = "classification") %>%
  set_engine("glmnet")

wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_reg)

# 5-fold CV, stratified on outcome, over 30 log-spaced LASSO penalties
# from 1e-4 to 1e-1.
set.seed(678)
folds <- vfold_cv(train_data, v = 5, strata = outcome)
grid  <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

tuned <- wflow %>%
  tune_grid(
    resamples = folds,
    grid      = grid,
    metrics   = metric_set(roc_auc, sens, spec),
    control   = control_grid(save_pred = TRUE)
  )
## → A | warning: The argument `weights` cannot be manually modified and was removed.
## There were issues with some computations   A: x1There were issues with some computations   A: x3There were issues with some computations   A: x5There were issues with some computations   A: x5
best <- tuned %>% select_best(metric = "roc_auc")

# Refit the workflow at the best penalty on the full training set.
final_fit <- wflow %>%
  finalize_workflow(best) %>%
  fit(train_data)
## Warning: The argument `weights` cannot be manually modified and was removed.

Model results

# Test-set predictions: truth, class probabilities (names cleaned by janitor,
# e.g. pred_survived_to_discharge), and the hard class prediction.
test_pred <- bind_cols(
  select(test_data, outcome),
  janitor::clean_names(predict(final_fit, test_data, type = "prob")),
  predict(final_fit, test_data)
)

# "Survived to discharge" (the second factor level) is the event class.
metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  test_pred,
  truth       = outcome,
  estimate    = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)

confusion <- test_pred %>% conf_mat(truth = outcome, estimate = .pred_class)

print(metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.867
## 2 sens     binary         0.846
## 3 spec     binary         1    
## 4 roc_auc  binary         0.962
print(confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     2                     2
##   Survived to discharge                0                    11
# Non-zero LASSO coefficients of the final fit, largest magnitude first.
final_fit %>%
  tidy() %>%
  select(term, estimate) %>%
  filter(estimate != 0) %>%
  arrange(desc(abs(estimate))) %>%
  print(n = 30)
## # A tibble: 11 × 2
##    term                   estimate
##    <chr>                     <dbl>
##  1 AMR_blaACT-7           1.34e+ 0
##  2 Plasmid_IncFIB(pECLA)  7.36e- 1
##  3 ST_90                 -6.62e- 1
##  4 ABX_Ciprofloxacin     -6.29e- 1
##  5 AMR_ARR-3              5.79e- 1
##  6 (Intercept)           -3.35e- 1
##  7 ABX_Flucloxacillin     2.12e- 1
##  8 ABX_Vancomycin         1.61e- 1
##  9 AMR_aph(3'')-Ib       -1.40e- 1
## 10 AMR_blaSHV-12         -8.45e- 3
## 11 AMR_aph(6)-Id         -8.97e-17
# Cross-validated ROC AUC across the LASSO penalty grid. metric = "roc_auc"
# keeps the plot consistent with its title; tune_grid also recorded sens and
# spec, which autoplot would otherwise show alongside.
autoplot(tuned, metric = "roc_auc") +
  ggtitle("Penalty vs. cross-validated AUC")

# Test-set ROC curve. "Survived to discharge" (the second factor level) is the
# event, matching pred_survived_to_discharge and the event_level used for
# `metrics`. Without event_level = "second", yardstick treats "Died in
# hospital" as the event while scoring with P(survived). The curve is then
# mirrored across the diagonal and contradicts the AUC shown in the title.
roc_curve(test_pred, outcome, pred_survived_to_discharge, event_level = "second") %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2) +
  geom_abline(lty = 2, colour = "grey") +
  coord_equal() +
  labs(title = glue::glue("ROC curve (AUC = {round(metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity (False-positive rate)", y = "Sensitivity (True-positive rate)")

# Heatmap of the LASSO test-set confusion matrix computed above.
autoplot(confusion, type = "heatmap") +
  scale_fill_gradient(low = "white", high = "hotpink") +
  theme_minimal()
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.

# Non-intercept coefficients kept by the LASSO, labelled by sign. Following the
# axis label, a positive log-odds coefficient is read as increasing survival.
coef_df <- final_fit %>%
  tidy() %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  mutate(direction = ifelse(estimate > 0, "Survival increase", "Survival decrease"))

# Horizontal bars ordered by coefficient magnitude.
coef_df %>%
  ggplot(aes(x = reorder(term, abs(estimate)), y = estimate, fill = direction)) +
  geom_col(width = 0.75) +
  coord_flip() +
  scale_fill_manual(values = c("Survival increase" = "red", "Survival decrease" = "slategray")) +
  labs(
    title = "Features retained by LASSO--Metadata And Sequencing Data",
    x     = NULL,
    y     = "Log-odds coefficient\n(positive = increases survival)"
  ) +
  theme_minimal() +
  theme(
    plot.title   = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text    = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text  = element_text(face = "bold")
  )

Random Forest

library(ranger)
library(vip)  # for variable importance plot
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Random forest: ranger engine, 1000 trees, mtry and min_n tuned later.
# Impurity-based importance is stored for the variable-importance plot.
rf_model <- rand_forest(trees = 1000, mtry = tune(), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("ranger", importance = "impurity")

# Same preprocessing recipe as the LASSO model.
rf_wflow <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(model_recipe)

# Regular 5 x 5 grid: mtry in [5, 30], min_n in [2, 10].
rf_grid <- grid_regular(mtry(range = c(5, 30)), min_n(range = c(2, 10)), levels = 5)

# Tune on the same CV folds used for the LASSO model.
set.seed(456)
rf_tuned <- rf_wflow %>%
  tune_grid(
    resamples = folds,
    grid      = rf_grid,
    metrics   = metric_set(roc_auc, accuracy, sens, spec),
    control   = control_grid(save_pred = TRUE)
  )

# Best hyperparameters by cross-validated ROC AUC, then refit on all training data.
rf_best <- rf_tuned %>% select_best(metric = "roc_auc")

rf_final <- rf_wflow %>%
  finalize_workflow(rf_best) %>%
  fit(data = train_data)

# Test-set predictions from the tuned forest (same column layout as test_pred).
rf_pred <- bind_cols(
  select(test_data, outcome),
  janitor::clean_names(predict(rf_final, test_data, type = "prob")),
  predict(rf_final, test_data)
)

# Test-set metrics with "Survived to discharge" as the event class.
rf_metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  rf_pred,
  truth       = outcome,
  estimate    = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)

rf_confusion <- rf_pred %>% conf_mat(truth = outcome, estimate = .pred_class)

print(rf_metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.8  
## 2 sens     binary         0.923
## 3 spec     binary         0    
## 4 roc_auc  binary         0.846
print(rf_confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     0                     1
##   Survived to discharge                2                    12
# ROC Curve
# "Survived to discharge" (second factor level) is the event, matching
# pred_survived_to_discharge and the event_level used for rf_metrics. Without
# event_level = "second", yardstick treats "Died in hospital" as the event. The
# curve is then mirrored across the diagonal and contradicts the AUC in the title.
roc_curve(rf_pred, outcome, pred_survived_to_discharge, event_level = "second") %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2, color = "darkgreen") +
  geom_abline(lty = 2, color = "grey") +
  coord_equal() +
  labs(title = glue::glue("Random Forest ROC (AUC = {round(rf_metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity", y = "Sensitivity")

# Ten most important predictors by ranger impurity importance (final RF fit).
vip(
  extract_fit_parsnip(rf_final),
  num_features = 10,
  geom         = "col",
  aesthetics   = list(fill = "darkturquoise")
) +
  labs(title = "Metadata and Genomic Sequencing Data Top 10 Important Features (Random Forest)") +
  theme_minimal() +
  theme(
    plot.title   = element_text(face = "bold", color = "black"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text    = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text  = element_text(face = "bold")
  )