PREDICTIVE MODEL WITH METADATA AND SYNTHETIC DATA

######### Prepare patient metadata #########
meta <- read_excel("Merged_metadata.xlsx", sheet = "Entero_IMP") %>%
  mutate(`Isolate ID` = Experiment)

abx_cols <- paste0("antibiotic_", 1:5)

abx_data <- meta %>% select(`Isolate ID`, all_of(abx_cols))

abx_long <- abx_data %>%
  pivot_longer(cols = -`Isolate ID`, names_to = "which_abx", values_to = "drug")

abx_clean <- abx_long %>%
  mutate(drug = str_trim(drug)) %>%
  filter(!is.na(drug) & drug != "") %>%
  mutate(flag = 1)

abx_unique <- abx_clean %>%
  distinct(`Isolate ID`, drug, .keep_all = TRUE)

abx_binary <- abx_unique %>%
  pivot_wider(id_cols = `Isolate ID`, names_from = drug, values_from = flag,
              values_fill = 0, names_prefix = "ABX_")

set.seed(123)
meta_core <- meta %>%
  select(`Isolate ID`, outcome, hospital_1m, icu_adm) %>%
  mutate(outcome = factor(outcome, levels = c("Died in hospital", "Survived to discharge")))

model_df <- meta_core %>%
  left_join(abx_binary, by = "Isolate ID") %>%
  mutate(across(where(is.numeric), ~replace_na(., 0))) %>%
  mutate(source = "real") %>%
  select(-`Isolate ID`)  # Remove ID from predictors

glimpse(model_df)
## Rows: 57
## Columns: 25
## $ outcome                              <fct> Survived to discharge, Survived t…
## $ hospital_1m                          <chr> "No", "Yes", "Yes", "Yes", "Yes",…
## $ icu_adm                              <chr> "No", "Yes", "No", "Yes", "Yes", …
## $ ABX_Ceftriaxone                      <dbl> 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ `ABX_Piperacillin/tazobactam`        <dbl> 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, …
## $ ABX_Other                            <dbl> 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, …
## $ ABX_Meropenem                        <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Benzylpenicillin                 <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Ciprofloxacin                    <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ `ABX_Trimethoprim-sulphamethoxazole` <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ ABX_Clarithromycin                   <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ `ABX_Ceftolozane/tazoabctam`         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ ABX_Caspofungin                      <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ ABX_Amoxicillin                      <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ ABX_Vancomycin                       <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ ABX_Cefepime                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Cephazolin                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Tobramycin                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Flucloxacillin                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `ABX_Amoxicillin/clavulanate`        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Trimethoprim                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Cephalexin                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Clindamycin                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ ABX_Fluconazole                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ source                               <chr> "real", "real", "real", "real", "…
model_df %>% count(outcome)
## # A tibble: 2 × 2
##   outcome                   n
##   <fct>                 <int>
## 1 Died in hospital          7
## 2 Survived to discharge    50
model_df <- model_df %>%
  rename_with(make.names) %>%
  select(-source)


codebook.syn(model_df)
## $tab
##                              variable     class nmiss perctmiss ndistinct
## 1                             outcome    factor     0         0         2
## 2                         hospital_1m character     0         0         2
## 3                             icu_adm character     0         0         2
## 4                     ABX_Ceftriaxone   numeric     0         0         2
## 5         ABX_Piperacillin.tazobactam   numeric     0         0         2
## 6                           ABX_Other   numeric     0         0         2
## 7                       ABX_Meropenem   numeric     0         0         2
## 8                ABX_Benzylpenicillin   numeric     0         0         2
## 9                   ABX_Ciprofloxacin   numeric     0         0         2
## 10 ABX_Trimethoprim.sulphamethoxazole   numeric     0         0         2
## 11                 ABX_Clarithromycin   numeric     0         0         2
## 12         ABX_Ceftolozane.tazoabctam   numeric     0         0         2
## 13                    ABX_Caspofungin   numeric     0         0         2
## 14                    ABX_Amoxicillin   numeric     0         0         2
## 15                     ABX_Vancomycin   numeric     0         0         2
## 16                       ABX_Cefepime   numeric     0         0         2
## 17                     ABX_Cephazolin   numeric     0         0         2
## 18                     ABX_Tobramycin   numeric     0         0         2
## 19                 ABX_Flucloxacillin   numeric     0         0         2
## 20        ABX_Amoxicillin.clavulanate   numeric     0         0         2
## 21                   ABX_Trimethoprim   numeric     0         0         2
## 22                     ABX_Cephalexin   numeric     0         0         2
## 23                    ABX_Clindamycin   numeric     0         0         2
## 24                    ABX_Fluconazole   numeric     0         0         2
##    details
## 1         
## 2         
## 3         
## 4         
## 5         
## 6         
## 7         
## 8         
## 9         
## 10        
## 11        
## 12        
## 13        
## 14        
## 15        
## 16        
## 17        
## 18        
## 19        
## 20        
## 21        
## 22        
## 23        
## 24        
## 
## $labs
## NULL

Create and prepare synthetic data

mysyn <- syn(model_df)
## CAUTION: Your data set has fewer observations (57) than we advise.
## We suggest that there should be at least 340 observations
## (100 + 10 * no. of variables used in modelling the data).
## Please check your synthetic data carefully with functions
## compare(), utility.tab(), and utility.gen().
## 
## 
## Variable(s): hospital_1m, icu_adm have been changed for synthesis from character to factor.
## Warning: In your synthesis there are numeric variables with 5 or fewer levels: ABX_Ceftriaxone, ABX_Piperacillin.tazobactam, ABX_Other, ABX_Meropenem, ABX_Benzylpenicillin, ABX_Ciprofloxacin, ABX_Trimethoprim.sulphamethoxazole, ABX_Clarithromycin, ABX_Ceftolozane.tazoabctam, ABX_Caspofungin, ABX_Amoxicillin, ABX_Vancomycin, ABX_Cefepime, ABX_Cephazolin, ABX_Tobramycin, ABX_Flucloxacillin, ABX_Amoxicillin.clavulanate, ABX_Trimethoprim, ABX_Cephalexin, ABX_Clindamycin, ABX_Fluconazole.
## Consider changing them to factors. You can do it using parameter 'minnumlevels'.
## Variables ABX_Ceftolozane.tazoabctam, ABX_Caspofungin are collinear. Variables later in 'visit.sequence'
## are derived from ABX_Ceftolozane.tazoabctam.
## 
## 
## Synthesis
## -----------
##  outcome hospital_1m icu_adm ABX_Ceftriaxone ABX_Piperacillin.tazobactam ABX_Other ABX_Meropenem ABX_Benzylpenicillin ABX_Ciprofloxacin ABX_Trimethoprim.sulphamethoxazole
##  ABX_Clarithromycin ABX_Ceftolozane.tazoabctam ABX_Caspofungin ABX_Amoxicillin ABX_Vancomycin ABX_Cefepime ABX_Cephazolin ABX_Tobramycin ABX_Flucloxacillin ABX_Amoxicillin.clavulanate
##  ABX_Trimethoprim ABX_Cephalexin ABX_Clindamycin ABX_Fluconazole
synthetic_df <- mysyn$syn %>%
  type.convert(as.is = TRUE) %>%
  rename_with(str_trim) %>%
  rename_with(make.names) %>%
  mutate(source = "synthetic",
         outcome = factor(outcome, levels = c("Died in hospital", "Survived to discharge")))

########## Combine real and synthetic data #############

combined_data <- bind_rows(model_df, synthetic_df)
mysyn <- syn(model_df)
## CAUTION: Your data set has fewer observations (57) than we advise.
## We suggest that there should be at least 340 observations
## (100 + 10 * no. of variables used in modelling the data).
## Please check your synthetic data carefully with functions
## compare(), utility.tab(), and utility.gen().
## 
## 
## Variable(s): hospital_1m, icu_adm have been changed for synthesis from character to factor.
## Warning: In your synthesis there are numeric variables with 5 or fewer levels: ABX_Ceftriaxone, ABX_Piperacillin.tazobactam, ABX_Other, ABX_Meropenem, ABX_Benzylpenicillin, ABX_Ciprofloxacin, ABX_Trimethoprim.sulphamethoxazole, ABX_Clarithromycin, ABX_Ceftolozane.tazoabctam, ABX_Caspofungin, ABX_Amoxicillin, ABX_Vancomycin, ABX_Cefepime, ABX_Cephazolin, ABX_Tobramycin, ABX_Flucloxacillin, ABX_Amoxicillin.clavulanate, ABX_Trimethoprim, ABX_Cephalexin, ABX_Clindamycin, ABX_Fluconazole.
## Consider changing them to factors. You can do it using parameter 'minnumlevels'.
## Variables ABX_Ceftolozane.tazoabctam, ABX_Caspofungin are collinear. Variables later in 'visit.sequence'
## are derived from ABX_Ceftolozane.tazoabctam.
## 
## 
## Synthesis
## -----------
##  outcome hospital_1m icu_adm ABX_Ceftriaxone ABX_Piperacillin.tazobactam ABX_Other ABX_Meropenem ABX_Benzylpenicillin ABX_Ciprofloxacin ABX_Trimethoprim.sulphamethoxazole
##  ABX_Clarithromycin ABX_Ceftolozane.tazoabctam ABX_Caspofungin ABX_Amoxicillin ABX_Vancomycin ABX_Cefepime ABX_Cephazolin ABX_Tobramycin ABX_Flucloxacillin ABX_Amoxicillin.clavulanate
##  ABX_Trimethoprim ABX_Cephalexin ABX_Clindamycin ABX_Fluconazole
summary(mysyn)
## Synthetic object with one synthesis using methods:
##                            outcome                        hospital_1m 
##                           "sample"                             "cart" 
##                            icu_adm                    ABX_Ceftriaxone 
##                             "cart"                             "cart" 
##        ABX_Piperacillin.tazobactam                          ABX_Other 
##                             "cart"                             "cart" 
##                      ABX_Meropenem               ABX_Benzylpenicillin 
##                             "cart"                             "cart" 
##                  ABX_Ciprofloxacin ABX_Trimethoprim.sulphamethoxazole 
##                             "cart"                             "cart" 
##                 ABX_Clarithromycin         ABX_Ceftolozane.tazoabctam 
##                             "cart"                             "cart" 
##                    ABX_Caspofungin                    ABX_Amoxicillin 
##                        "collinear"                             "cart" 
##                     ABX_Vancomycin                       ABX_Cefepime 
##                             "cart"                             "cart" 
##                     ABX_Cephazolin                     ABX_Tobramycin 
##                             "cart"                             "cart" 
##                 ABX_Flucloxacillin        ABX_Amoxicillin.clavulanate 
##                             "cart"                             "cart" 
##                   ABX_Trimethoprim                     ABX_Cephalexin 
##                             "cart"                             "cart" 
##                    ABX_Clindamycin                    ABX_Fluconazole 
##                             "cart"                             "cart" 
## 
##                   outcome   hospital_1m          icu_adm         
##  Died in hospital     : 7   Length:57          Length:57         
##  Survived to discharge:50   Class :character   Class :character  
##                             Mode  :character   Mode  :character  
##                                                                  
##                                                                  
##                                                                  
##  ABX_Ceftriaxone   ABX_Piperacillin.tazobactam   ABX_Other     
##  Min.   :0.00000   Min.   :0.0000              Min.   :0.0000  
##  1st Qu.:0.00000   1st Qu.:0.0000              1st Qu.:0.0000  
##  Median :0.00000   Median :0.0000              Median :0.0000  
##  Mean   :0.07018   Mean   :0.2807              Mean   :0.2807  
##  3rd Qu.:0.00000   3rd Qu.:1.0000              3rd Qu.:1.0000  
##  Max.   :1.00000   Max.   :1.0000              Max.   :1.0000  
##  ABX_Meropenem     ABX_Benzylpenicillin ABX_Ciprofloxacin
##  Min.   :0.00000   Min.   :0.00000      Min.   :0.0000   
##  1st Qu.:0.00000   1st Qu.:0.00000      1st Qu.:0.0000   
##  Median :0.00000   Median :0.00000      Median :0.0000   
##  Mean   :0.05263   Mean   :0.01754      Mean   :0.1228   
##  3rd Qu.:0.00000   3rd Qu.:0.00000      3rd Qu.:0.0000   
##  Max.   :1.00000   Max.   :1.00000      Max.   :1.0000   
##  ABX_Trimethoprim.sulphamethoxazole ABX_Clarithromycin
##  Min.   :0.00000                    Min.   :0.00000   
##  1st Qu.:0.00000                    1st Qu.:0.00000   
##  Median :0.00000                    Median :0.00000   
##  Mean   :0.08772                    Mean   :0.03509   
##  3rd Qu.:0.00000                    3rd Qu.:0.00000   
##  Max.   :1.00000                    Max.   :1.00000   
##  ABX_Ceftolozane.tazoabctam ABX_Caspofungin ABX_Amoxicillin   ABX_Vancomycin  
##  Min.   :0                  Min.   :0       Min.   :0.00000   Min.   :0.0000  
##  1st Qu.:0                  1st Qu.:0       1st Qu.:0.00000   1st Qu.:0.0000  
##  Median :0                  Median :0       Median :0.00000   Median :0.0000  
##  Mean   :0                  Mean   :0       Mean   :0.08772   Mean   :0.1053  
##  3rd Qu.:0                  3rd Qu.:0       3rd Qu.:0.00000   3rd Qu.:0.0000  
##  Max.   :0                  Max.   :0       Max.   :1.00000   Max.   :1.0000  
##   ABX_Cefepime     ABX_Cephazolin    ABX_Tobramycin    ABX_Flucloxacillin
##  Min.   :0.00000   Min.   :0.00000   Min.   :0.00000   Min.   :0.0000    
##  1st Qu.:0.00000   1st Qu.:0.00000   1st Qu.:0.00000   1st Qu.:0.0000    
##  Median :0.00000   Median :0.00000   Median :0.00000   Median :0.0000    
##  Mean   :0.03509   Mean   :0.01754   Mean   :0.05263   Mean   :0.1754    
##  3rd Qu.:0.00000   3rd Qu.:0.00000   3rd Qu.:0.00000   3rd Qu.:0.0000    
##  Max.   :1.00000   Max.   :1.00000   Max.   :1.00000   Max.   :1.0000    
##  ABX_Amoxicillin.clavulanate ABX_Trimethoprim  ABX_Cephalexin  
##  Min.   :0.0000              Min.   :0.00000   Min.   :0.0000  
##  1st Qu.:0.0000              1st Qu.:0.00000   1st Qu.:0.0000  
##  Median :0.0000              Median :0.00000   Median :0.0000  
##  Mean   :0.1404              Mean   :0.05263   Mean   :0.1053  
##  3rd Qu.:0.0000              3rd Qu.:0.00000   3rd Qu.:0.0000  
##  Max.   :1.0000              Max.   :1.00000   Max.   :1.0000  
##  ABX_Clindamycin   ABX_Fluconazole  
##  Min.   :0.00000   Min.   :0.00000  
##  1st Qu.:0.00000   1st Qu.:0.00000  
##  Median :0.00000   Median :0.00000  
##  Mean   :0.03509   Mean   :0.07018  
##  3rd Qu.:0.00000   3rd Qu.:0.00000  
##  Max.   :1.00000   Max.   :1.00000
compare(mysyn, model_df, stat = "counts")
## Calculations done for outcome 
## Calculations done for hospital_1m 
## Calculations done for icu_adm 
## Only 2 groups produced for ABX_Ceftriaxone even after changing method.
## Calculations done for ABX_Ceftriaxone 
## Only 2 groups produced for ABX_Piperacillin.tazobactam even after changing method.
## Calculations done for ABX_Piperacillin.tazobactam 
## Only 2 groups produced for ABX_Other even after changing method.
## Calculations done for ABX_Other 
## Only 2 groups produced for ABX_Meropenem even after changing method.
## Calculations done for ABX_Meropenem 
## Only 2 groups produced for ABX_Benzylpenicillin even after changing method.
## Calculations done for ABX_Benzylpenicillin 
## Only 2 groups produced for ABX_Ciprofloxacin even after changing method.
## Calculations done for ABX_Ciprofloxacin 
## Only 2 groups produced for ABX_Trimethoprim.sulphamethoxazole even after changing method.
## Calculations done for ABX_Trimethoprim.sulphamethoxazole 
## Only 2 groups produced for ABX_Clarithromycin even after changing method.
## Calculations done for ABX_Clarithromycin 
## Only 2 groups produced for ABX_Ceftolozane.tazoabctam even after changing method.
## Calculations done for ABX_Ceftolozane.tazoabctam 
## Only 2 groups produced for ABX_Caspofungin even after changing method.
## Calculations done for ABX_Caspofungin 
## Only 2 groups produced for ABX_Amoxicillin even after changing method.
## Calculations done for ABX_Amoxicillin 
## Only 2 groups produced for ABX_Vancomycin even after changing method.
## Calculations done for ABX_Vancomycin 
## Only 2 groups produced for ABX_Cefepime even after changing method.
## Calculations done for ABX_Cefepime 
## Only 2 groups produced for ABX_Cephazolin even after changing method.
## Calculations done for ABX_Cephazolin 
## Only 2 groups produced for ABX_Tobramycin even after changing method.
## Calculations done for ABX_Tobramycin 
## Only 2 groups produced for ABX_Flucloxacillin even after changing method.
## Calculations done for ABX_Flucloxacillin 
## Only 2 groups produced for ABX_Amoxicillin.clavulanate even after changing method.
## Calculations done for ABX_Amoxicillin.clavulanate 
## Only 2 groups produced for ABX_Trimethoprim even after changing method.
## Calculations done for ABX_Trimethoprim 
## Only 2 groups produced for ABX_Cephalexin even after changing method.
## Calculations done for ABX_Cephalexin 
## Only 2 groups produced for ABX_Clindamycin even after changing method.
## Calculations done for ABX_Clindamycin 
## Only 2 groups produced for ABX_Fluconazole even after changing method.
## Calculations done for ABX_Fluconazole
## 
## Comparing counts observed with synthetic

## Press return for next variable(s):

## Press return for next variable(s):

## Press return for next variable(s):

## Press return for next variable(s):

## Press return for next variable(s):

## 
## Selected utility measures:
##                                        pMSE   S_pMSE df
## outcome                            0.000000 0.000000  1
## hospital_1m                        0.003934 3.587413  1
## icu_adm                            0.000090 0.082459  1
## ABX_Ceftriaxone                    0.000265 0.241270  1
## ABX_Piperacillin.tazobactam        0.000097 0.088613  1
## ABX_Other                          0.000913 0.832454  1
## ABX_Meropenem                      0.000000 0.000000  1
## ABX_Benzylpenicillin               0.000751 0.684685  1
## ABX_Ciprofloxacin                  0.000817 0.745098  1
## ABX_Trimethoprim.sulphamethoxazole 0.000000 0.000000  1
## ABX_Clarithromycin                 0.000000 0.000000  1
## ABX_Ceftolozane.tazoabctam         0.002212 2.017699  1
## ABX_Caspofungin                    0.002212 2.017699  1
## ABX_Amoxicillin                    0.000000 0.000000  1
## ABX_Vancomycin                     0.000190 0.173648  1
## ABX_Cefepime                       0.000751 0.684685  1
## ABX_Cephazolin                     0.002273 2.072727  1
## ABX_Tobramycin                     0.002273 2.072727  1
## ABX_Flucloxacillin                 0.001364 1.244391  1
## ABX_Amoxicillin.clavulanate        0.000168 0.153535  1
## ABX_Trimethoprim                   0.002273 2.072727  1
## ABX_Cephalexin                     0.000221 0.201236  1
## ABX_Clindamycin                    0.000000 0.000000  1
## ABX_Fluconazole                    0.001543 1.407407  1

Modeling

set.seed(345)
splits     <- initial_split(combined_data, prop = 0.75, strata = outcome)
train_data <- training(splits)
test_data  <- testing(splits)
train_data %>% summarise(across(everything(), ~ sum(is.na(.)))) %>% t()
##                                    [,1]
## outcome                               0
## hospital_1m                           0
## icu_adm                               0
## ABX_Ceftriaxone                       0
## ABX_Piperacillin.tazobactam           0
## ABX_Other                             0
## ABX_Meropenem                         0
## ABX_Benzylpenicillin                  0
## ABX_Ciprofloxacin                     0
## ABX_Trimethoprim.sulphamethoxazole    0
## ABX_Clarithromycin                    0
## ABX_Ceftolozane.tazoabctam            0
## ABX_Caspofungin                       0
## ABX_Amoxicillin                       0
## ABX_Vancomycin                        0
## ABX_Cefepime                          0
## ABX_Cephazolin                        0
## ABX_Tobramycin                        0
## ABX_Flucloxacillin                    0
## ABX_Amoxicillin.clavulanate           0
## ABX_Trimethoprim                      0
## ABX_Cephalexin                        0
## ABX_Clindamycin                       0
## ABX_Fluconazole                       0
## source                               45
####### Recipe ########
model_recipe <- recipe(outcome ~ ., data = train_data) %>%
  update_role(source, new_role = "id") %>%
  step_string2factor(all_nominal_predictors()) %>%
  step_unknown(all_nominal_predictors()) %>%
  step_zv(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_impute_median(all_numeric_predictors()) %>%  # imputes numeric NA
  step_impute_mode(all_nominal_predictors()) %>%  # imputes factor NA
  step_upsample(outcome)

#### Model tunning ######

log_reg <- logistic_reg(penalty = tune(), mixture = 1, mode = "classification") %>%
  set_engine("glmnet")

wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(log_reg)

set.seed(678)
folds <- vfold_cv(train_data, v = 5, strata = outcome)
grid  <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

tuned <- tune_grid(wflow, resamples = folds, grid = grid,
                   metrics = metric_set(roc_auc, sens, spec),
                   control = control_grid(save_pred = TRUE))

best <- select_best(tuned, metric = "roc_auc")

final_fit <- finalize_workflow(wflow, best) %>% fit(train_data)

Evaluate on test set

test_pred <- bind_cols(
  test_data %>% select(outcome),
  predict(final_fit, test_data, type = "prob"),
  predict(final_fit, test_data))

names(test_pred)
## [1] "outcome"                     ".pred_Died in hospital"     
## [3] ".pred_Survived to discharge" ".pred_class"
metrics <- metric_set(roc_auc, accuracy, sens, spec)(test_pred,
  truth = outcome, estimate = .pred_class, ".pred_Survived to discharge",
  event_level = "second")

confusion <- conf_mat(test_pred, truth = outcome, estimate = .pred_class)

print(metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.966
## 2 sens     binary         1    
## 3 spec     binary         0.667
## 4 roc_auc  binary         0.949
print(confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     2                     0
##   Survived to discharge                1                    26
######### Coefficients and visualization ###########
coef_df <- tidy(final_fit) %>%
  filter(term != "(Intercept)", estimate != 0) %>%
  mutate(direction = ifelse(estimate > 0, "Survival increase", "Survival decrease"))

print(coef_df %>% arrange(desc(abs(estimate))) %>% select(term, estimate), n = 30)
## # A tibble: 14 × 2
##    term                        estimate
##    <chr>                          <dbl>
##  1 ABX_Vancomycin                 5.67 
##  2 ABX_Cefepime                  -4.73 
##  3 ABX_Cephalexin                 4.33 
##  4 ABX_Ciprofloxacin             -3.95 
##  5 ABX_Amoxicillin                3.11 
##  6 ABX_Meropenem                  3.05 
##  7 ABX_Cephazolin                -2.24 
##  8 hospital_1m_Yes                1.46 
##  9 ABX_Ceftriaxone               -1.31 
## 10 ABX_Flucloxacillin             1.20 
## 11 ABX_Amoxicillin.clavulanate    1.09 
## 12 ABX_Fluconazole                0.800
## 13 ABX_Piperacillin.tazobactam    0.313
## 14 icu_adm_Yes                    0.190
# Penalty vs AUC
autoplot(tuned) + ggtitle("Penalty vs AUC")

# ROC curve
roc_curve(test_pred, outcome, ".pred_Survived to discharge") %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2) +
  geom_abline(lty = 2) +
  coord_equal() +
  labs(title = glue::glue("ROC curve (AUC = {round(metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity (False-positive rate)", y = "Sensitivity (True-positive rate)")

# Confusion-matrix heatmap
conf_mat(test_pred, truth = outcome, estimate = .pred_class) %>%
  autoplot(type = "heatmap") + scale_fill_gradient(low = "white", high = "hotpink") + theme_minimal()
## Scale for fill is already present.
## Adding another scale for fill, which will replace the existing scale.

# Barplot of coefficients
ggplot(coef_df, aes(x = reorder(term, abs(estimate)), y = estimate, fill = direction)) +
  geom_col(width = 0.75) + coord_flip() +
  scale_fill_manual(values = c("Survival increase" = "red", "Survival decrease" = "slategray")) +
  labs(x = NULL, y = "Log-odds coefficient\n(positive = increases survival)",
       title = "Features retained by LASSO -- Metadata with Synthetic data") +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    plot.title = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.text = element_text(face = "bold")
  )

Random Forest

library(ranger)
library(vip)  # for variable importance plot
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Define random forest model
rf_model <- rand_forest(
  mtry = tune(), 
  trees = 1000, 
  min_n = tune()
) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

# Add to workflow
rf_wflow <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(rf_model)

# Tune hyperparameters
rf_grid <- grid_regular(
  mtry(range = c(5, 30)), 
  min_n(range = c(2, 10)), 
  levels = 5
)

set.seed(456)
rf_tuned <- tune_grid(
  rf_wflow, 
  resamples = folds, 
  grid = rf_grid,
  metrics = metric_set(roc_auc, accuracy, sens, spec),
  control = control_grid(save_pred = TRUE))
## → A | warning: ! 23 columns were requested but there were 20 predictors in the data.
##                ℹ 20 predictors will be used.
## There were issues with some computations   A: x1                                                 → B | warning: ! 30 columns were requested but there were 20 predictors in the data.
##                ℹ 20 predictors will be used.
## There were issues with some computations   A: x1There were issues with some computations   A: x1   B: x1There were issues with some computations   A: x2   B: x1There were issues with some computations   A: x3   B: x2There were issues with some computations   A: x3   B: x3There were issues with some computations   A: x4   B: x3There were issues with some computations   A: x5   B: x4There were issues with some computations   A: x5   B: x5                                                         → C | warning: ! 23 columns were requested but there were 22 predictors in the data.
##                ℹ 22 predictors will be used.
## There were issues with some computations   A: x5   B: x5There were issues with some computations   A: x5   B: x5   C: x1                                                                 → D | warning: ! 30 columns were requested but there were 22 predictors in the data.
##                ℹ 22 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x1There were issues with some computations   A: x5   B: x5   C: x2   D: x1There were issues with some computations   A: x5   B: x5   C: x3   D: x2There were issues with some computations   A: x5   B: x5   C: x3   D: x3There were issues with some computations   A: x5   B: x5   C: x4   D: x3There were issues with some computations   A: x5   B: x5   C: x5   D: x4There were issues with some computations   A: x5   B: x5   C: x5   D: x5                                                                         → E | warning: ! 23 columns were requested but there were 21 predictors in the data.
##                ℹ 21 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x5   D: x5There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x1                                                                                 → F | warning: ! 30 columns were requested but there were 21 predictors in the data.
##                ℹ 21 predictors will be used.
## There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x1There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x5   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x6   D: x5   E: x…There were issues with some computations   A: x5   B: x5   C: x7   D: x6   E: x…There were issues with some computations   A: x5   B: x5   C: x7   D: x7   E: x…There were issues with some computations   A: x5   B: x5   C: x8   D: x7   E: x…There were issues with some computations   A: x5   B: x5   C: x9   D: x8   E: x…There were issues with some computations   A: x5   B: x5   C: x10   D: x9   E: …There were issues with some computations   A: x5   B: x5   C: x11   D: x10   E:…There were issues with some computations   A: x5   B: x5   C: x12   D: x11   E:…There were issues with some computations   A: x5   B: x5   C: x13   D: x12   E:…There were issues with some computations   A: x5   B: x5   C: x13   D: x13   E:…There were issues with some computations   A: x5   B: x5   C: x14   D: x13   E:…There were issues with some computations   A: x5   B: x5   C: x15   D: x14   E:…There were issues with some computations   A: x5   B: x5   C: x15   D: x15   E:…
# Select best hyperparameters by AUC
rf_best <- select_best(rf_tuned, metric = "roc_auc")

# Finalize and fit
rf_final <- finalize_workflow(rf_wflow, rf_best) %>%
  fit(data = train_data)

# Predict on test set
rf_pred <- bind_cols(
  test_data %>% select(outcome),
  predict(rf_final, test_data, type = "prob") %>% janitor::clean_names(),
  predict(rf_final, test_data))

# Metrics
rf_metrics <- metric_set(roc_auc, accuracy, sens, spec)(
  rf_pred, truth = outcome, 
  estimate = .pred_class, pred_survived_to_discharge, event_level = "second")

rf_confusion <- conf_mat(rf_pred, truth = outcome, estimate = .pred_class)

# Print results
print(rf_metrics)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.862
## 2 sens     binary         0.885
## 3 spec     binary         0.667
## 4 roc_auc  binary         0.929
print(rf_confusion)
##                        Truth
## Prediction              Died in hospital Survived to discharge
##   Died in hospital                     2                     3
##   Survived to discharge                1                    23
# ROC Curve
roc_curve(rf_pred, outcome, pred_survived_to_discharge) %>%
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_path(linewidth = 1.2, color = "darkgreen") +
  geom_abline(lty = 2, color = "darkgrey") +
  coord_equal() +
  labs(title = glue::glue("Random Forest ROC (AUC = {round(rf_metrics %>% filter(.metric == 'roc_auc') %>% pull(.estimate), 3)})"),
       x = "1 – Specificity", y = "Sensitivity")

# Variable Importance
rf_final %>%
  extract_fit_parsnip() %>%
  vip(num_features = 10, geom = "col", aesthetics = list(fill = "darkturquoise")) +
  labs(title = "Metadata with Synthetic Data Top 10 Important Features (Random Forest)") +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    plot.title = element_text(face = "bold", color = "black"),
    axis.title.y = element_text(face = "bold"),
    axis.title.x = element_text(face = "bold"),
    axis.text = element_text(face = "bold"),
    legend.text = element_text(face = "bold")
  )