library(pacman)
p_load( tidymodels, tidyverse, plotly, skimr, themis, readxl, janitor, vip, nnet, brulee)
# Load the raw study data; the literal string "NA" marks missing values and
# cols() lets readr guess every column type without printing the spec.
msc <- read_csv(
  file = "mscph.csv",
  na = "NA",
  col_types = cols()
)
colnames(msc)
#> [1] "pep_id_sex" "pep_id_deliv_type"
#> [3] "pep_id_gest_age_birth_wks" "pep_id_hiv_serostatus_mo"
#> [5] "pep_id_birth_weight_grams" "mothers_recent_viral_load"
#> [7] "first_pcr_test_result" "second_pcr_test"
#> [9] "final_outcome_result" "mother_received_intervetion_y_n"
#> [11] "place_site_of_delivery" "brst_feeding_method"
#> [13] "type_of_avr_prophy_for_baby" "cotrimoxazole_administered_y_n"
#> [15] "age_mo" "education"
#> [17] "marital_st"
# Rename columns
msc <- msc |> rename(infant_sex = pep_id_sex,
delivery_type = pep_id_deliv_type,
infant_age_at_birth = pep_id_gest_age_birth_wks,
mother_hiv_serostatus = pep_id_hiv_serostatus_mo,
infant_weight = pep_id_birth_weight_grams,
mother_viral_load = mothers_recent_viral_load,
first_pcr = first_pcr_test_result,
second_pcr = second_pcr_test,
final_outcome = final_outcome_result,
mother_received_intervetion = mother_received_intervetion_y_n,
delivery_place = place_site_of_delivery,
breast_feeding_method = brst_feeding_method,
baby_prophylaxis = type_of_avr_prophy_for_baby,
cotrimaxazole_for_mmother = cotrimoxazole_administered_y_n,
mother_age = age_mo,
education_level = education,
marital_status = marital_st)
# To factors
msc_names <- c("infant_sex", "delivery_type", "mother_hiv_serostatus", "first_pcr", "second_pcr", "final_outcome", "mother_received_intervetion", "delivery_place", "breast_feeding_method", "baby_prophylaxis", "cotrimaxazole_for_mmother", "education_level", "marital_status")
msc <- msc |>
mutate(across(all_of(msc_names), factor))
str(msc)
#> tibble [600 × 17] (S3: tbl_df/tbl/data.frame)
#> $ infant_sex : Factor w/ 2 levels "Female","Male": 1 1 1 2 1 2 2 1 1 1 ...
#> $ delivery_type : Factor w/ 4 levels "Assisted vaginal",..: 2 4 4 4 4 4 4 2 4 4 ...
#> $ infant_age_at_birth : num [1:600] 37 38 38 38 38 37 38 39 38 38 ...
#> $ mother_hiv_serostatus : Factor w/ 3 levels "HIV-1 & HIV-2 positive",..: 1 2 2 2 2 2 2 1 2 2 ...
#> $ infant_weight : num [1:600] 3000 2600 2600 2000 3400 2400 3200 2900 2100 3300 ...
#> $ mother_viral_load : num [1:600] 20 20 20 20 20 20.3 20 876 20 20 ...
#> $ first_pcr : Factor w/ 2 levels "Negative","Positive": 1 1 1 1 1 1 1 1 1 1 ...
#> $ second_pcr : Factor w/ 2 levels "Negative","Positive": 1 1 1 1 1 1 1 1 1 1 ...
#> $ final_outcome : Factor w/ 2 levels "HIV Infected",..: 2 2 2 2 1 2 2 2 2 1 ...
#> $ mother_received_intervetion: Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 2 2 2 ...
#> $ delivery_place : Factor w/ 2 levels "Inside Facility",..: 1 1 1 1 1 1 1 1 1 1 ...
#> $ breast_feeding_method : Factor w/ 2 levels "Exclusive","Mixed Feeding": 2 2 2 1 1 2 2 1 1 1 ...
#> $ baby_prophylaxis : Factor w/ 2 levels "AZT + NVP","NVP": 2 2 2 2 2 2 2 2 2 2 ...
#> $ cotrimaxazole_for_mmother : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 2 2 2 2 ...
#> $ mother_age : num [1:600] 42 46 22 44 38 41 36 33 36 33 ...
#> $ education_level : Factor w/ 3 levels "Primary","Secondary",..: 3 3 2 2 2 2 2 2 3 2 ...
#> $ marital_status : Factor w/ 5 levels "Divorced","Married",..: 5 4 2 2 2 2 2 NA 4 2 ...
# Remove columns with NAs
new_msc <- msc |> select(-marital_status)
# Split data
set.seed(111)
data_split <- initial_split(new_msc, prop = .8, strata = final_outcome)
data_split
#> <Training/Testing/Total>
#> <480/120/600>
# Training and test data
m_train <- data_split |> training()
m_test <- data_split |> testing()
m_train |> glimpse()
#> Rows: 480
#> Columns: 16
#> $ infant_sex <fct> Female, Female, Male, Female, Female, Fema…
#> $ delivery_type <fct> Standard vaginal, Standard vaginal, Standa…
#> $ infant_age_at_birth <dbl> 38, 38, 40, 38, 38, 38, 38, 38, 38, 38, 38…
#> $ mother_hiv_serostatus <fct> HIV-1 positive, HIV-1 positive, HIV-1 posi…
#> $ infant_weight <dbl> 3400, 3300, 4000, 2200, 2200, 2500, 3500, …
#> $ mother_viral_load <dbl> 20, 20, 52, 20, 20, 20, 20, 20, 243, 20, 2…
#> $ first_pcr <fct> Negative, Negative, Negative, Positive, Po…
#> $ second_pcr <fct> Negative, Negative, Negative, Positive, Po…
#> $ final_outcome <fct> HIV Infected, HIV Infected, HIV Infected, …
#> $ mother_received_intervetion <fct> Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Ye…
#> $ delivery_place <fct> Inside Facility, Inside Facility, Inside F…
#> $ breast_feeding_method <fct> Exclusive, Exclusive, Mixed Feeding, Exclu…
#> $ baby_prophylaxis <fct> NVP, NVP, NVP, NVP, NVP, NVP, NVP, NVP, NV…
#> $ cotrimaxazole_for_mmother <fct> Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Ye…
#> $ mother_age <dbl> 38, 33, 23, 47, 44, 33, 34, 34, 40, 32, 40…
#> $ education_level <fct> Secondary, Secondary, Secondary, Tertiary,…
levels(m_train$final_outcome)
#> [1] "HIV Infected" "HIV Uninfected"
# Recipe ----
# Preprocessing shared by the nnet and brulee workflows:
#   * log10(x + 1) on mother_viral_load, whose values span a wide range
#     (e.g. 20 vs 876 in the str() preview);
#   * dummy-encode every factor predictor;
#   * add an infant_weight x infant_age_at_birth interaction term;
#   * drop zero-variance columns. With 4-fold CV on the training set, a rare
#     factor level may be absent from an analysis set, leaving an all-zero
#     dummy column that step_normalize() cannot scale.
#   * centre and scale all numeric predictors for neural-network training.
# NOTE(review): first_pcr / second_pcr are used as predictors; if
# final_outcome is derived from these test results, this leaks the label.
# Confirm with the data owner.
reci_m <- recipe(final_outcome ~ ., data = m_train) |>
  step_log(mother_viral_load, base = 10, offset = 1) |>
  step_dummy(all_nominal_predictors()) |>
  step_interact(terms = ~ infant_weight:infant_age_at_birth) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())
# Create model designt
nn_hiv <- mlp(hidden_units = 50, epochs = 10000, penalty = 0) |>
set_engine("nnet", MaxNWts = 5000) |>
set_mode("classification")
nn_hiv |> print()
#> Single Layer Neural Network Model Specification (classification)
#> Main Arguments:
#> hidden_units = 50
#> penalty = 0
#> epochs = 10000
#> Engine-Specific Arguments:
#> MaxNWts = 5000
#> Computational engine: nnet
# Add recipe and model design to a workflow and fit it on the training set.
# nnet draws random starting weights, so seed the RNG first. Otherwise every
# run yields a different network and different test metrics below.
set.seed(313)
workflow_nn <- workflow() |>
  add_model(nn_hiv) |>
  add_recipe(reci_m) |>
  fit(m_train)
# Predict
nn_hiv_test <- workflow_nn |>
augment(new_data = m_test)
nn_metrics <- metric_set(accuracy, sensitivity, specificity, precision, roc_auc, npv)
nn_hiv_test |>
nn_metrics(truth = final_outcome,
estimate = .pred_class, # For accuracy, sens, spec
`.pred_HIV Infected`)
#> # A tibble: 6 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy binary 0.758
#> 2 sensitivity binary 0.714
#> 3 specificity binary 0.782
#> 4 precision binary 0.638
#> 5 npv binary 0.836
#> 6 roc_auc binary 0.747
## Create a Model Design
bru_hiv <- mlp(hidden_units = tune(), dropout = tune(),
epochs = 1000, penalty = tune()) |>
set_engine("brulee") |>
set_mode("classification")
# Add the Recipe and the Model Design to a Workflow
workflow_bru <- workflow() |>
add_model(bru_hiv) |>
add_recipe(reci_m)
# Create a Hyper-Parameter Grid ----
# brulee errors when weight decay (penalty) and dropout are both non-zero
# ("Both weight decay and dropout should not be specified."). The original
# grid paired every penalty > 0 with dropout 0.2/0.4, so most grid points
# failed during tuning. Include penalty = 0 and keep only combinations that
# use at most one of the two regularisers.
bru_grid <- subset(
  expand.grid(
    hidden_units = c(2, 5, 10, 15),  # narrow range suited to ~480 rows
    penalty = c(0, 0.01, 0.1, 1.0),  # 0 = no weight decay (dropout only)
    dropout = c(0, 0.2, 0.4)         # 0 = no dropout (weight decay only)
  ),
  penalty == 0 | dropout == 0
)
# Cross validation
# Resamples for cross validation
set.seed(313)
bru_fold <- vfold_cv(m_train, v = 4, strata = final_outcome)
bru_metrics <- metric_set(accuracy, sensitivity, specificity, precision, roc_auc, npv)
# Tune the Workflow and Train All Models
set.seed(313)
bru_tune <- tune_grid(workflow_bru,
resamples = bru_fold,
grid = bru_grid,
metrics = bru_metrics)
#> → A | error: Both weight decay and dropout should not be specified.
#> There were issues with some computations   A: x1200
# Extract the Best Hyper-Parameter(s)
bru_best <- select_best(bru_tune, metric = "accuracy")
bru_best |> print()
#> # A tibble: 1 × 4
#> hidden_units penalty dropout .config
#> <dbl> <dbl> <dbl> <chr>
#> 1 15 0.1 0 pre0_mod31_post0
# Extract the Best Hyper-Parameter(s)
set.seed(313)
bru_best_fit <- workflow_bru |>
finalize_workflow(bru_best) |>
fit(m_train)
# Assess Prediction
bru_aug <- bru_best_fit |>
augment(new_data = m_test)
bru_aug |>
bru_metrics(truth = final_outcome,
estimate = .pred_class,
`.pred_HIV Infected`)
#> # A tibble: 6 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy binary 0.733
#> 2 sensitivity binary 0.5
#> 3 specificity binary 0.859
#> 4 precision binary 0.656
#> 5 npv binary 0.761
#> 6 roc_auc binary 0.814