# Workbook read below; each results table comes from its own named sheet.
galaxy_file <- "Galaxy_results.xlsx"
res_df   <- read_excel(path = galaxy_file, sheet = "ResFinder")
plasm_df <- read_excel(path = galaxy_file, sheet = "PlasmidFinder")
sum_df   <- read_excel(path = galaxy_file, sheet = "Summary")
###### ST extraction ######
# Isolate ID plus sequence type, dropping rows whose ST is missing or empty.
st_df <- sum_df %>%
  select(`Isolate ID`, ST = `Sequence Type`) %>%
  filter(!is.na(ST), ST != "")
# STs that appear in at least 3 rows of st_df are kept as features.
keep_st <- st_df %>%
  count(ST, sort = TRUE) %>%
  filter(n >= 3) %>%
  pull(ST)
# Wide 0/1 encoding: one `ST_<value>` column per retained ST.
# `values_fn = max` collapses repeated (isolate, ST) rows to a single 1.
st_wide <- st_df %>%
  filter(ST %in% keep_st) %>%
  mutate(flag = 1) %>%
  pivot_wider(
    id_cols = `Isolate ID`,
    names_from = ST,
    values_from = flag,
    values_fill = 0,
    values_fn = max,
    names_prefix = "ST_"
  )
##### AMR genes extraction #######
# Genes carried by at least 5 isolates are kept as features.
# (Isolate, gene) pairs are de-duplicated first so that several hits of the
# same gene within one isolate count once towards the prevalence threshold.
keep_genes <- res_df %>%
  distinct(`Isolate ID`, Gene) %>%
  count(Gene, sort = TRUE) %>%
  filter(n >= 5) %>%
  pull(Gene)
# Wide 0/1 encoding: one `AMR_<gene>` column per retained gene.
# `values_fn = max` collapses repeated hits within an isolate to a single 1.
arg_wide <- res_df %>%
  filter(Gene %in% keep_genes) %>%
  mutate(flag = 1) %>%
  pivot_wider(
    id_cols = `Isolate ID`,
    names_from = Gene,
    values_from = flag,
    values_fill = 0,
    values_fn = max,
    names_prefix = "AMR_"
  )
# Number of distinct genes per isolate (all genes, not only the retained ones).
arg_counts <- res_df %>%
  group_by(`Isolate ID`) %>%
  summarise(AMR_count = n_distinct(Gene), .groups = "drop")
###### Plasmid extraction ######
# Replicons carried by at least 3 isolates are kept as features.
# (Isolate, replicon) pairs are de-duplicated first so that repeated hits
# within one isolate count once towards the prevalence threshold.
keep_rep <- plasm_df %>%
  distinct(`Isolate ID`, Plasmid) %>%
  count(Plasmid, sort = TRUE) %>%
  filter(n >= 3) %>%
  pull(Plasmid)
# Wide 0/1 encoding: one `Plasmid_<replicon>` column per retained replicon.
# `values_fn = max` collapses repeated hits within an isolate to a single 1.
plasm_wide <- plasm_df %>%
  filter(Plasmid %in% keep_rep) %>%
  mutate(flag = 1) %>%
  pivot_wider(
    id_cols = `Isolate ID`,
    names_from = Plasmid,
    values_from = flag,
    values_fill = 0,
    values_fn = max,
    names_prefix = "Plasmid_"
  )
# Number of distinct replicons per isolate (all replicons, not only retained).
plasm_counts <- plasm_df %>%
  group_by(`Isolate ID`) %>%
  summarise(Plasmid_count = n_distinct(Plasmid), .groups = "drop")
###### Virulence genes extraction ######
# Start the sequencing feature table from the isolate IDs in the Summary sheet.
seq_feats <- sum_df %>%
  select(`Isolate ID`)
# Virulence features are optional: only built when the workbook has the sheet.
if ("VirulenceFinder" %in% excel_sheets(galaxy_file)) {
  vir_df <- read_excel(galaxy_file, sheet = "VirulenceFinder")
  # Genes carried by at least 3 isolates are kept. (Isolate, gene) pairs are
  # de-duplicated first so repeated hits within one isolate count once.
  keep_vir <- vir_df %>%
    distinct(`Isolate ID`, Gene) %>%
    count(Gene, sort = TRUE) %>%
    filter(n >= 3) %>%
    pull(Gene)
  # Wide 0/1 encoding: one `VIR_<gene>` column per retained gene.
  vir_wide <- vir_df %>%
    filter(Gene %in% keep_vir) %>%
    mutate(flag = 1) %>%
    pivot_wider(
      id_cols = `Isolate ID`,
      names_from = Gene,
      values_from = flag,
      values_fill = 0,
      values_fn = max,
      names_prefix = "VIR_"
    )
  # Number of distinct virulence genes per isolate.
  vir_counts <- vir_df %>%
    group_by(`Isolate ID`) %>%
    summarise(Virulence_count = n_distinct(Gene), .groups = "drop")
  seq_feats <- seq_feats %>%
    left_join(vir_wide, by = "Isolate ID") %>%
    left_join(vir_counts, by = "Isolate ID")
}
###### Merge all sequencing feature blocks #######
# Left-join each feature block onto seq_feats in turn, keyed on the isolate ID.
seq_feats <- Reduce(
  function(acc, block) left_join(acc, block, by = "Isolate ID"),
  list(arg_wide, arg_counts, plasm_wide, plasm_counts, st_wide),
  seq_feats
)
####### Patient metadata and antibiotics #########
meta <- read_excel(path = "Merged_metadata.xlsx", sheet = "Entero_IMP")
# Attach the sequence type to each metadata row.
meta <- left_join(meta, st_df, by = "Isolate ID")
# The antibiotic fields are the columns antibiotic_1 ... antibiotic_5.
abx_cols <- paste0("antibiotic_", seq_len(5))
abx_data <- select(meta, `Isolate ID`, all_of(abx_cols))
# Long format: one row per (isolate, antibiotic slot).
abx_long <- pivot_longer(
  abx_data,
  cols = all_of(abx_cols),
  names_to = "which_abx",
  values_to = "drug"
)
# Trim whitespace, drop empty slots, and add the indicator value.
abx_clean <- abx_long %>%
  mutate(drug = str_trim(drug)) %>%
  filter(!is.na(drug), drug != "") %>%
  mutate(flag = 1)
# A drug listed in several slots for the same isolate is kept once.
abx_unique <- distinct(abx_clean, `Isolate ID`, drug, .keep_all = TRUE)
# Wide 0/1 encoding: one `ABX_<drug>` column per drug name.
abx_binary <- pivot_wider(
  abx_unique,
  id_cols = `Isolate ID`,
  names_from = drug,
  values_from = flag,
  values_fill = 0,
  names_prefix = "ABX_"
)
####### Combine metadata and Sequencing data #######
# Outcome plus the two clinical covariates; the outcome levels are fixed so
# that "Survived to discharge" is always the second level.
meta_core <- meta %>%
  select(`Isolate ID`, outcome, hospital_1m, icu_adm) %>%
  mutate(outcome = factor(outcome, levels = c("Died in hospital", "Survived to discharge")))
# After the joins, an NA in an indicator/count feature means the isolate had
# no row in that feature block, so 0 is the right fill. The clinical
# covariates are excluded from the zero-fill: a missing value there is
# genuinely unknown and is left for the recipe's imputation steps.
model_df <- meta_core %>%
  left_join(abx_binary, by = "Isolate ID") %>%
  left_join(seq_feats, by = "Isolate ID") %>%
  mutate(across(
    where(is.numeric) & !any_of(c("hospital_1m", "icu_adm")),
    ~ replace_na(.x, 0)
  ))
# The isolate ID is an identifier, not a predictor.
model_df <- model_df %>%
  select(-`Isolate ID`)
colnames(model_df)
## [1] "outcome" "hospital_1m"
## [3] "icu_adm" "ABX_Ceftriaxone"
## [5] "ABX_Piperacillin/tazobactam" "ABX_Other"
## [7] "ABX_Meropenem" "ABX_Benzylpenicillin"
## [9] "ABX_Ciprofloxacin" "ABX_Trimethoprim-sulphamethoxazole"
## [11] "ABX_Clarithromycin" "ABX_Ceftolozane/tazoabctam"
## [13] "ABX_Caspofungin" "ABX_Amoxicillin"
## [15] "ABX_Vancomycin" "ABX_Cefepime"
## [17] "ABX_Cephazolin" "ABX_Tobramycin"
## [19] "ABX_Flucloxacillin" "ABX_Amoxicillin/clavulanate"
## [21] "ABX_Trimethoprim" "ABX_Cephalexin"
## [23] "ABX_Clindamycin" "ABX_Fluconazole"
## [25] "AMR_ARR-3" "AMR_aac(3)-IId"
## [27] "AMR_aac(6')-Ib-cr" "AMR_aph(3'')-Ib"
## [29] "AMR_aph(6)-Id" "AMR_blaACT-15"
## [31] "AMR_blaIMP-4" "AMR_blaOXA-1"
## [33] "AMR_blaTEM-1B" "AMR_catA2"
## [35] "AMR_catB3" "AMR_dfrA19"
## [37] "AMR_mph(A)" "AMR_qnrB2"
## [39] "AMR_sul1" "AMR_tet(D)"
## [41] "AMR_aac(6')-Ib3" "AMR_blaACT-7"
## [43] "AMR_fosA" "AMR_blaSHV-12"
## [45] "AMR_qnrA1" "AMR_blaACT-16"
## [47] "AMR_aac(6')-Ib-Hangzhou" "AMR_count"
## [49] "Plasmid_Col(pHAD28)" "Plasmid_ColE10"
## [51] "Plasmid_IncFII(pECLA)" "Plasmid_IncHI2"
## [53] "Plasmid_IncHI2A" "Plasmid_IncM2"
## [55] "Plasmid_IncFIB(pB171)" "Plasmid_IncFIB(pECLA)"
## [57] "Plasmid_IncR" "Plasmid_count"
## [59] "ST_90" "ST_415"
## [61] "ST_66" "ST_108"
## [63] "ST_133"
# Seeded 75/25 train/test split, stratified on the outcome.
set.seed(345)
splits <- initial_split(data = model_df, prop = 0.75, strata = outcome)
train_data <- training(splits)
test_data <- testing(splits)
# Preprocessing shared by both models.
# NOTE(review): step_impute_mode() comes after step_dummy(), by which point
# nominal predictors have presumably already been dummy-encoded -- confirm
# whether it selects anything.
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  # character predictors -> factors, with missing values as their own level
  step_string2factor(all_nominal_predictors()) %>%
  step_unknown(all_nominal_predictors()) %>%
  # drop constant factors, then dummy-encode and drop constant columns
  step_zv(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  # imputes numeric NA
  step_impute_median(all_numeric_predictors()) %>%
  # imputes factor NA
  step_impute_mode(all_nominal_predictors()) %>%
  # up-sample on the outcome column
  step_upsample(outcome)
# LASSO (mixture = 1) logistic regression; the penalty is tuned by CV.
# No `weights` engine argument is passed: parsnip strips it ("The argument
# `weights` cannot be manually modified and was removed."), so it never
# reached glmnet. A vector sized to train_data could not line up with the CV
# analysis sets or the up-sampled rows either. Class imbalance is handled by
# step_upsample() in the recipe.
log_reg <- logistic_reg(penalty = tune(), mixture = 1, mode = "classification") %>%
  set_engine("glmnet")
wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_reg)
# 5-fold CV on the training set, stratified on the outcome.
set.seed(678)
folds <- vfold_cv(train_data, v = 5, strata = outcome)
# 30 penalties, log-spaced between 1e-4 and 1e-1.
grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
# event_level = "second" makes the CV sens/spec refer to the same event
# ("Survived to discharge") as the test-set metrics; roc_auc is unaffected.
tuned <- tune_grid(
  wflow,
  resamples = folds,
  grid = grid,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_grid(save_pred = TRUE, event_level = "second")
)
# Refit on the full training set with the penalty giving the best CV AUC.
best <- select_best(tuned, metric = "roc_auc")
final_fit <- finalize_workflow(wflow, best) %>% fit(train_data)
###### Model results ######
# Test-set truth, class probabilities (names cleaned to snake_case) and the
# predicted class, side by side.
test_pred <- test_data %>%
  select(outcome) %>%
  bind_cols(
    predict(final_fit, test_data, type = "prob") %>% janitor::clean_names(),
    predict(final_fit, test_data)
  )
# All metrics treat the second outcome level ("Survived to discharge") as the
# event; roc_auc uses the matching probability column.
metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  test_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)
confusion <- conf_mat(data = test_pred, truth = outcome, estimate = .pred_class)
print(metrics)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.867
## 2 sens binary 0.846
## 3 spec binary 1
## 4 roc_auc binary 0.962
print(confusion)
## Truth
## Prediction Died in hospital Survived to discharge
## Died in hospital 2 2
## Survived to discharge 0 11
# Coefficients of the final model that are not exactly zero, ordered by
# decreasing absolute value; up to 30 rows are printed.
final_fit %>%
  tidy() %>%
  select(term, estimate) %>%
  filter(estimate != 0) %>%
  arrange(desc(abs(estimate))) %>%
  print(n = 30)
## # A tibble: 11 × 2
## term estimate
## <chr> <dbl>
## 1 AMR_blaACT-7 1.34e+ 0
## 2 Plasmid_IncFIB(pECLA) 7.36e- 1
## 3 ST_90 -6.62e- 1
## 4 ABX_Ciprofloxacin -6.29e- 1
## 5 AMR_ARR-3 5.79e- 1
## 6 (Intercept) -3.35e- 1
## 7 ABX_Flucloxacillin 2.12e- 1
## 8 ABX_Vancomycin 1.61e- 1
## 9 AMR_aph(3'')-Ib -1.40e- 1
## 10 AMR_blaSHV-12 -8.45e- 3
## 11 AMR_aph(6)-Id -8.97e-17
# Tuning results across the penalty grid.
autoplot(tuned) +
  labs(title = "Penalty vs. cross-validated AUC")
# Test-set ROC curve for the "Survived to discharge" probability, with the
# chance diagonal for reference.
roc_curve(test_pred, outcome, pred_survived_to_discharge) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_path(linewidth = 1.2) +
  geom_abline(linetype = 2, colour = "grey") +
  coord_equal() +
  labs(
    title = glue::glue("ROC curve (AUC = {round(metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
    x = "1 – Specificity (False-positive rate)",
    y = "Sensitivity (True-positive rate)"
  )
# Test-set confusion matrix as a heatmap; the fill scale replaces the default.
conf_mat(test_pred, truth = outcome, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  scale_fill_gradient(low = "white", high = "hotpink") +
  theme_minimal()
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
# Non-intercept coefficients retained by the LASSO. A tolerance is used
# instead of `estimate != 0` so that numerically-zero estimates (e.g. values
# around 1e-17 from floating-point arithmetic) are not shown as retained.
coef_df <- tidy(final_fit) %>%
  filter(term != "(Intercept)", abs(estimate) > sqrt(.Machine$double.eps)) %>%
  mutate(direction = ifelse(estimate > 0, "Survival increase", "Survival decrease"))
# Horizontal bar chart, features ordered by absolute coefficient size and
# coloured by the sign of the coefficient.
ggplot(coef_df, aes(x = reorder(term, abs(estimate)), y = estimate, fill = direction)) +
  geom_col(width = 0.75) +
  coord_flip() +
  scale_fill_manual(values = c("Survival increase" = "red", "Survival decrease" = "slategray")) +
  labs(
    x = NULL,
    y = "Log-odds coefficient\n(positive = increases survival)",
    title = "Features retained by LASSO--Metadata And Sequencing Data"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    plot.title = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.text = element_text(face = "bold")
  )
library(ranger)
library(vip) # for variable importance plot
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
# Random forest classifier: 1000 trees, with mtry and min_n left for tuning.
# The ranger engine records impurity-based variable importance.
rf_model <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_engine("ranger", importance = "impurity")
# Workflow reusing the same preprocessing recipe as the LASSO model.
rf_wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(rf_model)
# Tune hyperparameters
# Regular 5 x 5 grid: mtry in [5, 30], min_n in [2, 10].
rf_grid <- grid_regular(
  mtry(range = c(5, 30)),
  min_n(range = c(2, 10)),
  levels = 5
)
set.seed(456)
# Tuned on the same CV folds as the LASSO model. event_level = "second" makes
# the CV accuracy/sens/spec refer to the same event ("Survived to discharge")
# as the test-set metrics; roc_auc is unaffected.
rf_tuned <- tune_grid(
  rf_wflow,
  resamples = folds,
  grid = rf_grid,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  control = control_grid(save_pred = TRUE, event_level = "second")
)
# Select best hyperparameters by AUC
rf_best <- select_best(rf_tuned, metric = "roc_auc")
# Finalize and fit on the full training set
rf_final <- finalize_workflow(rf_wflow, rf_best) %>%
  fit(data = train_data)
# Test-set truth, class probabilities (names cleaned to snake_case) and the
# predicted class, side by side.
rf_pred <- test_data %>%
  select(outcome) %>%
  bind_cols(
    predict(rf_final, test_data, type = "prob") %>% janitor::clean_names(),
    predict(rf_final, test_data)
  )
# Metrics, with the second outcome level ("Survived to discharge") as the event
rf_metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  rf_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)
rf_confusion <- conf_mat(data = rf_pred, truth = outcome, estimate = .pred_class)
# Print results
print(rf_metrics)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.8
## 2 sens binary 0.923
## 3 spec binary 0
## 4 roc_auc binary 0.846
print(rf_confusion)
## Truth
## Prediction Died in hospital Survived to discharge
## Died in hospital 0 1
## Survived to discharge 2 12
# Test-set ROC curve for the random forest, with the chance diagonal.
roc_curve(rf_pred, outcome, pred_survived_to_discharge) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_path(linewidth = 1.2, colour = "darkgreen") +
  geom_abline(linetype = 2, colour = "grey") +
  coord_equal() +
  labs(
    title = glue::glue("Random Forest ROC (AUC = {round(rf_metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
    x = "1 – Specificity",
    y = "Sensitivity"
  )
# Top 10 features by the importance stored in the fitted ranger model.
rf_final %>%
  extract_fit_parsnip() %>%
  vip(
    num_features = 10,
    geom = "col",
    aesthetics = list(fill = "darkturquoise")
  ) +
  labs(title = "Metadata and Genomic Sequencing Data Top 10 Important Features (Random Forest)") +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    plot.title = element_text(face = "bold", color = "black"),
    axis.title.y = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.text = element_text(face = "bold")
  )