# Session setup ----
# Clear the workspace and close every open graphics device so objects and
# plots from a previous run cannot leak into this one.
rm(list = ls())
while (!is.null(dev.list())) dev.off()

# Working directory that holds Merged_metadata.xlsx (read by relative path
# below). Set the AUSTRALIA_DATA_DIR environment variable to run the script on
# another machine without editing it; the original path is the fallback.
data_dir <- Sys.getenv("AUSTRALIA_DATA_DIR")
if (!nzchar(data_dir)) {
  data_dir <- "/Users/cjavitia/Desktop/Internships/CFDE/AustraliaData"
}
setwd(data_dir)
library(tidymodels)
library(readxl)
library(stringr)
library(themis)
library(glmnet)
library(synthpop)
library(janitor)
library(ggplot2)
# Patient metadata and antibiotics ----
meta <- read_excel("Merged_metadata.xlsx", sheet = "Entero_IMP")
# The five antibiotic slots recorded per isolate.
abx_cols <- sprintf("antibiotic_%d", 1:5)

# Keep only the isolate ID and the antibiotic slots.
abx_data <- select(meta, `Isolate ID`, all_of(abx_cols))

# Stack the slots into one long (isolate, slot, drug) table.
abx_long <- pivot_longer(
  abx_data,
  cols = -`Isolate ID`,
  names_to = "which_abx",
  values_to = "drug"
)

# Trim whitespace, drop empty slots, and mark each remaining entry with 1.
abx_clean <- abx_long %>%
  mutate(drug = str_trim(drug)) %>%
  filter(!is.na(drug), drug != "") %>%
  mutate(flag = 1)

# A drug listed in more than one slot for the same isolate counts once.
abx_unique <- distinct(abx_clean, `Isolate ID`, drug, .keep_all = TRUE)

# Spread to one ABX_<drug> 0/1 column per drug. Isolates with no listed
# antibiotic have no rows left at this point and so are absent from the result.
abx_binary <- pivot_wider(
  abx_unique,
  id_cols = `Isolate ID`,
  names_from = drug,
  values_from = flag,
  values_fill = 0,
  names_prefix = "ABX_"
)
# Combine patient metadata with antibiotic indicators (metadata-only model; no sequencing features are used) ----
# Outcome plus admission metadata. Any outcome value other than the two
# listed levels becomes NA.
meta_core <- meta %>%
  select(`Isolate ID`, outcome, hospital_1m, icu_adm) %>%
  mutate(outcome = factor(outcome, levels = c("Died in hospital", "Survived to discharge")))

# Rows without a usable outcome cannot be used for supervised training; drop
# them, but say so instead of letting them flow silently into the split.
if (anyNA(meta_core$outcome)) {
  warning(
    sum(is.na(meta_core$outcome)),
    " isolate(s) with a missing or unrecognised outcome were dropped.",
    call. = FALSE
  )
  meta_core <- filter(meta_core, !is.na(outcome))
}

# Isolates absent from abx_binary had no antibiotic listed, so their ABX_
# indicators are 0. Only the ABX_ columns are filled: NAs in the other
# metadata columns are left for the recipe's imputation steps rather than
# being silently coded as 0.
model_df <- meta_core %>%
  left_join(abx_binary, by = "Isolate ID") %>%
  mutate(across(starts_with("ABX_"), ~ replace_na(.x, 0))) %>%
  select(-`Isolate ID`)

colnames(model_df)
## [1] "outcome" "hospital_1m"
## [3] "icu_adm" "ABX_Ceftriaxone"
## [5] "ABX_Piperacillin/tazobactam" "ABX_Other"
## [7] "ABX_Meropenem" "ABX_Benzylpenicillin"
## [9] "ABX_Ciprofloxacin" "ABX_Trimethoprim-sulphamethoxazole"
## [11] "ABX_Clarithromycin" "ABX_Ceftolozane/tazoabctam"
## [13] "ABX_Caspofungin" "ABX_Amoxicillin"
## [15] "ABX_Vancomycin" "ABX_Cefepime"
## [17] "ABX_Cephazolin" "ABX_Tobramycin"
## [19] "ABX_Flucloxacillin" "ABX_Amoxicillin/clavulanate"
## [21] "ABX_Trimethoprim" "ABX_Cephalexin"
## [23] "ABX_Clindamycin" "ABX_Fluconazole"
# Stratified 75/25 train/test split.
set.seed(345)
splits <- initial_split(model_df, prop = 0.75, strata = outcome)
train_data <- training(splits)
test_data <- testing(splits)

# Preprocessing, ordered per the recipes guidance: impute -> encode ->
# filter -> rebalance.
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  step_string2factor(all_nominal_predictors()) %>%
  # Numeric NAs are filled with the training median before encoding.
  step_impute_median(all_numeric_predictors()) %>%
  # Factor NAs become an explicit "unknown" level, so no factor NAs remain
  # and a separate mode imputation is unnecessary.
  step_unknown(all_nominal_predictors()) %>%
  step_zv(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  # Remove dummy columns that are constant in the training data
  # (e.g. an unused "unknown" level).
  step_zv(all_predictors()) %>%
  # Upsample the minority outcome class. step_upsample() defaults to
  # skip = TRUE, so new data (e.g. the test set) is not resampled.
  step_upsample(outcome)
# LASSO logistic regression: mixture = 1 is a pure L1 penalty and the
# penalty strength is tuned below.
# Class imbalance is handled by step_upsample() in the recipe. No case
# weights are passed through set_engine(): tune/parsnip drop a manual
# `weights` argument with a warning. A vector sized to the full training set
# could never line up with CV folds or upsampled rows anyway. Adding weights
# on top of upsampling would also double-count the imbalance correction.
log_reg <- logistic_reg(penalty = tune(), mixture = 1, mode = "classification") %>%
  set_engine("glmnet")

wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_reg)
# 5-fold stratified CV over 30 log-spaced penalties from 1e-4 to 1e-1.
set.seed(678)
folds <- vfold_cv(train_data, v = 5, strata = outcome)
grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

# event_level = "second" makes "Survived to discharge" the event. This
# matches the test-set metrics, which use event_level = "second", so CV
# sens/spec mean the same thing as the test-set sens/spec.
tuned <- tune_grid(
  wflow,
  resamples = folds,
  grid = grid,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_grid(save_pred = TRUE, event_level = "second")
)
## → A | warning: The argument `weights` cannot be manually modified and was removed.
## There were issues with some computations A: x1There were issues with some computations A: x4There were issues with some computations A: x5
# Refit on the whole training set at the penalty with the best CV ROC AUC.
best <- tuned %>% select_best(metric = "roc_auc")
final_fit <- wflow %>%
  finalize_workflow(best) %>%
  fit(data = train_data)
## Warning: The argument `weights` cannot be manually modified and was removed.
# Model results ----
# Test-set table: truth, class probabilities (names cleaned by janitor, e.g.
# pred_survived_to_discharge), and the predicted class.
test_pred <- bind_cols(
  test_data %>% select(outcome),
  final_fit %>%
    predict(new_data = test_data, type = "prob") %>%
    janitor::clean_names(),
  final_fit %>% predict(new_data = test_data)
)
# "Survived to discharge" (the second factor level) is the event here.
metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  test_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)
confusion <- test_pred %>% conf_mat(truth = outcome, estimate = .pred_class)
print(metrics)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.733
## 2 sens binary 0.692
## 3 spec binary 1
## 4 roc_auc binary 0.808
print(confusion)
## Truth
## Prediction Died in hospital Survived to discharge
## Died in hospital 2 4
## Survived to discharge 0 9
# Non-zero LASSO coefficients at the selected penalty, largest magnitude first.
final_fit %>%
  tidy() %>%
  filter(estimate != 0) %>%
  select(term, estimate) %>%
  arrange(desc(abs(estimate))) %>%
  print(n = 30)
## # A tibble: 17 × 2
## term estimate
## <chr> <dbl>
## 1 ABX_Vancomycin 1.27e+ 1
## 2 ABX_Cephalexin 5.12e+ 0
## 3 ABX_Clindamycin 4.83e+ 0
## 4 ABX_Meropenem 4.69e+ 0
## 5 ABX_Ciprofloxacin -4.36e+ 0
## 6 ABX_Ceftriaxone -3.33e+ 0
## 7 ABX_Flucloxacillin 3.03e+ 0
## 8 ABX_Amoxicillin/clavulanate 2.76e+ 0
## 9 ABX_Other -2.51e+ 0
## 10 ABX_Amoxicillin 2.02e+ 0
## 11 ABX_Ceftolozane/tazoabctam 1.98e+ 0
## 12 hospital_1m_Yes 1.41e+ 0
## 13 ABX_Fluconazole 9.72e- 1
## 14 ABX_Clarithromycin 9.19e- 1
## 15 ABX_Trimethoprim-sulphamethoxazole 9.03e- 1
## 16 (Intercept) -1.77e- 1
## 17 ABX_Caspofungin 8.96e-15
# Cross-validated ROC AUC across the penalty grid. Only roc_auc is plotted so
# the figure matches its title; tuned also holds sens/spec.
autoplot(tuned, metric = "roc_auc") +
  ggtitle("Penalty vs. cross-validated AUC")
# Test-set ROC curve for the LASSO model.
# pred_survived_to_discharge is the probability of the second outcome level,
# so event_level = "second" is required. With yardstick's default ("first")
# the curve is built for "Died in hospital" from the survival probability. It
# then lies mirrored below the diagonal and contradicts the AUC in the title,
# which is computed with event_level = "second".
roc_curve(test_pred, outcome, pred_survived_to_discharge, event_level = "second") %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2) +
  geom_abline(lty = 2, colour = "grey50") +
  coord_equal() +
  labs(
    title = glue::glue("ROC curve (AUC = {round(metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
    x = "1 – Specificity (False-positive rate)",
    y = "Sensitivity (True-positive rate)"
  )
# Confusion-matrix heatmap. The manual fill scale replaces autoplot's default
# one, which is why ggplot2 reports that a fill scale is already present.
test_pred %>%
  conf_mat(truth = outcome, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  scale_fill_gradient(low = "white", high = "hotpink") +
  theme_minimal()
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.
# Non-zero LASSO coefficients (intercept excluded), labelled by sign. The
# outcome's second level is "Survived to discharge", so a positive
# log-odds coefficient is read as increasing survival.
coef_df <- tidy(final_fit) %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  mutate(direction = if_else(estimate > 0, "Survival increase", "Survival decrease"))

# Horizontal bar chart, bars ordered by coefficient magnitude.
coef_df %>%
  ggplot(aes(x = reorder(term, abs(estimate)), y = estimate, fill = direction)) +
  geom_col(width = 0.75) +
  coord_flip() +
  scale_fill_manual(values = c("Survival increase" = "red", "Survival decrease" = "slategray")) +
  labs(
    x = NULL,
    y = "Log-odds coefficient\n(positive = increases survival)",
    title = "Features retained by LASSO Metadata Only"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text = element_text(face = "bold")
  )
library(ranger)
library(vip) # for variable importance plot
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
# Random forest: 1000 ranger trees with mtry and min_n tuned. Impurity-based
# importance is stored so vip() can plot variable importance afterwards.
rf_model <- rand_forest(trees = 1000, mtry = tune(), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("ranger", importance = "impurity")

# Same preprocessing recipe as the LASSO workflow.
rf_wflow <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(model_recipe)
# Tuning grid for the random forest.
# mtry cannot exceed the number of predictors the recipe produces. Larger
# values are clamped by the engine with a warning and only duplicate grid
# points, so the upper bound is the predictor count of the recipe prepped on
# the training set. Folds that drop extra zero-variance columns may still
# have slightly fewer predictors.
n_rf_predictors <- model_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  select(-outcome) %>%
  ncol()

rf_grid <- grid_regular(
  mtry(range = c(5, max(5, n_rf_predictors))),
  min_n(range = c(2, 10)),
  levels = 5
)
# Cross-validated grid search over mtry/min_n, reusing the LASSO folds.
set.seed(456)
rf_tuned <- tune_grid(
  rf_wflow,
  resamples = folds,
  grid = rf_grid,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  # Treat "Survived to discharge" (second level) as the event, consistent
  # with the test-set metrics, so CV sens/spec are directly comparable.
  control = control_grid(save_pred = TRUE, event_level = "second")
)
## → A | warning: ! 23 columns were requested but there were 18 predictors in the data.
## ℹ 18 predictors will be used.
## There were issues with some computations A: x1 → B | warning: ! 30 columns were requested but there were 18 predictors in the data.
## ℹ 18 predictors will be used.
## There were issues with some computations A: x1There were issues with some computations A: x2 B: x1There were issues with some computations A: x3 B: x2There were issues with some computations A: x4 B: x3There were issues with some computations A: x5 B: x4 → C | warning: ! 23 columns were requested but there were 21 predictors in the data.
## ℹ 21 predictors will be used.
## There were issues with some computations A: x5 B: x4There were issues with some computations A: x5 B: x5 C: x1 → D | warning: ! 30 columns were requested but there were 21 predictors in the data.
## ℹ 21 predictors will be used.
## There were issues with some computations A: x5 B: x5 C: x1There were issues with some computations A: x5 B: x5 C: x2 D: x1There were issues with some computations A: x5 B: x5 C: x2 D: x2There were issues with some computations A: x5 B: x5 C: x3 D: x3There were issues with some computations A: x5 B: x5 C: x4 D: x3There were issues with some computations A: x5 B: x5 C: x5 D: x4There were issues with some computations A: x5 B: x5 C: x6 D: x5There were issues with some computations A: x5 B: x5 C: x7 D: x6There were issues with some computations A: x5 B: x5 C: x8 D: x7There were issues with some computations A: x5 B: x5 C: x9 D: x8There were issues with some computations A: x5 B: x5 C: x10 D: x9 → E | warning: ! 23 columns were requested but there were 19 predictors in the data.
## ℹ 19 predictors will be used.
## There were issues with some computations A: x5 B: x5 C: x10 D: x9There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:… → F | warning: ! 30 columns were requested but there were 19 predictors in the data.
## ℹ 19 predictors will be used.
## There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x10 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x11 D: x10 E:…There were issues with some computations A: x5 B: x5 C: x12 D: x11 E:…There were issues with some computations A: x5 B: x5 C: x13 D: x12 E:…There were issues with some computations A: x5 B: x5 C: x14 D: x13 E:…There were issues with some computations A: x5 B: x5 C: x15 D: x14 E:…There were issues with some computations A: x5 B: x5 C: x15 D: x15 E:…
# Pick the mtry/min_n pair with the best CV ROC AUC and refit on all
# training data.
rf_best <- rf_tuned %>% select_best(metric = "roc_auc")
rf_final <- rf_wflow %>%
  finalize_workflow(rf_best) %>%
  fit(data = train_data)
## Warning: ! 23 columns were requested but there were 21 predictors in the data.
## ℹ 21 predictors will be used.
# Test-set table for the forest: truth, cleaned class-probability columns,
# and the predicted class.
rf_pred <- bind_cols(
  test_data %>% select(outcome),
  rf_final %>%
    predict(new_data = test_data, type = "prob") %>%
    janitor::clean_names(),
  rf_final %>% predict(new_data = test_data)
)
# Metrics, with "Survived to discharge" (second factor level) as the event.
rf_metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  rf_pred,
  truth = outcome,
  estimate = .pred_class,
  pred_survived_to_discharge,
  event_level = "second"
)
rf_confusion <- rf_pred %>% conf_mat(truth = outcome, estimate = .pred_class)
# Print results
print(rf_metrics)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.8
## 2 sens binary 0.769
## 3 spec binary 1
## 4 roc_auc binary 0.846
print(rf_confusion)
## Truth
## Prediction Died in hospital Survived to discharge
## Died in hospital 2 3
## Survived to discharge 0 10
# Test-set ROC curve for the random forest.
# pred_survived_to_discharge is the probability of the second outcome level,
# so event_level = "second" is required. With yardstick's default ("first")
# the curve is built for "Died in hospital" from the survival probability and
# lies mirrored below the diagonal, contradicting the AUC in the title.
roc_curve(rf_pred, outcome, pred_survived_to_discharge, event_level = "second") %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2, color = "darkgreen") +
  geom_abline(lty = 2, color = "darkgrey") +
  coord_equal() +
  labs(
    title = glue::glue("Random Forest ROC (AUC = {round(rf_metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
    x = "1 – Specificity",
    y = "Sensitivity"
  )
# Top-10 variable importance from the fitted ranger model (impurity
# importance, as requested in set_engine()).
vip(
  extract_fit_parsnip(rf_final),
  num_features = 10,
  geom = "col",
  aesthetics = list(fill = "darkturquoise")
) +
  labs(title = "Metadata Only Top 10 Important Features (Random Forest)") +
  theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", color = "black"),
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.title = element_blank(),
    legend.text = element_text(face = "bold")
  )