library(tidyverse)
library(tidymodels)
library(modeldata) #dataset cells

# Load the segmentation data shipped with modeldata and keep a working copy.
data("cells")
cell <- cells
# Use the black-and-white ggplot2 theme for every plot in this script.
theme_set(theme_bw())
print(cell)
# A tibble: 2,019 x 58
   case  class angle_ch_1 area_ch_1 avg_inten_ch_1 avg_inten_ch_2 avg_inten_ch_3
   <fct> <fct>      <dbl>     <int>          <dbl>          <dbl>          <dbl>
 1 Test  PS        143.         185           15.7           4.95           9.55
 2 Train PS        134.         819           31.9         207.            69.9 
 3 Train WS        107.         431           28.0         116.            63.9 
 4 Train PS         69.2        298           19.5         102.            28.2 
 5 Test  PS          2.89       285           24.3         112.            20.5 
 6 Test  WS         40.7        172          326.          654.           129.  
 7 Test  WS        174.         177          260.          596.           124.  
 8 Test  PS        180.         251           18.3           5.73          17.2 
 9 Test  WS         18.9        495           16.1          89.5           13.7 
10 Test  WS        153.         384           17.7          89.9           20.4 
# ... with 2,009 more rows, and 51 more variables: avg_inten_ch_4 <dbl>,
#   convex_hull_area_ratio_ch_1 <dbl>, convex_hull_perim_ratio_ch_1 <dbl>,
#   diff_inten_density_ch_1 <dbl>, diff_inten_density_ch_3 <dbl>,
#   diff_inten_density_ch_4 <dbl>, entropy_inten_ch_1 <dbl>,
#   entropy_inten_ch_3 <dbl>, entropy_inten_ch_4 <dbl>,
#   eq_circ_diam_ch_1 <dbl>, eq_ellipse_lwr_ch_1 <dbl>,
#   eq_ellipse_oblate_vol_ch_1 <dbl>, eq_ellipse_prolate_vol_ch_1 <dbl>,
#   eq_sphere_area_ch_1 <dbl>, eq_sphere_vol_ch_1 <dbl>,
#   fiber_align_2_ch_3 <dbl>, fiber_align_2_ch_4 <dbl>,
#   fiber_length_ch_1 <dbl>, fiber_width_ch_1 <dbl>, inten_cooc_asm_ch_3 <dbl>,
#   inten_cooc_asm_ch_4 <dbl>, inten_cooc_contrast_ch_3 <dbl>,
#   inten_cooc_contrast_ch_4 <dbl>, inten_cooc_entropy_ch_3 <dbl>,
#   inten_cooc_entropy_ch_4 <dbl>, inten_cooc_max_ch_3 <dbl>,
#   inten_cooc_max_ch_4 <dbl>, kurt_inten_ch_1 <dbl>, kurt_inten_ch_3 <dbl>,
#   kurt_inten_ch_4 <dbl>, length_ch_1 <dbl>, neighbor_avg_dist_ch_1 <dbl>,
#   neighbor_min_dist_ch_1 <dbl>, neighbor_var_dist_ch_1 <dbl>,
#   perim_ch_1 <dbl>, shape_bfr_ch_1 <dbl>, shape_lwr_ch_1 <dbl>,
#   shape_p_2_a_ch_1 <dbl>, skew_inten_ch_1 <dbl>, skew_inten_ch_3 <dbl>,
#   skew_inten_ch_4 <dbl>, spot_fiber_count_ch_3 <int>,
#   spot_fiber_count_ch_4 <dbl>, total_inten_ch_1 <int>,
#   total_inten_ch_2 <dbl>, total_inten_ch_3 <int>, total_inten_ch_4 <int>,
#   var_inten_ch_1 <dbl>, var_inten_ch_3 <dbl>, var_inten_ch_4 <dbl>,
#   width_ch_1 <dbl>
# Column-wise overview: type and leading values of every variable.
glimpse(x = cell)
Rows: 2,019
Columns: 58
$ case                         <fct> Test, Train, Train, Train, Test, Test, Te~
$ class                        <fct> PS, PS, WS, PS, PS, WS, WS, PS, WS, WS, W~
$ angle_ch_1                   <dbl> 143.247705, 133.752037, 106.646387, 69.15~
$ area_ch_1                    <int> 185, 819, 431, 298, 285, 172, 177, 251, 4~
$ avg_inten_ch_1               <dbl> 15.71186, 31.92327, 28.03883, 19.45614, 2~
$ avg_inten_ch_2               <dbl> 4.954802, 206.878517, 116.315534, 102.294~
$ avg_inten_ch_3               <dbl> 9.548023, 69.916880, 63.941748, 28.217544~
$ avg_inten_ch_4               <dbl> 2.214689, 164.153453, 106.696602, 31.0280~
$ convex_hull_area_ratio_ch_1  <dbl> 1.124509, 1.263158, 1.053310, 1.202625, 1~
$ convex_hull_perim_ratio_ch_1 <dbl> 0.9196827, 0.7970801, 0.9354750, 0.865829~
$ diff_inten_density_ch_1      <dbl> 29.51923, 31.87500, 32.48771, 26.73228, 3~
$ diff_inten_density_ch_3      <dbl> 13.77564, 43.12228, 35.98577, 22.91732, 2~
$ diff_inten_density_ch_4      <dbl> 6.826923, 79.308424, 51.357050, 26.393701~
$ entropy_inten_ch_1           <dbl> 4.969781, 6.087592, 5.883557, 5.420065, 5~
$ entropy_inten_ch_3           <dbl> 4.371017, 6.642761, 6.683000, 5.436732, 5~
$ entropy_inten_ch_4           <dbl> 2.718884, 7.880155, 7.144601, 5.778329, 5~
$ eq_circ_diam_ch_1            <dbl> 15.36954, 32.30558, 23.44892, 19.50279, 1~
$ eq_ellipse_lwr_ch_1          <dbl> 3.060676, 1.558394, 1.375386, 3.391220, 2~
$ eq_ellipse_oblate_vol_ch_1   <dbl> 336.9691, 2232.9055, 802.1945, 724.7143, ~
$ eq_ellipse_prolate_vol_ch_1  <dbl> 110.0963, 1432.8246, 583.2504, 213.7031, ~
$ eq_sphere_area_ch_1          <dbl> 742.1156, 3278.7256, 1727.4104, 1194.9320~
$ eq_sphere_vol_ch_1           <dbl> 1900.996, 17653.525, 6750.985, 3884.084, ~
$ fiber_align_2_ch_3           <dbl> 1.000000, 1.487935, 1.300522, 1.220424, 1~
$ fiber_align_2_ch_4           <dbl> 1.000000, 1.352374, 1.522316, 1.733250, 1~
$ fiber_length_ch_1            <dbl> 26.98132, 64.28230, 21.14115, 43.14112, 3~
$ fiber_width_ch_1             <dbl> 7.410365, 13.167079, 21.141150, 7.404412,~
$ inten_cooc_asm_ch_3          <dbl> 0.011183899, 0.028051061, 0.006862315, 0.~
$ inten_cooc_asm_ch_4          <dbl> 0.050448005, 0.012594975, 0.006141691, 0.~
$ inten_cooc_contrast_ch_3     <dbl> 40.751777, 8.227953, 14.446074, 7.299457,~
$ inten_cooc_contrast_ch_4     <dbl> 13.895439, 6.984046, 16.700843, 13.390884~
$ inten_cooc_entropy_ch_3      <dbl> 7.199458, 6.822138, 7.580100, 6.312641, 6~
$ inten_cooc_entropy_ch_4      <dbl> 5.249744, 7.098988, 7.671478, 7.197026, 5~
$ inten_cooc_max_ch_3          <dbl> 0.07741935, 0.15321477, 0.02835052, 0.162~
$ inten_cooc_max_ch_4          <dbl> 0.17197452, 0.07387141, 0.02319588, 0.077~
$ kurt_inten_ch_1              <dbl> -0.656744087, -0.248769067, -0.293484630,~
$ kurt_inten_ch_3              <dbl> -0.608058268, -0.330783900, 1.051281336, ~
$ kurt_inten_ch_4              <dbl> 0.7258145, -0.2652638, 0.1506140, -0.3472~
$ length_ch_1                  <dbl> 26.20779, 47.21855, 28.14303, 37.85957, 3~
$ neighbor_avg_dist_ch_1       <dbl> 370.4543, 174.4442, 158.4774, 206.3344, 2~
$ neighbor_min_dist_ch_1       <dbl> 99.10349, 30.11114, 34.94477, 33.08030, 2~
$ neighbor_var_dist_ch_1       <dbl> 127.96080, 81.38063, 90.43768, 116.89276,~
$ perim_ch_1                   <dbl> 68.78338, 154.89876, 84.56460, 101.09107,~
$ shape_bfr_ch_1               <dbl> 0.6651480, 0.5397584, 0.7243116, 0.589162~
$ shape_lwr_ch_1               <dbl> 2.462450, 1.468181, 1.328408, 2.826854, 2~
$ shape_p_2_a_ch_1             <dbl> 1.883006, 2.255810, 1.272193, 2.545840, 2~
$ skew_inten_ch_1              <dbl> 0.45450484, 0.39870467, 0.47248709, 0.881~
$ skew_inten_ch_3              <dbl> 0.46039340, 0.61973079, 0.97137879, 0.999~
$ skew_inten_ch_4              <dbl> 1.2327736, 0.5272631, 0.3247065, 0.604439~
$ spot_fiber_count_ch_3        <int> 1, 4, 2, 4, 1, 1, 0, 2, 1, 1, 1, 0, 0, 2,~
$ spot_fiber_count_ch_4        <dbl> 5, 12, 7, 8, 8, 5, 5, 8, 12, 8, 5, 6, 7, ~
$ total_inten_ch_1             <int> 2781, 24964, 11552, 5545, 6603, 53779, 43~
$ total_inten_ch_2             <dbl> 701, 160998, 47511, 28870, 30306, 107681,~
$ total_inten_ch_3             <int> 1690, 54675, 26344, 8042, 5569, 21234, 20~
$ total_inten_ch_4             <int> 392, 128368, 43959, 8843, 11037, 57231, 4~
$ var_inten_ch_1               <dbl> 12.47468, 18.80923, 17.29564, 13.81897, 1~
$ var_inten_ch_3               <dbl> 7.609035, 56.715352, 37.671053, 30.005643~
$ var_inten_ch_4               <dbl> 2.714100, 118.388139, 49.470524, 24.74953~
$ width_ch_1                   <dbl> 10.64297, 32.16126, 21.18553, 13.39283, 1~
# Number of missing values per column (all zero for this data set).
cell %>% is.na() %>% colSums()
                        case                        class 
                           0                            0 
                  angle_ch_1                    area_ch_1 
                           0                            0 
              avg_inten_ch_1               avg_inten_ch_2 
                           0                            0 
              avg_inten_ch_3               avg_inten_ch_4 
                           0                            0 
 convex_hull_area_ratio_ch_1 convex_hull_perim_ratio_ch_1 
                           0                            0 
     diff_inten_density_ch_1      diff_inten_density_ch_3 
                           0                            0 
     diff_inten_density_ch_4           entropy_inten_ch_1 
                           0                            0 
          entropy_inten_ch_3           entropy_inten_ch_4 
                           0                            0 
           eq_circ_diam_ch_1          eq_ellipse_lwr_ch_1 
                           0                            0 
  eq_ellipse_oblate_vol_ch_1  eq_ellipse_prolate_vol_ch_1 
                           0                            0 
         eq_sphere_area_ch_1           eq_sphere_vol_ch_1 
                           0                            0 
          fiber_align_2_ch_3           fiber_align_2_ch_4 
                           0                            0 
           fiber_length_ch_1             fiber_width_ch_1 
                           0                            0 
         inten_cooc_asm_ch_3          inten_cooc_asm_ch_4 
                           0                            0 
    inten_cooc_contrast_ch_3     inten_cooc_contrast_ch_4 
                           0                            0 
     inten_cooc_entropy_ch_3      inten_cooc_entropy_ch_4 
                           0                            0 
         inten_cooc_max_ch_3          inten_cooc_max_ch_4 
                           0                            0 
             kurt_inten_ch_1              kurt_inten_ch_3 
                           0                            0 
             kurt_inten_ch_4                  length_ch_1 
                           0                            0 
      neighbor_avg_dist_ch_1       neighbor_min_dist_ch_1 
                           0                            0 
      neighbor_var_dist_ch_1                   perim_ch_1 
                           0                            0 
              shape_bfr_ch_1               shape_lwr_ch_1 
                           0                            0 
            shape_p_2_a_ch_1              skew_inten_ch_1 
                           0                            0 
             skew_inten_ch_3              skew_inten_ch_4 
                           0                            0 
       spot_fiber_count_ch_3        spot_fiber_count_ch_4 
                           0                            0 
            total_inten_ch_1             total_inten_ch_2 
                           0                            0 
            total_inten_ch_3             total_inten_ch_4 
                           0                            0 
              var_inten_ch_1               var_inten_ch_3 
                           0                            0 
              var_inten_ch_4                   width_ch_1 
                           0                            0 
# The outcome (response) to predict is `class`, not a predictor:
#   PS: poorly segmented
#   WS: well segmented
# `case` holds Train/Test labels (presumably the data's original split); it is
# dropped so it is neither used as a predictor nor confused with the new
# train/test split made below.
cell <- cell %>% select(-case)

# Class balance of the outcome.
cell %>% count(class)
# A tibble: 2 x 2
  class     n
  <fct> <int>
1 PS     1300
2 WS      719
# Stratified 75/25 train/test split so both sets roughly keep the PS/WS ratio.
set.seed(2021)
cell_split <- initial_split(cell, prop = .75, strata = class)

cell_train <- cell_split %>% training()
cell_test <- cell_split %>% testing()
# 10-fold cross-validation on the training set, reused by every tuning run.
set.seed(2021)
cell_folds <- vfold_cv(cell_train, v = 10)
# Single-hidden-layer network (nnet engine) with three tunable
# hyperparameters; `trace = 0` turns off nnet's optimisation trace output.
mlp_nnet_model <- mlp(
  hidden_units = tune(),
  penalty = tune(),
  epochs = tune()
) %>%
  set_engine("nnet", trace = 0) %>%
  set_mode("classification")
# Preprocessing on the numeric predictors: Yeo-Johnson transform, centre and
# scale, then PCA with a tunable number of components.
mlp_nnet_rec <- recipe(class ~ ., data = cell_train) %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_pca(all_numeric_predictors(), num_comp = tune())
# Bundle recipe and model so preprocessing and model are tuned together.
mlp_nnet_wf <- workflow() %>%
  add_recipe(mlp_nnet_rec) %>%
  add_model(mlp_nnet_model)

# Parameter set of every argument marked tune() in the workflow.
mlp_param <- parameters(mlp_nnet_wf)
pull_dials_object(mlp_param, "hidden_units")  # default range for hidden_units
# Hidden Units (quantitative)
Range: [1, 10]
pull_dials_object(mlp_param, "penalty")  # default range for penalty
Amount of Regularization (quantitative)
Transformer:  log-10 
Range (transformed scale): [-10, 0]
pull_dials_object(mlp_param, "epochs")  # default range for epochs
# Epochs (quantitative)
Range: [10, 1000]
pull_dials_object(mlp_param, "num_comp")  # default range for num_comp
# Components (quantitative)
Range: [1, 4]
# Hand-built full factorial grid: every combination of the listed values
# (3 x 2 x 2 = 12 rows). Column names must match the tuning parameter ids
# (`penalty`, not `penaly`) for a grid to be passed to tune; note that
# num_comp is not included here.
crox <-
  crossing(
    hidden_units = 1:3,
    penalty = c(0.0, 0.1),
    epochs = c(100, 200)
  )

crox %>% View()


# expand_grid() builds the same combinations without sorting/de-duplicating:
# expand_grid(
#     hidden_units = 1:3,
#     penalty = c(0.0, 0.1),
#     epochs = c(100, 200)
#   )


# Regular grid over the default ranges, 2 levels per parameter
# (2^4 = 16 candidates).
grid_regular(mlp_param, levels = 2)
# A tibble: 16 x 4
   hidden_units      penalty epochs num_comp
          <int>        <dbl>  <int>    <int>
 1            1 0.0000000001     10        1
 2           10 0.0000000001     10        1
 3            1 1                10        1
 4           10 1                10        1
 5            1 0.0000000001   1000        1
 6           10 0.0000000001   1000        1
 7            1 1              1000        1
 8           10 1              1000        1
 9            1 0.0000000001     10        4
10           10 0.0000000001     10        4
11            1 1                10        4
12           10 1                10        4
13            1 0.0000000001   1000        4
14           10 0.0000000001   1000        4
15            1 1              1000        4
16           10 1              1000        4
# mlp_param %>%
#   grid_regular(levels = c(hidden_units = 3, penalty = 2, epochs = 2))
library(ggforce)

# Scatter-plot matrices comparing two 15-candidate designs.
# `original = FALSE` keeps values on the transformed scale, so penalty is in
# log10 units. `layer.diag = 2` draws only the second layer (geom_blank) on
# the diagonal, leaving those panels empty.
set.seed(2021)
ggplot(
  grid_random(mlp_param, size = 15, original = FALSE),
  aes(x = .panel_x, y = .panel_y)
) +
  geom_point() +
  geom_blank() +
  facet_matrix(vars(hidden_units, penalty, epochs), layer.diag = 2) +
  labs(title = "Random design with 15 candidates")

set.seed(2021)
ggplot(
  grid_latin_hypercube(mlp_param, size = 15, original = FALSE),
  aes(x = .panel_x, y = .panel_y)
) +
  geom_point() +
  geom_blank() +
  facet_matrix(vars(hidden_units, penalty, epochs), layer.diag = 2) +
  labs(title = "Latin Hypercube design with 15 candidates")

# Narrow the search space: epochs in [50, 200] and 0-40 PCA components
# (per the recipes step_pca docs, num_comp = 0 leaves the predictors
# untransformed, i.e. "no PCA" is one of the candidates).
mlp_param <- mlp_nnet_wf %>%
  parameters() %>%
  update(
    epochs = epochs(c(50, 200)),
    num_comp = num_comp(c(0, 40))
  )
roc_metric <- metric_set(roc_auc)
# tune_grid() documents a control_grid() object for `control`; keep the
# out-of-fold predictions and the workflow for later inspection.
keep_pred <- control_grid(save_pred = TRUE, save_workflow = TRUE)

set.seed(2021)

# Regular grid, 3 levels per parameter: 3^4 = 81 candidates per fold.
mlp_tuning <-
  tune_grid(
    mlp_nnet_wf,
    resamples = cell_folds,
    grid = mlp_param %>% grid_regular(levels = 3),
    metrics = roc_metric,
    control = keep_pred
  )

print(mlp_tuning)  # per-fold resampling results of the regular-grid search
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 x 5
   splits            id     .metrics         .notes          .predictions       
   <list>            <chr>  <list>           <list>          <list>             
 1 <split [1362/152~ Fold01 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
 2 <split [1362/152~ Fold02 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
 3 <split [1362/152~ Fold03 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
 4 <split [1362/152~ Fold04 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
 5 <split [1363/151~ Fold05 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
 6 <split [1363/151~ Fold06 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
 7 <split [1363/151~ Fold07 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
 8 <split [1363/151~ Fold08 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
 9 <split [1363/151~ Fold09 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
10 <split [1363/151~ Fold10 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
collect_metrics(mlp_tuning)  # mean ROC AUC per candidate across folds
# A tibble: 81 x 10
   hidden_units   penalty epochs num_comp .metric .estimator  mean     n std_err
          <int>     <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
 1            1     1e-10     50        0 roc_auc binary     0.871    10 0.0121 
 2            5     1e-10     50        0 roc_auc binary     0.871    10 0.00573
 3           10     1e-10     50        0 roc_auc binary     0.842    10 0.0102 
 4            1     1e- 5     50        0 roc_auc binary     0.874    10 0.00707
 5            5     1e- 5     50        0 roc_auc binary     0.867    10 0.00909
 6           10     1e- 5     50        0 roc_auc binary     0.849    10 0.00906
 7            1     1e+ 0     50        0 roc_auc binary     0.893    10 0.00535
 8            5     1e+ 0     50        0 roc_auc binary     0.896    10 0.00632
 9           10     1e+ 0     50        0 roc_auc binary     0.890    10 0.00438
10            1     1e-10    125        0 roc_auc binary     0.821    10 0.0112 
# ... with 71 more rows, and 1 more variable: .config <chr>
collect_predictions(mlp_tuning)  # out-of-fold predictions (save_pred = TRUE)
# A tibble: 122,634 x 10
   id     .pred_PS .pred_WS  .row num_comp hidden_units     penalty epochs class
   <chr>     <dbl>    <dbl> <int>    <int>        <int>       <dbl>  <int> <fct>
 1 Fold01    0.721    0.279    15        0            1       1e-10     50 PS   
 2 Fold01    0.621    0.379    27        0            1       1e-10     50 PS   
 3 Fold01    0.384    0.616    29        0            1       1e-10     50 PS   
 4 Fold01    0.721    0.279    35        0            1       1e-10     50 PS   
 5 Fold01    0.380    0.620    50        0            1       1e-10     50 PS   
 6 Fold01    0.719    0.281    52        0            1       1e-10     50 PS   
 7 Fold01    0.721    0.279    55        0            1       1e-10     50 PS   
 8 Fold01    0.381    0.619    57        0            1       1e-10     50 PS   
 9 Fold01    0.380    0.620    58        0            1       1e-10     50 PS   
10 Fold01    0.721    0.279    60        0            1       1e-10     50 PS   
# ... with 122,624 more rows, and 1 more variable: .config <chr>
# Top candidates by mean ROC AUC; the metric is named explicitly (newer tune
# versions warn when it is omitted).
mlp_tuning %>% show_best(metric = "roc_auc")
# A tibble: 5 x 10
  hidden_units penalty epochs num_comp .metric .estimator  mean     n std_err
         <int>   <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
1            5       1     50        0 roc_auc binary     0.896    10 0.00632
2            5       1    125       40 roc_auc binary     0.895    10 0.00762
3            5       1    125        0 roc_auc binary     0.895    10 0.00623
4            5       1     50       20 roc_auc binary     0.895    10 0.00653
5            5       1    200       40 roc_auc binary     0.894    10 0.00569
# ... with 1 more variable: .config <chr>
# Single best candidate by ROC AUC (metric named explicitly, as below).
mlp_tuning %>% select_best(metric = "roc_auc")
# A tibble: 1 x 5
  hidden_units penalty epochs num_comp .config              
         <int>   <dbl>  <int>    <int> <chr>                
1            5       1     50        0 Preprocessor1_Model08
mlp_tuning %>% autoplot() + theme(legend.position = "top")

parallel::detectCores(all.tests = FALSE, logical = FALSE)  # physical cores
[1] 8
parallel::detectCores(all.tests = FALSE, logical = TRUE)  # logical cores/threads
[1] 16
# Second search: pass an integer grid so tune_grid() generates 20
# space-filling candidates from the narrowed parameter set.
set.seed(2021)

mlp_tuning2 <- tune_grid(
  mlp_nnet_wf,
  resamples = cell_folds,
  param_info = mlp_param,
  grid = 20,
  metrics = roc_metric,
  control = keep_pred
)

print(mlp_tuning2)  # per-fold results of the 20-candidate search
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 x 5
   splits            id     .metrics         .notes          .predictions       
   <list>            <chr>  <list>           <list>          <list>             
 1 <split [1362/152~ Fold01 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
 2 <split [1362/152~ Fold02 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
 3 <split [1362/152~ Fold03 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
 4 <split [1362/152~ Fold04 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
 5 <split [1363/151~ Fold05 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
 6 <split [1363/151~ Fold06 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
 7 <split [1363/151~ Fold07 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
 8 <split [1363/151~ Fold08 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
 9 <split [1363/151~ Fold09 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
10 <split [1363/151~ Fold10 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
collect_predictions(mlp_tuning2)  # out-of-fold predictions per candidate
# A tibble: 30,280 x 10
   id     .pred_PS .pred_WS  .row num_comp hidden_units penalty epochs class
   <chr>     <dbl>    <dbl> <int>    <int>        <int>   <dbl>  <int> <fct>
 1 Fold01    0.720    0.280    15       14            6 0.00106    112 PS   
 2 Fold01    0.341    0.659    27       14            6 0.00106    112 PS   
 3 Fold01    0.716    0.284    29       14            6 0.00106    112 PS   
 4 Fold01    0.635    0.365    35       14            6 0.00106    112 PS   
 5 Fold01    0.349    0.651    50       14            6 0.00106    112 PS   
 6 Fold01    0.664    0.336    52       14            6 0.00106    112 PS   
 7 Fold01    0.720    0.280    55       14            6 0.00106    112 PS   
 8 Fold01    0.509    0.491    57       14            6 0.00106    112 PS   
 9 Fold01    0.341    0.659    58       14            6 0.00106    112 PS   
10 Fold01    0.730    0.270    60       14            6 0.00106    112 PS   
# ... with 30,270 more rows, and 1 more variable: .config <chr>
collect_metrics(mlp_tuning2)  # mean ROC AUC per candidate across folds
# A tibble: 20 x 10
   hidden_units  penalty epochs num_comp .metric .estimator  mean     n std_err
          <int>    <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
 1            6 1.06e- 3    112       14 roc_auc binary     0.854    10 0.00960
 2            6 4.31e- 8    135       39 roc_auc binary     0.868    10 0.00675
 3            6 4.80e-10    147       15 roc_auc binary     0.855    10 0.0111 
 4            1 1.37e- 1     77       28 roc_auc binary     0.876    10 0.00559
 5            5 4.81e- 6    105       18 roc_auc binary     0.846    10 0.00993
 6            9 8.70e- 2    164       18 roc_auc binary     0.851    10 0.0101 
 7            5 1.43e- 4    184       33 roc_auc binary     0.856    10 0.00943
 8            9 3.51e- 5    157        5 roc_auc binary     0.860    10 0.00631
 9            7 5.82e- 4     70        7 roc_auc binary     0.866    10 0.00726
10            2 2.14e- 2    176       37 roc_auc binary     0.884    10 0.00694
11           10 1.59e-10     55       35 roc_auc binary     0.836    10 0.0118 
12            4 2.60e- 6    100        3 roc_auc binary     0.870    10 0.00577
13            2 6.04e- 9     81       25 roc_auc binary     0.839    10 0.00736
14            2 1.04e- 8    189       31 roc_auc binary     0.853    10 0.00686
15            4 1.74e- 5     62       26 roc_auc binary     0.855    10 0.00812
16            3 1.95e- 9    199       22 roc_auc binary     0.860    10 0.00562
17            9 3.32e- 7    118       12 roc_auc binary     0.833    10 0.00749
18            8 3.96e- 3    128        1 roc_auc binary     0.769    10 0.00806
19            7 1.15e- 7    148        9 roc_auc binary     0.849    10 0.00965
20            3 6.59e- 1     93       20 roc_auc binary     0.892    10 0.00520
# ... with 1 more variable: .config <chr>
# Ten best candidates by mean ROC AUC (metric named explicitly).
mlp_tuning2 %>% show_best(metric = "roc_auc", n = 10)
# A tibble: 10 x 10
   hidden_units  penalty epochs num_comp .metric .estimator  mean     n std_err
          <int>    <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
 1            3 6.59e- 1     93       20 roc_auc binary     0.892    10 0.00520
 2            2 2.14e- 2    176       37 roc_auc binary     0.884    10 0.00694
 3            1 1.37e- 1     77       28 roc_auc binary     0.876    10 0.00559
 4            4 2.60e- 6    100        3 roc_auc binary     0.870    10 0.00577
 5            6 4.31e- 8    135       39 roc_auc binary     0.868    10 0.00675
 6            7 5.82e- 4     70        7 roc_auc binary     0.866    10 0.00726
 7            3 1.95e- 9    199       22 roc_auc binary     0.860    10 0.00562
 8            9 3.51e- 5    157        5 roc_auc binary     0.860    10 0.00631
 9            5 1.43e- 4    184       33 roc_auc binary     0.856    10 0.00943
10            6 4.80e-10    147       15 roc_auc binary     0.855    10 0.0111 
# ... with 1 more variable: .config <chr>
# The argument is `metric`; `.metric` would be swallowed by `...` (ignored in
# older tune, an error in versions that check for empty dots).
mlp_tuning2 %>% select_best(metric = "roc_auc")
# A tibble: 1 x 5
  hidden_units penalty epochs num_comp .config              
         <int>   <dbl>  <int>    <int> <chr>                
1            3   0.659     93       20 Preprocessor19_Model1
mlp_tuning2 %>% autoplot()

Finalizing the model. I chose mlp_tuning because it has the better mean ROC AUC (0.896, versus 0.892 for mlp_tuning2).

# Take the best regular-grid candidate and plug its values into the workflow.
best_mlp_model <- select_best(mlp_tuning, metric = "roc_auc")

last_mlp_wf <- finalize_workflow(mlp_nnet_wf, best_mlp_model)

print(last_mlp_wf)
== Workflow ====================================================================
Preprocessor: Recipe
Model: mlp()

-- Preprocessor ----------------------------------------------------------------
3 Recipe Steps

* step_YeoJohnson()
* step_normalize()
* step_pca()

-- Model -----------------------------------------------------------------------
Single Layer Neural Network Specification (classification)

Main Arguments:
  hidden_units = 5
  penalty = 1
  epochs = 50

Engine-Specific Arguments:
  trace = 0

Computational engine: nnet 
# Fit the finalized workflow on the training set, evaluate on the test set.
mlp_fit <- last_fit(last_mlp_wf, split = cell_split)

print(mlp_fit)
# Resampling results
# Manual resampling 
# A tibble: 1 x 6
  splits        id           .metrics      .notes      .predictions    .workflow
  <list>        <chr>        <list>        <list>      <list>          <list>   
1 <split [1514~ train/test ~ <tibble [2 x~ <tibble [0~ <tibble [505 x~ <workflo~
collect_metrics(mlp_fit)  # test-set accuracy and ROC AUC
# A tibble: 2 x 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.798 Preprocessor1_Model1
2 roc_auc  binary         0.879 Preprocessor1_Model1
# Test-set predictions from the final fit.
mlp_pred <- collect_predictions(mlp_fit)
print(mlp_pred)
# A tibble: 505 x 7
   id               .pred_PS .pred_WS  .row .pred_class class .config           
   <chr>               <dbl>    <dbl> <int> <fct>       <fct> <chr>             
 1 train/test split    0.393    0.607     3 WS          WS    Preprocessor1_Mod~
 2 train/test split    0.621    0.379     9 PS          WS    Preprocessor1_Mod~
 3 train/test split    0.716    0.284    10 PS          WS    Preprocessor1_Mod~
 4 train/test split    0.308    0.692    11 WS          WS    Preprocessor1_Mod~
 5 train/test split    0.726    0.274    13 PS          PS    Preprocessor1_Mod~
 6 train/test split    0.401    0.599    16 WS          PS    Preprocessor1_Mod~
 7 train/test split    0.460    0.540    22 WS          PS    Preprocessor1_Mod~
 8 train/test split    0.725    0.275    24 PS          PS    Preprocessor1_Mod~
 9 train/test split    0.583    0.417    30 PS          WS    Preprocessor1_Mod~
10 train/test split    0.326    0.674    35 WS          PS    Preprocessor1_Mod~
# ... with 495 more rows
# Confusion matrix of hard class predictions vs. truth on the test set.
conf_mat(mlp_pred, truth = class, estimate = .pred_class)
          Truth
Prediction  PS  WS
        PS 273  50
        WS  52 130
# Test-set ROC AUC. Probability metrics take the class-probability column(s)
# through `...`; roc_auc() has no `estimate` argument. PS is the first level
# of `class`, so .pred_PS is the event probability under the default
# event_level = "first".
mlp_pred %>%
  roc_auc(
    truth = class,
    .pred_PS
  )
# A tibble: 1 x 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.879
# ROC curve for the test set; as with roc_auc(), the probability column is
# passed through `...` rather than a (non-existent) `estimate` argument.
mlp_pred %>%
  roc_curve(
    truth = class,
    .pred_PS
  ) %>%
  autoplot()

library(finetune)

set.seed(2021)

# Racing with repeated-measures ANOVA over the same 20-candidate space:
# candidates judged worse than the current best after the early folds are
# dropped, so later folds evaluate fewer models. `verbose_elim = TRUE`
# reports the eliminations as they happen.
mlp_race <-
  tune_race_anova(
    mlp_nnet_wf,
    resamples = cell_folds,
    grid = 20,
    param_info = mlp_param,
    metrics = roc_metric,
    control = control_race(verbose_elim = TRUE)
  )

mlp_race
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 x 5
   splits             id     .order .metrics          .notes          
   <list>             <chr>   <int> <list>            <list>          
 1 <split [1362/152]> Fold01      3 <tibble [20 x 8]> <tibble [0 x 1]>
 2 <split [1362/152]> Fold04      2 <tibble [20 x 8]> <tibble [0 x 1]>
 3 <split [1363/151]> Fold08      1 <tibble [20 x 8]> <tibble [0 x 1]>
 4 <split [1363/151]> Fold05      4 <tibble [13 x 8]> <tibble [0 x 1]>
 5 <split [1363/151]> Fold07      5 <tibble [4 x 8]>  <tibble [0 x 1]>
 6 <split [1363/151]> Fold06      6 <tibble [3 x 8]>  <tibble [0 x 1]>
 7 <split [1362/152]> Fold03      7 <tibble [3 x 8]>  <tibble [0 x 1]>
 8 <split [1362/152]> Fold02      8 <tibble [3 x 8]>  <tibble [0 x 1]>
 9 <split [1363/151]> Fold09      9 <tibble [3 x 8]>  <tibble [0 x 1]>
10 <split [1363/151]> Fold10     10 <tibble [3 x 8]>  <tibble [0 x 1]>
collect_metrics(mlp_race)  # `n` shows how many folds each candidate survived
# A tibble: 20 x 10
   hidden_units  penalty epochs num_comp .metric .estimator  mean     n std_err
          <int>    <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
 1            1 2.07e- 9    101       15 roc_auc binary     0.840     3 0.0215 
 2            5 6.15e- 9    147        3 roc_auc binary     0.843     3 0.00837
 3            8 1.29e- 7    114       13 roc_auc binary     0.848     4 0.00970
 4            9 5.80e- 1    156       34 roc_auc binary     0.866     5 0.00955
 5            4 1.39e-10     95       34 roc_auc binary     0.853     4 0.00626
 6           10 3.34e- 5    122        0 roc_auc binary     0.823     3 0.0150 
 7            3 3.09e- 6     69        5 roc_auc binary     0.878    10 0.00650
 8            7 5.63e- 2    148       19 roc_auc binary     0.854     4 0.00899
 9            3 4.81e- 3    178       36 roc_auc binary     0.835     3 0.0237 
10            7 8.50e- 4     63       29 roc_auc binary     0.800     3 0.0234 
11            8 3.87e- 6    138       39 roc_auc binary     0.847     4 0.0120 
12            8 1.49e- 1    200       10 roc_auc binary     0.873    10 0.00897
13            6 6.68e- 8    105       22 roc_auc binary     0.844     4 0.00980
14            2 1.15e- 2    128        6 roc_auc binary     0.882    10 0.00318
15            9 3.83e- 7     84       23 roc_auc binary     0.835     3 0.0344 
16            2 2.50e- 8     56       25 roc_auc binary     0.850     4 0.00547
17            4 5.57e-10    192       18 roc_auc binary     0.827     3 0.0206 
18            2 1.08e- 4    176       27 roc_auc binary     0.849     4 0.00941
19            5 3.03e- 3    169       31 roc_auc binary     0.849     4 0.0191 
20            6 1.68e- 5     74       11 roc_auc binary     0.854     4 0.0169 
# ... with 1 more variable: .config <chr>
show_best(mlp_race, metric = "roc_auc")  # best surviving candidates
# A tibble: 3 x 10
  hidden_units    penalty epochs num_comp .metric .estimator  mean     n std_err
         <int>      <dbl>  <int>    <int> <chr>   <chr>      <dbl> <int>   <dbl>
1            2 0.0115        128        6 roc_auc binary     0.882    10 0.00318
2            3 0.00000309     69        5 roc_auc binary     0.878    10 0.00650
3            8 0.149         200       10 roc_auc binary     0.873    10 0.00897
# ... with 1 more variable: .config <chr>
# Finalize the workflow with the race winner and evaluate it on the test set.
best_race <- select_best(mlp_race, metric = "roc_auc")

last_race_wf <- finalize_workflow(mlp_nnet_wf, best_race)

race_fit <- last_fit(last_race_wf, split = cell_split)

collect_predictions(race_fit)
# A tibble: 505 x 7
   id               .pred_PS .pred_WS  .row .pred_class class .config           
   <chr>               <dbl>    <dbl> <int> <fct>       <fct> <chr>             
 1 train/test split    0.441    0.559     3 WS          WS    Preprocessor1_Mod~
 2 train/test split    0.624    0.376     9 PS          WS    Preprocessor1_Mod~
 3 train/test split    0.707    0.293    10 PS          WS    Preprocessor1_Mod~
 4 train/test split    0.328    0.672    11 WS          WS    Preprocessor1_Mod~
 5 train/test split    0.728    0.272    13 PS          PS    Preprocessor1_Mod~
 6 train/test split    0.427    0.573    16 WS          PS    Preprocessor1_Mod~
 7 train/test split    0.477    0.523    22 WS          PS    Preprocessor1_Mod~
 8 train/test split    0.730    0.270    24 PS          PS    Preprocessor1_Mod~
 9 train/test split    0.596    0.404    30 PS          WS    Preprocessor1_Mod~
10 train/test split    0.323    0.677    35 WS          PS    Preprocessor1_Mod~
# ... with 495 more rows
collect_metrics(race_fit)  # test-set accuracy and ROC AUC of the race winner
# A tibble: 2 x 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.804 Preprocessor1_Model1
2 roc_auc  binary         0.879 Preprocessor1_Model1